autotest: allow message hooks to be instances of a MessageHook class

allows state to be encapsulated within the object rather than simply in the same scope / closure of the method being called.

Will allow easier re-use of these blocks
This commit is contained in:
Peter Barker 2024-06-22 00:22:02 +10:00 committed by Peter Barker
parent 0c7a527ad3
commit eb539f2c4a
2 changed files with 110 additions and 41 deletions

View File

@ -8,7 +8,6 @@ from __future__ import print_function
import math
import os
import signal
import time
from pymavlink import quaternion
from pymavlink import mavutil
@ -2136,42 +2135,18 @@ class AutoTestPlane(vehicle_test_suite.TestSuite):
self.fly_home_land_and_disarm()
def deadreckoning_main(self, disable_airspeed_sensor=False):
self.context_push()
self.set_parameter("EK3_OPTIONS", 1)
self.set_parameter("AHRS_OPTIONS", 3)
self.set_parameter("LOG_REPLAY", 1)
self.reboot_sitl()
self.wait_ready_to_arm()
self.gpi = None
self.simstate = None
self.last_print = 0
self.max_divergence = 0
def validate_global_position_int_against_simstate(mav, m):
if m.get_type() == 'GLOBAL_POSITION_INT':
self.gpi = m
elif m.get_type() == 'SIMSTATE':
self.simstate = m
if self.gpi is None:
return
if self.simstate is None:
return
divergence = self.get_distance_int(self.gpi, self.simstate)
if disable_airspeed_sensor:
max_allowed_divergence = 300
else:
max_allowed_divergence = 150
if (time.time() - self.last_print > 1 or
divergence > self.max_divergence):
self.progress("position-estimate-divergence=%fm" % (divergence,))
self.last_print = time.time()
if divergence > self.max_divergence:
self.max_divergence = divergence
if divergence > max_allowed_divergence:
raise NotAchievedException(
"global-position-int diverged from simstate by %fm (max=%fm" %
(divergence, max_allowed_divergence,))
self.install_message_hook(validate_global_position_int_against_simstate)
if disable_airspeed_sensor:
max_allowed_divergence = 300
else:
max_allowed_divergence = 150
self.install_message_hook_context(vehicle_test_suite.TestSuite.ValidateGlobalPositionIntAgainstSimState(self, max_allowed_divergence=max_allowed_divergence)) # noqa
try:
# wind is from the West:
@ -2232,9 +2207,11 @@ class AutoTestPlane(vehicle_test_suite.TestSuite):
raise NotAchievedException("GPS use re-started %f sec after jamming stopped" % time_since_jamming_stopped)
self.set_rc(3, 1000)
self.fly_home_land_and_disarm()
self.progress("max-divergence: %fm" % (self.max_divergence,))
finally:
self.remove_message_hook(validate_global_position_int_against_simstate)
pass
self.context_pop()
self.reboot_sitl()
def Deadreckoning(self):
'''Test deadreckoning support'''

View File

@ -1832,6 +1832,7 @@ class TestSuite(ABC):
self.terrain_data_messages_sent = 0 # count of messages back
self.dronecan_tests = dronecan_tests
self.statustext_id = 1
self.message_hooks = [] # functions or MessageHook instances
def __del__(self):
if self.rc_thread is not None:
@ -3325,6 +3326,71 @@ class TestSuite(ABC):
return
self.drain_all_pexpects()
class MessageHook():
'''base class for objects that watch the message stream and check for
validity of fields'''
def __init__(self, suite):
self.suite = suite
def process(self):
pass
def progress_prefix(self):
return ""
def progress(self, string):
string = self.progress_prefix() + string
self.suite.progress(string)
class ValidateIntPositionAgainstSimState(MessageHook):
'''monitors a message containing a position containing lat/lng in 1e7,
makes sure it stays close to SIMSTATE'''
def __init__(self, suite, other_int_message_name, max_allowed_divergence=150):
super(TestSuite.ValidateIntPositionAgainstSimState, self).__init__(suite)
self.other_int_message_name = other_int_message_name
self.max_allowed_divergence = max_allowed_divergence
self.max_divergence = 0
self.gpi = None
self.simstate = None
self.last_print = 0
self.min_print_interval = 1 # seconds
def progress_prefix(self):
return "VIPASS: "
def process(self, mav, m):
if m.get_type() == self.other_int_message_name:
self.gpi = m
elif m.get_type() == 'SIMSTATE':
self.simstate = m
if self.gpi is None:
return
if self.simstate is None:
return
divergence = self.suite.get_distance_int(self.gpi, self.simstate)
if (time.time() - self.last_print > self.min_print_interval or
divergence > self.max_divergence):
self.progress(f"distance(SIMSTATE,{self.other_int_message_name})={divergence:.5f}m")
self.last_print = time.time()
if divergence > self.max_divergence:
self.max_divergence = divergence
if divergence > self.max_allowed_divergence:
raise NotAchievedException(
"%s diverged from simstate by %fm (max=%fm" %
(self.other_int_message_name, divergence, self.max_allowed_divergence,))
def hook_removed(self):
self.progress(f"Maximum divergence was {self.max_divergence}m (max={self.max_allowed_divergence}m)")
class ValidateGlobalPositionIntAgainstSimState(ValidateIntPositionAgainstSimState):
def __init__(self, suite, **kwargs):
super(TestSuite.ValidateGlobalPositionIntAgainstSimState, self).__init__(suite, 'GLOBAL_POSITION_INT', **kwargs)
class ValidateAHRS3AgainstSimState(ValidateIntPositionAgainstSimState):
def __init__(self, suite, **kwargs):
super(TestSuite.ValidateAHRS3AgainstSimState, self).__init__(suite, 'AHRS3', **kwargs)
def message_hook(self, mav, msg):
"""Called as each mavlink msg is received."""
# print("msg: %s" % str(msg))
@ -3336,6 +3402,13 @@ class TestSuite(ABC):
self.idle_hook(mav)
self.do_heartbeats()
for h in self.message_hooks:
if isinstance(h, TestSuite.MessageHook):
h.process(mav, msg)
continue
# assume it's a function
h(mav, msg)
def send_message_hook(self, msg, x):
self.write_msg_to_tlog(msg)
@ -4643,14 +4716,14 @@ class TestSuite(ABC):
return wploader.count()
def install_message_hook(self, hook):
self.mav.message_hooks.append(hook)
self.message_hooks.append(hook)
def install_message_hook_context(self, hook):
'''installs a message hook which will be removed when the context goes
away'''
if self.mav is None:
return
self.mav.message_hooks.append(hook)
self.message_hooks.append(hook)
self.context_get().message_hooks.append(hook)
def remove_message_hook(self, hook):
@ -4658,7 +4731,9 @@ class TestSuite(ABC):
once'''
if self.mav is None:
return
self.mav.message_hooks.remove(hook)
self.message_hooks.remove(hook)
if isinstance(hook, TestSuite.MessageHook):
hook.hook_removed()
def install_example_script_context(self, scriptname):
'''installs an example script which will be removed when the context goes
@ -7480,6 +7555,23 @@ class TestSuite(ABC):
**kwargs
)
def wait_distance_between(self, series1, series2, min_distance, max_distance, timeout=30, **kwargs):
"""Wait for distance between two position series to be between two thresholds."""
def get_distance():
self.drain_mav()
m1 = self.mav.messages[series1]
m2 = self.mav.messages[series2]
return self.get_distance_int(m1, m2)
self.wait_and_maintain_range(
value_name=f"Distance({series1}, {series2})",
minimum=min_distance,
maximum=max_distance,
current_value_getter=lambda: get_distance(),
timeout=timeout,
**kwargs
)
def wait_distance(self, distance, accuracy=2, timeout=30, **kwargs):
"""Wait for flight of a given distance."""
start = self.mav.location()
@ -8517,7 +8609,7 @@ Also, ignores heartbeats not from our target system'''
tee = TeeBoth(test_output_filename, 'w', self.mavproxy_logfile, suppress_stdout=suppress_stdout)
start_message_hooks = copy.copy(self.mav.message_hooks)
start_message_hooks = copy.copy(self.message_hooks)
prettyname = "%s (%s)" % (name, desc)
self.start_test(prettyname)
@ -8551,9 +8643,9 @@ Also, ignores heartbeats not from our target system'''
ex = e
# reset the message hooks; we've failed-via-exception and
# can't expect the hooks to have been cleaned up
for h in copy.copy(self.mav.message_hooks):
for h in copy.copy(self.message_hooks):
if h not in start_message_hooks:
self.mav.message_hooks.remove(h)
self.message_hooks.remove(h)
hooks_removed = True
self.test_timings[desc] = time.time() - start_time
reset_needed = self.contexts[-1].sitl_commandline_customised
@ -8640,9 +8732,9 @@ Also, ignores heartbeats not from our target system'''
self.progress("Done popping extra contexts")
# make sure we don't leave around stray listeners:
if len(self.mav.message_hooks) != len(start_message_hooks):
if len(self.message_hooks) != len(start_message_hooks):
self.progress("Stray message listeners: %s vs start %s" %
(str(self.mav.message_hooks), str(start_message_hooks)))
(str(self.message_hooks), str(start_message_hooks)))
passed = False
if passed: