Refactor mavlink parser

This commit is contained in:
Vasily Evseenko 2023-02-01 17:42:07 +03:00
parent bf894cede1
commit d41554d880
5 changed files with 206 additions and 213 deletions

View File

@ -1,9 +1,10 @@
import sys
import os
from twisted.internet import utils
from logging import currentframe
from twisted.python import log as twisted_log
from twisted.python import log
__orig_msg = twisted_log.msg
__orig_msg = log.msg
_srcfile = os.path.splitext(os.path.normcase(__file__))[0] + '.py'
# Returns escape codes from format codes
@ -132,3 +133,34 @@ def _log_msg(*args, **kwargs):
kwargs['system'] = '%s #%s' % ('.'.join(tmp), log_level_map[level])
return __orig_msg(*args, **kwargs)
class ExecError(Exception):
pass
def call_and_check_rc(cmd, *args, **kwargs):
def _check_rc(_args):
(stdout, stderr, rc) = _args
if rc != 0:
err = ExecError('RC %d: %s %s' % (rc, cmd, ' '.join(args)))
err.stdout = stdout.strip()
err.stderr = stderr.strip()
raise err
log.msg('# %s' % (' '.join((cmd,) + args),))
if stdout and kwargs.get('log_stdout', True):
log.msg(stdout)
return stdout
def _got_signal(f):
f.trap(tuple)
stdout, stderr, signum = f.value
err = ExecError('Got signal %d: %s %s' % (signum, cmd, ' '.join(args)))
err.stdout = stdout.strip()
err.stderr = stderr.strip()
raise err
return utils.getProcessOutputAndValue(cmd, args, env=os.environ).addCallbacks(_check_rc, _got_signal)

View File

@ -18,49 +18,135 @@
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
from . import mavlink
import struct
from . import call_and_check_rc, ExecError
from .mavlink import MAV_MODE_FLAG_SAFETY_ARMED, MAVLINK_MSG_ID_HEARTBEAT
from twisted.python import log
from twisted.internet import reactor, defer
from twisted.internet import reactor, defer, utils
from twisted.internet.protocol import Protocol, DatagramProtocol
class MAVLinkProtocol(Protocol):
def __init__(self, src_system, src_component):
self.mav = mavlink.MAVLink(self, srcSystem=src_system, srcComponent=src_component)
def parse_mavlink_l2_v1(msg):
plen, seq, sys_id, comp_id, msg_id = struct.unpack('<BBBBB', msg[1:6])
return ((seq, sys_id, comp_id, msg_id), bytes(msg[6:6 + plen]))
def write(self, msg):
raise NotImplementedError
def messageReceived(self, message):
raise NotImplementedError
def parse_mavlink_l2_v2(msg):
plen, iflags, cflags, seq, sys_id, comp_id, msg_id_low, msg_id_high = struct.unpack('<BBBBBBHB', msg[1:10])
return ((seq, sys_id, comp_id, msg_id_low + (msg_id_high << 16)), bytes(msg[10:10 + plen]))
def mavlink_parser_gen(parse_l2=False):
buffer = bytearray()
mlist = []
skip = 0
bad = 0
parse_map = { 0xfe: parse_mavlink_l2_v1,
0xfd: parse_mavlink_l2_v2 }
while True:
# GC
if skip > 4096:
buffer = buffer[skip:]
skip = 0
data = yield mlist
mlist = []
if not data:
continue
buffer.extend(data)
while len(buffer) - skip >= 8:
version = buffer[skip]
# mavlink 1
if version == 0xfe:
mlen = 8 + buffer[skip + 1]
# mavlink 2
elif version == 0xfd:
mlen, flags = struct.unpack('BB', buffer[skip + 1 : skip + 3])
if flags & ~0x01:
log.msg('Unsupported mavlink flags: 0x%x' % (flags,))
mlen += (25 if flags & 0x01 else 12)
else:
skip += 1
bad += 1
continue
if bad:
log.msg('skip %d bad bytes before sync' % (bad,))
bad = 0
if len(buffer) - skip < mlen:
break
if parse_l2:
mlist.append(parse_map[version](buffer[skip: skip + mlen]))
else:
mlist.append(bytes(buffer[skip: skip + mlen]))
skip += mlen
class MavlinkARMProtocol(object):
def __init__(self, call_on_arm, call_on_disarm):
self.call_on_arm = call_on_arm
self.call_on_disarm = call_on_disarm
self.armed = None
self.locked = False
self.mavlink_fsm = mavlink_parser_gen(parse_l2=True)
self.mavlink_fsm.send(None)
def dataReceived(self, data):
try:
m_list = self.mav.parse_buffer(data)
except mavlink.MAVError as e:
log.msg('Mavlink error: %s' % (e,))
for l2_headers, m in self.mavlink_fsm.send(data):
self.messageReceived(l2_headers, m)
def messageReceived(self, l2_headers, message):
seq, sys_id, comp_id, msg_id = l2_headers
if (sys_id, comp_id, msg_id) != (1, 1, MAVLINK_MSG_ID_HEARTBEAT):
return
if m_list is not None:
for m in m_list:
self.messageReceived(m)
armed = bool(message[6] & MAV_MODE_FLAG_SAFETY_ARMED)
if not self.locked:
self.locked = True
class MAVLinkUDPProtocol(MAVLinkProtocol, DatagramProtocol):
def __init__(self, src_system, src_component, peer=None):
MAVLinkProtocol.__init__(self, src_system, src_component)
self.reply_addr = peer
def _unlock(x):
self.locked = False
return x
def write(self, msg):
if self.transport is not None and self.reply_addr is not None:
self.transport.write(msg, self.reply_addr)
return defer.maybeDeferred(self.change_state, armed).addBoth(_unlock)
def datagramReceived(self, data, addr):
self.reply_addr = addr
self.dataReceived(data)
def change_state(self, armed):
if armed == self.armed:
return
self.armed = armed
cmd = None
if armed:
log.msg('State change: ARMED')
cmd = self.call_on_arm
else:
log.msg('State change: DISARMED')
cmd = self.call_on_disarm
def on_err(f):
log.msg('Command exec failed: %s' % (f.value,), isError=1)
if f.value.stdout:
log.msg(f.value.stdout, isError=1)
if f.value.stderr:
log.msg(f.value.stderr, isError=1)
if cmd is not None:
return call_and_check_rc(cmd).addErrback(on_err)
class MAVLinkSerialProtocol(MAVLinkProtocol):
def write(self, msg):
if self.transport is not None:
self.transport.write(msg)

View File

@ -21,44 +21,13 @@
import struct
import os
from contextlib import closing
from twisted.python import log
from twisted.internet import reactor, defer, utils
from twisted.internet import reactor, defer
from twisted.internet.protocol import DatagramProtocol, Protocol
from . import mavlink
from . import mavlink, mavlink_protocol
from .conf import settings
from .mavlink_protocol import MAVLinkProtocol
class ExecError(Exception):
pass
def call_and_check_rc(cmd, *args, **kwargs):
def _check_rc(_args):
(stdout, stderr, rc) = _args
if rc != 0:
err = ExecError('RC %d: %s %s' % (rc, cmd, ' '.join(args)))
err.stdout = stdout.strip()
err.stderr = stderr.strip()
raise err
log.msg('# %s' % (' '.join((cmd,) + args),))
if stdout and kwargs.get('log_stdout', True):
log.msg(stdout)
return stdout
def _got_signal(f):
f.trap(tuple)
stdout, stderr, signum = f.value
err = ExecError('Got signal %d: %s %s' % (signum, cmd, ' '.join(args)))
err.stdout = stdout.strip()
err.stderr = stderr.strip()
raise err
return utils.getProcessOutputAndValue(cmd, args, env=os.environ).addCallbacks(_check_rc, _got_signal)
class ProxyProtocol:
@ -212,37 +181,11 @@ class MavlinkUDPProxyProtocol(DatagramProtocol, MavlinkProxyProtocol):
# Split batch of mavlink packets due to issues with mavlink-router
i = 0
while i < len(msg):
if len(msg) - i < 8:
log.msg('Too short mavlink packet: %r' % (msg[i:],))
break
with closing(mavlink_protocol.mavlink_parser_gen()) as mavlink_fsm:
mavlink_fsm.send(None)
version = struct.unpack('B', msg[i : i + 1])[0]
# mavlink 1
if version == 0xfe:
mlen = 8 + struct.unpack('B', msg[i + 1 : i + 2])[0]
self.transport.write(msg[i: i + mlen], self.reply_addr)
i += mlen
# mavlink 2
elif version == 0xfd:
mlen, flags = struct.unpack('BB', msg[i + 1 : i + 3])
if flags & ~0x01:
log.msg('Unsupported mavlink flags: 0x%x' % (flags,))
self.transport.write(msg[i:], self.reply_addr)
break
mlen += (25 if flags & 0x01 else 12)
self.transport.write(msg[i : i + mlen], self.reply_addr)
i += mlen
else:
log.msg('Unsupported mavlink version: 0x%x' % (version,))
self.transport.write(msg[i:], self.reply_addr)
break
for m in mavlink_fsm.send(msg):
self.transport.write(m, self.reply_addr)
@ -258,59 +201,9 @@ class MavlinkSerialProxyProtocol(Protocol, MavlinkProxyProtocol):
mavlink_sys_id=mavlink_sys_id, mavlink_comp_id=mavlink_comp_id)
self.arm_proto = arm_proto
self.mavlink_fsm = self.mavlink_parser()
self.mavlink_fsm = mavlink_protocol.mavlink_parser_gen()
self.mavlink_fsm.send(None)
def mavlink_parser(self):
buffer = bytearray()
mlist = []
skip = 0
bad = 0
while True:
# GC
if skip > 4096:
buffer = buffer[skip:]
skip = 0
data = yield mlist
mlist = []
if not data:
continue
buffer.extend(data)
while len(buffer) - skip >= 8:
version = buffer[skip]
# mavlink 1
if version == 0xfe:
mlen = 8 + buffer[skip + 1]
# mavlink 2
elif version == 0xfd:
mlen, flags = struct.unpack('BB', buffer[skip + 1 : skip + 3])
if flags & ~0x01:
log.msg('Unsupported mavlink flags: 0x%x' % (flags,))
mlen += (25 if flags & 0x01 else 12)
else:
skip += 1
bad += 1
continue
if bad:
log.msg('skip %d bad bytes before sync' % (bad,))
bad = 0
if len(buffer) - skip < mlen:
break
mlist.append(bytes(buffer[skip: skip + mlen]))
skip += mlen
def write(self, msg):
if self.arm_proto:
self.arm_proto.dataReceived(msg)
@ -319,60 +212,7 @@ class MavlinkSerialProxyProtocol(Protocol, MavlinkProxyProtocol):
self.transport.write(msg)
def dataReceived(self, data):
m_list = self.mavlink_fsm.send(data)
for m in m_list:
for m in self.mavlink_fsm.send(data):
if self.arm_proto:
self.arm_proto.dataReceived(m)
self.messageReceived(m)
class MavlinkARMProtocol(MAVLinkProtocol):
def __init__(self, call_on_arm, call_on_disarm):
MAVLinkProtocol.__init__(self, None, None)
self.call_on_arm = call_on_arm
self.call_on_disarm = call_on_disarm
self.armed = None
self.locked = False
def messageReceived(self, message):
if (message._header.msgId, message._header.srcSystem, message._header.srcComponent) != (mavlink.MAVLINK_MSG_ID_HEARTBEAT, 1, 1):
return
armed = bool(message.base_mode & mavlink.MAV_MODE_FLAG_SAFETY_ARMED)
if not self.locked:
self.locked = True
def _unlock(x):
self.locked = False
return x
return defer.maybeDeferred(self.change_state, armed).addBoth(_unlock)
def change_state(self, armed):
if armed == self.armed:
return
self.armed = armed
cmd = None
if armed:
log.msg('State change: ARMED')
cmd = self.call_on_arm
else:
log.msg('State change: DISARMED')
cmd = self.call_on_disarm
def on_err(f):
log.msg('Command exec failed: %s' % (f.value,), isError=1)
if f.value.stdout:
log.msg(f.value.stdout, isError=1)
if f.value.stderr:
log.msg(f.value.stderr, isError=1)
if cmd is not None:
return call_and_check_rc(cmd).addErrback(on_err)

View File

@ -34,9 +34,10 @@ from twisted.protocols.basic import LineReceiver
from twisted.internet.error import ReactorNotRunning
from twisted.internet.serialport import SerialPort
from . import _log_msg, ConsoleObserver
from . import _log_msg, ConsoleObserver, call_and_check_rc, ExecError
from .common import abort_on_crash, exit_status, df_sleep
from .proxy import UDPProxyProtocol, MavlinkSerialProxyProtocol, MavlinkUDPProxyProtocol, MavlinkARMProtocol, call_and_check_rc, ExecError
from .proxy import UDPProxyProtocol, MavlinkSerialProxyProtocol, MavlinkUDPProxyProtocol
from .mavlink_protocol import MavlinkARMProtocol
from .tuntap import TUNTAPProtocol, TUNTAPTransport
from .conf import settings, cfg_files

View File

@ -6,7 +6,9 @@ from twisted.python import log
from twisted.trial import unittest
from twisted.internet import reactor, defer
from twisted.internet.protocol import DatagramProtocol
from ..mavlink import MAVLink_heartbeat_message, MAVLink
from ..proxy import UDPProxyProtocol, MavlinkUDPProxyProtocol
from ..mavlink_protocol import MavlinkARMProtocol
from ..common import df_sleep
class Echo(DatagramProtocol):
@ -21,20 +23,27 @@ class SendPacket(DatagramProtocol):
self.msg = msg
self.addr = addr
self.count = count
self.replies = []
def startProtocol(self):
log.msg('send %d of %s to %s' % (self.count, self.msg, self.addr))
log.msg('send %d of %r to %s' % (self.count, self.msg, self.addr))
for i in range(self.count):
self.transport.write(self.msg, self.addr)
def datagramReceived(self, data, addr):
log.msg("received back %r from %s" % (data, addr))
self.df.callback((data, addr))
self.replies.append((data, addr))
if len(self.replies) == self.count:
self.df.callback(self.replies)
class UDPProxyTestCase(unittest.TestCase):
def setUp(self):
self.p1 = MavlinkUDPProxyProtocol(addr=None, mirror=None, arm_proto=None, agg_max_size=1445, agg_timeout=1, inject_rssi=True, mavlink_sys_id=3, mavlink_comp_id=242)
self.arm_proto = MavlinkARMProtocol(call_on_arm='/bin/true',
call_on_disarm='/bin/true')
self.p1 = MavlinkUDPProxyProtocol(addr=None, mirror=None, arm_proto=self.arm_proto, agg_max_size=1445, agg_timeout=1, inject_rssi=True, mavlink_sys_id=3, mavlink_comp_id=242)
self.p2 = UDPProxyProtocol(('127.0.0.1', 14553))
self.p1.peer = self.p2
self.p2.peer = self.p1
@ -50,15 +59,15 @@ class UDPProxyTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_proxy(self):
addr = ('127.0.0.1', 14551)
p = SendPacket(b'test', addr, 10)
p = SendPacket(b'\xfd\t\x00\x00\x00\x03\xf2m\x00\x00\x02\x00\x03\x00\x01\x01d\x00\x04\xa8\xad', addr, 10)
ep3 = reactor.listenUDP(9999, p)
ep4 = reactor.listenUDP(14553, Echo())
try:
ts = time.time()
_data, _addr = yield p.df
_replies = yield p.df
_expected = [(b'\xfd\t\x00\x00\x00\x03\xf2m\x00\x00\x02\x00\x03\x00\x01\x01d\x00\x04\xa8\xad', addr)] * 10
self.assertGreater(time.time() - ts, 1.0)
self.assertEqual(_addr, addr)
self.assertEqual(_data, b'test' * 10)
self.assertEqual(_replies, _expected)
finally:
ep4.stopListening()
ep3.stopListening()
@ -74,9 +83,34 @@ class UDPProxyTestCase(unittest.TestCase):
try:
self.p1.send_rssi(1, 2, 3, 4)
ts = time.time()
_data, _addr = yield p.df
_replies = yield p.df
_expected = [(b'\xfd\t\x00\x00\x00\x03\xf2m\x00\x00\x02\x00\x03\x00\x01\x01d\x00\x04\xa8\xad', addr)]
self.assertLess(time.time() - ts, 1.0)
self.assertEqual(_addr, addr)
self.assertEqual(_data, b'\xfd\t\x00\x00\x00\x03\xf2m\x00\x00\x02\x00\x03\x00\x01\x01d\x00\x04\xa8\xad')
self.assertEqual(_replies, _expected)
finally:
ep3.stopListening()
@defer.inlineCallbacks
def __test_arm_protocol(self, force_mavlink1):
addr = ('127.0.0.1', 14551)
mav = MAVLink(None, srcSystem=1, srcComponent=1)
msg = MAVLink_heartbeat_message(1, 8, 128, 0, 0, 1).pack(mav, force_mavlink1=force_mavlink1)
p = SendPacket(msg, addr)
ep3 = reactor.listenUDP(9999, p)
ep4 = reactor.listenUDP(14553, Echo())
try:
ts = time.time()
yield p.df
self.assertEqual(self.arm_proto.armed, True)
finally:
ep3.stopListening()
ep4.stopListening()
def test_arm_protocol_mav1(self):
return self.__test_arm_protocol(True)
def test_arm_protocol_mav2(self):
return self.__test_arm_protocol(False)