/**
 * @file testLoopyBelief.cpp
 * @brief Loopy belief propagation on a discrete factor graph with only unary
 *        and binary factors, exercised on a small Cloudy/Sprinkler/Rain example
 * @author Duy-Nguyen Ta
 * @date Oct 11, 2013
 */
| 7 | |
#include <CppUnitLite/TestHarness.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/inference/VariableIndex.h>

#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>
| 16 | |
| 17 | using namespace std; |
| 18 | using namespace boost; |
| 19 | using namespace gtsam; |
| 20 | |
| 21 | /** |
| 22 | * Loopy belief solver for graphs with only binary and unary factors |
| 23 | */ |
| 24 | class LoopyBelief { |
| 25 | /** Star graph struct for each node, containing |
| 26 | * - the star graph itself |
| 27 | * - the product of original unary factors so we don't have to recompute it |
| 28 | * later, and |
| 29 | * - the factor indices of the corrected belief factors of the neighboring |
| 30 | * nodes |
| 31 | */ |
| 32 | typedef std::map<Key, size_t> CorrectedBeliefIndices; |
| 33 | struct StarGraph { |
| 34 | DiscreteFactorGraph::shared_ptr star; |
| 35 | CorrectedBeliefIndices correctedBeliefIndices; |
| 36 | DecisionTreeFactor::shared_ptr unary; |
| 37 | VariableIndex varIndex_; |
| 38 | StarGraph(const DiscreteFactorGraph::shared_ptr& _star, |
| 39 | const CorrectedBeliefIndices& _beliefIndices, |
| 40 | const DecisionTreeFactor::shared_ptr& _unary) |
| 41 | : star(_star), |
| 42 | correctedBeliefIndices(_beliefIndices), |
| 43 | unary(_unary), |
| 44 | varIndex_(*_star) {} |
| 45 | |
| 46 | void print(const std::string& s = "" ) const { |
| 47 | cout << s << ":" << endl; |
| 48 | star->print(s: "Star graph: " ); |
| 49 | for (const auto& [key, _] : correctedBeliefIndices) { |
| 50 | cout << "Belief factor index for " << key << ": " |
| 51 | << correctedBeliefIndices.at(k: key) << endl; |
| 52 | } |
| 53 | if (unary) unary->print(s: "Unary: " ); |
| 54 | } |
| 55 | }; |
| 56 | |
| 57 | typedef std::map<Key, StarGraph> StarGraphs; |
| 58 | StarGraphs starGraphs_; ///< star graph at each variable |
| 59 | |
| 60 | public: |
| 61 | /** Constructor |
| 62 | * Need all discrete keys to access node's cardinality for creating belief |
| 63 | * factors |
| 64 | * TODO: so troublesome!! |
| 65 | */ |
| 66 | LoopyBelief(const DiscreteFactorGraph& graph, |
| 67 | const std::map<Key, DiscreteKey>& allDiscreteKeys) |
| 68 | : starGraphs_(buildStarGraphs(graph, allDiscreteKeys)) {} |
| 69 | |
| 70 | /// print |
| 71 | void print(const std::string& s = "" ) const { |
| 72 | cout << s << ":" << endl; |
| 73 | for (const auto& [key, _] : starGraphs_) { |
| 74 | starGraphs_.at(k: key).print(s: "Node " + std::to_string(val: key) + ":" ); |
| 75 | } |
| 76 | } |
| 77 | |
| 78 | /// One step of belief propagation |
| 79 | DiscreteFactorGraph::shared_ptr iterate( |
| 80 | const std::map<Key, DiscreteKey>& allDiscreteKeys) { |
| 81 | static const bool debug = false; |
| 82 | DiscreteFactorGraph::shared_ptr beliefs(new DiscreteFactorGraph()); |
| 83 | std::map<Key, std::map<Key, DiscreteFactor::shared_ptr> > allMessages; |
| 84 | // Eliminate each star graph |
| 85 | for (const auto& [key, _] : starGraphs_) { |
| 86 | // cout << "***** Node " << key << "*****" << endl; |
| 87 | // initialize belief to the unary factor from the original graph |
| 88 | DecisionTreeFactor::shared_ptr beliefAtKey; |
| 89 | |
| 90 | // keep intermediate messages to divide later |
| 91 | std::map<Key, DiscreteFactor::shared_ptr> messages; |
| 92 | |
| 93 | // eliminate each neighbor in this star graph one by one |
| 94 | for (const auto& [neighbor, _] : starGraphs_.at(k: key).correctedBeliefIndices) { |
| 95 | DiscreteFactorGraph subGraph; |
| 96 | for (size_t factor : starGraphs_.at(k: key).varIndex_[neighbor]) { |
| 97 | subGraph.push_back(factor: starGraphs_.at(k: key).star->at(i: factor)); |
| 98 | } |
| 99 | if (debug) subGraph.print(s: "------- Subgraph:" ); |
| 100 | const auto [dummyCond, message] = |
| 101 | EliminateDiscrete(factors: subGraph, frontalKeys: Ordering{neighbor}); |
| 102 | // store the new factor into messages |
| 103 | messages.insert(x: make_pair(x: neighbor, y: message)); |
| 104 | if (debug) message->print(s: "------- Message: " ); |
| 105 | |
| 106 | // Belief is the product of all messages and the unary factor |
| 107 | // Incorporate new the factor to belief |
| 108 | if (!beliefAtKey) |
| 109 | beliefAtKey = |
| 110 | std::dynamic_pointer_cast<DecisionTreeFactor>(r: message); |
| 111 | else |
| 112 | beliefAtKey = std::make_shared<DecisionTreeFactor>( |
| 113 | args: (*beliefAtKey) * |
| 114 | (*std::dynamic_pointer_cast<DecisionTreeFactor>(r: message))); |
| 115 | } |
| 116 | if (starGraphs_.at(k: key).unary) |
| 117 | beliefAtKey = std::make_shared<DecisionTreeFactor>( |
| 118 | args: (*beliefAtKey) * (*starGraphs_.at(k: key).unary)); |
| 119 | if (debug) beliefAtKey->print(s: "New belief at key: " ); |
| 120 | // normalize belief |
| 121 | double sum = 0.0; |
| 122 | for (size_t v = 0; v < allDiscreteKeys.at(k: key).second; ++v) { |
| 123 | DiscreteValues val; |
| 124 | val[key] = v; |
| 125 | sum += (*beliefAtKey)(val); |
| 126 | } |
| 127 | // TODO(kartikarcot): Check if this makes sense |
| 128 | string sumFactorTable; |
| 129 | for (size_t v = 0; v < allDiscreteKeys.at(k: key).second; ++v) { |
| 130 | sumFactorTable = sumFactorTable + " " + std::to_string(val: sum); |
| 131 | } |
| 132 | DecisionTreeFactor sumFactor(allDiscreteKeys.at(k: key), sumFactorTable); |
| 133 | if (debug) sumFactor.print(s: "denomFactor: " ); |
| 134 | beliefAtKey = |
| 135 | std::make_shared<DecisionTreeFactor>(args: (*beliefAtKey) / sumFactor); |
| 136 | if (debug) beliefAtKey->print(s: "New belief at key normalized: " ); |
| 137 | beliefs->push_back(factor: beliefAtKey); |
| 138 | allMessages[key] = messages; |
| 139 | } |
| 140 | |
| 141 | // Update corrected beliefs |
| 142 | VariableIndex beliefFactors(*beliefs); |
| 143 | for (const auto& [key, _] : starGraphs_) { |
| 144 | std::map<Key, DiscreteFactor::shared_ptr> messages = allMessages[key]; |
| 145 | for (const auto& [neighbor, _] : starGraphs_.at(k: key).correctedBeliefIndices) { |
| 146 | DecisionTreeFactor correctedBelief = |
| 147 | (*std::dynamic_pointer_cast<DecisionTreeFactor>( |
| 148 | r: beliefs->at(i: beliefFactors[key].front()))) / |
| 149 | (*std::dynamic_pointer_cast<DecisionTreeFactor>( |
| 150 | r: messages.at(k: neighbor))); |
| 151 | if (debug) correctedBelief.print(s: "correctedBelief: " ); |
| 152 | size_t beliefIndex = |
| 153 | starGraphs_.at(k: neighbor).correctedBeliefIndices.at(k: key); |
| 154 | starGraphs_.at(k: neighbor).star->replace( |
| 155 | index: beliefIndex, |
| 156 | factor: std::make_shared<DecisionTreeFactor>(args&: correctedBelief)); |
| 157 | } |
| 158 | } |
| 159 | |
| 160 | if (debug) print(s: "After update: " ); |
| 161 | |
| 162 | return beliefs; |
| 163 | } |
| 164 | |
| 165 | private: |
| 166 | /** |
| 167 | * Build star graphs for each node. |
| 168 | */ |
| 169 | StarGraphs buildStarGraphs( |
| 170 | const DiscreteFactorGraph& graph, |
| 171 | const std::map<Key, DiscreteKey>& allDiscreteKeys) const { |
| 172 | StarGraphs starGraphs; |
| 173 | VariableIndex varIndex(graph); ///< access to all factors of each node |
| 174 | for (const auto& [key, _] : varIndex) { |
| 175 | // initialize to multiply with other unary factors later |
| 176 | DecisionTreeFactor::shared_ptr prodOfUnaries; |
| 177 | |
| 178 | // collect all factors involving this key in the original graph |
| 179 | DiscreteFactorGraph::shared_ptr star(new DiscreteFactorGraph()); |
| 180 | for (size_t factorIndex : varIndex[key]) { |
| 181 | star->push_back(factor: graph.at(i: factorIndex)); |
| 182 | |
| 183 | // accumulate unary factors |
| 184 | if (graph.at(i: factorIndex)->size() == 1) { |
| 185 | if (!prodOfUnaries) |
| 186 | prodOfUnaries = graph.at<DecisionTreeFactor>(i: factorIndex); |
| 187 | else |
| 188 | prodOfUnaries = std::make_shared<DecisionTreeFactor>( |
| 189 | args: *prodOfUnaries * (*graph.at<DecisionTreeFactor>(i: factorIndex))); |
| 190 | } |
| 191 | } |
| 192 | |
| 193 | // add the belief factor for each neighbor variable to this star graph |
| 194 | // also record the factor index for later modification |
| 195 | KeySet neighbors = star->keys(); |
| 196 | neighbors.erase(x: key); |
| 197 | CorrectedBeliefIndices correctedBeliefIndices; |
| 198 | for (Key neighbor : neighbors) { |
| 199 | // TODO: default table for keys with more than 2 values? |
| 200 | string initialBelief; |
| 201 | for (size_t v = 0; v < allDiscreteKeys.at(k: neighbor).second - 1; ++v) { |
| 202 | initialBelief = initialBelief + "0.0 " ; |
| 203 | } |
| 204 | initialBelief = initialBelief + "1.0" ; |
| 205 | star->push_back( |
| 206 | factor: DecisionTreeFactor(allDiscreteKeys.at(k: neighbor), initialBelief)); |
| 207 | correctedBeliefIndices.insert(x: make_pair(x&: neighbor, y: star->size() - 1)); |
| 208 | } |
| 209 | starGraphs.insert(x: make_pair( |
| 210 | x: key, y: StarGraph(star, correctedBeliefIndices, prodOfUnaries))); |
| 211 | } |
| 212 | return starGraphs; |
| 213 | } |
| 214 | }; |
| 215 | |
| 216 | /* ************************************************************************* */ |
| 217 | |
| 218 | TEST_UNSAFE(LoopyBelief, construction) { |
| 219 | // Variables: Cloudy, Sprinkler, Rain, Wet |
| 220 | DiscreteKey C(0, 2), S(1, 2), R(2, 2), W(3, 2); |
| 221 | |
| 222 | // Map from key to DiscreteKey for building belief factor. |
| 223 | // TODO: this is bad! |
| 224 | std::map<Key, DiscreteKey> allKeys{{0, C}, {1, S}, {2, R}, {3, W}}; |
| 225 | |
| 226 | // Build graph |
| 227 | DecisionTreeFactor pC(C, "0.5 0.5" ); |
| 228 | DiscreteConditional pSC(S | C = "0.5/0.5 0.9/0.1" ); |
| 229 | DiscreteConditional pRC(R | C = "0.8/0.2 0.2/0.8" ); |
| 230 | DecisionTreeFactor pSR(S & R, "0.0 0.9 0.9 0.99" ); |
| 231 | |
| 232 | DiscreteFactorGraph graph; |
| 233 | graph.push_back(factor: pC); |
| 234 | graph.push_back(factor: pSC); |
| 235 | graph.push_back(factor: pRC); |
| 236 | graph.push_back(factor: pSR); |
| 237 | |
| 238 | graph.print(s: "graph: " ); |
| 239 | |
| 240 | LoopyBelief solver(graph, allKeys); |
| 241 | solver.print(s: "Loopy belief: " ); |
| 242 | |
| 243 | // Main loop |
| 244 | for (size_t iter = 0; iter < 20; ++iter) { |
| 245 | cout << "==================================" << endl; |
| 246 | cout << "iteration: " << iter << endl; |
| 247 | DiscreteFactorGraph::shared_ptr beliefs = solver.iterate(allDiscreteKeys: allKeys); |
| 248 | beliefs->print(); |
| 249 | } |
| 250 | } |
| 251 | |
| 252 | /* ************************************************************************* */ |
| 253 | int main() { |
| 254 | TestResult tr; |
| 255 | return TestRegistry::runAllTests(result&: tr); |
| 256 | } |
| 257 | /* ************************************************************************* */ |
| 258 | |