ekf2: Compute stapdown INS propagation using closed-form solution

This commit is contained in:
bresch 2023-10-30 14:22:55 +01:00
parent 0d6c2c8ce9
commit 5f02c21d73
3 changed files with 198 additions and 17 deletions

View File

@ -42,6 +42,7 @@
#include "ekf.h"
#include <mathlib/mathlib.h>
#include <ekf_derivation/generated/predict_vel_pos_closed_form.h>
bool Ekf::init(uint64_t timestamp)
{
@ -289,29 +290,18 @@ void Ekf::predictState(const imuSample &imu_delayed)
// subtract component of angular rate due to earth rotation
corrected_delta_ang -= _R_to_earth.transpose() * _earth_rate_NED * imu_delayed.delta_ang_dt;
// Calculate an earth frame delta velocity
const Vector3f delta_vel_bias_scaled = getAccelBias() * imu_delayed.delta_vel_dt;
const Vector3f corrected_delta_vel = imu_delayed.delta_vel - delta_vel_bias_scaled;
sym::PredictVelPosClosedForm(_state.vector(), corrected_delta_vel, imu_delayed.delta_vel_dt, corrected_delta_ang, imu_delayed.delta_ang_dt, CONSTANTS_ONE_G, FLT_EPSILON, &_state.vel, &_state.pos);
const Quatf dq(AxisAnglef{corrected_delta_ang});
// rotate the previous quaternion by the delta quaternion using a quaternion multiplication
_state.quat_nominal = (_state.quat_nominal * dq).normalized();
_R_to_earth = Dcmf(_state.quat_nominal);
// Calculate an earth frame delta velocity
const Vector3f delta_vel_bias_scaled = getAccelBias() * imu_delayed.delta_vel_dt;
const Vector3f corrected_delta_vel = imu_delayed.delta_vel - delta_vel_bias_scaled;
const Vector3f corrected_delta_vel_ef = _R_to_earth * corrected_delta_vel;
// save the previous value of velocity so we can use trapzoidal integration
const Vector3f vel_last = _state.vel;
// calculate the increment in velocity using the current orientation
_state.vel += corrected_delta_vel_ef;
// compensate for acceleration due to gravity
_state.vel(2) += CONSTANTS_ONE_G * imu_delayed.delta_vel_dt;
// predict position states via trapezoidal integration of velocity
_state.pos += (vel_last + _state.vel) * imu_delayed.delta_vel_dt * 0.5f;
constrainStates();
@ -326,6 +316,7 @@ void Ekf::predictState(const imuSample &imu_delayed)
// calculate a filtered horizontal acceleration with a 1 sec time constant
// this are used for manoeuvre detection elsewhere
const float alpha = 1.0f - imu_delayed.delta_vel_dt;
const Vector3f corrected_delta_vel_ef = _R_to_earth * corrected_delta_vel;
_accel_lpf_NE = _accel_lpf_NE * alpha + corrected_delta_vel_ef.xy();
// calculate a yaw change about the earth frame vertical

View File

@ -205,6 +205,63 @@ def predict_covariance(
return P_new
def predict_vel_pos_closed_form(
state: VState,
d_vel: sf.V3,
d_vel_dt: sf.Scalar,
d_ang: sf.V3,
d_ang_dt: sf.Scalar,
g: sf.Scalar,
epsilon: sf.Scalar
) -> (sf.V3, sf.V3):
# Closed-form integration of accelerometer and gyro measurements based on
# Goppert, James, et al. "A Closed-form Solution for the Strapdown Inertial Navigation Initial Value Problem." arXiv preprint arXiv:2310.04886 (2023).
# TODO: check which dt to use
state = vstate_to_state(state)
gyro = d_ang / d_ang_dt
accel = d_vel / d_vel_dt
dt = d_vel_dt
R_0 = state["quat_nominal"].to_rotation_matrix()
P_0 = sf.M32.block_matrix([[state["vel"], state["pos"]]])
R_l_prime = sf.Rot3.from_tangent(d_ang).to_rotation_matrix()
R_r_prime = sf.M33.eye()
A_M = sf.M32([
[0, 0],
[0, 0],
[g, 0]
])
A_N = sf.M32.block_matrix([[accel, sf.V3()]])
B = sf.M22([
[0, 1],
[0, 0]
])
def P(omega, A, B) -> sf.M32:
theta = (gyro.dot(gyro))**0.5
C1 = (1 - theta**2 / 2 - sf.cos(theta)) / sf.Max(theta**2, epsilon)
C2 = (theta - sf.sin(theta)) / sf.Max(theta**3, epsilon)
C3 = (theta**2 / 2 - theta**4 / 24 + sf.cos(theta) - 1) / sf.Max(theta**4, epsilon)
P = A + (A * B) / 2
P += omega * A * (C1 * sf.M22.eye() + C2 * B)
P += omega * omega * A * (C2 * sf.M22.eye() + C3 * B)
return P
P_M = P(sf.M33(), A_M * dt, -B * dt)
P_N = P(sf.M33.skew_symmetric(gyro * dt), A_N * dt, B * dt)
P_new = R_r_prime * R_0 * P_N + (R_r_prime * P_0 + P_M) * (sf.M22.eye() + B * dt)
# The attitude propagation can be computed as follows, but since the result is simple,
# it is directly implemented in the code.
# R_new = R_r_prime * R_0 * R_l_prime
# q_xyzw = sf.Rot3.from_rotation_matrix(R_new).to_storage()
# q_wxyz = sf.V4(q_xyzw[3], q_xyzw[0], q_xyzw[1], q_xyzw[2])
return (P_new.col(0), P_new.col(1))
def compute_airspeed_innov_and_innov_var(
state: VState,
P: MTangent,
@ -652,5 +709,6 @@ generate_px4_function(compute_flow_xy_innov_var_and_hx, output_names=["innov_var
generate_px4_function(compute_flow_y_innov_var_and_h, output_names=["innov_var", "H"])
generate_px4_function(compute_gnss_yaw_pred_innov_var_and_h, output_names=["meas_pred", "innov_var", "H"])
generate_px4_function(compute_gravity_innov_var_and_k_and_h, output_names=["innov", "innov_var", "Kx", "Ky", "Kz"])
generate_px4_function(predict_vel_pos_closed_form, output_names=["v_new", "p_new"])
generate_px4_state(State, tangent_idx)

View File

@ -0,0 +1,132 @@
// -----------------------------------------------------------------------------
// This file was autogenerated by symforce from template:
// function/FUNCTION.h.jinja
// Do NOT modify by hand.
// -----------------------------------------------------------------------------
#pragma once
#include <matrix/math.hpp>
namespace sym {
/**
* This function was autogenerated from a symbolic function. Do not modify by hand.
*
* Symbolic function: predict_vel_pos_closed_form
*
* Args:
* state: Matrix24_1
* d_vel: Matrix31
* d_vel_dt: Scalar
* d_ang: Matrix31
* d_ang_dt: Scalar
* g: Scalar
* epsilon: Scalar
*
* Outputs:
* v_new: Matrix31
* p_new: Matrix31
*/
template <typename Scalar>
void PredictVelPosClosedForm(const matrix::Matrix<Scalar, 24, 1>& state,
const matrix::Matrix<Scalar, 3, 1>& d_vel, const Scalar d_vel_dt,
const matrix::Matrix<Scalar, 3, 1>& d_ang, const Scalar d_ang_dt,
const Scalar g, const Scalar epsilon,
matrix::Matrix<Scalar, 3, 1>* const v_new = nullptr,
matrix::Matrix<Scalar, 3, 1>* const p_new = nullptr) {
// Total ops: 179
// Input arrays
// Intermediate terms (61)
const Scalar _tmp0 = 2 * state(0, 0) * state(3, 0);
const Scalar _tmp1 = 2 * state(2, 0);
const Scalar _tmp2 = _tmp1 * state(1, 0);
const Scalar _tmp3 = -_tmp0 + _tmp2;
const Scalar _tmp4 = d_vel_dt / d_ang_dt;
const Scalar _tmp5 = d_ang(0, 0) * d_vel(2, 0);
const Scalar _tmp6 = d_ang(2, 0) * d_vel(0, 0);
const Scalar _tmp7 = -_tmp4 * _tmp5 + _tmp4 * _tmp6;
const Scalar _tmp8 = std::pow(d_ang_dt, Scalar(-2));
const Scalar _tmp9 = _tmp8 * std::pow(d_ang(0, 0), Scalar(2));
const Scalar _tmp10 = _tmp8 * std::pow(d_ang(2, 0), Scalar(2));
const Scalar _tmp11 = _tmp8 * std::pow(d_ang(1, 0), Scalar(2));
const Scalar _tmp12 = _tmp10 + _tmp11 + _tmp9;
const Scalar _tmp13 = std::pow(_tmp12, Scalar(Scalar(1.0)));
const Scalar _tmp14 = std::sqrt(_tmp12);
const Scalar _tmp15 = std::cos(_tmp14);
const Scalar _tmp16 = (Scalar(1) / Scalar(2)) * _tmp13;
const Scalar _tmp17 = (-_tmp15 - _tmp16 + 1) / math::max<Scalar>(_tmp13, epsilon);
const Scalar _tmp18 = std::pow(d_vel_dt, Scalar(2));
const Scalar _tmp19 = _tmp18 * _tmp8;
const Scalar _tmp20 = _tmp19 * d_ang(1, 0);
const Scalar _tmp21 = _tmp20 * d_ang(2, 0);
const Scalar _tmp22 = _tmp20 * d_ang(0, 0);
const Scalar _tmp23 = -_tmp10 * _tmp18;
const Scalar _tmp24 = -_tmp18 * _tmp9;
const Scalar _tmp25 =
_tmp21 * d_vel(2, 0) + _tmp22 * d_vel(0, 0) + d_vel(1, 0) * (_tmp23 + _tmp24);
const Scalar _tmp26 =
(_tmp14 - std::sin(_tmp14)) / math::max<Scalar>(epsilon, (_tmp12 * std::sqrt(_tmp12)));
const Scalar _tmp27 = _tmp17 * _tmp7 + _tmp25 * _tmp26 + d_vel(1, 0);
const Scalar _tmp28 = -2 * std::pow(state(3, 0), Scalar(2));
const Scalar _tmp29 = 1 - 2 * std::pow(state(2, 0), Scalar(2));
const Scalar _tmp30 = _tmp28 + _tmp29;
const Scalar _tmp31 = -_tmp11 * _tmp18;
const Scalar _tmp32 =
_tmp19 * _tmp5 * d_ang(2, 0) + _tmp22 * d_vel(1, 0) + d_vel(0, 0) * (_tmp23 + _tmp31);
const Scalar _tmp33 = _tmp4 * d_ang(1, 0);
const Scalar _tmp34 = _tmp4 * d_vel(1, 0);
const Scalar _tmp35 = _tmp33 * d_vel(2, 0) - _tmp34 * d_ang(2, 0);
const Scalar _tmp36 = _tmp17 * _tmp35 + _tmp26 * _tmp32 + d_vel(0, 0);
const Scalar _tmp37 = _tmp1 * state(0, 0);
const Scalar _tmp38 = 2 * state(1, 0);
const Scalar _tmp39 = _tmp38 * state(3, 0);
const Scalar _tmp40 = _tmp37 + _tmp39;
const Scalar _tmp41 =
_tmp19 * _tmp6 * d_ang(0, 0) + _tmp21 * d_vel(1, 0) + d_vel(2, 0) * (_tmp24 + _tmp31);
const Scalar _tmp42 = -_tmp33 * d_vel(0, 0) + _tmp34 * d_ang(0, 0);
const Scalar _tmp43 = _tmp17 * _tmp42 + _tmp26 * _tmp41 + d_vel(2, 0);
const Scalar _tmp44 = -2 * std::pow(state(1, 0), Scalar(2));
const Scalar _tmp45 = _tmp28 + _tmp44 + 1;
const Scalar _tmp46 = _tmp0 + _tmp2;
const Scalar _tmp47 = _tmp1 * state(3, 0);
const Scalar _tmp48 = _tmp38 * state(0, 0);
const Scalar _tmp49 = _tmp47 - _tmp48;
const Scalar _tmp50 = _tmp47 + _tmp48;
const Scalar _tmp51 = -_tmp37 + _tmp39;
const Scalar _tmp52 = _tmp29 + _tmp44;
const Scalar _tmp53 = d_vel_dt * g + state(6, 0);
const Scalar _tmp54 = (Scalar(1) / Scalar(2)) * d_vel_dt;
const Scalar _tmp55 = std::pow(_tmp12, Scalar(Scalar(2.0)));
const Scalar _tmp56 = d_vel_dt * (_tmp15 + _tmp16 - Scalar(1) / Scalar(24) * _tmp55 - 1) /
math::max<Scalar>(_tmp55, epsilon);
const Scalar _tmp57 = _tmp26 * d_vel_dt;
const Scalar _tmp58 = _tmp25 * _tmp56 + _tmp54 * d_vel(1, 0) + _tmp57 * _tmp7;
const Scalar _tmp59 = _tmp41 * _tmp56 + _tmp42 * _tmp57 + _tmp54 * d_vel(2, 0);
const Scalar _tmp60 = _tmp32 * _tmp56 + _tmp35 * _tmp57 + _tmp54 * d_vel(0, 0);
// Output terms (2)
if (v_new != nullptr) {
matrix::Matrix<Scalar, 3, 1>& _v_new = (*v_new);
_v_new(0, 0) = _tmp27 * _tmp3 + _tmp30 * _tmp36 + _tmp40 * _tmp43 + state(4, 0);
_v_new(1, 0) = _tmp27 * _tmp45 + _tmp36 * _tmp46 + _tmp43 * _tmp49 + state(5, 0);
_v_new(2, 0) = _tmp27 * _tmp50 + _tmp36 * _tmp51 + _tmp43 * _tmp52 + _tmp53;
}
if (p_new != nullptr) {
matrix::Matrix<Scalar, 3, 1>& _p_new = (*p_new);
_p_new(0, 0) =
_tmp3 * _tmp58 + _tmp30 * _tmp60 + _tmp40 * _tmp59 + d_vel_dt * state(4, 0) + state(7, 0);
_p_new(1, 0) =
_tmp45 * _tmp58 + _tmp46 * _tmp60 + _tmp49 * _tmp59 + d_vel_dt * state(5, 0) + state(8, 0);
_p_new(2, 0) = -Scalar(1) / Scalar(2) * _tmp18 * g + _tmp50 * _tmp58 + _tmp51 * _tmp60 +
_tmp52 * _tmp59 + _tmp53 * d_vel_dt + state(9, 0);
}
} // NOLINT(readability/fn_size)
// NOLINTNEXTLINE(readability/fn_size)
} // namespace sym