/*
 * @file testCSP.cpp
 * @brief Unit tests for the CSP (constraint satisfaction problem) solver
 * @date Feb 5, 2012
 * @author Frank Dellaert
 */
| 7 | |
| 8 | #include <gtsam_unstable/discrete/CSP.h> |
| 9 | #include <gtsam_unstable/discrete/Domain.h> |
| 10 | |
| 11 | #include <CppUnitLite/TestHarness.h> |
| 12 | |
| 13 | #include <fstream> |
| 14 | #include <iostream> |
| 15 | |
| 16 | using namespace std; |
| 17 | using namespace gtsam; |
| 18 | |
| 19 | /* ************************************************************************* */ |
| 20 | TEST(CSP, SingleValue) { |
| 21 | // Create keys for Idaho, Arizona, and Utah, allowing two colors for each: |
| 22 | size_t nrColors = 3; |
| 23 | DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); |
| 24 | |
| 25 | // Check that a single value is equal to a decision stump with only one "1": |
| 26 | SingleValue singleValue(AZ, 2); |
| 27 | DecisionTreeFactor f1(AZ, "0 0 1" ); |
| 28 | EXPECT(assert_equal(f1, singleValue.toDecisionTreeFactor())); |
| 29 | |
| 30 | // Create domains |
| 31 | Domains domains; |
| 32 | domains.emplace(args: 0, args: Domain(ID)); |
| 33 | domains.emplace(args: 1, args: Domain(AZ)); |
| 34 | domains.emplace(args: 2, args: Domain(UT)); |
| 35 | |
| 36 | // Ensure arc-consistency: just wipes out values in AZ domain: |
| 37 | EXPECT(singleValue.ensureArcConsistency(1, &domains)); |
| 38 | LONGS_EQUAL(3, domains.at(0).nrValues()); |
| 39 | LONGS_EQUAL(1, domains.at(1).nrValues()); |
| 40 | LONGS_EQUAL(3, domains.at(2).nrValues()); |
| 41 | } |
| 42 | |
| 43 | /* ************************************************************************* */ |
| 44 | TEST(CSP, BinaryAllDif) { |
| 45 | // Create keys for Idaho, Arizona, and Utah, allowing 2 colors for each: |
| 46 | size_t nrColors = 2; |
| 47 | DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); |
| 48 | |
| 49 | // Check construction and conversion |
| 50 | BinaryAllDiff c1(ID, UT); |
| 51 | DecisionTreeFactor f1(ID & UT, "0 1 1 0" ); |
| 52 | EXPECT(assert_equal(f1, c1.toDecisionTreeFactor())); |
| 53 | |
| 54 | // Check construction and conversion |
| 55 | BinaryAllDiff c2(UT, AZ); |
| 56 | DecisionTreeFactor f2(UT & AZ, "0 1 1 0" ); |
| 57 | EXPECT(assert_equal(f2, c2.toDecisionTreeFactor())); |
| 58 | |
| 59 | // Check multiplication of factors with constraint: |
| 60 | DecisionTreeFactor f3 = f1 * f2; |
| 61 | EXPECT(assert_equal(f3, c1 * f2)); |
| 62 | EXPECT(assert_equal(f3, c2 * f1)); |
| 63 | } |
| 64 | |
| 65 | /* ************************************************************************* */ |
| 66 | TEST(CSP, AllDiff) { |
| 67 | // Create keys for Idaho, Arizona, and Utah, allowing two colors for each: |
| 68 | size_t nrColors = 3; |
| 69 | DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); |
| 70 | |
| 71 | // Check construction and conversion |
| 72 | vector<DiscreteKey> dkeys{ID, UT, AZ}; |
| 73 | AllDiff alldiff(dkeys); |
| 74 | DecisionTreeFactor actual = alldiff.toDecisionTreeFactor(); |
| 75 | // GTSAM_PRINT(actual); |
| 76 | actual.dot(name: "actual" ); |
| 77 | DecisionTreeFactor f2( |
| 78 | ID & AZ & UT, |
| 79 | "0 0 0 0 0 1 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 1 0 0 0 0 0" ); |
| 80 | EXPECT(assert_equal(f2, actual)); |
| 81 | |
| 82 | // Create domains. |
| 83 | Domains domains; |
| 84 | domains.emplace(args: 0, args: Domain(ID)); |
| 85 | domains.emplace(args: 1, args: Domain(AZ)); |
| 86 | domains.emplace(args: 2, args: Domain(UT)); |
| 87 | |
| 88 | // First constrict AZ domain: |
| 89 | SingleValue singleValue(AZ, 2); |
| 90 | EXPECT(singleValue.ensureArcConsistency(1, &domains)); |
| 91 | |
| 92 | // Arc-consistency |
| 93 | EXPECT(alldiff.ensureArcConsistency(0, &domains)); |
| 94 | EXPECT(!alldiff.ensureArcConsistency(1, &domains)); |
| 95 | EXPECT(alldiff.ensureArcConsistency(2, &domains)); |
| 96 | LONGS_EQUAL(2, domains.at(0).nrValues()); |
| 97 | LONGS_EQUAL(1, domains.at(1).nrValues()); |
| 98 | LONGS_EQUAL(2, domains.at(2).nrValues()); |
| 99 | } |
| 100 | |
| 101 | /* ************************************************************************* */ |
| 102 | TEST(CSP, allInOne) { |
| 103 | // Create keys for Idaho, Arizona, and Utah, allowing 3 colors for each: |
| 104 | size_t nrColors = 2; |
| 105 | DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); |
| 106 | |
| 107 | // Create the CSP |
| 108 | CSP csp; |
| 109 | csp.addAllDiff(key1: ID, key2: UT); |
| 110 | csp.addAllDiff(key1: UT, key2: AZ); |
| 111 | |
| 112 | // Check an invalid combination, with ID==UT==AZ all same color |
| 113 | DiscreteValues invalid; |
| 114 | invalid[ID.first] = 0; |
| 115 | invalid[UT.first] = 0; |
| 116 | invalid[AZ.first] = 0; |
| 117 | EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9); |
| 118 | |
| 119 | // Check a valid combination |
| 120 | DiscreteValues valid; |
| 121 | valid[ID.first] = 0; |
| 122 | valid[UT.first] = 1; |
| 123 | valid[AZ.first] = 0; |
| 124 | EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9); |
| 125 | |
| 126 | // Just for fun, create the product and check it |
| 127 | DecisionTreeFactor product = csp.product()->toDecisionTreeFactor(); |
| 128 | // product.dot("product"); |
| 129 | DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0" ); |
| 130 | EXPECT(assert_equal(expectedProduct, product)); |
| 131 | |
| 132 | // Solve |
| 133 | auto mpe = csp.optimize(); |
| 134 | DiscreteValues expected {{ID.first, 1}, {UT.first, 0}, {AZ.first, 1}}; |
| 135 | EXPECT(assert_equal(expected, mpe)); |
| 136 | EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9); |
| 137 | } |
| 138 | |
| 139 | /* ************************************************************************* */ |
| 140 | TEST(CSP, WesternUS) { |
| 141 | // Create keys for all states in Western US, with 4 color possibilities. |
| 142 | size_t nrColors = 4; |
| 143 | DiscreteKey WA(0, nrColors), OR(3, nrColors), CA(1, nrColors), |
| 144 | NV(2, nrColors), ID(8, nrColors), UT(9, nrColors), AZ(10, nrColors), |
| 145 | MT(4, nrColors), WY(5, nrColors), CO(7, nrColors), NM(6, nrColors); |
| 146 | |
| 147 | // Create the CSP |
| 148 | CSP csp; |
| 149 | csp.addAllDiff(key1: WA, key2: ID); |
| 150 | csp.addAllDiff(key1: WA, key2: OR); |
| 151 | csp.addAllDiff(key1: OR, key2: ID); |
| 152 | csp.addAllDiff(key1: OR, key2: CA); |
| 153 | csp.addAllDiff(key1: OR, key2: NV); |
| 154 | csp.addAllDiff(key1: CA, key2: NV); |
| 155 | csp.addAllDiff(key1: CA, key2: AZ); |
| 156 | csp.addAllDiff(key1: ID, key2: MT); |
| 157 | csp.addAllDiff(key1: ID, key2: WY); |
| 158 | csp.addAllDiff(key1: ID, key2: UT); |
| 159 | csp.addAllDiff(key1: ID, key2: NV); |
| 160 | csp.addAllDiff(key1: NV, key2: UT); |
| 161 | csp.addAllDiff(key1: NV, key2: AZ); |
| 162 | csp.addAllDiff(key1: UT, key2: WY); |
| 163 | csp.addAllDiff(key1: UT, key2: CO); |
| 164 | csp.addAllDiff(key1: UT, key2: NM); |
| 165 | csp.addAllDiff(key1: UT, key2: AZ); |
| 166 | csp.addAllDiff(key1: AZ, key2: CO); |
| 167 | csp.addAllDiff(key1: AZ, key2: NM); |
| 168 | csp.addAllDiff(key1: MT, key2: WY); |
| 169 | csp.addAllDiff(key1: WY, key2: CO); |
| 170 | csp.addAllDiff(key1: CO, key2: NM); |
| 171 | |
| 172 | DiscreteValues mpe{{0, 2}, {1, 3}, {2, 2}, {3, 1}, {4, 1}, {5, 3}, |
| 173 | {6, 3}, {7, 2}, {8, 0}, {9, 1}, {10, 0}}; |
| 174 | |
| 175 | // Create ordering according to example in ND-CSP.lyx |
| 176 | const Ordering ordering{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; |
| 177 | |
| 178 | // Solve using that ordering: |
| 179 | auto actualMPE = csp.optimize(ordering); |
| 180 | |
| 181 | EXPECT(assert_equal(mpe, actualMPE)); |
| 182 | EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9); |
| 183 | |
| 184 | // Write out the dual graph for hmetis |
| 185 | #ifdef DUAL |
| 186 | VariableIndexOrdered index(csp); |
| 187 | index.print("index" ); |
| 188 | ofstream os("/Users/dellaert/src/hmetis-1.5-osx-i686/US-West-dual.txt" ); |
| 189 | index.outputMetisFormat(os); |
| 190 | #endif |
| 191 | } |
| 192 | |
| 193 | /* ************************************************************************* */ |
| 194 | TEST(CSP, ArcConsistency) { |
| 195 | // Create keys for Idaho, Arizona, and Utah, allowing three colors for each: |
| 196 | size_t nrColors = 3; |
| 197 | DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); |
| 198 | |
| 199 | // Create the CSP using just one all-diff constraint, plus constrain Arizona. |
| 200 | CSP csp; |
| 201 | vector<DiscreteKey> dkeys{ID, UT, AZ}; |
| 202 | csp.addAllDiff(dkeys); |
| 203 | csp.addSingleValue(dkey: AZ, value: 2); |
| 204 | // GTSAM_PRINT(csp); |
| 205 | |
| 206 | // Check an invalid combination, with ID==UT==AZ all same color |
| 207 | DiscreteValues invalid; |
| 208 | invalid[ID.first] = 0; |
| 209 | invalid[UT.first] = 1; |
| 210 | invalid[AZ.first] = 0; |
| 211 | EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9); |
| 212 | |
| 213 | // Check a valid combination |
| 214 | DiscreteValues valid; |
| 215 | valid[ID.first] = 0; |
| 216 | valid[UT.first] = 1; |
| 217 | valid[AZ.first] = 2; |
| 218 | EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9); |
| 219 | |
| 220 | // Solve |
| 221 | auto mpe = csp.optimize(); |
| 222 | DiscreteValues expected {{ID.first, 1}, {UT.first, 0}, {AZ.first, 2}}; |
| 223 | EXPECT(assert_equal(expected, mpe)); |
| 224 | EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9); |
| 225 | |
| 226 | // ensure arc-consistency, i.e., narrow domains... |
| 227 | Domains domains; |
| 228 | domains.emplace(args: 0, args: Domain(ID)); |
| 229 | domains.emplace(args: 1, args: Domain(AZ)); |
| 230 | domains.emplace(args: 2, args: Domain(UT)); |
| 231 | |
| 232 | SingleValue singleValue(AZ, 2); |
| 233 | AllDiff alldiff(dkeys); |
| 234 | EXPECT(singleValue.ensureArcConsistency(1, &domains)); |
| 235 | EXPECT(alldiff.ensureArcConsistency(0, &domains)); |
| 236 | EXPECT(!alldiff.ensureArcConsistency(1, &domains)); |
| 237 | EXPECT(alldiff.ensureArcConsistency(2, &domains)); |
| 238 | LONGS_EQUAL(2, domains.at(0).nrValues()); |
| 239 | LONGS_EQUAL(1, domains.at(1).nrValues()); |
| 240 | LONGS_EQUAL(2, domains.at(2).nrValues()); |
| 241 | |
| 242 | // Parial application, version 1 |
| 243 | DiscreteValues known; |
| 244 | known[AZ.first] = 2; |
| 245 | DiscreteFactor::shared_ptr reduced1 = alldiff.partiallyApply(known); |
| 246 | DecisionTreeFactor f3(ID & UT, "0 1 1 1 0 1 1 1 0" ); |
| 247 | EXPECT(assert_equal(f3, reduced1->toDecisionTreeFactor())); |
| 248 | DiscreteFactor::shared_ptr reduced2 = singleValue.partiallyApply(values: known); |
| 249 | DecisionTreeFactor f4(AZ, "0 0 1" ); |
| 250 | EXPECT(assert_equal(f4, reduced2->toDecisionTreeFactor())); |
| 251 | |
| 252 | // Parial application, version 2 |
| 253 | DiscreteFactor::shared_ptr reduced3 = alldiff.partiallyApply(domains); |
| 254 | EXPECT(assert_equal(f3, reduced3->toDecisionTreeFactor())); |
| 255 | DiscreteFactor::shared_ptr reduced4 = singleValue.partiallyApply(domains); |
| 256 | EXPECT(assert_equal(f4, reduced4->toDecisionTreeFactor())); |
| 257 | |
| 258 | // full arc-consistency test |
| 259 | csp.runArcConsistency(cardinality: nrColors); |
| 260 | // GTSAM_PRINT(csp); |
| 261 | } |
| 262 | |
| 263 | /* ************************************************************************* */ |
| 264 | int main() { |
| 265 | TestResult tr; |
| 266 | return TestRegistry::runAllTests(result&: tr); |
| 267 | } |
| 268 | /* ************************************************************************* */ |
| 269 | |