1/**
2 * @file testLoopyBelief.cpp
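 * @brief Tests a loopy belief propagation solver for discrete factor graphs
 *        that contain only unary and binary factors.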
4 * @author Duy-Nguyen Ta
5 * @date Oct 11, 2013
6 */
7
#include <CppUnitLite/TestHarness.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/inference/VariableIndex.h>

#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
16
17using namespace std;
18using namespace boost;
19using namespace gtsam;
20
21/**
22 * Loopy belief solver for graphs with only binary and unary factors
23 */
24class LoopyBelief {
25 /** Star graph struct for each node, containing
26 * - the star graph itself
27 * - the product of original unary factors so we don't have to recompute it
28 * later, and
29 * - the factor indices of the corrected belief factors of the neighboring
30 * nodes
31 */
32 typedef std::map<Key, size_t> CorrectedBeliefIndices;
33 struct StarGraph {
34 DiscreteFactorGraph::shared_ptr star;
35 CorrectedBeliefIndices correctedBeliefIndices;
36 DecisionTreeFactor::shared_ptr unary;
37 VariableIndex varIndex_;
38 StarGraph(const DiscreteFactorGraph::shared_ptr& _star,
39 const CorrectedBeliefIndices& _beliefIndices,
40 const DecisionTreeFactor::shared_ptr& _unary)
41 : star(_star),
42 correctedBeliefIndices(_beliefIndices),
43 unary(_unary),
44 varIndex_(*_star) {}
45
46 void print(const std::string& s = "") const {
47 cout << s << ":" << endl;
48 star->print(s: "Star graph: ");
49 for (const auto& [key, _] : correctedBeliefIndices) {
50 cout << "Belief factor index for " << key << ": "
51 << correctedBeliefIndices.at(k: key) << endl;
52 }
53 if (unary) unary->print(s: "Unary: ");
54 }
55 };
56
57 typedef std::map<Key, StarGraph> StarGraphs;
58 StarGraphs starGraphs_; ///< star graph at each variable
59
60 public:
61 /** Constructor
62 * Need all discrete keys to access node's cardinality for creating belief
63 * factors
64 * TODO: so troublesome!!
65 */
66 LoopyBelief(const DiscreteFactorGraph& graph,
67 const std::map<Key, DiscreteKey>& allDiscreteKeys)
68 : starGraphs_(buildStarGraphs(graph, allDiscreteKeys)) {}
69
70 /// print
71 void print(const std::string& s = "") const {
72 cout << s << ":" << endl;
73 for (const auto& [key, _] : starGraphs_) {
74 starGraphs_.at(k: key).print(s: "Node " + std::to_string(val: key) + ":");
75 }
76 }
77
78 /// One step of belief propagation
79 DiscreteFactorGraph::shared_ptr iterate(
80 const std::map<Key, DiscreteKey>& allDiscreteKeys) {
81 static const bool debug = false;
82 DiscreteFactorGraph::shared_ptr beliefs(new DiscreteFactorGraph());
83 std::map<Key, std::map<Key, DiscreteFactor::shared_ptr> > allMessages;
84 // Eliminate each star graph
85 for (const auto& [key, _] : starGraphs_) {
86 // cout << "***** Node " << key << "*****" << endl;
87 // initialize belief to the unary factor from the original graph
88 DecisionTreeFactor::shared_ptr beliefAtKey;
89
90 // keep intermediate messages to divide later
91 std::map<Key, DiscreteFactor::shared_ptr> messages;
92
93 // eliminate each neighbor in this star graph one by one
94 for (const auto& [neighbor, _] : starGraphs_.at(k: key).correctedBeliefIndices) {
95 DiscreteFactorGraph subGraph;
96 for (size_t factor : starGraphs_.at(k: key).varIndex_[neighbor]) {
97 subGraph.push_back(factor: starGraphs_.at(k: key).star->at(i: factor));
98 }
99 if (debug) subGraph.print(s: "------- Subgraph:");
100 const auto [dummyCond, message] =
101 EliminateDiscrete(factors: subGraph, frontalKeys: Ordering{neighbor});
102 // store the new factor into messages
103 messages.insert(x: make_pair(x: neighbor, y: message));
104 if (debug) message->print(s: "------- Message: ");
105
106 // Belief is the product of all messages and the unary factor
107 // Incorporate new the factor to belief
108 if (!beliefAtKey)
109 beliefAtKey =
110 std::dynamic_pointer_cast<DecisionTreeFactor>(r: message);
111 else
112 beliefAtKey = std::make_shared<DecisionTreeFactor>(
113 args: (*beliefAtKey) *
114 (*std::dynamic_pointer_cast<DecisionTreeFactor>(r: message)));
115 }
116 if (starGraphs_.at(k: key).unary)
117 beliefAtKey = std::make_shared<DecisionTreeFactor>(
118 args: (*beliefAtKey) * (*starGraphs_.at(k: key).unary));
119 if (debug) beliefAtKey->print(s: "New belief at key: ");
120 // normalize belief
121 double sum = 0.0;
122 for (size_t v = 0; v < allDiscreteKeys.at(k: key).second; ++v) {
123 DiscreteValues val;
124 val[key] = v;
125 sum += (*beliefAtKey)(val);
126 }
127 // TODO(kartikarcot): Check if this makes sense
128 string sumFactorTable;
129 for (size_t v = 0; v < allDiscreteKeys.at(k: key).second; ++v) {
130 sumFactorTable = sumFactorTable + " " + std::to_string(val: sum);
131 }
132 DecisionTreeFactor sumFactor(allDiscreteKeys.at(k: key), sumFactorTable);
133 if (debug) sumFactor.print(s: "denomFactor: ");
134 beliefAtKey =
135 std::make_shared<DecisionTreeFactor>(args: (*beliefAtKey) / sumFactor);
136 if (debug) beliefAtKey->print(s: "New belief at key normalized: ");
137 beliefs->push_back(factor: beliefAtKey);
138 allMessages[key] = messages;
139 }
140
141 // Update corrected beliefs
142 VariableIndex beliefFactors(*beliefs);
143 for (const auto& [key, _] : starGraphs_) {
144 std::map<Key, DiscreteFactor::shared_ptr> messages = allMessages[key];
145 for (const auto& [neighbor, _] : starGraphs_.at(k: key).correctedBeliefIndices) {
146 DecisionTreeFactor correctedBelief =
147 (*std::dynamic_pointer_cast<DecisionTreeFactor>(
148 r: beliefs->at(i: beliefFactors[key].front()))) /
149 (*std::dynamic_pointer_cast<DecisionTreeFactor>(
150 r: messages.at(k: neighbor)));
151 if (debug) correctedBelief.print(s: "correctedBelief: ");
152 size_t beliefIndex =
153 starGraphs_.at(k: neighbor).correctedBeliefIndices.at(k: key);
154 starGraphs_.at(k: neighbor).star->replace(
155 index: beliefIndex,
156 factor: std::make_shared<DecisionTreeFactor>(args&: correctedBelief));
157 }
158 }
159
160 if (debug) print(s: "After update: ");
161
162 return beliefs;
163 }
164
165 private:
166 /**
167 * Build star graphs for each node.
168 */
169 StarGraphs buildStarGraphs(
170 const DiscreteFactorGraph& graph,
171 const std::map<Key, DiscreteKey>& allDiscreteKeys) const {
172 StarGraphs starGraphs;
173 VariableIndex varIndex(graph); ///< access to all factors of each node
174 for (const auto& [key, _] : varIndex) {
175 // initialize to multiply with other unary factors later
176 DecisionTreeFactor::shared_ptr prodOfUnaries;
177
178 // collect all factors involving this key in the original graph
179 DiscreteFactorGraph::shared_ptr star(new DiscreteFactorGraph());
180 for (size_t factorIndex : varIndex[key]) {
181 star->push_back(factor: graph.at(i: factorIndex));
182
183 // accumulate unary factors
184 if (graph.at(i: factorIndex)->size() == 1) {
185 if (!prodOfUnaries)
186 prodOfUnaries = graph.at<DecisionTreeFactor>(i: factorIndex);
187 else
188 prodOfUnaries = std::make_shared<DecisionTreeFactor>(
189 args: *prodOfUnaries * (*graph.at<DecisionTreeFactor>(i: factorIndex)));
190 }
191 }
192
193 // add the belief factor for each neighbor variable to this star graph
194 // also record the factor index for later modification
195 KeySet neighbors = star->keys();
196 neighbors.erase(x: key);
197 CorrectedBeliefIndices correctedBeliefIndices;
198 for (Key neighbor : neighbors) {
199 // TODO: default table for keys with more than 2 values?
200 string initialBelief;
201 for (size_t v = 0; v < allDiscreteKeys.at(k: neighbor).second - 1; ++v) {
202 initialBelief = initialBelief + "0.0 ";
203 }
204 initialBelief = initialBelief + "1.0";
205 star->push_back(
206 factor: DecisionTreeFactor(allDiscreteKeys.at(k: neighbor), initialBelief));
207 correctedBeliefIndices.insert(x: make_pair(x&: neighbor, y: star->size() - 1));
208 }
209 starGraphs.insert(x: make_pair(
210 x: key, y: StarGraph(star, correctedBeliefIndices, prodOfUnaries)));
211 }
212 return starGraphs;
213 }
214};
215
216/* ************************************************************************* */
217
218TEST_UNSAFE(LoopyBelief, construction) {
219 // Variables: Cloudy, Sprinkler, Rain, Wet
220 DiscreteKey C(0, 2), S(1, 2), R(2, 2), W(3, 2);
221
222 // Map from key to DiscreteKey for building belief factor.
223 // TODO: this is bad!
224 std::map<Key, DiscreteKey> allKeys{{0, C}, {1, S}, {2, R}, {3, W}};
225
226 // Build graph
227 DecisionTreeFactor pC(C, "0.5 0.5");
228 DiscreteConditional pSC(S | C = "0.5/0.5 0.9/0.1");
229 DiscreteConditional pRC(R | C = "0.8/0.2 0.2/0.8");
230 DecisionTreeFactor pSR(S & R, "0.0 0.9 0.9 0.99");
231
232 DiscreteFactorGraph graph;
233 graph.push_back(factor: pC);
234 graph.push_back(factor: pSC);
235 graph.push_back(factor: pRC);
236 graph.push_back(factor: pSR);
237
238 graph.print(s: "graph: ");
239
240 LoopyBelief solver(graph, allKeys);
241 solver.print(s: "Loopy belief: ");
242
243 // Main loop
244 for (size_t iter = 0; iter < 20; ++iter) {
245 cout << "==================================" << endl;
246 cout << "iteration: " << iter << endl;
247 DiscreteFactorGraph::shared_ptr beliefs = solver.iterate(allDiscreteKeys: allKeys);
248 beliefs->print();
249 }
250}
251
252/* ************************************************************************* */
253int main() {
254 TestResult tr;
255 return TestRegistry::runAllTests(result&: tr);
256}
257/* ************************************************************************* */
258