/**
 * @file ABC_EQF.h
 * @brief Header file for the Attitude-Bias-Calibration Equivariant Filter
 *
 * This file contains the declarations and the template definitions of the
 * Equivariant Filter (EqF) for attitude estimation with both gyroscope bias
 * and sensor extrinsic calibration, based on the paper: "Overcoming Bias:
 * Equivariant Filter Design for Biased Attitude Estimation with Online
 * Calibration" by Fornasier et al.
 *
 * @author Darshan Rajasekaran
 * @author Jennifer Oum
 */
11
12#ifndef ABC_EQF_H
13#define ABC_EQF_H
14#pragma once
15#include <gtsam/base/Matrix.h>
16#include <gtsam/base/Vector.h>
17#include <gtsam/geometry/Rot3.h>
18#include <gtsam/geometry/Unit3.h>
19#include <gtsam/inference/Symbol.h>
20#include <gtsam/navigation/ImuBias.h>
21#include <gtsam/nonlinear/Values.h>
22#include <gtsam/slam/dataset.h>
23
24#include <chrono>
25#include <cmath>
26#include <fstream>
27#include <functional>
28#include <iostream>
29#include <numeric> // For std::accumulate
30#include <string>
31#include <vector>
32
33#include "ABC.h"
34
35// All implementations are wrapped in this namespace to avoid conflicts
36namespace gtsam {
37namespace abc_eqf_lib {
38
39using namespace std;
40using namespace gtsam;
41
42//========================================================================
43// Helper Functions for EqF
44//========================================================================
45
46/// Calculate numerical differential
47
48Matrix numericalDifferential(std::function<Vector(const Vector&)> f,
49 const Vector& x);
50
51/**
52 * Compute the lift of the system (Theorem 3.8, Equation 7)
53 * @param xi State
54 * @param u Input
55 * @return Lift vector
56 */
57template <size_t N>
58Vector lift(const State<N>& xi, const Input& u);
59
60/**
61 * Action of the symmetry group on the state space (Equation 4)
62 * @param X Group element
63 * @param xi State
64 * @return New state after group action
65 */
66template <size_t N>
67State<N> operator*(const G<N>& X, const State<N>& xi);
68
69/**
70 * Action of the symmetry group on the input space (Equation 5)
71 * @param X Group element
72 * @param u Input
73 * @return New input after group action
74 */
75template <size_t N>
76Input velocityAction(const G<N>& X, const Input& u);
77
78/**
79 * Action of the symmetry group on the output space (Equation 6)
80 * @param X Group element
81 * @param y Direction measurement
82 * @param idx Calibration index
83 * @return New direction after group action
84 */
85template <size_t N>
86Vector3 outputAction(const G<N>& X, const Unit3& y, int idx);
87
88/**
89 * Differential of the phi action at E = Id in local coordinates
90 * @param xi State
91 * @return Differential matrix
92 */
93template <size_t N>
94Matrix stateActionDiff(const State<N>& xi);
95
96//========================================================================
97// Equivariant Filter (EqF)
98//========================================================================
99
100/// Equivariant Filter (EqF) implementation
101template <size_t N>
102class EqF {
103 private:
104 int dof; // Degrees of freedom
105 G<N> X_hat; // Filter state
106 Matrix Sigma; // Error covariance
107 State<N> xi_0; // Origin state
108 Matrix Dphi0; // Differential of phi at origin
109 Matrix InnovationLift; // Innovation lift matrix
110
111 /**
112 * Return the state matrix A0t (Equation 14a)
113 * @param u Input
114 * @return State matrix A0t
115 */
116 Matrix stateMatrixA(const Input& u) const;
117
118 /**
119 * Return the state transition matrix Phi (Equation 17)
120 * @param u Input
121 * @param dt Time step
122 * @return State transition matrix Phi
123 */
124 Matrix stateTransitionMatrix(const Input& u, double dt) const;
125
126 /**
127 * Return the Input matrix Bt
128 * @return Input matrix Bt
129 */
130 Matrix inputMatrixBt() const;
131
132 /**
133 * Return the measurement matrix C0 (Equation 14b)
134 * @param d Known direction
135 * @param idx Calibration index
136 * @return Measurement matrix C0
137 */
138 Matrix measurementMatrixC(const Unit3& d, int idx) const;
139
140 /**
141 * Return the measurement output matrix Dt
142 * @param idx Calibration index
143 * @return Measurement output matrix Dt
144 */
145 Matrix outputMatrixDt(int idx) const;
146
147 public:
148 /**
149 * Initialize EqF
150 * @param Sigma Initial covariance
151 * @param m Number of sensors
152 */
153 EqF(const Matrix& Sigma, int m);
154
155 /**
156 * Return estimated state
157 * @return Current state estimate
158 */
159 State<N> stateEstimate() const;
160
161 /**
162 * Propagate the filter state
163 * @param u Angular velocity measurement
164 * @param dt Time step
165 */
166 void propagation(const Input& u, double dt);
167
168 /**
169 * Update the filter state with a measurement
170 * @param y Direction measurement
171 */
172 void update(const Measurement& y);
173};
174
175//========================================================================
176// Helper Functions Implementation
177//========================================================================
178
179/**
180 * Maps system dynamics to the symmetry group
181 * @param xi State
182 * @param u Input
183 * @return Lifted input in Lie Algebra
184 * Uses Vector zero & Rot3 inverse, matrix functions
185 */
186template <size_t N>
187Vector lift(const State<N>& xi, const Input& u) {
188 Vector L = Vector::Zero(size: 6 + 3 * N);
189
190 // First 3 elements
191 L.head<3>() = u.w - xi.b;
192
193 // Next 3 elements
194 L.segment<3>(start: 3) = -u.W() * xi.b;
195
196 // Remaining elements
197 for (size_t i = 0; i < N; i++) {
198 L.segment<3>(start: 6 + 3 * i) = xi.S[i].inverse().matrix() * L.head<3>();
199 }
200
201 return L;
202}
203/**
204 * Implements group actions on the states
205 * @param X A symmetry group element G consisting of the attitude, bias and the
206 * calibration components X.a -> Rotation matrix containing the attitude X.b ->
207 * A skew-symmetric matrix representing bias X.B -> A vector of Rotation
208 * matrices for the calibration components
209 * @param xi State object
210 * xi.R -> Attitude (Rot3)
211 * xi.b -> Gyroscope Bias(Vector 3)
212 * xi.S -> Vector of calibration matrices(Rot3)
213 * @return Transformed state
214 * Uses the Rot3 inverse and Vee functions
215 */
216template <size_t N>
217State<N> operator*(const G<N>& X, const State<N>& xi) {
218 std::array<Rot3, N> new_S;
219
220 for (size_t i = 0; i < N; i++) {
221 new_S[i] = X.A.inverse() * xi.S[i] * X.B[i];
222 }
223
224 return State<N>(xi.R * X.A, X.A.inverse().matrix() * (xi.b - Rot3::Vee(X: X.a)),
225 new_S);
226}
227/**
228 * Transforms the angular velocity measurements b/w frames
229 * @param X A symmetry group element X with the components
230 * @param u Inputs
231 * @return Transformed inputs
232 * Uses Rot3 Inverse, matrix and Vee functions and is critical for maintaining
233 * the input equivariance
234 */
235template <size_t N>
236Input velocityAction(const G<N>& X, const Input& u) {
237 return Input{X.A.inverse().matrix() * (u.w - Rot3::Vee(X: X.a)), u.Sigma};
238}
239/**
240 * Transforms the Direction measurements based on the calibration type ( Eqn 6)
241 * @param X Group element X
242 * @param y Direction measurement y
243 * @param idx Calibration index
244 * @return Transformed direction
245 * Uses Rot3 inverse, matric and Unit3 unitvector functions
246 */
247template <size_t N>
248Vector3 outputAction(const G<N>& X, const Unit3& y, int idx) {
249 if (idx == -1) {
250 return X.A.inverse().matrix() * y.unitVector();
251 } else {
252 if (idx >= static_cast<int>(N)) {
253 throw std::out_of_range("Calibration index out of range");
254 }
255 return X.B[idx].inverse().matrix() * y.unitVector();
256 }
257}
258
259/**
260 * @brief Calculates the Jacobian matrix using central difference approximation
261 * @param f Vector function f
262 * @param x The point at which Jacobian is evaluated
263 * @return Matrix containing numerical partial derivatives of f at x
264 * Uses Vector's size() and Zero(), Matrix's Zero() and col() methods
265 */
266Matrix numericalDifferential(std::function<Vector(const Vector&)> f,
267 const Vector& x) {
268 double h = 1e-6;
269 Vector fx = f(x);
270 int n = fx.size();
271 int m = x.size();
272 Matrix Df = Matrix::Zero(rows: n, cols: m);
273
274 for (int j = 0; j < m; j++) {
275 Vector ej = Vector::Zero(size: m);
276 ej(j) = 1.0;
277
278 Vector fplus = f(x + h * ej);
279 Vector fminus = f(x - h * ej);
280
281 Df.col(i: j) = (fplus - fminus) / (2 * h);
282 }
283
284 return Df;
285}
286
287/**
288 * Computes the differential of a state action at the identity of the symmetry
289 * group
290 * @param xi State object Xi representing the point at which to evaluate the
291 * differential
292 * @return A matrix representing the jacobian of the state action
293 * Uses numericalDifferential, and Rot3 expmap, logmap
294 */
295template <size_t N>
296Matrix stateActionDiff(const State<N>& xi) {
297 std::function<Vector(const Vector&)> coordsAction = [&xi](const Vector& U) {
298 G<N> groupElement = G<N>::exp(U);
299 State<N> transformed = groupElement * xi;
300 return xi.localCoordinates(transformed);
301 };
302
303 Vector zeros = Vector::Zero(size: 6 + 3 * N);
304 Matrix differential = numericalDifferential(f: coordsAction, x: zeros);
305 return differential;
306}
307
308//========================================================================
309// Equivariant Filter (EqF) Implementation
310//========================================================================
311/**
312 * Initializes the EqF with state dimension validation and computes lifted
313 * innovation mapping
314 * @param Sigma Initial covariance
315 * @param n Number of calibration states
316 * @param m Number of sensors
317 * Uses SelfAdjointSolver, completeOrthoganalDecomposition().pseudoInverse()
318 */
319template <size_t N>
320EqF<N>::EqF(const Matrix& Sigma, int m)
321 : dof(6 + 3 * N),
322 X_hat(G<N>::identity(N)),
323 Sigma(Sigma),
324 xi_0(State<N>::identity()) {
325 if (Sigma.rows() != dof || Sigma.cols() != dof) {
326 throw std::invalid_argument(
327 "Initial covariance dimensions must match the degrees of freedom");
328 }
329
330 // Check positive semi-definite
331 Eigen::SelfAdjointEigenSolver<Matrix> eigensolver(Sigma);
332 if (eigensolver.eigenvalues().minCoeff() < -1e-10) {
333 throw std::invalid_argument(
334 "Covariance matrix must be semi-positive definite");
335 }
336
337 if (N < 0) {
338 throw std::invalid_argument(
339 "Number of calibration states must be non-negative");
340 }
341
342 if (m <= 1) {
343 throw std::invalid_argument(
344 "Number of direction sensors must be at least 2");
345 }
346
347 // Compute differential of phi
348 Dphi0 = stateActionDiff(xi_0);
349 InnovationLift = Dphi0.completeOrthogonalDecomposition().pseudoInverse();
350}
351/**
352 * Computes the internal group state to a physical state estimate
353 * @return Current state estimate
354 */
355template <size_t N>
356State<N> EqF<N>::stateEstimate() const {
357 return X_hat * xi_0;
358}
359/**
360 * Implements the prediction step of the EqF using system dynamics and
361 * covariance propagation and advances the filter state by symmtery-preserving
362 * dynamics.Uses a Lie group integrator scheme for discrete time propagation
363 * @param u Angular velocity measurements
364 * @param dt time steps
365 * Updated internal state and covariance
366 */
367template <size_t N>
368void EqF<N>::propagation(const Input& u, double dt) {
369 State<N> state_est = stateEstimate();
370 Vector L = lift(state_est, u);
371
372 Matrix Phi_DT = stateTransitionMatrix(u, dt);
373 Matrix Bt = inputMatrixBt();
374
375 Matrix tempSigma = blockDiag(A: u.Sigma, B: repBlock(A: 1e-9 * I_3x3, n: N));
376 Matrix M_DT = (Bt * tempSigma * Bt.transpose()) * dt;
377
378 X_hat = X_hat * G<N>::exp(L * dt);
379 Sigma = Phi_DT * Sigma * Phi_DT.transpose() + M_DT;
380}
381/**
382 * Implements the correction step of the filter using discrete measurements
383 * Computes the measurement residual, Kalman gain and the updates both the state
384 * and covariance
385 *
386 * @param y Measurements
387 */
388template <size_t N>
389void EqF<N>::update(const Measurement& y) {
390 if (y.cal_idx > static_cast<int>(N)) {
391 throw std::invalid_argument("Calibration index out of range");
392 }
393
394 // Get vector representations for checking
395 Vector3 y_vec = y.y.unitVector();
396 Vector3 d_vec = y.d.unitVector();
397
398 // Skip update if any NaN values are present
399 if (std::isnan(x: y_vec[0]) || std::isnan(x: y_vec[1]) || std::isnan(x: y_vec[2]) ||
400 std::isnan(x: d_vec[0]) || std::isnan(x: d_vec[1]) || std::isnan(x: d_vec[2])) {
401 return; // Skip this measurement
402 }
403
404 Matrix Ct = measurementMatrixC(d: y.d, idx: y.cal_idx);
405 Vector3 action_result = outputAction(X_hat.inv(), y.y, y.cal_idx);
406 Vector3 delta_vec = Rot3::Hat(xi: y.d.unitVector()) * action_result;
407 Matrix Dt = outputMatrixDt(idx: y.cal_idx);
408 Matrix S = Ct * Sigma * Ct.transpose() + Dt * y.Sigma * Dt.transpose();
409 Matrix K = Sigma * Ct.transpose() * S.inverse();
410 Vector Delta = InnovationLift * K * delta_vec;
411 X_hat = G<N>::exp(Delta) * X_hat;
412 Sigma = (Matrix::Identity(rows: dof, cols: dof) - K * Ct) * Sigma;
413}
414/**
415 * Computes linearized continuous time state matrix
416 * @param u Angular velocity
417 * @return Linearized state matrix
418 * Uses Matrix zero and Identity functions
419 */
420template <size_t N>
421Matrix EqF<N>::stateMatrixA(const Input& u) const {
422 Matrix3 W0 = velocityAction(X_hat.inv(), u).W();
423 Matrix A1 = Matrix::Zero(rows: 6, cols: 6);
424 A1.block<3, 3>(startRow: 0, startCol: 3) = -I_3x3;
425 A1.block<3, 3>(startRow: 3, startCol: 3) = W0;
426 Matrix A2 = repBlock(A: W0, n: N);
427 return blockDiag(A: A1, B: A2);
428}
429
430/**
431 * Computes the discrete time state transition matrix
432 * @param u Angular velocity
433 * @param dt time step
434 * @return State transition matrix in discrete time
435 */
436template <size_t N>
437Matrix EqF<N>::stateTransitionMatrix(const Input& u, double dt) const {
438 Matrix3 W0 = velocityAction(X_hat.inv(), u).W();
439 Matrix Phi1 = Matrix::Zero(rows: 6, cols: 6);
440
441 Matrix3 Phi12 = -dt * (I_3x3 + (dt / 2) * W0 + ((dt * dt) / 6) * W0 * W0);
442 Matrix3 Phi22 = I_3x3 + dt * W0 + ((dt * dt) / 2) * W0 * W0;
443
444 Phi1.block<3, 3>(startRow: 0, startCol: 0) = I_3x3;
445 Phi1.block<3, 3>(startRow: 0, startCol: 3) = Phi12;
446 Phi1.block<3, 3>(startRow: 3, startCol: 3) = Phi22;
447 Matrix Phi2 = repBlock(A: Phi22, n: N);
448 return blockDiag(A: Phi1, B: Phi2);
449}
450/**
451 * Computes the input uncertainty propagation matrix
452 * @return
453 * Uses the blockdiag matrix
454 */
455template <size_t N>
456Matrix EqF<N>::inputMatrixBt() const {
457 Matrix B1 = blockDiag(X_hat.A.matrix(), X_hat.A.matrix());
458 Matrix B2(3 * N, 3 * N);
459
460 for (size_t i = 0; i < N; ++i) {
461 B2.block<3, 3>(startRow: 3 * i, startCol: 3 * i) = X_hat.B[i].matrix();
462 }
463
464 return blockDiag(A: B1, B: B2);
465}
466/**
467 * Computes the linearized measurement matrix. The structure depends on whether
468 * the sensor has a calibration state
469 * @param d reference direction
470 * @param idx Calibration index
471 * @return Measurement matrix
472 * Uses the matrix zero, Rot3 hat and the Unitvector functions
473 */
474template <size_t N>
475Matrix EqF<N>::measurementMatrixC(const Unit3& d, int idx) const {
476 Matrix Cc = Matrix::Zero(rows: 3, cols: 3 * N);
477
478 // If the measurement is related to a sensor that has a calibration state
479 if (idx >= 0) {
480 // Set the correct 3x3 block in Cc
481 Cc.block<3, 3>(startRow: 0, startCol: 3 * idx) = Rot3::Hat(xi: d.unitVector());
482 }
483
484 Matrix3 wedge_d = Rot3::Hat(xi: d.unitVector());
485
486 // Create the combined matrix
487 Matrix temp(3, 6 + 3 * N);
488 temp.block<3, 3>(startRow: 0, startCol: 0) = wedge_d;
489 temp.block<3, 3>(startRow: 0, startCol: 3) = Matrix3::Zero();
490 temp.block(startRow: 0, startCol: 6, blockRows: 3, blockCols: 3 * N) = Cc;
491
492 return wedge_d * temp;
493}
494/**
495 * Computes the measurement uncertainty propagation matrix
496 * @param idx Calibration index
497 * @return Returns B[idx] for calibrated sensors, A for uncalibrated
498 */
499template <size_t N>
500Matrix EqF<N>::outputMatrixDt(int idx) const {
501 // If the measurement is related to a sensor that has a calibration state
502 if (idx >= 0) {
503 if (idx >= static_cast<int>(N)) {
504 throw std::out_of_range("Calibration index out of range");
505 }
506 return X_hat.B[idx].matrix();
507 } else {
508 return X_hat.A.matrix();
509 }
510}
511
512} // namespace abc_eqf_lib
513
514template <size_t N>
515struct traits<abc_eqf_lib::EqF<N>>
516 : internal::LieGroupTraits<abc_eqf_lib::EqF<N>> {};
517} // namespace gtsam
518
519#endif // ABC_EQF_H