1/* ----------------------------------------------------------------------------
2
3 * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4 * Atlanta, Georgia 30332-0415
5 * All Rights Reserved
6 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7
8 * See LICENSE for the license information
9
10 * -------------------------------------------------------------------------- */
11
12/**
 * @file testSubgraphPreconditioner.cpp
14 * @brief Unit tests for SubgraphPreconditioner
15 * @author Frank Dellaert
16 **/
17
18#include <tests/smallExample.h>
19
20#include <gtsam/base/numericalDerivative.h>
21#include <gtsam/inference/Ordering.h>
22#include <gtsam/inference/Symbol.h>
23#include <gtsam/linear/GaussianEliminationTree.h>
24#include <gtsam/linear/GaussianFactorGraph.h>
25#include <gtsam/linear/SubgraphPreconditioner.h>
26#include <gtsam/linear/iterative.h>
27#include <gtsam/slam/dataset.h>
28#include <gtsam/symbolic/SymbolicFactorGraph.h>
29
30#include <CppUnitLite/TestHarness.h>
31
32#include <fstream>
33
34using namespace std;
35using namespace gtsam;
36using namespace example;
37
38// define keys
39// Create key for simulated planar graph
40Symbol key(int x, int y) { return symbol_shorthand::X(j: 1000 * x + y); }
41
42/* ************************************************************************* */
43TEST(SubgraphPreconditioner, planarOrdering) {
44 // Check canonical ordering
45 Ordering ordering = planarOrdering(N: 3),
46 expected{key(x: 3, y: 3), key(x: 2, y: 3), key(x: 1, y: 3), key(x: 3, y: 2), key(x: 2, y: 2),
47 key(x: 1, y: 2), key(x: 3, y: 1), key(x: 2, y: 1), key(x: 1, y: 1)};
48 EXPECT(assert_equal(expected, ordering));
49}
50
51/* ************************************************************************* */
52/** unnormalized error */
53static double error(const GaussianFactorGraph& fg, const VectorValues& x) {
54 double total_error = 0.;
55 for (const GaussianFactor::shared_ptr& factor : fg)
56 total_error += factor->error(c: x);
57 return total_error;
58}
59
60/* ************************************************************************* */
61TEST(SubgraphPreconditioner, planarGraph) {
62 // Check planar graph construction
63 const auto [A, xtrue] = planarGraph(N: 3);
64 LONGS_EQUAL(13, A.size());
65 LONGS_EQUAL(9, xtrue.size());
66 DOUBLES_EQUAL(0, error(A, xtrue), 1e-9); // check zero error for xtrue
67
68 // Check that xtrue is optimal
69 GaussianBayesNet R1 = *A.eliminateSequential();
70 VectorValues actual = R1.optimize();
71 EXPECT(assert_equal(xtrue, actual));
72}
73
74/* ************************************************************************* */
75TEST(SubgraphPreconditioner, splitOffPlanarTree) {
76 // Build a planar graph
77 const auto [A, xtrue] = planarGraph(N: 3);
78
79 // Get the spanning tree and constraints, and check their sizes
80 const auto [T, C] = splitOffPlanarTree(N: 3, original: A);
81 LONGS_EQUAL(9, T.size());
82 LONGS_EQUAL(4, C.size());
83
84 // Check that the tree can be solved to give the ground xtrue
85 GaussianBayesNet R1 = *T.eliminateSequential();
86 VectorValues xbar = R1.optimize();
87 EXPECT(assert_equal(xtrue, xbar));
88}
89
90/* ************************************************************************* */
91TEST(SubgraphPreconditioner, system) {
92 // Build a planar graph
93 size_t N = 3;
94 const auto [Ab, xtrue] = planarGraph(N); // A*x-b
95
96 // Get the spanning tree and remaining graph
97 auto [Ab1, Ab2] = splitOffPlanarTree(N, original: Ab);
98
99 // Eliminate the spanning tree to build a prior
100 const Ordering ord = planarOrdering(N);
101 auto Rc1 = *Ab1.eliminateSequential(ordering: ord); // R1*x-c1
102 VectorValues xbar = Rc1.optimize(); // xbar = inv(R1)*c1
103
104 // Create Subgraph-preconditioned system
105 const SubgraphPreconditioner system(Ab2, Rc1, xbar);
106
107 // Get corresponding matrices for tests. Add dummy factors to Ab2 to make
108 // sure it works with the ordering.
109 Ordering ordering = Rc1.ordering(); // not ord in general!
110 Ab2.add(key1: key(x: 1, y: 1), A1: Z_2x2, b: Z_2x1);
111 Ab2.add(key1: key(x: 1, y: 2), A1: Z_2x2, b: Z_2x1);
112 Ab2.add(key1: key(x: 1, y: 3), A1: Z_2x2, b: Z_2x1);
113 const auto [A, b] = Ab.jacobian(ordering);
114 const auto [A1, b1] = Ab1.jacobian(ordering);
115 const auto [A2, b2] = Ab2.jacobian(ordering);
116 Matrix R1 = Rc1.matrix(ordering).first;
117 Matrix Abar(13 * 2, 9 * 2);
118 Abar.topRows(n: 9 * 2) = Matrix::Identity(rows: 9 * 2, cols: 9 * 2);
119 Abar.bottomRows(n: 8) = A2.topRows(n: 8) * R1.inverse();
120
121 // Helper function to vectorize in correct order, which is the order in which
122 // we eliminated the spanning tree.
123 auto vec = [ordering](const VectorValues& x) { return x.vector(keys: ordering); };
124
125 // Set up y0 as all zeros
126 const VectorValues y0 = system.zero();
127
128 // y1 = perturbed y0
129 VectorValues y1 = system.zero();
130 y1[key(x: 3, y: 3)] = Vector2(1.0, -1.0);
131
132 // Check backSubstituteTranspose works with R1
133 VectorValues actual = Rc1.backSubstituteTranspose(gx: y1);
134 Vector expected = R1.transpose().inverse() * vec(y1);
135 EXPECT(assert_equal(expected, vec(actual)));
136
137 // Check corresponding x values
138 // for y = 0, we get xbar:
139 EXPECT(assert_equal(xbar, system.x(y0)));
140 // for non-zero y, answer is x = xbar + inv(R1)*y
141 const Vector expected_x1 = vec(xbar) + R1.inverse() * vec(y1);
142 const VectorValues x1 = system.x(y: y1);
143 EXPECT(assert_equal(expected_x1, vec(x1)));
144
145 // Check errors
146 DOUBLES_EQUAL(0, error(Ab, xbar), 1e-9);
147 DOUBLES_EQUAL(0, system.error(y0), 1e-9);
148 DOUBLES_EQUAL(2, error(Ab, x1), 1e-9);
149 DOUBLES_EQUAL(2, system.error(y1), 1e-9);
150
151 // Check that transposeMultiplyAdd <=> y += alpha * Abar' * e
152 // We check for e1 =[1;0] and e2=[0;1] corresponding to T and C
153 const double alpha = 0.5;
154 Errors e1, e2;
155 for (size_t i = 0; i < 13; i++) {
156 e1.push_back(x: i < 9 ? Vector2(1, 1) : Vector2(0, 0));
157 e2.push_back(x: i >= 9 ? Vector2(1, 1) : Vector2(0, 0));
158 }
159 Vector ee1(13 * 2), ee2(13 * 2);
160 ee1 << Vector::Ones(newSize: 9 * 2), Vector::Zero(size: 4 * 2);
161 ee2 << Vector::Zero(size: 9 * 2), Vector::Ones(newSize: 4 * 2);
162
163 // Check transposeMultiplyAdd for e1
164 VectorValues y = system.zero();
165 system.transposeMultiplyAdd(alpha, e: e1, y);
166 Vector expected_y = alpha * Abar.transpose() * ee1;
167 EXPECT(assert_equal(expected_y, vec(y)));
168
169 // Check transposeMultiplyAdd for e2
170 y = system.zero();
171 system.transposeMultiplyAdd(alpha, e: e2, y);
172 expected_y = alpha * Abar.transpose() * ee2;
173 EXPECT(assert_equal(expected_y, vec(y)));
174
175 // Test gradient in y
176 auto g = system.gradient(y: y0);
177 Vector expected_g = Vector::Zero(size: 18);
178 EXPECT(assert_equal(expected_g, vec(g)));
179}
180
181/* ************************************************************************* */
182TEST(SubgraphPreconditioner, conjugateGradients) {
183 // Build a planar graph
184 size_t N = 3;
185 const auto [Ab, xtrue] = planarGraph(N); // A*x-b
186
187 // Get the spanning tree
188 const auto [Ab1, Ab2] = splitOffPlanarTree(N, original: Ab);
189
190 // Eliminate the spanning tree to build a prior
191 GaussianBayesNet Rc1 = *Ab1.eliminateSequential(); // R1*x-c1
192 VectorValues xbar = Rc1.optimize(); // xbar = inv(R1)*c1
193
194 // Create Subgraph-preconditioned system
195 SubgraphPreconditioner system(Ab2, Rc1, xbar);
196
197 // Create zero config y0 and perturbed config y1
198 VectorValues y0 = VectorValues::Zero(other: xbar);
199
200 VectorValues y1 = y0;
201 y1[key(x: 2, y: 2)] = Vector2(1.0, -1.0);
202 VectorValues x1 = system.x(y: y1);
203
204 // Solve for the remaining constraints using PCG
205 ConjugateGradientParameters parameters;
206 VectorValues actual = conjugateGradients<SubgraphPreconditioner,
207 VectorValues, Errors>(Ab: system, x: y1, parameters);
208 EXPECT(assert_equal(y0,actual));
209
210 // Compare with non preconditioned version:
211 VectorValues actual2 = conjugateGradientDescent(fg: Ab, x: x1, parameters);
212 EXPECT(assert_equal(xtrue, actual2, 1e-4));
213}
214
215/* ************************************************************************* */
216int main() {
217 TestResult tr;
218 return TestRegistry::runAllTests(result&: tr);
219}
220/* ************************************************************************* */
221