diff --git a/Tools/FilterTestTool/BiquadFilter.py b/Tools/FilterTestTool/BiquadFilter.py
new file mode 100755
index 0000000000..8811cc2167
--- /dev/null
+++ b/Tools/FilterTestTool/BiquadFilter.py
@@ -0,0 +1,201 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+""" ArduPilot BiquadFilter
+
+This program is free software: you can redistribute it and/or modify it under
+the terms of the GNU General Public License as published by the Free Software
+Foundation, either version 3 of the License, or (at your option) any later
+version.
+This program is distributed in the hope that it will be useful, but WITHOUT
+ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
+You should have received a copy of the GNU General Public License along with
+this program. If not, see .
+"""
+
+__author__ = "Guglielmo Cassinelli"
+__contact__ = "gdguglie@gmail.com"
+
+import numpy as np
+
+
+class DigitalLPF:
+ def __init__(self, cutoff_freq, sample_freq):
+ self._cutoff_freq = cutoff_freq
+ self._sample_freq = sample_freq
+
+ self._output = 0
+
+ self.compute_alpha()
+
+ def compute_alpha(self):
+
+ if self._cutoff_freq <= 0 or self._sample_freq <= 0:
+ self.alpha = 1.
+ else:
+ dt = 1. / self._sample_freq
+ rc = 1. / (np.pi * 2 * self._cutoff_freq)
+ a = dt / (dt + rc)
+ self.alpha = np.clip(a, 0, 1)
+
+ def apply(self, sample):
+ self._output += (sample - self._output) * self.alpha
+ return self._output
+
+
+class BiquadFilterType:
+ LPF = 0
+ PEAK = 1
+ NOTCH = 2
+
+
+class BiquadFilter:
+
+ def __init__(self, center_freq, sample_freq, type=BiquadFilterType.LPF, attenuation=10, bandwidth=15):
+ self._center_freq = int(center_freq)
+ self._attenuation_db = int(attenuation) # used only by notch, use setter
+ self._bandwidth_hz = int(bandwidth) # used only by notch, use setter
+
+ self._sample_freq = sample_freq
+ self._type = type
+
+ self._delayed_sample1 = 0
+ self._delayed_sample2 = 0
+ self._delayed_output1 = 0
+ self._delayed_output2 = 0
+
+ self.b0 = 0.
+ self.b1 = 0.
+ self.b2 = 0.
+ self.a0 = 1
+ self.a1 = 0.
+ self.a2 = 0.
+
+ self.compute_params()
+
+ def get_sample_freq(self):
+ return self._sample_freq
+
+ def reset(self):
+ self._delayed_sample1 = 0
+ self._delayed_sample2 = 0
+ self._delayed_output1 = 0
+ self._delayed_output2 = 0
+
+ def get_type(self):
+ return self._type
+
+ def set_attenuation(self, attenuation_db):
+ self._attenuation_db = int(attenuation_db)
+ self.compute_params()
+
+ def set_bandwidth(self, bandwidth_hz):
+ self._bandwidth_hz = int(bandwidth_hz)
+ self.compute_params()
+
+ def set_center_freq(self, cutoff_freq):
+ self._center_freq = int(cutoff_freq)
+ self.compute_params()
+
+ def compute_params(self):
+
+ omega = 2 * np.pi * self._center_freq / self._sample_freq
+ sin_om = np.sin(omega)
+ cos_om = np.cos(omega)
+
+ if self._type == BiquadFilterType.LPF:
+
+ if self._center_freq > 0:
+ Q = 1 / np.sqrt(2)
+ alpha = sin_om / (2 * Q)
+
+ self.b0 = (1 - cos_om) / 2
+ self.b1 = 1 - cos_om
+ self.b2 = self.b0
+ self.a0 = 1 + alpha
+ self.a1 = -2 * cos_om
+ self.a2 = 1 - alpha
+
+ elif self._type == BiquadFilterType.PEAK:
+
+ A = 10 ** (-self._attenuation_db / 40)
+
+ # why not the formula below? It prevents a division by 0 when bandwidth = 2*frequency
+ octaves = np.log2(self._center_freq / (self._center_freq - self._bandwidth_hz / 2)) * 2
+ Q = np.sqrt(2 ** octaves) / (2 ** octaves - 1)
+
+ # Q = self._center_freq / self._bandwidth_hz
+
+ alpha = sin_om / (2 * Q / A)
+
+ self.b0 = 1.0 + alpha * A
+ self.b1 = -2.0 * cos_om
+ self.b2 = 1.0 - alpha * A
+ self.a0 = 1.0 + alpha / A
+ self.a1 = -2.0 * cos_om
+ self.a2 = 1.0 - alpha / A
+
+ elif self._type == BiquadFilterType.NOTCH:
+ alpha = sin_om * np.sinh(np.log(2) / 2 * self._bandwidth_hz * omega * sin_om)
+
+ self.b0 = 1
+ self.b1 = -2 * cos_om
+ self.b2 = self.b0
+ self.a0 = 1 + alpha
+ self.a1 = -2 * cos_om
+ self.a2 = 1 - alpha
+
+ self.b0 /= self.a0
+ self.b1 /= self.a0
+ self.b2 /= self.a0
+ self.a1 /= self.a0
+ self.a2 /= self.a0
+
+ def apply(self, sample):
+
+ if self._center_freq <= 0:
+ return sample
+
+ output = (self.b0 * sample + self.b1 * self._delayed_sample1 + self.b2 * self._delayed_sample2 - self.a1
+ * self._delayed_output1 - self.a2 * self._delayed_output2)
+
+ self._delayed_sample2 = self._delayed_sample1
+ self._delayed_sample1 = sample
+
+ self._delayed_output2 = self._delayed_output1
+ self._delayed_output1 = output
+
+ return output
+
+ def get_params(self):
+
+ return {
+ "a1": self.a1,
+ "a2": self.a2,
+ "b0": self.b0,
+ "b1": self.b1,
+ "b2": self.b2,
+ }
+
+ def get_center_freq(self):
+ return self._center_freq
+
+ def get_attenuation(self):
+ return self._attenuation_db
+
+ def get_bandwidth(self):
+ return self._bandwidth_hz
+
+ def freq_response(self, f):
+ if self._center_freq <= 0:
+ return 1
+
+ phi = (np.sin(np.pi * f * 2 / (2 * self._sample_freq))) ** 2
+ r = (((self.b0 + self.b1 + self.b2) ** 2 - 4 * (self.b0 * self.b1 + 4 * self.b0 * self.b2 + self.b1 * self.b2)
+ * phi + 16 * self.b0 * self.b2 * phi * phi)
+ / ((1 + self.a1 + self.a2) ** 2 - 4 * (self.a1 + 4 * self.a2 + self.a1 * self.a2) * phi + 16
+ * self.a2 * phi * phi))
+ # if r < 0:
+ # r = 0
+ return r ** .5
diff --git a/Tools/FilterTestTool/FilterTest.py b/Tools/FilterTestTool/FilterTest.py
new file mode 100644
index 0000000000..5b70641ad6
--- /dev/null
+++ b/Tools/FilterTestTool/FilterTest.py
@@ -0,0 +1,468 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+""" ArduPilot IMU Filter Test Class
+
+This program is free software: you can redistribute it and/or modify it under
+the terms of the GNU General Public License as published by the Free Software
+Foundation, either version 3 of the License, or (at your option) any later
+version.
+This program is distributed in the hope that it will be useful, but WITHOUT
+ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
+You should have received a copy of the GNU General Public License along with
+this program. If not, see .
+"""
+
+__author__ = "Guglielmo Cassinelli"
+__contact__ = "gdguglie@gmail.com"
+
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.widgets import Slider
+from matplotlib.animation import FuncAnimation
+from scipy import signal
+from BiquadFilter import BiquadFilterType, BiquadFilter
+
+sliders = [] # matplotlib sliders must be global
+anim = None # matplotlib animations must be global
+
+
+class FilterTest:
+ FILTER_DEBOUNCE = 10 # ms
+
+ FILT_SHAPE_DT_FACTOR = 1 # increase to reduce filter shape size
+
+ FFT_N = 512
+
+ filters = {}
+
+ def __init__(self, acc_t, acc_x, acc_y, acc_z, gyr_t, gyr_x, gyr_y, gyr_z, acc_freq, gyr_freq,
+ acc_lpf_cutoff, gyr_lpf_cutoff,
+ acc_notch_freq, acc_notch_att, acc_notch_band,
+ gyr_notch_freq, gyr_notch_att, gyr_notch_band,
+ log_name, accel_notch=False, second_notch=False):
+
+ self.filter_color_map = plt.get_cmap('summer')
+
+ self.filters["acc"] = [
+ BiquadFilter(acc_lpf_cutoff, acc_freq)
+ ]
+
+ if accel_notch:
+ self.filters["acc"].append(
+ BiquadFilter(acc_notch_freq, acc_freq, BiquadFilterType.PEAK, acc_notch_att, acc_notch_band),
+ )
+
+ self.filters["gyr"] = [
+ BiquadFilter(gyr_lpf_cutoff, gyr_freq),
+ BiquadFilter(gyr_notch_freq, gyr_freq, BiquadFilterType.PEAK, gyr_notch_att, gyr_notch_band)
+ ]
+
+ if second_notch:
+ self.filters["acc"].append(
+ BiquadFilter(acc_notch_freq * 2, acc_freq, BiquadFilterType.PEAK, acc_notch_att, acc_notch_band)
+ )
+ self.filters["gyr"].append(
+ BiquadFilter(gyr_notch_freq * 2, gyr_freq, BiquadFilterType.PEAK, gyr_notch_att, gyr_notch_band)
+ )
+
+ self.ACC_t = acc_t
+ self.ACC_x = acc_x
+ self.ACC_y = acc_y
+ self.ACC_z = acc_z
+
+ self.GYR_t = gyr_t
+ self.GYR_x = gyr_x
+ self.GYR_y = gyr_y
+ self.GYR_z = gyr_z
+
+ self.GYR_freq = gyr_freq
+ self.ACC_freq = acc_freq
+
+ self.gyr_dt = 1. / gyr_freq
+ self.acc_dt = 1. / acc_freq
+
+ self.timer = None
+
+ self.updated_artists = []
+
+ # INIT
+ self.init_plot(log_name)
+
+ def test_acc_filters(self):
+ filt_xs = self.test_filters(self.filters["acc"], self.ACC_t, self.ACC_x)
+ filt_ys = self.test_filters(self.filters["acc"], self.ACC_t, self.ACC_y)
+ filt_zs = self.test_filters(self.filters["acc"], self.ACC_t, self.ACC_z)
+ return filt_xs, filt_ys, filt_zs
+
+ def test_gyr_filters(self):
+ filt_xs = self.test_filters(self.filters["gyr"], self.GYR_t, self.GYR_x)
+ filt_ys = self.test_filters(self.filters["gyr"], self.GYR_t, self.GYR_y)
+ filt_zs = self.test_filters(self.filters["gyr"], self.GYR_t, self.GYR_z)
+ return filt_xs, filt_ys, filt_zs
+
+ def test_filters(self, filters, Ts, Xs):
+ for f in filters:
+ f.reset()
+
+ x_filtered = []
+
+ for i, t in enumerate(Ts):
+ x = Xs[i]
+
+ x_f = x
+ for filt in filters:
+ x_f = filt.apply(x_f)
+
+ x_filtered.append(x_f)
+
+ return x_filtered
+
+ def get_filter_shape(self, filter):
+ samples = int(filter.get_sample_freq()) # resolution of filter shape based on sample rate
+ x_space = np.linspace(0.0, samples // 2, samples // int(2 * self.FILT_SHAPE_DT_FACTOR))
+ return x_space, filter.freq_response(x_space)
+
+ def init_signal_plot(self, ax, Ts, Xs, Ys, Zs, Xs_filtered, Ys_filtered, Zs_filtered, label):
+ ax.plot(Ts, Xs, linewidth=1, label="{}X".format(label), alpha=0.5)
+ ax.plot(Ts, Ys, linewidth=1, label="{}Y".format(label), alpha=0.5)
+ ax.plot(Ts, Zs, linewidth=1, label="{}Z".format(label), alpha=0.5)
+ filtered_x_ax, = ax.plot(Ts, Xs_filtered, linewidth=1, label="{}X filtered".format(label), alpha=1)
+ filtered_y_ax, = ax.plot(Ts, Ys_filtered, linewidth=1, label="{}Y filtered".format(label), alpha=1)
+ filtered_z_ax, = ax.plot(Ts, Zs_filtered, linewidth=1, label="{}Z filtered".format(label), alpha=1)
+ ax.legend(prop={'size': 8})
+ return filtered_x_ax, filtered_y_ax, filtered_z_ax
+
+ def fft_to_xdata(self, fft):
+ n = len(fft)
+ norm_factor = 2. / n
+ return norm_factor * np.abs(fft[:n // 2])
+
+ def plot_fft(self, ax, x, fft, label):
+ fft_ax, = ax.plot(x, self.fft_to_xdata(fft), label=label)
+ return fft_ax
+
+ def init_fft(self, ax, Ts, Xs, Ys, Zs, sample_rate, dt, Xs_filtered, Ys_filtered, Zs_filtered, label):
+
+ _freqs_raw_x, _times_raw_x, _stft_raw_x = signal.stft(Xs, sample_rate, window='hann', nperseg=self.FFT_N)
+ raw_fft_x = np.average(np.abs(_stft_raw_x), axis=1)
+
+ _freqs_raw_y, _times_raw_y, _stft_raw_y = signal.stft(Ys, sample_rate, window='hann', nperseg=self.FFT_N)
+ raw_fft_y = np.average(np.abs(_stft_raw_y), axis=1)
+
+ _freqs_raw_z, _times_raw_z, _stft_raw_z = signal.stft(Zs, sample_rate, window='hann', nperseg=self.FFT_N)
+ raw_fft_z = np.average(np.abs(_stft_raw_z), axis=1)
+
+ _freqs_x, _times_x, _stft_x = signal.stft(Xs_filtered, sample_rate, window='hann', nperseg=self.FFT_N)
+ filtered_fft_x = np.average(np.abs(_stft_x), axis=1)
+
+ _freqs_y, _times_y, _stft_y = signal.stft(Ys_filtered, sample_rate, window='hann', nperseg=self.FFT_N)
+ filtered_fft_y = np.average(np.abs(_stft_y), axis=1)
+
+ _freqs_z, _times_z, _stft_z = signal.stft(Zs_filtered, sample_rate, window='hann', nperseg=self.FFT_N)
+ filtered_fft_z = np.average(np.abs(_stft_z), axis=1)
+
+ ax.plot(_freqs_raw_x, raw_fft_x, alpha=0.5, linewidth=1, label="{}x FFT".format(label))
+ ax.plot(_freqs_raw_y, raw_fft_y, alpha=0.5, linewidth=1, label="{}y FFT".format(label))
+ ax.plot(_freqs_raw_z, raw_fft_z, alpha=0.5, linewidth=1, label="{}z FFT".format(label))
+
+ filtered_fft_ax_x, = ax.plot(_freqs_x, filtered_fft_x, label="filt. {}x FFT".format(label))
+ filtered_fft_ax_y, = ax.plot(_freqs_y, filtered_fft_y, label="filt. {}y FFT".format(label))
+ filtered_fft_ax_z, = ax.plot(_freqs_z, filtered_fft_z, label="filt. {}z FFT".format(label))
+
+ # FFT
+ # samples = len(Ts)
+ # x_space = np.linspace(0.0, 1.0 / (2.0 * dt), samples // 2)
+ # filtered_data = np.hanning(len(Xs_filtered)) * Xs_filtered
+ # raw_fft = np.fft.fft(np.hanning(len(Xs)) * Xs)
+ # filtered_fft = np.fft.fft(filtered_data, n=self.FFT_N)
+ # self.plot_fft(ax, x_space, raw_fft, "{} FFT".format(label))
+ # fft_freq = np.fft.fftfreq(self.FFT_N, d=dt)
+ # x_space
+ # filtered_fft_ax = self.plot_fft(ax, fft_freq[:self.FFT_N // 2], filtered_fft, "filtered {} FFT".format(label))
+
+ ax.set_xlabel("frequency")
+ # ax.set_xscale("log")
+ # ax.xaxis.set_major_formatter(ScalarFormatter())
+ ax.legend(prop={'size': 8})
+
+ return filtered_fft_ax_x, filtered_fft_ax_y, filtered_fft_ax_z
+
+ def init_filter_shape(self, ax, filter, color):
+ center = filter.get_center_freq()
+ x_space, lpf_shape = self.get_filter_shape(filter)
+
+ plot_slpf_shape, = ax.plot(x_space, lpf_shape, c=color, label="LPF shape")
+ xvline_lpf_cutoff = ax.axvline(x=center, linestyle="--", c=color) # LPF cutoff freq
+
+ return plot_slpf_shape, xvline_lpf_cutoff
+
+ def create_slider(self, name, rect, max, value, color, callback):
+ global sliders
+ ax_slider = self.fig.add_axes(rect, facecolor='lightgoldenrodyellow')
+ slider = Slider(ax_slider, name, 0, max, valinit=np.sqrt(max * value), valstep=1, color=color)
+ slider.valtext.set_text(value)
+
+ # slider.drawon = False
+
+ def changed(val, cbk, max, slider):
+ # non linear slider to better control small values
+ val = int(val ** 2 / max)
+ slider.valtext.set_text(val)
+ cbk(val)
+
+ slider.on_changed(lambda val, cbk=callback, max=max, s=slider: changed(val, cbk, max, s))
+ sliders.append(slider)
+
+ def delay_update(self, update_cbk):
+ def _delayed_update(self, cbk):
+ self.timer.stop()
+ cbk()
+
+ # delay actual filtering
+ if self.fig:
+ if self.timer:
+ self.timer.stop()
+ self.timer = self.fig.canvas.new_timer(interval=self.FILTER_DEBOUNCE)
+ self.timer.add_callback(lambda self=self: _delayed_update(self, update_cbk))
+ self.timer.start()
+
+ def update_filter_shape(self, filter, shape, center_line):
+ x_data, new_shape = self.get_filter_shape(filter)
+
+ shape.set_ydata(new_shape)
+ center_line.set_xdata(filter.get_center_freq())
+
+ self.updated_artists.extend([
+ shape,
+ center_line,
+ ])
+
+ def update_signal_and_fft_plot(self, filters_key, time_list, sample_lists, signal_shapes, fft_shapes, shape,
+ center_line, sample_rate):
+ # print("update_signal_and_fft_plot", self.filters[filters_key][0].get_center_freq())
+ Xs, Ys, Zs = sample_lists
+ signal_shape_x, signal_shape_y, signal_shape_z = signal_shapes
+ fft_shape_x, fft_shape_y, fft_shape_z = fft_shapes
+
+ Xs_filtered = self.test_filters(self.filters[filters_key], time_list, Xs)
+ Ys_filtered = self.test_filters(self.filters[filters_key], time_list, Ys)
+ Zs_filtered = self.test_filters(self.filters[filters_key], time_list, Zs)
+
+ signal_shape_x.set_ydata(Xs_filtered)
+ signal_shape_y.set_ydata(Ys_filtered)
+ signal_shape_z.set_ydata(Zs_filtered)
+
+ self.updated_artists.extend([signal_shape_x, signal_shape_y, signal_shape_z])
+
+ _freqs_x, _times_x, _stft_x = signal.stft(Xs_filtered, sample_rate, window='hann', nperseg=self.FFT_N)
+ filtered_fft_x = np.average(np.abs(_stft_x), axis=1)
+
+ _freqs_y, _times_y, _stft_y = signal.stft(Ys_filtered, sample_rate, window='hann', nperseg=self.FFT_N)
+ filtered_fft_y = np.average(np.abs(_stft_y), axis=1)
+
+ _freqs_z, _times_z, _stft_z = signal.stft(Zs_filtered, sample_rate, window='hann', nperseg=self.FFT_N)
+ filtered_fft_z = np.average(np.abs(_stft_z), axis=1)
+
+ fft_shape_x.set_ydata(filtered_fft_x)
+ fft_shape_y.set_ydata(filtered_fft_y)
+ fft_shape_z.set_ydata(filtered_fft_z)
+
+ self.updated_artists.extend([
+ fft_shape_x, fft_shape_y, fft_shape_z,
+ shape, center_line,
+ ])
+
+ # self.fig.canvas.draw()
+
+ def animation_update(self):
+ updated_artists = self.updated_artists.copy()
+
+ # if updated_artists:
+ # print("animation update")
+
+ # reset updated artists
+ self.updated_artists = []
+
+ return updated_artists
+
+ def update_filter(self, val, cbk, filter, shape, center_line, filters_key, time_list, sample_lists, signal_shapes,
+ fft_shapes):
+ # this callback sets the parameter controlled by the slider
+ cbk(val)
+ # print("filter update",val)
+ # update filter shape and delay fft update
+ self.update_filter_shape(filter, shape, center_line)
+ sample_freq = filter.get_sample_freq()
+ self.delay_update(
+ lambda self=self: self.update_signal_and_fft_plot(filters_key, time_list, sample_lists, signal_shapes,
+ fft_shapes, shape, center_line, sample_freq))
+
+ def create_filter_control(self, name, filter, rect, max, default, shape, center_line, cbk, filters_key, time_list,
+ sample_lists, signal_shapes, fft_shapes, filt_color):
+ self.create_slider(name, rect, max, default, filt_color, lambda val, cbk=cbk, self=self, filter=filter, shape=shape,
+ center_line=center_line, filters_key=filters_key,
+ time_list=time_list, sample_list=sample_lists,
+ signal_shape=signal_shapes, fft_shape=fft_shapes:
+ self.update_filter(val, cbk, filter, shape, center_line, filters_key,
+ time_list, sample_list, signal_shape, fft_shape))
+
+ def create_controls(self, filters_key, base_rect, padding, ax_fft, time_list, sample_lists, signal_shapes,
+ fft_shapes):
+ ax_filter = ax_fft.twinx()
+ ax_filter.set_navigate(False)
+ ax_filter.set_yticks([])
+
+ num_filters = len(self.filters[filters_key])
+
+ for i, filter in enumerate(self.filters[filters_key]):
+ filt_type = filter.get_type()
+ filt_color = self.filter_color_map(i / num_filters)
+ filt_shape, filt_cutoff = self.init_filter_shape(ax_filter, filter, filt_color)
+
+ if filt_type == BiquadFilterType.PEAK:
+ name = "Notch"
+ else:
+ name = "LPF"
+
+ # control for center freq is common to all filters
+ self.create_filter_control("{} freq".format(name), filter, base_rect, 500, filter.get_center_freq(),
+ filt_shape, filt_cutoff,
+ lambda val, filter=filter: filter.set_center_freq(val),
+ filters_key, time_list, sample_lists, signal_shapes, fft_shapes, filt_color)
+ # move down of control height + padding
+ base_rect[1] -= (base_rect[3] + padding)
+
+ if filt_type == BiquadFilterType.PEAK:
+ self.create_filter_control("{} att (db)".format(name), filter, base_rect, 100, filter.get_attenuation(),
+ filt_shape, filt_cutoff,
+ lambda val, filter=filter: filter.set_attenuation(val),
+ filters_key, time_list, sample_lists, signal_shapes, fft_shapes, filt_color)
+ base_rect[1] -= (base_rect[3] + padding)
+ self.create_filter_control("{} band".format(name), filter, base_rect, 300, filter.get_bandwidth(),
+ filt_shape, filt_cutoff,
+ lambda val, filter=filter: filter.set_bandwidth(val),
+ filters_key, time_list, sample_lists, signal_shapes, fft_shapes, filt_color)
+ base_rect[1] -= (base_rect[3] + padding)
+
+ def create_spectrogram(self, data, name, sample_rate):
+ freqs, times, Sx = signal.spectrogram(np.array(data), fs=sample_rate, window='hanning',
+ nperseg=self.FFT_N, noverlap=self.FFT_N - self.FFT_N // 10,
+ detrend=False, scaling='spectrum')
+
+ f, ax = plt.subplots(figsize=(4.8, 2.4))
+ ax.pcolormesh(times, freqs, 10 * np.log10(Sx), cmap='viridis')
+ ax.set_title(name)
+ ax.set_ylabel('Frequency (Hz)')
+ ax.set_xlabel('Time (s)')
+
+ def init_plot(self, log_name):
+
+ self.fig = plt.figure(figsize=(14, 9))
+ self.fig.canvas.set_window_title("ArduPilot Filter Test Tool - {}".format(log_name))
+ self.fig.canvas.draw()
+
+ rows = 2
+ cols = 3
+ raw_acc_index = 1
+ fft_acc_index = raw_acc_index + 1
+ raw_gyr_index = cols + 1
+ fft_gyr_index = raw_gyr_index + 1
+
+ # signal
+ self.ax_acc = self.fig.add_subplot(rows, cols, raw_acc_index)
+ self.ax_gyr = self.fig.add_subplot(rows, cols, raw_gyr_index, sharex=self.ax_acc)
+
+ accx_filtered, accy_filtered, accz_filtered = self.test_acc_filters()
+ self.ax_filtered_accx, self.ax_filtered_accy, self.ax_filtered_accz = self.init_signal_plot(self.ax_acc,
+ self.ACC_t,
+ self.ACC_x,
+ self.ACC_y,
+ self.ACC_z,
+ accx_filtered,
+ accy_filtered,
+ accz_filtered,
+ "AccX")
+
+ gyrx_filtered, gyry_filtered, gyrz_filtered = self.test_gyr_filters()
+ self.ax_filtered_gyrx, self.ax_filtered_gyry, self.ax_filtered_gyrz = self.init_signal_plot(self.ax_gyr,
+ self.GYR_t,
+ self.GYR_x,
+ self.GYR_y,
+ self.GYR_z,
+ gyrx_filtered,
+ gyry_filtered,
+ gyrz_filtered,
+ "GyrX")
+
+ # FFT
+ self.ax_acc_fft = self.fig.add_subplot(rows, cols, fft_acc_index)
+ self.ax_gyr_fft = self.fig.add_subplot(rows, cols, fft_gyr_index)
+
+ self.acc_filtered_fft_ax_x, self.acc_filtered_fft_ax_y, self.acc_filtered_fft_ax_z = self.init_fft(
+ self.ax_acc_fft, self.ACC_t, self.ACC_x, self.ACC_y, self.ACC_z, self.ACC_freq, self.acc_dt, accx_filtered,
+ accy_filtered, accz_filtered, "AccX")
+ self.gyr_filtered_fft_ax_x, self.gyr_filtered_fft_ax_y, self.gyr_filtered_fft_ax_z = self.init_fft(
+ self.ax_gyr_fft, self.GYR_t, self.GYR_x, self.GYR_y, self.GYR_z, self.GYR_freq, self.gyr_dt, gyrx_filtered,
+ gyry_filtered, gyrz_filtered, "GyrX")
+
+ self.fig.tight_layout()
+
+ # TODO add y z
+ self.create_controls("acc", [0.75, 0.95, 0.2, 0.02], 0.01, self.ax_acc_fft, self.ACC_t,
+ (self.ACC_x, self.ACC_y, self.ACC_z),
+ (self.ax_filtered_accx, self.ax_filtered_accy, self.ax_filtered_accz),
+ (self.acc_filtered_fft_ax_x, self.acc_filtered_fft_ax_y, self.acc_filtered_fft_ax_z))
+ self.create_controls("gyr", [0.75, 0.45, 0.2, 0.02], 0.01, self.ax_gyr_fft, self.GYR_t,
+ (self.GYR_x, self.GYR_y, self.GYR_z),
+ (self.ax_filtered_gyrx, self.ax_filtered_gyry, self.ax_filtered_gyrz),
+ (self.gyr_filtered_fft_ax_x, self.gyr_filtered_fft_ax_y, self.gyr_filtered_fft_ax_z))
+
+ # setup animation for continuous update
+ global anim
+ anim = FuncAnimation(self.fig, lambda frame, self=self: self.animation_update(), interval=1, blit=False)
+
+ # Work in progress here...
+ # self.create_spectrogram(self.GYR_x, "GyrX", self.GYR_freq)
+ # self.create_spectrogram(gyrx_filtered, "GyrX filtered", self.GYR_freq)
+ # self.create_spectrogram(self.ACC_x, "AccX", self.ACC_freq)
+ # self.create_spectrogram(accx_filtered, "AccX filtered", self.ACC_freq)
+
+ plt.show()
+
+ self.print_filter_param_info()
+
+ def print_filter_param_info(self):
+ if len(self.filters["acc"]) > 2 or len(self.filters["gyr"]) > 2:
+ print("Testing too many filters unsupported from firmware, cannot calculate parameters to set them")
+ return
+
+ print("To have the last filter settings in the graphs set the following parameters:\n")
+
+ for f in self.filters["acc"]:
+ filt_type = f.get_type()
+
+ if filt_type == BiquadFilterType.PEAK: # NOTCH
+ print("INS_NOTCA_ENABLE,", 1)
+ print("INS_NOTCA_FREQ,", f.get_center_freq())
+ print("INS_NOTCA_BW,", f.get_bandwidth())
+ print("INS_NOTCA_ATT,", f.get_attenuation())
+ else: # LPF
+ print("INS_ACCEL_FILTER,", f.get_center_freq())
+
+ for f in self.filters["gyr"]:
+ filt_type = f.get_type()
+
+ if filt_type == BiquadFilterType.PEAK: # NOTCH
+ print("INS_NOTCH_ENABLE,", 1)
+ print("INS_NOTCH_FREQ,", f.get_center_freq())
+ print("INS_NOTCH_BW,", f.get_bandwidth())
+ print("INS_NOTCH_ATT,", f.get_attenuation())
+ else: # LPF
+ print("INS_GYRO_FILTER,", f.get_center_freq())
+
+ print("\n+---------+")
+ print("| WARNING |")
+ print("+---------+")
+ print("Always check the onboard FFT to setup filters, this tool only simulate effects of filtering.")
diff --git a/Tools/FilterTestTool/Readme.md b/Tools/FilterTestTool/Readme.md
new file mode 100644
index 0000000000..730ae2453a
--- /dev/null
+++ b/Tools/FilterTestTool/Readme.md
@@ -0,0 +1,32 @@
+# ArduPilot IMU Filter Test Tool
+
+**Warning: always check the onboard FFT to setup filters, this tool only simulate effects of filtering.**
+
+This is a tool to simulate IMU filtering on a raw IMU log.
+To run it:
+
+```bash
+ python run_filter_test.py
+```
+
+This will open a file chooser dialog to select a log file.
+
+
+Log file can also be specified from command line
+
+```bash
+ python run_filter_test.py logfile.bin
+```
+
+To choose a smaller section of the log begin and/or end time can be specified in seconds.
+E.g. to open only the log section between 60 and 120 seconds:
+
+```bash
+ python run_filter_test.py logfile.bin -b 60 -e 120
+```
+
+More info here:
+
+ https://discuss.ardupilot.org/t/imu-filter-tool/43633
+
+
diff --git a/Tools/FilterTestTool/run_filter_test.py b/Tools/FilterTestTool/run_filter_test.py
new file mode 100755
index 0000000000..be9dcb9514
--- /dev/null
+++ b/Tools/FilterTestTool/run_filter_test.py
@@ -0,0 +1,269 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+""" ArduPilot IMU Filter Test Tool
+
+This program is free software: you can redistribute it and/or modify it under
+the terms of the GNU General Public License as published by the Free Software
+Foundation, either version 3 of the License, or (at your option) any later
+version.
+This program is distributed in the hope that it will be useful, but WITHOUT
+ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
+You should have received a copy of the GNU General Public License along with
+this program. If not, see .
+"""
+
+__author__ = "Guglielmo Cassinelli"
+__contact__ = "gdguglie@gmail.com"
+
+try: # Python 3.x
+ from tkinter import Tk
+ from tkinter.filedialog import askopenfilename
+except ImportError: # Python 2.x
+ from Tkinter import Tk
+ from tkFileDialog import askopenfilename
+
+import argparse
+import ntpath
+import numpy as np
+from pymavlink import mavutil
+
+
+"""
+read command line parameters
+"""
+
+parser = argparse.ArgumentParser(description='ArduPilot IMU Filter Tester Tool. Input one log file from ')
+parser.add_argument('file', nargs='?', default=None, help='bin log file containing raw IMU logs')
+parser.add_argument('--begin-time', '-b', type=int, default=0, help='start from second')
+parser.add_argument('--end-time', '-e', type=int, default=-1, help='end to second')
+
+args = parser.parse_args()
+
+log_file = args.file
+begin_time = args.begin_time
+end_time = args.end_time
+
+# if log not input by command line
+if not log_file:
+ # GUI log file chooser
+ root = Tk()
+ root.withdraw()
+ root.focus_force()
+ log_file = askopenfilename(title="Select log file", filetypes=(("log files", ".bin .log"), ("all files", "*.*")))
+ root.update()
+ root.destroy()
+
+if log_file is None or log_file == "":
+ print("No log file to open")
+ quit()
+
+log_name = ntpath.basename(log_file)
+
+"""
+default settings
+"""
+POST_FILTER_LOGGING_BIT = 2 ** 1
+
+RAW_IMU_LOG_BIT = 2 ** 19
+
+PREVENT_POST_FILTER_LOGS = False
+
+PARAMS_TO_CHECK = [
+ "INS_LOG_BAT_OPT", "INS_GYRO_FILTER", "INS_ACCEL_FILTER",
+ "INS_NOTCH_ENABLE", "INS_NOTCH_FREQ", "INS_NOTCH_BW", "INS_NOTCH_ATT",
+ "INS_NOTCA_ENABLE", "INS_NOTCA_FREQ", "INS_NOTCA_BW", "INS_NOTCA_ATT",
+ "LOG_BITMASK"
+]
+
+DEFAULT_ACC_FILTER = 80 # hz
+DEFAULT_GYR_FILTER = 80 # hz
+
+DEFAULT_ACC_NOTCH_FREQ = 150 # hz
+DEFAULT_ACC_NOTCH_ATTENUATION = 30 # db
+DEFAULT_ACC_NOTCH_BANDWIDTH = 100 # hz
+
+DEFAULT_GYR_NOTCH_FREQ = 145
+DEFAULT_GYR_NOTCH_ATTENUATION = 30 # db
+DEFAULT_GYR_NOTCH_BANDWIDTH = 100 # hz
+
+ACCEL_NOTCH_FILTER = True
+
+"""
+load LOG
+"""
+print("Loading %s...\n" % log_name)
+
+mlog = mavutil.mavlink_connection(log_file)
+
+log_start_time = 0
+log_end_time = 0
+
+ACC_t = []
+ACC_x = []
+ACC_y = []
+ACC_z = []
+
+GYR_t = []
+GYR_x = []
+GYR_y = []
+GYR_z = []
+
+params = {}
+
+while True:
+ m = mlog._parse_next()
+ """
+ @type m DFMessage
+ """
+
+ if m is None:
+ break
+
+ if m.fmt.name == "PARM":
+ # check param value
+
+ if m.Name in PARAMS_TO_CHECK:
+ print(m.Name, ", ", m.Value)
+ params[m.Name] = m.Value
+
+ try:
+ m_time_sec = m.TimeUS / 1000000.
+
+ if log_start_time == 0:
+ log_start_time = m_time_sec
+
+ if m_time_sec < begin_time:
+ continue
+
+ if end_time > 0 and m_time_sec > end_time:
+ continue
+ except AttributeError:
+ pass
+
+ if m.fmt.name == "ACC1":
+ ACC_t.append(m_time_sec)
+ ACC_x.append(m.AccX)
+ ACC_y.append(m.AccY)
+ ACC_z.append(m.AccZ)
+
+ elif m.fmt.name == "GYR1":
+ GYR_t.append(m_time_sec)
+ GYR_x.append(m.GyrX)
+ GYR_y.append(m.GyrY)
+ GYR_z.append(m.GyrZ)
+
+
+def print_log_msg_stats(log_time_list, msg_name):
+ msg_count = len(log_time_list)
+
+ if msg_count > 0:
+ msg_total_time = log_time_list[-1] - log_time_list[0]
+ msg_freq = msg_count / msg_total_time
+ else:
+ msg_total_time = 0
+ msg_freq = 0
+
+ print("\n{} {} logs for a duration of {:.1f} secs".format(msg_count, msg_name, msg_total_time))
+ print(msg_name + " frequency = {:.2f} hz".format(msg_freq))
+
+ return msg_freq
+
+
+def get_mean_and_std(np_arr):
+ mean = np.mean(np_arr)
+ std = np.std(np_arr)
+ return mean, std
+
+
+def print_mean_and_std(np_arr, name=""):
+ mean, std = get_mean_and_std(np_arr)
+ print("{} mean {:.3f} std {:.3f}".format(name, mean, std))
+
+
+def set_bit(number, bit_index, bit_value):
+ """Set the index:th bit of v to 1 if x is truthy, else to 0, and return the new value."""
+ mask = 1 << bit_index # Compute mask, an integer with just bit 'index' set.
+ number &= ~mask # Clear the bit indicated by the mask (if x is False)
+ if bit_value:
+ number |= mask # If x was True, set the bit indicated by the mask.
+ return number # Return the result, we're done.
+
+
+ACC_freq = print_log_msg_stats(ACC_t, "ACC")
+GYR_freq = print_log_msg_stats(GYR_t, "GYR")
+
+if not ACC_t or not GYR_t:
+ print("\nNo RAW IMU logs to analyze")
+ quit()
+
+if "INS_LOG_BAT_OPT" in params:
+ log_bat_opt = int(params["INS_LOG_BAT_OPT"])
+ if log_bat_opt & POST_FILTER_LOGGING_BIT:
+ print("\nINS_LOG_BAT_OPT was set to {} which enables post filter logging,"
+ "use pre filter logging to not sum multiple filter passes.".format(log_bat_opt))
+ print("(set INS_LOG_BAT_OPT = {})".format(set_bit(log_bat_opt, 1, 0)))
+
+ if PREVENT_POST_FILTER_LOGS:
+ quit()
+else:
+ print("couldn't check ")
+
+if "LOG_BITMASK" in params:
+ log_bitmask = int(params["LOG_BITMASK"])
+ if not log_bitmask & RAW_IMU_LOG_BIT:
+ print("\nWARNING: LOG_BITMASK was not set to enable RAW_IMU logging, please enable it to have best resolution")
+else:
+ print("\nWARNING: Cannot read LOG_BITMASK, please ensure to have enabled RAW_IMU logging")
+
+# set filter parameters
+print("Reading filter parameters to set initial filter values...")
+
+if "INS_GYRO_FILTER" in params:
+ DEFAULT_GYR_FILTER = params["INS_GYRO_FILTER"]
+
+if "INS_ACCEL_FILTER" in params:
+ DEFAULT_ACC_FILTER = params["INS_ACCEL_FILTER"]
+
+if "INS_NOTCH_ENABLE" in params:
+ if params["INS_NOTCH_ENABLE"] != 0:
+ if "INS_NOTCH_ATT" in params:
+ DEFAULT_GYR_NOTCH_ATTENUATION = params["INS_NOTCH_ATT"]
+ else:
+ DEFAULT_GYR_NOTCH_ATTENUATION = 0
+
+ if "INS_NOTCH_BW" in params:
+ DEFAULT_GYR_NOTCH_BANDWIDTH = params["INS_NOTCH_BW"]
+
+ if "INS_NOTCH_FREQ" in params:
+ DEFAULT_GYR_NOTCH_FREQ = params["INS_NOTCH_FREQ"]
+
+if "INS_NOTCA_ENABLE" in params:
+ if params["INS_NOTCA_ENABLE"] != 0:
+ if "INS_NOTCA_ATT" in params:
+ DEFAULT_ACC_NOTCH_ATTENUATION = params["INS_NOTCA_ATT"]
+ else:
+ DEFAULT_ACC_NOTCH_ATTENUATION = 0
+
+ if "INS_NOTCA_BW" in params:
+ DEFAULT_ACC_NOTCH_BANDWIDTH = params["INS_NOTCA_BW"]
+
+ if "INS_NOTCA_FREQ" in params:
+ DEFAULT_ACC_NOTCH_FREQ = params["INS_NOTCA_FREQ"]
+
+else:
+ print("The firmware that produced this log does not support notch filter on accelerometer")
+ ACCEL_NOTCH_FILTER = False
+
+
+"""
+run filter tet
+"""
+from FilterTest import FilterTest
+
+filter_test = FilterTest(ACC_t, ACC_x, ACC_y, ACC_z, GYR_t, GYR_x, GYR_y, GYR_z, ACC_freq, GYR_freq,
+ DEFAULT_ACC_FILTER, DEFAULT_GYR_FILTER,
+ DEFAULT_ACC_NOTCH_FREQ, DEFAULT_ACC_NOTCH_ATTENUATION, DEFAULT_ACC_NOTCH_BANDWIDTH,
+ DEFAULT_GYR_NOTCH_FREQ, DEFAULT_GYR_NOTCH_ATTENUATION, DEFAULT_GYR_NOTCH_BANDWIDTH,
+ log_name, ACCEL_NOTCH_FILTER)