diff --git a/libraries/AP_NavEKF3/derivation/code_gen.py b/libraries/AP_NavEKF3/derivation/code_gen.py index 8aa14500b0..b17d359df9 100644 --- a/libraries/AP_NavEKF3/derivation/code_gen.py +++ b/libraries/AP_NavEKF3/derivation/code_gen.py @@ -9,11 +9,22 @@ class CodeGenerator: self.file_name = file_name self.file = open(self.file_name, 'w') + # custom SymPy -> C function mappings. note that at least one entry must + # match, the last entry will always be used if none match! + self._custom_funcs = { + "Pow": [ + (lambda b, e: e == 2, lambda b, e: f"sq({b})"), # use square function for b**2 + (lambda b, e: e == -1, lambda b, e: f"1.0F/({b})"), # inverse + (lambda b, e: e == -2, lambda b, e: f"1.0F/sq({b})"), # inverse square + (lambda b, e: True, "powf"), # otherwise use default powf + ], + } + def print_string(self, string): self.file.write("// " + string + "\n") def get_ccode(self, expression): - return ccode(expression, type_aliases={real:float32}) + return ccode(expression, type_aliases={real:float32}, user_functions=self._custom_funcs) def write_subexpressions(self,subexpressions): write_string = ""