From d41554d88067f42461c9610eca6e41a5d30293ff Mon Sep 17 00:00:00 2001 From: Vasily Evseenko Date: Wed, 1 Feb 2023 17:42:07 +0300 Subject: [PATCH] Refactor mavlink parser --- wfb_ng/__init__.py | 36 +++++++- wfb_ng/mavlink_protocol.py | 146 +++++++++++++++++++++++------- wfb_ng/proxy.py | 178 ++----------------------------------- wfb_ng/server.py | 5 +- wfb_ng/tests/test_proxy.py | 54 ++++++++--- 5 files changed, 206 insertions(+), 213 deletions(-) diff --git a/wfb_ng/__init__.py b/wfb_ng/__init__.py index f7e45e3..43f434b 100644 --- a/wfb_ng/__init__.py +++ b/wfb_ng/__init__.py @@ -1,9 +1,10 @@ import sys import os +from twisted.internet import utils from logging import currentframe -from twisted.python import log as twisted_log +from twisted.python import log -__orig_msg = twisted_log.msg +__orig_msg = log.msg _srcfile = os.path.splitext(os.path.normcase(__file__))[0] + '.py' # Returns escape codes from format codes @@ -132,3 +133,34 @@ def _log_msg(*args, **kwargs): kwargs['system'] = '%s #%s' % ('.'.join(tmp), log_level_map[level]) return __orig_msg(*args, **kwargs) + + +class ExecError(Exception): + pass + + +def call_and_check_rc(cmd, *args, **kwargs): + def _check_rc(_args): + (stdout, stderr, rc) = _args + if rc != 0: + err = ExecError('RC %d: %s %s' % (rc, cmd, ' '.join(args))) + err.stdout = stdout.strip() + err.stderr = stderr.strip() + raise err + + log.msg('# %s' % (' '.join((cmd,) + args),)) + + if stdout and kwargs.get('log_stdout', True): + log.msg(stdout) + + return stdout + + def _got_signal(f): + f.trap(tuple) + stdout, stderr, signum = f.value + err = ExecError('Got signal %d: %s %s' % (signum, cmd, ' '.join(args))) + err.stdout = stdout.strip() + err.stderr = stderr.strip() + raise err + + return utils.getProcessOutputAndValue(cmd, args, env=os.environ).addCallbacks(_check_rc, _got_signal) diff --git a/wfb_ng/mavlink_protocol.py b/wfb_ng/mavlink_protocol.py index b8cf6a6..a8fd2a6 100644 --- a/wfb_ng/mavlink_protocol.py +++ b/wfb_ng/mavlink_protocol.py @@ -18,49 +18,135 @@ # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # -from . import mavlink +import struct +from . import call_and_check_rc, ExecError +from .mavlink import MAV_MODE_FLAG_SAFETY_ARMED, MAVLINK_MSG_ID_HEARTBEAT from twisted.python import log -from twisted.internet import reactor, defer +from twisted.internet import reactor, defer, utils from twisted.internet.protocol import Protocol, DatagramProtocol -class MAVLinkProtocol(Protocol): - def __init__(self, src_system, src_component): - self.mav = mavlink.MAVLink(self, srcSystem=src_system, srcComponent=src_component) +def parse_mavlink_l2_v1(msg): + plen, seq, sys_id, comp_id, msg_id = struct.unpack(' 4096: + buffer = buffer[skip:] + skip = 0 + + data = yield mlist + mlist = [] + + if not data: + continue + + buffer.extend(data) + + while len(buffer) - skip >= 8: + version = buffer[skip] + + # mavlink 1 + if version == 0xfe: + mlen = 8 + buffer[skip + 1] + + # mavlink 2 + elif version == 0xfd: + mlen, flags = struct.unpack('BB', buffer[skip + 1 : skip + 3]) + + if flags & ~0x01: + log.msg('Unsupported mavlink flags: 0x%x' % (flags,)) + + mlen += (25 if flags & 0x01 else 12) + else: + skip += 1 + bad += 1 + continue + + if bad: + log.msg('skip %d bad bytes before sync' % (bad,)) + bad = 0 + + if len(buffer) - skip < mlen: + break + + if parse_l2: + mlist.append(parse_map[version](buffer[skip: skip + mlen])) + else: + mlist.append(bytes(buffer[skip: skip + mlen])) + + skip += mlen + + + +class MavlinkARMProtocol(object): + def __init__(self, call_on_arm, call_on_disarm): + self.call_on_arm = call_on_arm + self.call_on_disarm = call_on_disarm + self.armed = None + self.locked = False + self.mavlink_fsm = mavlink_parser_gen(parse_l2=True) + self.mavlink_fsm.send(None) def dataReceived(self, data): - try: - m_list = self.mav.parse_buffer(data) - except mavlink.MAVError as e: - log.msg('Mavlink error: %s' % (e,)) + for l2_headers, m in self.mavlink_fsm.send(data): + self.messageReceived(l2_headers, m) + + def messageReceived(self, l2_headers, message): + seq, sys_id, comp_id, msg_id = l2_headers + + if (sys_id, comp_id, msg_id) != (1, 1, MAVLINK_MSG_ID_HEARTBEAT): return - if m_list is not None: - for m in m_list: - self.messageReceived(m) + armed = bool(message[6] & MAV_MODE_FLAG_SAFETY_ARMED) + if not self.locked: + self.locked = True -class MAVLinkUDPProtocol(MAVLinkProtocol, DatagramProtocol): - def __init__(self, src_system, src_component, peer=None): - MAVLinkProtocol.__init__(self, src_system, src_component) - self.reply_addr = peer + def _unlock(x): + self.locked = False + return x - def write(self, msg): - if self.transport is not None and self.reply_addr is not None: - self.transport.write(msg, self.reply_addr) + return defer.maybeDeferred(self.change_state, armed).addBoth(_unlock) - def datagramReceived(self, data, addr): - self.reply_addr = addr - self.dataReceived(data) + def change_state(self, armed): + if armed == self.armed: + return + self.armed = armed + cmd = None + + if armed: + log.msg('State change: ARMED') + cmd = self.call_on_arm + else: + log.msg('State change: DISARMED') + cmd = self.call_on_disarm + + def on_err(f): + log.msg('Command exec failed: %s' % (f.value,), isError=1) + + if f.value.stdout: + log.msg(f.value.stdout, isError=1) + + if f.value.stderr: + log.msg(f.value.stderr, isError=1) + + if cmd is not None: + return call_and_check_rc(cmd).addErrback(on_err) -class MAVLinkSerialProtocol(MAVLinkProtocol): - def write(self, msg): - if self.transport is not None: - self.transport.write(msg) diff --git a/wfb_ng/proxy.py b/wfb_ng/proxy.py index 88a49c3..d37a52a 100644 --- a/wfb_ng/proxy.py +++ b/wfb_ng/proxy.py @@ -21,44 +21,13 @@ import struct import os +from contextlib import closing from twisted.python import log -from twisted.internet import reactor, defer, utils +from twisted.internet import reactor, defer from twisted.internet.protocol import DatagramProtocol, Protocol -from . import mavlink +from . import mavlink, mavlink_protocol from .conf import settings -from .mavlink_protocol import MAVLinkProtocol - - -class ExecError(Exception): - pass - - -def call_and_check_rc(cmd, *args, **kwargs): - def _check_rc(_args): - (stdout, stderr, rc) = _args - if rc != 0: - err = ExecError('RC %d: %s %s' % (rc, cmd, ' '.join(args))) - err.stdout = stdout.strip() - err.stderr = stderr.strip() - raise err - - log.msg('# %s' % (' '.join((cmd,) + args),)) - - if stdout and kwargs.get('log_stdout', True): - log.msg(stdout) - - return stdout - - def _got_signal(f): - f.trap(tuple) - stdout, stderr, signum = f.value - err = ExecError('Got signal %d: %s %s' % (signum, cmd, ' '.join(args))) - err.stdout = stdout.strip() - err.stderr = stderr.strip() - raise err - - return utils.getProcessOutputAndValue(cmd, args, env=os.environ).addCallbacks(_check_rc, _got_signal) class ProxyProtocol: @@ -212,37 +181,11 @@ class MavlinkUDPProxyProtocol(DatagramProtocol, MavlinkProxyProtocol): # Split batch of mavlink packets due to issues with mavlink-router - i = 0 - while i < len(msg): - if len(msg) - i < 8: - log.msg('Too short mavlink packet: %r' % (msg[i:],)) - break + with closing(mavlink_protocol.mavlink_parser_gen()) as mavlink_fsm: + mavlink_fsm.send(None) - version = struct.unpack('B', msg[i : i + 1])[0] - - # mavlink 1 - if version == 0xfe: - mlen = 8 + struct.unpack('B', msg[i + 1 : i + 2])[0] - self.transport.write(msg[i: i + mlen], self.reply_addr) - i += mlen - - # mavlink 2 - elif version == 0xfd: - mlen, flags = struct.unpack('BB', msg[i + 1 : i + 3]) - - if flags & ~0x01: - log.msg('Unsupported mavlink flags: 0x%x' % (flags,)) - self.transport.write(msg[i:], self.reply_addr) - break - - mlen += (25 if flags & 0x01 else 12) - self.transport.write(msg[i : i + mlen], self.reply_addr) - i += mlen - - else: - log.msg('Unsupported mavlink version: 0x%x' % (version,)) - self.transport.write(msg[i:], self.reply_addr) - break + for m in mavlink_fsm.send(msg): + self.transport.write(m, self.reply_addr) @@ -258,59 +201,9 @@ class MavlinkSerialProxyProtocol(Protocol, MavlinkProxyProtocol): mavlink_sys_id=mavlink_sys_id, mavlink_comp_id=mavlink_comp_id) self.arm_proto = arm_proto - self.mavlink_fsm = self.mavlink_parser() + self.mavlink_fsm = mavlink_protocol.mavlink_parser_gen() self.mavlink_fsm.send(None) - def mavlink_parser(self): - buffer = bytearray() - mlist = [] - skip = 0 - bad = 0 - - while True: - # GC - if skip > 4096: - buffer = buffer[skip:] - skip = 0 - - data = yield mlist - mlist = [] - - if not data: - continue - - buffer.extend(data) - - while len(buffer) - skip >= 8: - version = buffer[skip] - - # mavlink 1 - if version == 0xfe: - mlen = 8 + buffer[skip + 1] - - # mavlink 2 - elif version == 0xfd: - mlen, flags = struct.unpack('BB', buffer[skip + 1 : skip + 3]) - - if flags & ~0x01: - log.msg('Unsupported mavlink flags: 0x%x' % (flags,)) - - mlen += (25 if flags & 0x01 else 12) - else: - skip += 1 - bad += 1 - continue - - if bad: - log.msg('skip %d bad bytes before sync' % (bad,)) - bad = 0 - - if len(buffer) - skip < mlen: - break - - mlist.append(bytes(buffer[skip: skip + mlen])) - skip += mlen - def write(self, msg): if self.arm_proto: self.arm_proto.dataReceived(msg) @@ -319,60 +212,7 @@ class MavlinkSerialProxyProtocol(Protocol, MavlinkProxyProtocol): self.transport.write(msg) def dataReceived(self, data): - m_list = self.mavlink_fsm.send(data) - - for m in m_list: + for m in self.mavlink_fsm.send(data): if self.arm_proto: self.arm_proto.dataReceived(m) self.messageReceived(m) - - -class MavlinkARMProtocol(MAVLinkProtocol): - def __init__(self, call_on_arm, call_on_disarm): - MAVLinkProtocol.__init__(self, None, None) - self.call_on_arm = call_on_arm - self.call_on_disarm = call_on_disarm - self.armed = None - self.locked = False - - def messageReceived(self, message): - if (message._header.msgId, message._header.srcSystem, message._header.srcComponent) != (mavlink.MAVLINK_MSG_ID_HEARTBEAT, 1, 1): - return - - armed = bool(message.base_mode & mavlink.MAV_MODE_FLAG_SAFETY_ARMED) - - if not self.locked: - self.locked = True - - def _unlock(x): - self.locked = False - return x - - return defer.maybeDeferred(self.change_state, armed).addBoth(_unlock) - - def change_state(self, armed): - if armed == self.armed: - return - - self.armed = armed - cmd = None - - if armed: - log.msg('State change: ARMED') - cmd = self.call_on_arm - else: - log.msg('State change: DISARMED') - cmd = self.call_on_disarm - - def on_err(f): - log.msg('Command exec failed: %s' % (f.value,), isError=1) - - if f.value.stdout: - log.msg(f.value.stdout, isError=1) - - if f.value.stderr: - log.msg(f.value.stderr, isError=1) - - if cmd is not None: - return call_and_check_rc(cmd).addErrback(on_err) - diff --git a/wfb_ng/server.py b/wfb_ng/server.py index fff0f4a..5bc2a23 100644 --- a/wfb_ng/server.py +++ b/wfb_ng/server.py @@ -34,9 +34,10 @@ from twisted.protocols.basic import LineReceiver from twisted.internet.error import ReactorNotRunning from twisted.internet.serialport import SerialPort -from . import _log_msg, ConsoleObserver +from . import _log_msg, ConsoleObserver, call_and_check_rc, ExecError from .common import abort_on_crash, exit_status, df_sleep -from .proxy import UDPProxyProtocol, MavlinkSerialProxyProtocol, MavlinkUDPProxyProtocol, MavlinkARMProtocol, call_and_check_rc, ExecError +from .proxy import UDPProxyProtocol, MavlinkSerialProxyProtocol, MavlinkUDPProxyProtocol +from .mavlink_protocol import MavlinkARMProtocol from .tuntap import TUNTAPProtocol, TUNTAPTransport from .conf import settings, cfg_files diff --git a/wfb_ng/tests/test_proxy.py b/wfb_ng/tests/test_proxy.py index 1689514..cc409ab 100644 --- a/wfb_ng/tests/test_proxy.py +++ b/wfb_ng/tests/test_proxy.py @@ -6,7 +6,9 @@ from twisted.python import log from twisted.trial import unittest from twisted.internet import reactor, defer from twisted.internet.protocol import DatagramProtocol +from ..mavlink import MAVLink_heartbeat_message, MAVLink from ..proxy import UDPProxyProtocol, MavlinkUDPProxyProtocol +from ..mavlink_protocol import MavlinkARMProtocol from ..common import df_sleep class Echo(DatagramProtocol): @@ -21,20 +23,27 @@ class SendPacket(DatagramProtocol): self.msg = msg self.addr = addr self.count = count + self.replies = [] def startProtocol(self): - log.msg('send %d of %s to %s' % (self.count, self.msg, self.addr)) + log.msg('send %d of %r to %s' % (self.count, self.msg, self.addr)) for i in range(self.count): self.transport.write(self.msg, self.addr) def datagramReceived(self, data, addr): log.msg("received back %r from %s" % (data, addr)) - self.df.callback((data, addr)) + self.replies.append((data, addr)) + + if len(self.replies) == self.count: + self.df.callback(self.replies) class UDPProxyTestCase(unittest.TestCase): def setUp(self): - self.p1 = MavlinkUDPProxyProtocol(addr=None, mirror=None, arm_proto=None, agg_max_size=1445, agg_timeout=1, inject_rssi=True, mavlink_sys_id=3, mavlink_comp_id=242) + self.arm_proto = MavlinkARMProtocol(call_on_arm='/bin/true', + call_on_disarm='/bin/true') + + self.p1 = MavlinkUDPProxyProtocol(addr=None, mirror=None, arm_proto=self.arm_proto, agg_max_size=1445, agg_timeout=1, inject_rssi=True, mavlink_sys_id=3, mavlink_comp_id=242) self.p2 = UDPProxyProtocol(('127.0.0.1', 14553)) self.p1.peer = self.p2 self.p2.peer = self.p1 @@ -50,15 +59,15 @@ class UDPProxyTestCase(unittest.TestCase): @defer.inlineCallbacks def test_proxy(self): addr = ('127.0.0.1', 14551) - p = SendPacket(b'test', addr, 10) + p = SendPacket(b'\xfd\t\x00\x00\x00\x03\xf2m\x00\x00\x02\x00\x03\x00\x01\x01d\x00\x04\xa8\xad', addr, 10) ep3 = reactor.listenUDP(9999, p) ep4 = reactor.listenUDP(14553, Echo()) try: ts = time.time() - _data, _addr = yield p.df + _replies = yield p.df + _expected = [(b'\xfd\t\x00\x00\x00\x03\xf2m\x00\x00\x02\x00\x03\x00\x01\x01d\x00\x04\xa8\xad', addr)] * 10 self.assertGreater(time.time() - ts, 1.0) - self.assertEqual(_addr, addr) - self.assertEqual(_data, b'test' * 10) + self.assertEqual(_replies, _expected) finally: ep4.stopListening() ep3.stopListening() @@ -74,9 +83,34 @@ class UDPProxyTestCase(unittest.TestCase): try: self.p1.send_rssi(1, 2, 3, 4) ts = time.time() - _data, _addr = yield p.df + _replies = yield p.df + _expected = [(b'\xfd\t\x00\x00\x00\x03\xf2m\x00\x00\x02\x00\x03\x00\x01\x01d\x00\x04\xa8\xad', addr)] self.assertLess(time.time() - ts, 1.0) - self.assertEqual(_addr, addr) - self.assertEqual(_data, b'\xfd\t\x00\x00\x00\x03\xf2m\x00\x00\x02\x00\x03\x00\x01\x01d\x00\x04\xa8\xad') + self.assertEqual(_replies, _expected) finally: ep3.stopListening() + + @defer.inlineCallbacks + def __test_arm_protocol(self, force_mavlink1): + addr = ('127.0.0.1', 14551) + mav = MAVLink(None, srcSystem=1, srcComponent=1) + msg = MAVLink_heartbeat_message(1, 8, 128, 0, 0, 1).pack(mav, force_mavlink1=force_mavlink1) + + p = SendPacket(msg, addr) + + ep3 = reactor.listenUDP(9999, p) + ep4 = reactor.listenUDP(14553, Echo()) + try: + ts = time.time() + yield p.df + self.assertEqual(self.arm_proto.armed, True) + finally: + ep3.stopListening() + ep4.stopListening() + + + def test_arm_protocol_mav1(self): + return self.__test_arm_protocol(True) + + def test_arm_protocol_mav2(self): + return self.__test_arm_protocol(False)