/**
 * @file ABC_EQF_Demo.cpp
 * @brief Demonstration of the full Attitude-Bias-Calibration Equivariant Filter
 *
 * This demo shows the Equivariant Filter (EqF) for attitude estimation with
 * both gyroscope bias and sensor extrinsic calibration, based on the paper
 * "Overcoming Bias: Equivariant Filter Design for Biased Attitude Estimation
 * with Online Calibration" by Fornasier et al.
 *
 * @author Darshan Rajasekaran
 * @author Jennifer Oum
 */
| 11 | |
#include "ABC_EQF.h"

#include <array>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <fstream>
#include <iostream>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
| 13 | |
| 14 | // Use namespace for convenience |
| 15 | using namespace gtsam; |
| 16 | constexpr size_t N = 1; // Number of calibration states |
| 17 | using M = abc_eqf_lib::State<N>; |
| 18 | using Group = abc_eqf_lib::G<N>; |
| 19 | using EqFilter = abc_eqf_lib::EqF<N>; |
| 20 | using gtsam::abc_eqf_lib::EqF; |
| 21 | using gtsam::abc_eqf_lib::Input; |
| 22 | using gtsam::abc_eqf_lib::Measurement; |
| 23 | |
| 24 | /// Data structure for ground-truth, input and output data |
| 25 | struct Data { |
| 26 | M xi; /// Ground-truth state |
| 27 | Input u; /// Input measurements |
| 28 | std::vector<Measurement> y; /// Output measurements |
| 29 | int n_meas; /// Number of measurements |
| 30 | double t; /// Time |
| 31 | double dt; /// Time step |
| 32 | }; |
| 33 | |
| 34 | //======================================================================== |
| 35 | // Data Processing Functions |
| 36 | //======================================================================== |
| 37 | |
| 38 | /** |
| 39 | * Load data from CSV file into a vector of Data objects for the EqF |
| 40 | * |
| 41 | * CSV format: |
| 42 | * - t: Time |
| 43 | * - q_w, q_x, q_y, q_z: True attitude quaternion |
| 44 | * - b_x, b_y, b_z: True bias |
| 45 | * - cq_w_0, cq_x_0, cq_y_0, cq_z_0: True calibration quaternion |
| 46 | * - w_x, w_y, w_z: Angular velocity measurements |
| 47 | * - std_w_x, std_w_y, std_w_z: Angular velocity measurement standard deviations |
| 48 | * - std_b_x, std_b_y, std_b_z: Bias process noise standard deviations |
| 49 | * - y_x_0, y_y_0, y_z_0, y_x_1, y_y_1, y_z_1: Direction measurements |
| 50 | * - std_y_x_0, std_y_y_0, std_y_z_0, std_y_x_1, std_y_y_1, std_y_z_1: Direction |
| 51 | * measurement standard deviations |
| 52 | * - d_x_0, d_y_0, d_z_0, d_x_1, d_y_1, d_z_1: Reference directions |
| 53 | * |
| 54 | */ |
| 55 | std::vector<Data> loadDataFromCSV(const std::string& filename, int startRow = 0, |
| 56 | int maxRows = -1, int downsample = 1); |
| 57 | |
| 58 | /// Process data with EqF and print summary results |
| 59 | void processDataWithEqF(EqFilter& filter, const std::vector<Data>& data_list, |
| 60 | int printInterval = 10); |
| 61 | |
| 62 | //======================================================================== |
| 63 | // Data Processing Functions Implementation |
| 64 | //======================================================================== |
| 65 | |
| 66 | /* |
| 67 | * Loads the test data from the csv file |
| 68 | * startRow First row to load based on csv, 0 by default |
| 69 | * maxRows maximum rows to load, defaults to all rows |
| 70 | * downsample Downsample factor, default 1 |
| 71 | * A list of data objects |
| 72 | */ |
| 73 | |
| 74 | std::vector<Data> loadDataFromCSV(const std::string& filename, int startRow, |
| 75 | int maxRows, int downsample) { |
| 76 | std::vector<Data> data_list; |
| 77 | std::ifstream file(filename); |
| 78 | |
| 79 | if (!file.is_open()) { |
| 80 | throw std::runtime_error("Failed to open file: " + filename); |
| 81 | } |
| 82 | |
| 83 | std::cout << "Loading data from " << filename << "..." << std::flush; |
| 84 | |
| 85 | std::string line; |
| 86 | int lineNumber = 0; |
| 87 | int rowCount = 0; |
| 88 | int errorCount = 0; |
| 89 | double prevTime = 0.0; |
| 90 | |
| 91 | // Skip header |
| 92 | std::getline(is&: file, str&: line); |
| 93 | lineNumber++; |
| 94 | |
| 95 | // Skip to startRow |
| 96 | while (lineNumber < startRow && std::getline(is&: file, str&: line)) { |
| 97 | lineNumber++; |
| 98 | } |
| 99 | |
| 100 | // Read data |
| 101 | while (std::getline(is&: file, str&: line) && (maxRows == -1 || rowCount < maxRows)) { |
| 102 | lineNumber++; |
| 103 | |
| 104 | // Apply downsampling |
| 105 | if ((lineNumber - startRow - 1) % downsample != 0) { |
| 106 | continue; |
| 107 | } |
| 108 | |
| 109 | std::istringstream ss(line); |
| 110 | std::string token; |
| 111 | std::vector<double> values; |
| 112 | |
| 113 | // Parse line into values |
| 114 | while (std::getline(in&: ss, str&: token, delim: ',')) { |
| 115 | try { |
| 116 | values.push_back(x: std::stod(str: token)); |
| 117 | } catch (const std::exception& e) { |
| 118 | errorCount++; |
| 119 | values.push_back(x: 0.0); // Use default value |
| 120 | } |
| 121 | } |
| 122 | |
| 123 | // Check if we have enough values |
| 124 | if (values.size() < 39) { |
| 125 | errorCount++; |
| 126 | continue; |
| 127 | } |
| 128 | |
| 129 | // Extract values |
| 130 | double t = values[0]; |
| 131 | double dt = (rowCount == 0) ? 0.0 : t - prevTime; |
| 132 | prevTime = t; |
| 133 | |
| 134 | // Create ground truth state |
| 135 | Quaternion quat(values[1], values[2], values[3], values[4]); // w, x, y, z |
| 136 | Rot3 R = Rot3(quat); |
| 137 | |
| 138 | Vector3 b(values[5], values[6], values[7]); |
| 139 | |
| 140 | Quaternion calQuat(values[8], values[9], values[10], |
| 141 | values[11]); // w, x, y, z |
| 142 | std::array<Rot3, N> S = {Rot3(calQuat)}; |
| 143 | |
| 144 | M xi(R, b, S); |
| 145 | |
| 146 | // Create input |
| 147 | Vector3 w(values[12], values[13], values[14]); |
| 148 | |
| 149 | // Create input covariance matrix (6x6) |
| 150 | // First 3x3 block for angular velocity, second 3x3 block for bias process |
| 151 | // noise |
| 152 | Matrix inputCov = Matrix::Zero(rows: 6, cols: 6); |
| 153 | inputCov(0, 0) = values[15] * values[15]; // std_w_x^2 |
| 154 | inputCov(1, 1) = values[16] * values[16]; // std_w_y^2 |
| 155 | inputCov(2, 2) = values[17] * values[17]; // std_w_z^2 |
| 156 | inputCov(3, 3) = values[18] * values[18]; // std_b_x^2 |
| 157 | inputCov(4, 4) = values[19] * values[19]; // std_b_y^2 |
| 158 | inputCov(5, 5) = values[20] * values[20]; // std_b_z^2 |
| 159 | |
| 160 | Input u{.w: w, .Sigma: inputCov}; |
| 161 | |
| 162 | // Create measurements |
| 163 | std::vector<Measurement> measurements; |
| 164 | |
| 165 | // First measurement (calibrated sensor, cal_idx = 0) |
| 166 | Vector3 y0(values[21], values[22], values[23]); |
| 167 | Vector3 d0(values[33], values[34], values[35]); |
| 168 | |
| 169 | // Normalize vectors if needed |
| 170 | if (abs(x: y0.norm() - 1.0) > 1e-5) y0.normalize(); |
| 171 | if (abs(x: d0.norm() - 1.0) > 1e-5) d0.normalize(); |
| 172 | |
| 173 | // Measurement covariance |
| 174 | Matrix3 covY0 = Matrix3::Zero(); |
| 175 | covY0(0, 0) = values[27] * values[27]; // std_y_x_0^2 |
| 176 | covY0(1, 1) = values[28] * values[28]; // std_y_y_0^2 |
| 177 | covY0(2, 2) = values[29] * values[29]; // std_y_z_0^2 |
| 178 | |
| 179 | // Create measurement |
| 180 | measurements.push_back(x: Measurement{.y: Unit3(y0), .d: Unit3(d0), .Sigma: covY0, .cal_idx: 0}); |
| 181 | |
| 182 | // Second measurement (calibrated sensor, cal_idx = -1) |
| 183 | Vector3 y1(values[24], values[25], values[26]); |
| 184 | Vector3 d1(values[36], values[37], values[38]); |
| 185 | |
| 186 | // Normalize vectors if needed |
| 187 | if (abs(x: y1.norm() - 1.0) > 1e-5) y1.normalize(); |
| 188 | if (abs(x: d1.norm() - 1.0) > 1e-5) d1.normalize(); |
| 189 | |
| 190 | // Measurement covariance |
| 191 | Matrix3 covY1 = Matrix3::Zero(); |
| 192 | covY1(0, 0) = values[30] * values[30]; // std_y_x_1^2 |
| 193 | covY1(1, 1) = values[31] * values[31]; // std_y_y_1^2 |
| 194 | covY1(2, 2) = values[32] * values[32]; // std_y_z_1^2 |
| 195 | |
| 196 | // Create measurement |
| 197 | measurements.push_back(x: Measurement{.y: Unit3(y1), .d: Unit3(d1), .Sigma: covY1, .cal_idx: -1}); |
| 198 | |
| 199 | // Create Data object and add to list |
| 200 | data_list.push_back(x: Data{.xi: xi, .u: u, .y: measurements, .n_meas: 2, .t: t, .dt: dt}); |
| 201 | |
| 202 | rowCount++; |
| 203 | |
| 204 | // Show loading progress every 1000 rows |
| 205 | if (rowCount % 1000 == 0) { |
| 206 | std::cout << "." << std::flush; |
| 207 | } |
| 208 | } |
| 209 | |
| 210 | std::cout << " Done!" << std::endl; |
| 211 | std::cout << "Loaded " << data_list.size() << " data points" ; |
| 212 | |
| 213 | if (errorCount > 0) { |
| 214 | std::cout << " (" << errorCount << " errors encountered)" ; |
| 215 | } |
| 216 | |
| 217 | std::cout << std::endl; |
| 218 | |
| 219 | return data_list; |
| 220 | } |
| 221 | |
| 222 | /// Takes in the data and runs an EqF on it and reports the results |
| 223 | void processDataWithEqF(EqFilter& filter, const std::vector<Data>& data_list, |
| 224 | int printInterval) { |
| 225 | if (data_list.empty()) { |
| 226 | std::cerr << "No data to process" << std::endl; |
| 227 | return; |
| 228 | } |
| 229 | |
| 230 | std::cout << "Processing " << data_list.size() << " data points with EqF..." |
| 231 | << std::endl; |
| 232 | |
| 233 | // Track performance metrics |
| 234 | std::vector<double> att_errors; |
| 235 | std::vector<double> bias_errors; |
| 236 | std::vector<double> cal_errors; |
| 237 | |
| 238 | // Track time for performance measurement |
| 239 | auto start = std::chrono::high_resolution_clock::now(); |
| 240 | |
| 241 | int totalMeasurements = 0; |
| 242 | int validMeasurements = 0; |
| 243 | |
| 244 | // Define constant for converting radians to degrees |
| 245 | const double RAD_TO_DEG = 180.0 / M_PI; |
| 246 | |
| 247 | // Print a progress indicator |
| 248 | int progressStep = data_list.size() / 10; // 10 progress updates |
| 249 | if (progressStep < 1) progressStep = 1; |
| 250 | |
| 251 | std::cout << "Progress: " ; |
| 252 | |
| 253 | for (size_t i = 0; i < data_list.size(); i++) { |
| 254 | const Data& data = data_list[i]; |
| 255 | |
| 256 | // Propagate filter with current input and time step |
| 257 | filter.propagation(u: data.u, dt: data.dt); |
| 258 | |
| 259 | // Process all measurements |
| 260 | for (const auto& y : data.y) { |
| 261 | totalMeasurements++; |
| 262 | |
| 263 | // Skip invalid measurements |
| 264 | Vector3 y_vec = y.y.unitVector(); |
| 265 | Vector3 d_vec = y.d.unitVector(); |
| 266 | if (std::isnan(x: y_vec[0]) || std::isnan(x: y_vec[1]) || |
| 267 | std::isnan(x: y_vec[2]) || std::isnan(x: d_vec[0]) || |
| 268 | std::isnan(x: d_vec[1]) || std::isnan(x: d_vec[2])) { |
| 269 | continue; |
| 270 | } |
| 271 | |
| 272 | try { |
| 273 | filter.update(y); |
| 274 | validMeasurements++; |
| 275 | } catch (const std::exception& e) { |
| 276 | std::cerr << "Error updating at t=" << data.t << ": " << e.what() |
| 277 | << std::endl; |
| 278 | } |
| 279 | } |
| 280 | |
| 281 | // Get current state estimate |
| 282 | M estimate = filter.stateEstimate(); |
| 283 | |
| 284 | // Calculate errors |
| 285 | Vector3 att_error = Rot3::Logmap(R: data.xi.R.between(g: estimate.R)); |
| 286 | Vector3 bias_error = estimate.b - data.xi.b; |
| 287 | Vector3 cal_error = Vector3::Zero(); |
| 288 | if (!data.xi.S.empty() && !estimate.S.empty()) { |
| 289 | cal_error = Rot3::Logmap(R: data.xi.S[0].between(g: estimate.S[0])); |
| 290 | } |
| 291 | |
| 292 | // Store errors |
| 293 | att_errors.push_back(x: att_error.norm()); |
| 294 | bias_errors.push_back(x: bias_error.norm()); |
| 295 | cal_errors.push_back(x: cal_error.norm()); |
| 296 | |
| 297 | // Show progress dots |
| 298 | if (i % progressStep == 0) { |
| 299 | std::cout << "." << std::flush; |
| 300 | } |
| 301 | } |
| 302 | |
| 303 | std::cout << " Done!" << std::endl; |
| 304 | |
| 305 | auto end = std::chrono::high_resolution_clock::now(); |
| 306 | std::chrono::duration<double> elapsed = end - start; |
| 307 | |
| 308 | // Calculate average errors |
| 309 | double avg_att_error = 0.0; |
| 310 | double avg_bias_error = 0.0; |
| 311 | double avg_cal_error = 0.0; |
| 312 | |
| 313 | if (!att_errors.empty()) { |
| 314 | avg_att_error = std::accumulate(first: att_errors.begin(), last: att_errors.end(), init: 0.0) / |
| 315 | att_errors.size(); |
| 316 | avg_bias_error = |
| 317 | std::accumulate(first: bias_errors.begin(), last: bias_errors.end(), init: 0.0) / |
| 318 | bias_errors.size(); |
| 319 | avg_cal_error = std::accumulate(first: cal_errors.begin(), last: cal_errors.end(), init: 0.0) / |
| 320 | cal_errors.size(); |
| 321 | } |
| 322 | |
| 323 | // Calculate final errors from last data point |
| 324 | const Data& final_data = data_list.back(); |
| 325 | M final_estimate = filter.stateEstimate(); |
| 326 | Vector3 final_att_error = |
| 327 | Rot3::Logmap(R: final_data.xi.R.between(g: final_estimate.R)); |
| 328 | Vector3 final_bias_error = final_estimate.b - final_data.xi.b; |
| 329 | Vector3 final_cal_error = Vector3::Zero(); |
| 330 | if (!final_data.xi.S.empty() && !final_estimate.S.empty()) { |
| 331 | final_cal_error = |
| 332 | Rot3::Logmap(R: final_data.xi.S[0].between(g: final_estimate.S[0])); |
| 333 | } |
| 334 | |
| 335 | // Print summary statistics |
| 336 | std::cout << "\n=== Filter Performance Summary ===" << std::endl; |
| 337 | std::cout << "Processing time: " << elapsed.count() << " seconds" |
| 338 | << std::endl; |
| 339 | std::cout << "Processed measurements: " << totalMeasurements |
| 340 | << " (valid: " << validMeasurements << ")" << std::endl; |
| 341 | |
| 342 | // Average errors |
| 343 | std::cout << "\n-- Average Errors --" << std::endl; |
| 344 | std::cout << "Attitude: " << (avg_att_error * RAD_TO_DEG) << "°" << std::endl; |
| 345 | std::cout << "Bias: " << avg_bias_error << std::endl; |
| 346 | std::cout << "Calibration: " << (avg_cal_error * RAD_TO_DEG) << "°" |
| 347 | << std::endl; |
| 348 | |
| 349 | // Final errors |
| 350 | std::cout << "\n-- Final Errors --" << std::endl; |
| 351 | std::cout << "Attitude: " << (final_att_error.norm() * RAD_TO_DEG) << "°" |
| 352 | << std::endl; |
| 353 | std::cout << "Bias: " << final_bias_error.norm() << std::endl; |
| 354 | std::cout << "Calibration: " << (final_cal_error.norm() * RAD_TO_DEG) << "°" |
| 355 | << std::endl; |
| 356 | |
| 357 | // Print a brief comparison of final estimate vs ground truth |
| 358 | std::cout << "\n-- Final State vs Ground Truth --" << std::endl; |
| 359 | std::cout << "Attitude (RPY) - Estimate: " |
| 360 | << (final_estimate.R.rpy() * RAD_TO_DEG).transpose() |
| 361 | << "° | Truth: " << (final_data.xi.R.rpy() * RAD_TO_DEG).transpose() |
| 362 | << "°" << std::endl; |
| 363 | std::cout << "Bias - Estimate: " << final_estimate.b.transpose() |
| 364 | << " | Truth: " << final_data.xi.b.transpose() << std::endl; |
| 365 | |
| 366 | if (!final_estimate.S.empty() && !final_data.xi.S.empty()) { |
| 367 | std::cout << "Calibration (RPY) - Estimate: " |
| 368 | << (final_estimate.S[0].rpy() * RAD_TO_DEG).transpose() |
| 369 | << "° | Truth: " |
| 370 | << (final_data.xi.S[0].rpy() * RAD_TO_DEG).transpose() << "°" |
| 371 | << std::endl; |
| 372 | } |
| 373 | } |
| 374 | |
| 375 | int main(int argc, char* argv[]) { |
| 376 | std::cout << "ABC-EqF: Attitude-Bias-Calibration Equivariant Filter Demo" |
| 377 | << std::endl; |
| 378 | std::cout << "==============================================================" |
| 379 | << std::endl; |
| 380 | |
| 381 | try { |
| 382 | // Parse command line options |
| 383 | std::string csvFilePath; |
| 384 | int maxRows = -1; // Process all rows by default |
| 385 | int downsample = 1; // No downsampling by default |
| 386 | |
| 387 | if (argc > 1) { |
| 388 | csvFilePath = argv[1]; |
| 389 | } else { |
| 390 | // Try to find the EQFdata file in the GTSAM examples directory |
| 391 | try { |
| 392 | csvFilePath = findExampleDataFile(name: "EqFdata.csv" ); |
| 393 | } catch (const std::exception& e) { |
| 394 | std::cerr << "Error: Could not find EqFdata.csv" << std::endl; |
| 395 | std::cerr << "Usage: " << argv[0] |
| 396 | << " [csv_file_path] [max_rows] [downsample]" << std::endl; |
| 397 | return 1; |
| 398 | } |
| 399 | } |
| 400 | |
| 401 | // Optional command line parameters |
| 402 | if (argc > 2) { |
| 403 | maxRows = std::stoi(str: argv[2]); |
| 404 | } |
| 405 | |
| 406 | if (argc > 3) { |
| 407 | downsample = std::stoi(str: argv[3]); |
| 408 | } |
| 409 | |
| 410 | // Load data from CSV file |
| 411 | std::vector<Data> data = |
| 412 | loadDataFromCSV(filename: csvFilePath, startRow: 0, maxRows, downsample); |
| 413 | |
| 414 | if (data.empty()) { |
| 415 | std::cerr << "No data available to process. Exiting." << std::endl; |
| 416 | return 1; |
| 417 | } |
| 418 | |
| 419 | // Initialize the EqF filter with one calibration state |
| 420 | int n_sensors = 2; |
| 421 | |
| 422 | // Initial covariance - larger values allow faster convergence |
| 423 | Matrix initialSigma = Matrix::Identity(rows: 6 + 3 * N, cols: 6 + 3 * N); |
| 424 | initialSigma.diagonal().head<3>() = |
| 425 | Vector3::Constant(value: 0.1); // Attitude uncertainty |
| 426 | initialSigma.diagonal().segment<3>(start: 3) = |
| 427 | Vector3::Constant(value: 0.01); // Bias uncertainty |
| 428 | initialSigma.diagonal().tail<3>() = |
| 429 | Vector3::Constant(value: 0.1); // Calibration uncertainty |
| 430 | |
| 431 | // Create filter |
| 432 | EqFilter filter(initialSigma, n_sensors); |
| 433 | |
| 434 | // Process data |
| 435 | processDataWithEqF(filter, data_list: data); |
| 436 | |
| 437 | } catch (const std::exception& e) { |
| 438 | std::cerr << "Error: " << e.what() << std::endl; |
| 439 | return 1; |
| 440 | } |
| 441 | |
| 442 | std::cout << "\nEqF demonstration completed successfully." << std::endl; |
| 443 | return 0; |
| 444 | } |