1/* ----------------------------------------------------------------------------
2
3 * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4 * Atlanta, Georgia 30332-0415
5 * All Rights Reserved
6 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7
8 * See LICENSE for the license information
9
10 * -------------------------------------------------------------------------- */
11
/**
 * @file DiscreteBayesNet.h
 * @date Feb 15, 2011
 * @author Duy-Nguyen Ta
 * @author Frank Dellaert
 */
18
19#pragma once
20
21#include <gtsam/discrete/DiscreteConditional.h>
22#include <gtsam/discrete/DiscreteDistribution.h>
23#include <gtsam/inference/BayesNet.h>
24#include <gtsam/inference/FactorGraph.h>
25
26#include <memory>
27#include <map>
28#include <string>
29#include <utility>
30#include <vector>
31
32namespace gtsam {
33
34/**
35 * A Bayes net made from discrete conditional distributions.
36 * @ingroup discrete
37 */
38class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
39 public:
40 typedef BayesNet<DiscreteConditional> Base;
41 typedef DiscreteBayesNet This;
42 typedef DiscreteConditional ConditionalType;
43 typedef std::shared_ptr<This> shared_ptr;
44 typedef std::shared_ptr<ConditionalType> sharedConditional;
45
46 /// @name Standard Constructors
47 /// @{
48
49 /// Construct empty Bayes net.
50 DiscreteBayesNet() {}
51
52 /** Construct from iterator over conditionals */
53 template <typename ITERATOR>
54 DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
55 : Base(firstConditional, lastConditional) {}
56
57 /** Construct from container of factors (shared_ptr or plain objects) */
58 template <class CONTAINER>
59 explicit DiscreteBayesNet(const CONTAINER& conditionals)
60 : Base(conditionals) {}
61
62 /** Implicit copy/downcast constructor to override explicit template
63 * container constructor */
64 template <class DERIVEDCONDITIONAL>
65 DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph)
66 : Base(graph) {}
67
68 /// @}
69
70 /// @name Testable
71 /// @{
72
73 /** Check equality */
74 bool equals(const This& bn, double tol = 1e-9) const;
75
76 /// @}
77
78 /// @name Standard Interface
79 /// @{
80
81 // Add inherited versions of add.
82 using Base::add;
83
84 /** Add a DiscreteDistribution using a table or a string */
85 void add(const DiscreteKey& key, const std::string& spec) {
86 emplace_shared<DiscreteDistribution>(args: key, args: spec);
87 }
88
89 /** Add a DiscreteCondtional */
90 template <typename... Args>
91 void add(Args&&... args) {
92 emplace_shared<DiscreteConditional>(std::forward<Args>(args)...);
93 }
94
95 //** evaluate for given DiscreteValues */
96 double evaluate(const DiscreteValues & values) const;
97
98 //** (Preferred) sugar for the above for given DiscreteValues */
99 double operator()(const DiscreteValues & values) const {
100 return evaluate(values);
101 }
102
103 //** log(evaluate(values)) for given DiscreteValues */
104 double logProbability(const DiscreteValues & values) const;
105
106 /**
107 * @brief do ancestral sampling
108 *
109 * Assumes the Bayes net is reverse topologically sorted, i.e. last
110 * conditional will be sampled first. If the Bayes net resulted from
111 * eliminating a factor graph, this is true for the elimination ordering.
112 *
113 * @return a sampled value for all variables.
114 */
115 DiscreteValues sample(std::mt19937_64* rng = nullptr) const;
116
117 /**
118 * @brief do ancestral sampling, given certain variables.
119 *
120 * Assumes the Bayes net is reverse topologically sorted *and* that the
121 * Bayes net does not contain any conditionals for the given values.
122 *
123 * @return given values extended with sampled value for all other variables.
124 */
125 DiscreteValues sample(DiscreteValues given,
126 std::mt19937_64* rng = nullptr) const;
127
128 /**
129 * @brief Prune the Bayes net
130 *
131 * @param maxNrLeaves The maximum number of leaves to keep.
132 * @param marginalThreshold If given, threshold on marginals to prune variables.
133 * @param fixedValues If given, return the fixed values removed.
134 * @return A new DiscreteBayesNet with pruned conditionals.
135 */
136 DiscreteBayesNet prune(size_t maxNrLeaves,
137 const std::optional<double>& marginalThreshold = {},
138 DiscreteValues* fixedValues = nullptr) const;
139
140 /**
141 * @brief Multiply all conditionals into one big joint conditional
142 * and return it.
143 *
144 * NOTE: possibly quite expensive.
145 *
146 * @return DiscreteConditional
147 */
148 DiscreteConditional joint() const;
149
150 ///@}
151 /// @name Wrapper support
152 /// @{
153
154 /// Render as markdown tables.
155 std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
156 const DiscreteFactor::Names& names = {}) const;
157
158 /// Render as html tables.
159 std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
160 const DiscreteFactor::Names& names = {}) const;
161
162 /// @}
163 /// @name HybridValues methods.
164 /// @{
165
166 using Base::error; // Expose error(const HybridValues&) method..
167 using Base::evaluate; // Expose evaluate(const HybridValues&) method..
168 using Base::logProbability; // Expose logProbability(const HybridValues&)
169
170 /// @}
171
172 private:
173#if GTSAM_ENABLE_BOOST_SERIALIZATION
174 /** Serialization function */
175 friend class boost::serialization::access;
176 template<class ARCHIVE>
177 void serialize(ARCHIVE & ar, const unsigned int /*version*/) {
178 ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
179 }
180#endif
181 };
182
183// traits
184template<> struct traits<DiscreteBayesNet> : public Testable<DiscreteBayesNet> {};
185
186} // \ namespace gtsam
187
188