/*
 * testCSP.cpp
 * @brief Unit tests for the discrete constraint-satisfaction (CSP) solver
 * @date Feb 5, 2012
 * @author Frank Dellaert
 */
7
8#include <gtsam_unstable/discrete/CSP.h>
9#include <gtsam_unstable/discrete/Domain.h>
10
11#include <CppUnitLite/TestHarness.h>
12
13#include <fstream>
14#include <iostream>
15
16using namespace std;
17using namespace gtsam;
18
19/* ************************************************************************* */
20TEST(CSP, SingleValue) {
21 // Create keys for Idaho, Arizona, and Utah, allowing two colors for each:
22 size_t nrColors = 3;
23 DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors);
24
25 // Check that a single value is equal to a decision stump with only one "1":
26 SingleValue singleValue(AZ, 2);
27 DecisionTreeFactor f1(AZ, "0 0 1");
28 EXPECT(assert_equal(f1, singleValue.toDecisionTreeFactor()));
29
30 // Create domains
31 Domains domains;
32 domains.emplace(args: 0, args: Domain(ID));
33 domains.emplace(args: 1, args: Domain(AZ));
34 domains.emplace(args: 2, args: Domain(UT));
35
36 // Ensure arc-consistency: just wipes out values in AZ domain:
37 EXPECT(singleValue.ensureArcConsistency(1, &domains));
38 LONGS_EQUAL(3, domains.at(0).nrValues());
39 LONGS_EQUAL(1, domains.at(1).nrValues());
40 LONGS_EQUAL(3, domains.at(2).nrValues());
41}
42
43/* ************************************************************************* */
44TEST(CSP, BinaryAllDif) {
45 // Create keys for Idaho, Arizona, and Utah, allowing 2 colors for each:
46 size_t nrColors = 2;
47 DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors);
48
49 // Check construction and conversion
50 BinaryAllDiff c1(ID, UT);
51 DecisionTreeFactor f1(ID & UT, "0 1 1 0");
52 EXPECT(assert_equal(f1, c1.toDecisionTreeFactor()));
53
54 // Check construction and conversion
55 BinaryAllDiff c2(UT, AZ);
56 DecisionTreeFactor f2(UT & AZ, "0 1 1 0");
57 EXPECT(assert_equal(f2, c2.toDecisionTreeFactor()));
58
59 // Check multiplication of factors with constraint:
60 DecisionTreeFactor f3 = f1 * f2;
61 EXPECT(assert_equal(f3, c1 * f2));
62 EXPECT(assert_equal(f3, c2 * f1));
63}
64
65/* ************************************************************************* */
66TEST(CSP, AllDiff) {
67 // Create keys for Idaho, Arizona, and Utah, allowing two colors for each:
68 size_t nrColors = 3;
69 DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors);
70
71 // Check construction and conversion
72 vector<DiscreteKey> dkeys{ID, UT, AZ};
73 AllDiff alldiff(dkeys);
74 DecisionTreeFactor actual = alldiff.toDecisionTreeFactor();
75 // GTSAM_PRINT(actual);
76 actual.dot(name: "actual");
77 DecisionTreeFactor f2(
78 ID & AZ & UT,
79 "0 0 0 0 0 1 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 1 0 0 0 0 0");
80 EXPECT(assert_equal(f2, actual));
81
82 // Create domains.
83 Domains domains;
84 domains.emplace(args: 0, args: Domain(ID));
85 domains.emplace(args: 1, args: Domain(AZ));
86 domains.emplace(args: 2, args: Domain(UT));
87
88 // First constrict AZ domain:
89 SingleValue singleValue(AZ, 2);
90 EXPECT(singleValue.ensureArcConsistency(1, &domains));
91
92 // Arc-consistency
93 EXPECT(alldiff.ensureArcConsistency(0, &domains));
94 EXPECT(!alldiff.ensureArcConsistency(1, &domains));
95 EXPECT(alldiff.ensureArcConsistency(2, &domains));
96 LONGS_EQUAL(2, domains.at(0).nrValues());
97 LONGS_EQUAL(1, domains.at(1).nrValues());
98 LONGS_EQUAL(2, domains.at(2).nrValues());
99}
100
101/* ************************************************************************* */
102TEST(CSP, allInOne) {
103 // Create keys for Idaho, Arizona, and Utah, allowing 3 colors for each:
104 size_t nrColors = 2;
105 DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors);
106
107 // Create the CSP
108 CSP csp;
109 csp.addAllDiff(key1: ID, key2: UT);
110 csp.addAllDiff(key1: UT, key2: AZ);
111
112 // Check an invalid combination, with ID==UT==AZ all same color
113 DiscreteValues invalid;
114 invalid[ID.first] = 0;
115 invalid[UT.first] = 0;
116 invalid[AZ.first] = 0;
117 EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9);
118
119 // Check a valid combination
120 DiscreteValues valid;
121 valid[ID.first] = 0;
122 valid[UT.first] = 1;
123 valid[AZ.first] = 0;
124 EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9);
125
126 // Just for fun, create the product and check it
127 DecisionTreeFactor product = csp.product()->toDecisionTreeFactor();
128 // product.dot("product");
129 DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0");
130 EXPECT(assert_equal(expectedProduct, product));
131
132 // Solve
133 auto mpe = csp.optimize();
134 DiscreteValues expected {{ID.first, 1}, {UT.first, 0}, {AZ.first, 1}};
135 EXPECT(assert_equal(expected, mpe));
136 EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9);
137}
138
139/* ************************************************************************* */
140TEST(CSP, WesternUS) {
141 // Create keys for all states in Western US, with 4 color possibilities.
142 size_t nrColors = 4;
143 DiscreteKey WA(0, nrColors), OR(3, nrColors), CA(1, nrColors),
144 NV(2, nrColors), ID(8, nrColors), UT(9, nrColors), AZ(10, nrColors),
145 MT(4, nrColors), WY(5, nrColors), CO(7, nrColors), NM(6, nrColors);
146
147 // Create the CSP
148 CSP csp;
149 csp.addAllDiff(key1: WA, key2: ID);
150 csp.addAllDiff(key1: WA, key2: OR);
151 csp.addAllDiff(key1: OR, key2: ID);
152 csp.addAllDiff(key1: OR, key2: CA);
153 csp.addAllDiff(key1: OR, key2: NV);
154 csp.addAllDiff(key1: CA, key2: NV);
155 csp.addAllDiff(key1: CA, key2: AZ);
156 csp.addAllDiff(key1: ID, key2: MT);
157 csp.addAllDiff(key1: ID, key2: WY);
158 csp.addAllDiff(key1: ID, key2: UT);
159 csp.addAllDiff(key1: ID, key2: NV);
160 csp.addAllDiff(key1: NV, key2: UT);
161 csp.addAllDiff(key1: NV, key2: AZ);
162 csp.addAllDiff(key1: UT, key2: WY);
163 csp.addAllDiff(key1: UT, key2: CO);
164 csp.addAllDiff(key1: UT, key2: NM);
165 csp.addAllDiff(key1: UT, key2: AZ);
166 csp.addAllDiff(key1: AZ, key2: CO);
167 csp.addAllDiff(key1: AZ, key2: NM);
168 csp.addAllDiff(key1: MT, key2: WY);
169 csp.addAllDiff(key1: WY, key2: CO);
170 csp.addAllDiff(key1: CO, key2: NM);
171
172 DiscreteValues mpe{{0, 2}, {1, 3}, {2, 2}, {3, 1}, {4, 1}, {5, 3},
173 {6, 3}, {7, 2}, {8, 0}, {9, 1}, {10, 0}};
174
175 // Create ordering according to example in ND-CSP.lyx
176 const Ordering ordering{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
177
178 // Solve using that ordering:
179 auto actualMPE = csp.optimize(ordering);
180
181 EXPECT(assert_equal(mpe, actualMPE));
182 EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9);
183
184 // Write out the dual graph for hmetis
185#ifdef DUAL
186 VariableIndexOrdered index(csp);
187 index.print("index");
188 ofstream os("/Users/dellaert/src/hmetis-1.5-osx-i686/US-West-dual.txt");
189 index.outputMetisFormat(os);
190#endif
191}
192
193/* ************************************************************************* */
194TEST(CSP, ArcConsistency) {
195 // Create keys for Idaho, Arizona, and Utah, allowing three colors for each:
196 size_t nrColors = 3;
197 DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors);
198
199 // Create the CSP using just one all-diff constraint, plus constrain Arizona.
200 CSP csp;
201 vector<DiscreteKey> dkeys{ID, UT, AZ};
202 csp.addAllDiff(dkeys);
203 csp.addSingleValue(dkey: AZ, value: 2);
204 // GTSAM_PRINT(csp);
205
206 // Check an invalid combination, with ID==UT==AZ all same color
207 DiscreteValues invalid;
208 invalid[ID.first] = 0;
209 invalid[UT.first] = 1;
210 invalid[AZ.first] = 0;
211 EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9);
212
213 // Check a valid combination
214 DiscreteValues valid;
215 valid[ID.first] = 0;
216 valid[UT.first] = 1;
217 valid[AZ.first] = 2;
218 EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9);
219
220 // Solve
221 auto mpe = csp.optimize();
222 DiscreteValues expected {{ID.first, 1}, {UT.first, 0}, {AZ.first, 2}};
223 EXPECT(assert_equal(expected, mpe));
224 EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9);
225
226 // ensure arc-consistency, i.e., narrow domains...
227 Domains domains;
228 domains.emplace(args: 0, args: Domain(ID));
229 domains.emplace(args: 1, args: Domain(AZ));
230 domains.emplace(args: 2, args: Domain(UT));
231
232 SingleValue singleValue(AZ, 2);
233 AllDiff alldiff(dkeys);
234 EXPECT(singleValue.ensureArcConsistency(1, &domains));
235 EXPECT(alldiff.ensureArcConsistency(0, &domains));
236 EXPECT(!alldiff.ensureArcConsistency(1, &domains));
237 EXPECT(alldiff.ensureArcConsistency(2, &domains));
238 LONGS_EQUAL(2, domains.at(0).nrValues());
239 LONGS_EQUAL(1, domains.at(1).nrValues());
240 LONGS_EQUAL(2, domains.at(2).nrValues());
241
242 // Parial application, version 1
243 DiscreteValues known;
244 known[AZ.first] = 2;
245 DiscreteFactor::shared_ptr reduced1 = alldiff.partiallyApply(known);
246 DecisionTreeFactor f3(ID & UT, "0 1 1 1 0 1 1 1 0");
247 EXPECT(assert_equal(f3, reduced1->toDecisionTreeFactor()));
248 DiscreteFactor::shared_ptr reduced2 = singleValue.partiallyApply(values: known);
249 DecisionTreeFactor f4(AZ, "0 0 1");
250 EXPECT(assert_equal(f4, reduced2->toDecisionTreeFactor()));
251
252 // Parial application, version 2
253 DiscreteFactor::shared_ptr reduced3 = alldiff.partiallyApply(domains);
254 EXPECT(assert_equal(f3, reduced3->toDecisionTreeFactor()));
255 DiscreteFactor::shared_ptr reduced4 = singleValue.partiallyApply(domains);
256 EXPECT(assert_equal(f4, reduced4->toDecisionTreeFactor()));
257
258 // full arc-consistency test
259 csp.runArcConsistency(cardinality: nrColors);
260 // GTSAM_PRINT(csp);
261}
262
263/* ************************************************************************* */
264int main() {
265 TestResult tr;
266 return TestRegistry::runAllTests(result&: tr);
267}
268/* ************************************************************************* */
269