1/**
2 * @file ABC_EQF_Demo.cpp
3 * @brief Demonstration of the full Attitude-Bias-Calibration Equivariant Filter
4 *
 * This demo runs the Equivariant Filter (EqF) for attitude estimation with
 * online estimation of both the gyroscope bias and the sensor extrinsic
 * calibration, following the paper "Overcoming Bias: Equivariant Filter
 * Design for Biased Attitude Estimation with Online Calibration" by
 * Fornasier et al.
 *
 * Authors: Darshan Rajasekaran & Jennifer Oum
10 */
11
#include "ABC_EQF.h"

#include <algorithm>
#include <array>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <exception>
#include <fstream>
#include <iostream>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Use namespace for convenience
using namespace gtsam;
16constexpr size_t N = 1; // Number of calibration states
17using M = abc_eqf_lib::State<N>;
18using Group = abc_eqf_lib::G<N>;
19using EqFilter = abc_eqf_lib::EqF<N>;
20using gtsam::abc_eqf_lib::EqF;
21using gtsam::abc_eqf_lib::Input;
22using gtsam::abc_eqf_lib::Measurement;
23
24/// Data structure for ground-truth, input and output data
25struct Data {
26 M xi; /// Ground-truth state
27 Input u; /// Input measurements
28 std::vector<Measurement> y; /// Output measurements
29 int n_meas; /// Number of measurements
30 double t; /// Time
31 double dt; /// Time step
32};
33
34//========================================================================
35// Data Processing Functions
36//========================================================================
37
38/**
39 * Load data from CSV file into a vector of Data objects for the EqF
40 *
41 * CSV format:
42 * - t: Time
43 * - q_w, q_x, q_y, q_z: True attitude quaternion
44 * - b_x, b_y, b_z: True bias
45 * - cq_w_0, cq_x_0, cq_y_0, cq_z_0: True calibration quaternion
46 * - w_x, w_y, w_z: Angular velocity measurements
47 * - std_w_x, std_w_y, std_w_z: Angular velocity measurement standard deviations
48 * - std_b_x, std_b_y, std_b_z: Bias process noise standard deviations
49 * - y_x_0, y_y_0, y_z_0, y_x_1, y_y_1, y_z_1: Direction measurements
50 * - std_y_x_0, std_y_y_0, std_y_z_0, std_y_x_1, std_y_y_1, std_y_z_1: Direction
51 * measurement standard deviations
52 * - d_x_0, d_y_0, d_z_0, d_x_1, d_y_1, d_z_1: Reference directions
53 *
54 */
55std::vector<Data> loadDataFromCSV(const std::string& filename, int startRow = 0,
56 int maxRows = -1, int downsample = 1);
57
58/// Process data with EqF and print summary results
59void processDataWithEqF(EqFilter& filter, const std::vector<Data>& data_list,
60 int printInterval = 10);
61
62//========================================================================
63// Data Processing Functions Implementation
64//========================================================================
65
66/*
67 * Loads the test data from the csv file
68 * startRow First row to load based on csv, 0 by default
69 * maxRows maximum rows to load, defaults to all rows
70 * downsample Downsample factor, default 1
71 * A list of data objects
72 */
73
74std::vector<Data> loadDataFromCSV(const std::string& filename, int startRow,
75 int maxRows, int downsample) {
76 std::vector<Data> data_list;
77 std::ifstream file(filename);
78
79 if (!file.is_open()) {
80 throw std::runtime_error("Failed to open file: " + filename);
81 }
82
83 std::cout << "Loading data from " << filename << "..." << std::flush;
84
85 std::string line;
86 int lineNumber = 0;
87 int rowCount = 0;
88 int errorCount = 0;
89 double prevTime = 0.0;
90
91 // Skip header
92 std::getline(is&: file, str&: line);
93 lineNumber++;
94
95 // Skip to startRow
96 while (lineNumber < startRow && std::getline(is&: file, str&: line)) {
97 lineNumber++;
98 }
99
100 // Read data
101 while (std::getline(is&: file, str&: line) && (maxRows == -1 || rowCount < maxRows)) {
102 lineNumber++;
103
104 // Apply downsampling
105 if ((lineNumber - startRow - 1) % downsample != 0) {
106 continue;
107 }
108
109 std::istringstream ss(line);
110 std::string token;
111 std::vector<double> values;
112
113 // Parse line into values
114 while (std::getline(in&: ss, str&: token, delim: ',')) {
115 try {
116 values.push_back(x: std::stod(str: token));
117 } catch (const std::exception& e) {
118 errorCount++;
119 values.push_back(x: 0.0); // Use default value
120 }
121 }
122
123 // Check if we have enough values
124 if (values.size() < 39) {
125 errorCount++;
126 continue;
127 }
128
129 // Extract values
130 double t = values[0];
131 double dt = (rowCount == 0) ? 0.0 : t - prevTime;
132 prevTime = t;
133
134 // Create ground truth state
135 Quaternion quat(values[1], values[2], values[3], values[4]); // w, x, y, z
136 Rot3 R = Rot3(quat);
137
138 Vector3 b(values[5], values[6], values[7]);
139
140 Quaternion calQuat(values[8], values[9], values[10],
141 values[11]); // w, x, y, z
142 std::array<Rot3, N> S = {Rot3(calQuat)};
143
144 M xi(R, b, S);
145
146 // Create input
147 Vector3 w(values[12], values[13], values[14]);
148
149 // Create input covariance matrix (6x6)
150 // First 3x3 block for angular velocity, second 3x3 block for bias process
151 // noise
152 Matrix inputCov = Matrix::Zero(rows: 6, cols: 6);
153 inputCov(0, 0) = values[15] * values[15]; // std_w_x^2
154 inputCov(1, 1) = values[16] * values[16]; // std_w_y^2
155 inputCov(2, 2) = values[17] * values[17]; // std_w_z^2
156 inputCov(3, 3) = values[18] * values[18]; // std_b_x^2
157 inputCov(4, 4) = values[19] * values[19]; // std_b_y^2
158 inputCov(5, 5) = values[20] * values[20]; // std_b_z^2
159
160 Input u{.w: w, .Sigma: inputCov};
161
162 // Create measurements
163 std::vector<Measurement> measurements;
164
165 // First measurement (calibrated sensor, cal_idx = 0)
166 Vector3 y0(values[21], values[22], values[23]);
167 Vector3 d0(values[33], values[34], values[35]);
168
169 // Normalize vectors if needed
170 if (abs(x: y0.norm() - 1.0) > 1e-5) y0.normalize();
171 if (abs(x: d0.norm() - 1.0) > 1e-5) d0.normalize();
172
173 // Measurement covariance
174 Matrix3 covY0 = Matrix3::Zero();
175 covY0(0, 0) = values[27] * values[27]; // std_y_x_0^2
176 covY0(1, 1) = values[28] * values[28]; // std_y_y_0^2
177 covY0(2, 2) = values[29] * values[29]; // std_y_z_0^2
178
179 // Create measurement
180 measurements.push_back(x: Measurement{.y: Unit3(y0), .d: Unit3(d0), .Sigma: covY0, .cal_idx: 0});
181
182 // Second measurement (calibrated sensor, cal_idx = -1)
183 Vector3 y1(values[24], values[25], values[26]);
184 Vector3 d1(values[36], values[37], values[38]);
185
186 // Normalize vectors if needed
187 if (abs(x: y1.norm() - 1.0) > 1e-5) y1.normalize();
188 if (abs(x: d1.norm() - 1.0) > 1e-5) d1.normalize();
189
190 // Measurement covariance
191 Matrix3 covY1 = Matrix3::Zero();
192 covY1(0, 0) = values[30] * values[30]; // std_y_x_1^2
193 covY1(1, 1) = values[31] * values[31]; // std_y_y_1^2
194 covY1(2, 2) = values[32] * values[32]; // std_y_z_1^2
195
196 // Create measurement
197 measurements.push_back(x: Measurement{.y: Unit3(y1), .d: Unit3(d1), .Sigma: covY1, .cal_idx: -1});
198
199 // Create Data object and add to list
200 data_list.push_back(x: Data{.xi: xi, .u: u, .y: measurements, .n_meas: 2, .t: t, .dt: dt});
201
202 rowCount++;
203
204 // Show loading progress every 1000 rows
205 if (rowCount % 1000 == 0) {
206 std::cout << "." << std::flush;
207 }
208 }
209
210 std::cout << " Done!" << std::endl;
211 std::cout << "Loaded " << data_list.size() << " data points";
212
213 if (errorCount > 0) {
214 std::cout << " (" << errorCount << " errors encountered)";
215 }
216
217 std::cout << std::endl;
218
219 return data_list;
220}
221
222/// Takes in the data and runs an EqF on it and reports the results
223void processDataWithEqF(EqFilter& filter, const std::vector<Data>& data_list,
224 int printInterval) {
225 if (data_list.empty()) {
226 std::cerr << "No data to process" << std::endl;
227 return;
228 }
229
230 std::cout << "Processing " << data_list.size() << " data points with EqF..."
231 << std::endl;
232
233 // Track performance metrics
234 std::vector<double> att_errors;
235 std::vector<double> bias_errors;
236 std::vector<double> cal_errors;
237
238 // Track time for performance measurement
239 auto start = std::chrono::high_resolution_clock::now();
240
241 int totalMeasurements = 0;
242 int validMeasurements = 0;
243
244 // Define constant for converting radians to degrees
245 const double RAD_TO_DEG = 180.0 / M_PI;
246
247 // Print a progress indicator
248 int progressStep = data_list.size() / 10; // 10 progress updates
249 if (progressStep < 1) progressStep = 1;
250
251 std::cout << "Progress: ";
252
253 for (size_t i = 0; i < data_list.size(); i++) {
254 const Data& data = data_list[i];
255
256 // Propagate filter with current input and time step
257 filter.propagation(u: data.u, dt: data.dt);
258
259 // Process all measurements
260 for (const auto& y : data.y) {
261 totalMeasurements++;
262
263 // Skip invalid measurements
264 Vector3 y_vec = y.y.unitVector();
265 Vector3 d_vec = y.d.unitVector();
266 if (std::isnan(x: y_vec[0]) || std::isnan(x: y_vec[1]) ||
267 std::isnan(x: y_vec[2]) || std::isnan(x: d_vec[0]) ||
268 std::isnan(x: d_vec[1]) || std::isnan(x: d_vec[2])) {
269 continue;
270 }
271
272 try {
273 filter.update(y);
274 validMeasurements++;
275 } catch (const std::exception& e) {
276 std::cerr << "Error updating at t=" << data.t << ": " << e.what()
277 << std::endl;
278 }
279 }
280
281 // Get current state estimate
282 M estimate = filter.stateEstimate();
283
284 // Calculate errors
285 Vector3 att_error = Rot3::Logmap(R: data.xi.R.between(g: estimate.R));
286 Vector3 bias_error = estimate.b - data.xi.b;
287 Vector3 cal_error = Vector3::Zero();
288 if (!data.xi.S.empty() && !estimate.S.empty()) {
289 cal_error = Rot3::Logmap(R: data.xi.S[0].between(g: estimate.S[0]));
290 }
291
292 // Store errors
293 att_errors.push_back(x: att_error.norm());
294 bias_errors.push_back(x: bias_error.norm());
295 cal_errors.push_back(x: cal_error.norm());
296
297 // Show progress dots
298 if (i % progressStep == 0) {
299 std::cout << "." << std::flush;
300 }
301 }
302
303 std::cout << " Done!" << std::endl;
304
305 auto end = std::chrono::high_resolution_clock::now();
306 std::chrono::duration<double> elapsed = end - start;
307
308 // Calculate average errors
309 double avg_att_error = 0.0;
310 double avg_bias_error = 0.0;
311 double avg_cal_error = 0.0;
312
313 if (!att_errors.empty()) {
314 avg_att_error = std::accumulate(first: att_errors.begin(), last: att_errors.end(), init: 0.0) /
315 att_errors.size();
316 avg_bias_error =
317 std::accumulate(first: bias_errors.begin(), last: bias_errors.end(), init: 0.0) /
318 bias_errors.size();
319 avg_cal_error = std::accumulate(first: cal_errors.begin(), last: cal_errors.end(), init: 0.0) /
320 cal_errors.size();
321 }
322
323 // Calculate final errors from last data point
324 const Data& final_data = data_list.back();
325 M final_estimate = filter.stateEstimate();
326 Vector3 final_att_error =
327 Rot3::Logmap(R: final_data.xi.R.between(g: final_estimate.R));
328 Vector3 final_bias_error = final_estimate.b - final_data.xi.b;
329 Vector3 final_cal_error = Vector3::Zero();
330 if (!final_data.xi.S.empty() && !final_estimate.S.empty()) {
331 final_cal_error =
332 Rot3::Logmap(R: final_data.xi.S[0].between(g: final_estimate.S[0]));
333 }
334
335 // Print summary statistics
336 std::cout << "\n=== Filter Performance Summary ===" << std::endl;
337 std::cout << "Processing time: " << elapsed.count() << " seconds"
338 << std::endl;
339 std::cout << "Processed measurements: " << totalMeasurements
340 << " (valid: " << validMeasurements << ")" << std::endl;
341
342 // Average errors
343 std::cout << "\n-- Average Errors --" << std::endl;
344 std::cout << "Attitude: " << (avg_att_error * RAD_TO_DEG) << "°" << std::endl;
345 std::cout << "Bias: " << avg_bias_error << std::endl;
346 std::cout << "Calibration: " << (avg_cal_error * RAD_TO_DEG) << "°"
347 << std::endl;
348
349 // Final errors
350 std::cout << "\n-- Final Errors --" << std::endl;
351 std::cout << "Attitude: " << (final_att_error.norm() * RAD_TO_DEG) << "°"
352 << std::endl;
353 std::cout << "Bias: " << final_bias_error.norm() << std::endl;
354 std::cout << "Calibration: " << (final_cal_error.norm() * RAD_TO_DEG) << "°"
355 << std::endl;
356
357 // Print a brief comparison of final estimate vs ground truth
358 std::cout << "\n-- Final State vs Ground Truth --" << std::endl;
359 std::cout << "Attitude (RPY) - Estimate: "
360 << (final_estimate.R.rpy() * RAD_TO_DEG).transpose()
361 << "° | Truth: " << (final_data.xi.R.rpy() * RAD_TO_DEG).transpose()
362 << "°" << std::endl;
363 std::cout << "Bias - Estimate: " << final_estimate.b.transpose()
364 << " | Truth: " << final_data.xi.b.transpose() << std::endl;
365
366 if (!final_estimate.S.empty() && !final_data.xi.S.empty()) {
367 std::cout << "Calibration (RPY) - Estimate: "
368 << (final_estimate.S[0].rpy() * RAD_TO_DEG).transpose()
369 << "° | Truth: "
370 << (final_data.xi.S[0].rpy() * RAD_TO_DEG).transpose() << "°"
371 << std::endl;
372 }
373}
374
375int main(int argc, char* argv[]) {
376 std::cout << "ABC-EqF: Attitude-Bias-Calibration Equivariant Filter Demo"
377 << std::endl;
378 std::cout << "=============================================================="
379 << std::endl;
380
381 try {
382 // Parse command line options
383 std::string csvFilePath;
384 int maxRows = -1; // Process all rows by default
385 int downsample = 1; // No downsampling by default
386
387 if (argc > 1) {
388 csvFilePath = argv[1];
389 } else {
390 // Try to find the EQFdata file in the GTSAM examples directory
391 try {
392 csvFilePath = findExampleDataFile(name: "EqFdata.csv");
393 } catch (const std::exception& e) {
394 std::cerr << "Error: Could not find EqFdata.csv" << std::endl;
395 std::cerr << "Usage: " << argv[0]
396 << " [csv_file_path] [max_rows] [downsample]" << std::endl;
397 return 1;
398 }
399 }
400
401 // Optional command line parameters
402 if (argc > 2) {
403 maxRows = std::stoi(str: argv[2]);
404 }
405
406 if (argc > 3) {
407 downsample = std::stoi(str: argv[3]);
408 }
409
410 // Load data from CSV file
411 std::vector<Data> data =
412 loadDataFromCSV(filename: csvFilePath, startRow: 0, maxRows, downsample);
413
414 if (data.empty()) {
415 std::cerr << "No data available to process. Exiting." << std::endl;
416 return 1;
417 }
418
419 // Initialize the EqF filter with one calibration state
420 int n_sensors = 2;
421
422 // Initial covariance - larger values allow faster convergence
423 Matrix initialSigma = Matrix::Identity(rows: 6 + 3 * N, cols: 6 + 3 * N);
424 initialSigma.diagonal().head<3>() =
425 Vector3::Constant(value: 0.1); // Attitude uncertainty
426 initialSigma.diagonal().segment<3>(start: 3) =
427 Vector3::Constant(value: 0.01); // Bias uncertainty
428 initialSigma.diagonal().tail<3>() =
429 Vector3::Constant(value: 0.1); // Calibration uncertainty
430
431 // Create filter
432 EqFilter filter(initialSigma, n_sensors);
433
434 // Process data
435 processDataWithEqF(filter, data_list: data);
436
437 } catch (const std::exception& e) {
438 std::cerr << "Error: " << e.what() << std::endl;
439 return 1;
440 }
441
442 std::cout << "\nEqF demonstration completed successfully." << std::endl;
443 return 0;
444}