1/*
2 * AllDiff.cpp
3 * @brief General "all-different" constraint
4 * @date Feb 6, 2012
5 * @author Frank Dellaert
6 */
7
8#include <gtsam/base/Testable.h>
9#include <gtsam_unstable/discrete/AllDiff.h>
10#include <gtsam_unstable/discrete/Domain.h>
11
12#include <optional>
13
14namespace gtsam {
15
16/* ************************************************************************* */
17AllDiff::AllDiff(const DiscreteKeys& dkeys) : Constraint(dkeys.indices()) {
18 for (const DiscreteKey& dkey : dkeys) cardinalities_.insert(x: dkey);
19}
20
21/* ************************************************************************* */
22void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const {
23 std::cout << s << "AllDiff on ";
24 for (Key dkey : keys_) std::cout << formatter(dkey) << " ";
25 std::cout << std::endl;
26}
27
28/* ************************************************************************* */
29double AllDiff::evaluate(const Assignment<Key>& values) const {
30 std::set<size_t> taken; // record values taken by keys
31 for (Key dkey : keys_) {
32 size_t value = values.at(k: dkey); // get the value for that key
33 if (taken.count(x: value)) return 0.0; // check if value alreday taken
34 taken.insert(x: value); // if not, record it as taken and keep checking
35 }
36 return 1.0;
37}
38
39/* ************************************************************************* */
40DecisionTreeFactor AllDiff::toDecisionTreeFactor() const {
41 // We will do this by converting the allDif into many BinaryAllDiff
42 // constraints
43 DecisionTreeFactor converted;
44 size_t nrKeys = keys_.size();
45 for (size_t i1 = 0; i1 < nrKeys; i1++)
46 for (size_t i2 = i1 + 1; i2 < nrKeys; i2++) {
47 BinaryAllDiff binary12(discreteKey(i: i1), discreteKey(i: i2));
48 converted = converted * binary12.toDecisionTreeFactor();
49 }
50 return converted;
51}
52
53/* ************************************************************************* */
54DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const {
55 // TODO: can we do this more efficiently?
56 return toDecisionTreeFactor() * f;
57}
58
59/* ************************************************************************* */
60bool AllDiff::ensureArcConsistency(Key j, Domains* domains) const {
61 Domain& Dj = domains->at(k: j);
62
63 // Though strictly not part of allDiff, we check for
64 // a value in domains->at(j) that does not occur in any other connected domain.
65 // If found, we make this a singleton...
66 // TODO: make a new constraint where this really is true
67 std::optional<Domain> maybeChanged = Dj.checkAllDiff(keys: keys_, domains: *domains);
68 if (maybeChanged) {
69 Dj = *maybeChanged;
70 return true;
71 }
72
73 // Check all other domains for singletons and erase corresponding values.
74 // This is the same as arc-consistency on the equivalent binary constraints
75 bool changed = false;
76 for (Key k : keys_)
77 if (k != j) {
78 const Domain& Dk = domains->at(k: k);
79 if (Dk.isSingleton()) { // check if singleton
80 size_t value = Dk.firstValue();
81 if (Dj.contains(value)) {
82 Dj.erase(value); // erase value if true
83 changed = true;
84 }
85 }
86 }
87 return changed;
88}
89
90/* ************************************************************************* */
91Constraint::shared_ptr AllDiff::partiallyApply(const DiscreteValues& values) const {
92 DiscreteKeys newKeys;
93 // loop over keys and add them only if they do not appear in values
94 for (Key k : keys_)
95 if (values.find(x: k) == values.end()) {
96 newKeys.push_back(x: DiscreteKey(k, cardinalities_.at(k: k)));
97 }
98 return std::make_shared<AllDiff>(args&: newKeys);
99}
100
101/* ************************************************************************* */
102Constraint::shared_ptr AllDiff::partiallyApply(
103 const Domains& domains) const {
104 DiscreteValues known;
105 for (Key k : keys_) {
106 const Domain& Dk = domains.at(k: k);
107 if (Dk.isSingleton()) known[k] = Dk.firstValue();
108 }
109 return partiallyApply(values: known);
110}
111
112/* ************************************************************************* */
113} // namespace gtsam
114