diff --git a/libraries/AP_Scripting/AP_Scripting.cpp b/libraries/AP_Scripting/AP_Scripting.cpp index 347a7d440f..82c6f220cb 100644 --- a/libraries/AP_Scripting/AP_Scripting.cpp +++ b/libraries/AP_Scripting/AP_Scripting.cpp @@ -53,6 +53,10 @@ static_assert(SCRIPTING_STACK_SIZE <= SCRIPTING_STACK_MAX_SIZE, "Scripting requi #define SCRIPTING_ENABLE_DEFAULT 0 #endif +#if AP_NETWORKING_ENABLED +#include +#endif + extern const AP_HAL::HAL& hal; const AP_Param::GroupInfo AP_Scripting::var_info[] = { @@ -287,6 +291,17 @@ void AP_Scripting::thread(void) { } num_pwm_source = 0; +#if AP_NETWORKING_ENABLED + // clear allocated sockets + for (uint8_t i=0; i #include #include "AP_Scripting_CANSensor.h" +#include #ifndef SCRIPTING_MAX_NUM_I2C_DEVICE #define SCRIPTING_MAX_NUM_I2C_DEVICE 4 @@ -31,6 +32,13 @@ #define SCRIPTING_MAX_NUM_PWM_SOURCE 4 +#if AP_NETWORKING_ENABLED +#ifndef SCRIPTING_MAX_NUM_NET_SOCKET +#define SCRIPTING_MAX_NUM_NET_SOCKET 50 +#endif +class SocketAPM; +#endif + class AP_Scripting { public: @@ -111,6 +119,12 @@ public: int get_current_ref() { return current_ref; } void set_current_ref(int ref) { current_ref = ref; } +#if AP_NETWORKING_ENABLED + // SocketAPM storage + uint8_t num_net_sockets; + SocketAPM *_net_sockets[SCRIPTING_MAX_NUM_NET_SOCKET]; +#endif + struct mavlink_msg { mavlink_message_t msg; mavlink_channel_t chan; diff --git a/libraries/AP_Scripting/docs/docs.lua b/libraries/AP_Scripting/docs/docs.lua index 40ddb7ac9c..4f70ca6067 100644 --- a/libraries/AP_Scripting/docs/docs.lua +++ b/libraries/AP_Scripting/docs/docs.lua @@ -452,6 +452,51 @@ function motor_factor_table_ud:roll(index) end ---@param value number function motor_factor_table_ud:roll(index, value) end +-- network socket class +---@class SocketAPM_ud +local SocketAPM_ud = {} + +-- desc +---@return boolean +function SocketAPM_ud:is_connected() end + +-- desc +---@param param1 boolean +---@return boolean +function SocketAPM_ud:set_blocking(param1) end + +-- desc +---@param param1 integer +---@return boolean +function SocketAPM_ud:listen(param1) end + +-- desc +---@param param1 string +---@param param2 uint32_t_ud +---@return integer +function SocketAPM_ud:send(param1, param2) end + +-- desc +---@param param1 string +---@param param2 integer +---@return boolean +function SocketAPM_ud:bind(param1, param2) end + +-- desc +---@param param1 string +---@param param2 integer +---@return boolean +function SocketAPM_ud:connect(param1, param2) end + +-- desc +function SocketAPM_ud:__gc() end + +-- desc +function SocketAPM_ud:accept(param1) end + +-- desc +function SocketAPM_ud:recv(param1) end + -- desc ---@class AP_HAL__PWMSource_ud diff --git a/libraries/AP_Scripting/examples/net_test.lua b/libraries/AP_Scripting/examples/net_test.lua new file mode 100644 index 0000000000..e40f4d59aa --- /dev/null +++ b/libraries/AP_Scripting/examples/net_test.lua @@ -0,0 +1,140 @@ +--[[ + example script to test lua socket API +--]] + +local MAV_SEVERITY = {EMERGENCY=0, ALERT=1, CRITICAL=2, ERROR=3, WARNING=4, NOTICE=5, INFO=6, DEBUG=7} + +PARAM_TABLE_KEY = 46 +PARAM_TABLE_PREFIX = "NT_" + +-- bind a parameter to a variable given +function bind_param(name) + local p = Parameter() + assert(p:init(name), string.format('could not find %s parameter', name)) + return p +end + +-- add a parameter and bind it to a variable +function bind_add_param(name, idx, default_value) + assert(param:add_param(PARAM_TABLE_KEY, idx, name, default_value), string.format('could not add param %s', name)) + return bind_param(PARAM_TABLE_PREFIX .. name) +end + +-- Setup Parameters +assert(param:add_table(PARAM_TABLE_KEY, PARAM_TABLE_PREFIX, 6), 'net_test: could not add param table') + +--[[ + // @Param: NT_ENABLE + // @DisplayName: enable network tests + // @Description: Enable network tests + // @Values: 0:Disabled,1:Enabled + // @User: Standard +--]] +local NT_ENABLE = bind_add_param('ENABLE', 1, 0) +if NT_ENABLE:get() == 0 then + return +end + +local NT_TEST_IP = { bind_add_param('TEST_IP0', 2, 192), + bind_add_param('TEST_IP1', 3, 168), + bind_add_param('TEST_IP2', 4, 13), + bind_add_param('TEST_IP3', 5, 15) } + +local NT_BIND_PORT = bind_add_param('BIND_PORT', 6, 15001) + +local PORT_ECHO = 7 + +gcs:send_text(MAV_SEVERITY.INFO, "net_test: starting") + +local function test_ip() + return string.format("%u.%u.%u.%u", NT_TEST_IP[1]:get(), NT_TEST_IP[2]:get(), NT_TEST_IP[3]:get(), NT_TEST_IP[4]:get()) +end + +local counter = 0 +local sock_tcp_echo = SocketAPM(0) +local sock_udp_echo = SocketAPM(1) +local sock_tcp_in = SocketAPM(0) +local sock_tcp_in2 = nil +local sock_udp_in = SocketAPM(1) + +if not sock_tcp_echo then + gcs:send_text(MAV_SEVERITY.ERROR, "net_test: failed to create tcp echo socket") + return +end + +if not sock_udp_echo then + gcs:send_text(MAV_SEVERITY.ERROR, "net_test: failed to create udp echo socket") + return +end + +if not sock_tcp_in:bind("0.0.0.0", NT_BIND_PORT:get()) then + gcs:send_text(MAV_SEVERITY.ERROR, "net_test: failed to bind to TCP 5001") +end + +if not sock_tcp_in:listen(1) then + gcs:send_text(MAV_SEVERITY.ERROR, "net_test: failed to listen") +end + +if not sock_udp_in:bind("0.0.0.0", NT_BIND_PORT:get()) then + gcs:send_text(MAV_SEVERITY.ERROR, "net_test: failed to bind to UDP 5001") +end + +--[[ + test TCP or UDP echo +--]] +local function test_echo(name, sock) + if not sock:is_connected() then + if not sock:connect(test_ip(), PORT_ECHO) then + gcs:send_text(MAV_SEVERITY.ERROR, string.format("test_echo(%s): failed to connect", name)) + return + end + + if not sock:set_blocking(true) then + gcs:send_text(MAV_SEVERITY.ERROR, string.format("test_echo(%s): failed to set blocking", name)) + return + end + end + + local s = string.format("testing %u", counter) + local nsent = sock:send(s, #s) + if nsent ~= #s then + gcs:send_text(MAV_SEVERITY.ERROR, string.format("test_echo(%s): failed to send", name)) + return + end + local r = sock:recv(#s) + if r then + gcs:send_text(MAV_SEVERITY.ERROR, string.format("test_echo(%s): got reply '%s'", name, r)) + end +end + +--[[ + test a simple server +--]] +local function test_server(name, sock) + if name == "TCP" then + if not sock_tcp_in2 then + sock_tcp_in2 = sock:accept() + if not sock_tcp_in2 then + return + end + gcs:send_text(MAV_SEVERITY.ERROR, string.format("test_server(%s): new connection", name)) + end + sock = sock_tcp_in2 + end + + local r = sock:recv(1024) + if r and #r > 0 then + gcs:send_text(MAV_SEVERITY.ERROR, string.format("test_server(%s): got input '%s'", name, r)) + end +end + +local function update() + test_echo("TCP", sock_tcp_echo) + test_echo("UDP", sock_udp_echo) + test_server("TCP", sock_tcp_in) + test_server("UDP", sock_udp_in) + counter = counter + 1 + return update,1000 +end + +return update,100 diff --git a/libraries/AP_Scripting/generator/description/bindings.desc b/libraries/AP_Scripting/generator/description/bindings.desc index a37476a525..d2c25274b3 100644 --- a/libraries/AP_Scripting/generator/description/bindings.desc +++ b/libraries/AP_Scripting/generator/description/bindings.desc @@ -536,6 +536,19 @@ ap_object AP_HAL::I2CDevice method write_register boolean uint8_t'skip_check uin ap_object AP_HAL::I2CDevice manual read_registers AP_HAL__I2CDevice_read_registers 2 ap_object AP_HAL::I2CDevice method set_address void uint8_t'skip_check +include AP_HAL/utility/Socket.h depends (AP_NETWORKING_ENABLED==1) +global manual SocketAPM lua_get_SocketAPM 1 depends (AP_NETWORKING_ENABLED==1) + +ap_object SocketAPM depends (AP_NETWORKING_ENABLED==1) +ap_object SocketAPM method connect boolean string uint16_t'skip_check +ap_object SocketAPM method bind boolean string uint16_t'skip_check +ap_object SocketAPM method send int32_t string uint32_t'skip_check +ap_object SocketAPM method listen boolean uint8_t'skip_check +ap_object SocketAPM method set_blocking boolean boolean +ap_object SocketAPM method is_connected boolean +ap_object SocketAPM manual close SocketAPM_close 0 +ap_object SocketAPM manual recv SocketAPM_recv 1 +ap_object SocketAPM manual accept SocketAPM_accept 1 ap_object AP_HAL::AnalogSource depends !defined(HAL_DISABLE_ADC_DRIVER) ap_object AP_HAL::AnalogSource method set_pin boolean uint8_t'skip_check diff --git a/libraries/AP_Scripting/lua_bindings.cpp b/libraries/AP_Scripting/lua_bindings.cpp index 1ff491ed74..a3d574e8ed 100644 --- a/libraries/AP_Scripting/lua_bindings.cpp +++ b/libraries/AP_Scripting/lua_bindings.cpp @@ -15,6 +15,11 @@ #include #include +#include +#if AP_NETWORKING_ENABLED +#include +#endif + extern const AP_HAL::HAL& hal; extern "C" { @@ -755,6 +760,121 @@ int lua_get_PWMSource(lua_State *L) { return 1; } +#if AP_NETWORKING_ENABLED +/* + allocate a SocketAPM + */ +int lua_get_SocketAPM(lua_State *L) { + binding_argcheck(L, 1); + const uint8_t datagram = get_uint8_t(L, 1); + auto *scripting = AP::scripting(); + + lua_gc(L, LUA_GCCOLLECT, 0); + + if (scripting->num_net_sockets >= SCRIPTING_MAX_NUM_NET_SOCKET) { + return luaL_argerror(L, 1, "no sockets available"); + } + + auto *sock = new SocketAPM(datagram); + if (sock == nullptr) { + return luaL_argerror(L, 1, "SocketAPM device nullptr"); + } + scripting->_net_sockets[scripting->num_net_sockets] = sock; + + new_SocketAPM(L); + *((SocketAPM**)luaL_checkudata(L, -1, "SocketAPM")) = scripting->_net_sockets[scripting->num_net_sockets]; + + scripting->num_net_sockets++; + + return 1; +} + +/* + socket close + */ +int SocketAPM_close(lua_State *L) { + binding_argcheck(L, 1); + + SocketAPM *ud = *check_SocketAPM(L, 1); + + auto *scripting = AP::scripting(); + + if (scripting->num_net_sockets == 0) { + return luaL_argerror(L, 1, "socket close error"); + } + + // clear allocated socket + for (uint8_t i=0; i_net_sockets[i] == ud) { + ud->close(); + delete ud; + scripting->_net_sockets[i] = nullptr; + scripting->num_net_sockets--; + break; + } + } + + return 0; +} + +/* + receive from a socket to a lua string + */ +int SocketAPM_recv(lua_State *L) { + binding_argcheck(L, 2); + + SocketAPM * ud = *check_SocketAPM(L, 1); + + const uint16_t count = get_uint16_t(L, 2); + uint8_t *data = (uint8_t*)malloc(count); + if (data == nullptr) { + return 0; + } + + const auto ret = ud->recv(data, count, 0); + if (ret < 0) { + free(data); + return 0; + } + + // push to lua string + lua_pushlstring(L, (const char *)data, ret); + free(data); + + return 1; +} + +/* + TCP socket accept() call + */ +int SocketAPM_accept(lua_State *L) { + binding_argcheck(L, 1); + + SocketAPM * ud = *check_SocketAPM(L, 1); + + auto *scripting = AP::scripting(); + if (scripting->num_net_sockets >= SCRIPTING_MAX_NUM_NET_SOCKET) { + return luaL_argerror(L, 1, "no sockets available"); + } + + auto *sock = ud->accept(0); + if (sock == nullptr) { + return 0; + } + + scripting->_net_sockets[scripting->num_net_sockets] = sock; + + new_SocketAPM(L); + *((SocketAPM**)luaL_checkudata(L, -1, "SocketAPM")) = scripting->_net_sockets[scripting->num_net_sockets]; + + scripting->num_net_sockets++; + + return 1; +} + +#endif // AP_NETWORKING_ENABLED + + int lua_get_current_ref() { auto *scripting = AP::scripting(); diff --git a/libraries/AP_Scripting/lua_bindings.h b/libraries/AP_Scripting/lua_bindings.h index 8eef08bca3..797a551b20 100644 --- a/libraries/AP_Scripting/lua_bindings.h +++ b/libraries/AP_Scripting/lua_bindings.h @@ -14,6 +14,10 @@ int lua_dirlist(lua_State *L); int lua_removefile(lua_State *L); int SRV_Channels_get_safety_state(lua_State *L); int lua_get_PWMSource(lua_State *L); +int lua_get_SocketAPM(lua_State *L); +int SocketAPM_recv(lua_State *L); +int SocketAPM_accept(lua_State *L); +int SocketAPM_close(lua_State *L); int lua_mavlink_init(lua_State *L); int lua_mavlink_receive_chan(lua_State *L); int lua_mavlink_register_rx_msgid(lua_State *L);