AP_HAL: make Socket.cpp safe for lwip and SITL usage

This commit is contained in:
Andrew Tridgell 2023-12-12 19:02:42 +11:00
parent 896b95654c
commit e53729f331
2 changed files with 40 additions and 26 deletions

View File

@ -21,9 +21,23 @@
#if AP_NETWORKING_SOCKETS_ENABLED #if AP_NETWORKING_SOCKETS_ENABLED
#include "Socket.h" #include "Socket.h"
#if AP_NETWORKING_BACKEND_CHIBIOS || AP_NETWORKING_BACKEND_PPP
#include <lwip/sockets.h>
#else
// SITL or Linux
#include <fcntl.h>
#include <unistd.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>
#include <sys/select.h>
#endif
#include <errno.h> #include <errno.h>
#if AP_NETWORKING_BACKEND_CHIBIOS #if AP_NETWORKING_BACKEND_CHIBIOS || AP_NETWORKING_BACKEND_PPP
#define CALL_PREFIX(x) ::lwip_##x #define CALL_PREFIX(x) ::lwip_##x
#else #else
#define CALL_PREFIX(x) ::x #define CALL_PREFIX(x) ::x
@ -33,6 +47,8 @@
#define MSG_NOSIGNAL 0 #define MSG_NOSIGNAL 0
#endif #endif
static_assert(sizeof(last_in_addr) == sizeof(struct sockaddr_in), "last_in_addr must match sockaddr_in size");
/* /*
constructor constructor
*/ */
@ -185,7 +201,7 @@ bool SocketAPM::connect_timeout(const char *address, uint16_t port, uint32_t tim
} }
int sock_error = 0; int sock_error = 0;
socklen_t len = sizeof(sock_error); socklen_t len = sizeof(sock_error);
if (getsockopt(fd, SOL_SOCKET, SO_ERROR, (void*)&sock_error, &len) != 0) { if (CALL_PREFIX(getsockopt)(fd, SOL_SOCKET, SO_ERROR, (void*)&sock_error, &len) != 0) {
return false; return false;
} }
connected = sock_error == 0; connected = sock_error == 0;
@ -294,10 +310,10 @@ ssize_t SocketAPM::recv(void *buf, size_t size, uint32_t timeout_ms)
errno = EWOULDBLOCK; errno = EWOULDBLOCK;
return -1; return -1;
} }
socklen_t len = sizeof(in_addr); socklen_t len = sizeof(last_in_addr);
int fin = get_read_fd(); int fin = get_read_fd();
ssize_t ret; ssize_t ret;
ret = CALL_PREFIX(recvfrom)(fin, buf, size, MSG_DONTWAIT, (sockaddr *)&in_addr, &len); ret = CALL_PREFIX(recvfrom)(fin, buf, size, MSG_DONTWAIT, (sockaddr *)&last_in_addr, &len);
if (ret <= 0) { if (ret <= 0) {
if (!datagram && connected && ret == 0) { if (!datagram && connected && ret == 0) {
// remote host has closed connection // remote host has closed connection
@ -314,9 +330,9 @@ ssize_t SocketAPM::recv(void *buf, size_t size, uint32_t timeout_ms)
if (CALL_PREFIX(getsockname)(fd, (struct sockaddr *)&send_addr, &send_len) != 0) { if (CALL_PREFIX(getsockname)(fd, (struct sockaddr *)&send_addr, &send_len) != 0) {
return -1; return -1;
} }
if (in_addr.sin_port == send_addr.sin_port && if (last_in_addr.sin_port == send_addr.sin_port &&
in_addr.sin_family == send_addr.sin_family && last_in_addr.sin_family == send_addr.sin_family &&
in_addr.sin_addr.s_addr == send_addr.sin_addr.s_addr) { last_in_addr.sin_addr.s_addr == send_addr.sin_addr.s_addr) {
// discard packets from ourselves // discard packets from ourselves
return -1; return -1;
} }
@ -329,8 +345,9 @@ ssize_t SocketAPM::recv(void *buf, size_t size, uint32_t timeout_ms)
*/ */
void SocketAPM::last_recv_address(const char *&ip_addr, uint16_t &port) const void SocketAPM::last_recv_address(const char *&ip_addr, uint16_t &port) const
{ {
ip_addr = inet_ntoa(in_addr.sin_addr); static char buf[16];
port = ntohs(in_addr.sin_port); auto *str = last_recv_address(buf, sizeof(buf), port);
ip_addr = str;
} }
/* /*
@ -338,11 +355,11 @@ void SocketAPM::last_recv_address(const char *&ip_addr, uint16_t &port) const
*/ */
const char *SocketAPM::last_recv_address(char *ip_addr_buf, uint8_t buflen, uint16_t &port) const const char *SocketAPM::last_recv_address(char *ip_addr_buf, uint8_t buflen, uint16_t &port) const
{ {
const char *ret = inet_ntop(AF_INET, (void*)&in_addr.sin_addr, ip_addr_buf, buflen); const char *ret = CALL_PREFIX(inet_ntop)(AF_INET, (void*)&last_in_addr.sin_addr, ip_addr_buf, buflen);
if (ret == nullptr) { if (ret == nullptr) {
return nullptr; return nullptr;
} }
port = ntohs(in_addr.sin_port); port = ntohs(last_in_addr.sin_port);
return ret; return ret;
} }
@ -427,8 +444,8 @@ SocketAPM *SocketAPM::accept(uint32_t timeout_ms)
return nullptr; return nullptr;
} }
socklen_t len = sizeof(in_addr); socklen_t len = sizeof(last_in_addr);
int newfd = CALL_PREFIX(accept)(fd, (sockaddr *)&in_addr, &len); int newfd = CALL_PREFIX(accept)(fd, (sockaddr *)&last_in_addr, &len);
if (newfd == -1) { if (newfd == -1) {
return nullptr; return nullptr;
} }

View File

@ -19,22 +19,12 @@
#include <AP_HAL/AP_HAL.h> #include <AP_HAL/AP_HAL.h>
#include <AP_Networking/AP_Networking_Config.h> #include <AP_Networking/AP_Networking_Config.h>
#if AP_NETWORKING_SOCKETS_ENABLED #if AP_NETWORKING_SOCKETS_ENABLED
#if HAL_OS_SOCKETS #if HAL_OS_SOCKETS
#include <fcntl.h> struct sockaddr_in;
#include <unistd.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>
#include <sys/select.h>
#elif AP_NETWORKING_BACKEND_CHIBIOS
#include <AP_Networking/AP_Networking_ChibiOS.h>
#include <lwip/sockets.h>
#endif
class SocketAPM { class SocketAPM {
public: public:
@ -91,7 +81,13 @@ public:
private: private:
bool datagram; bool datagram;
struct sockaddr_in in_addr {}; struct {
uint16_t sin_family;
uint16_t sin_port;
struct {
uint32_t s_addr;
} sin_addr;
} last_in_addr;
bool is_multicast_address(struct sockaddr_in &addr) const; bool is_multicast_address(struct sockaddr_in &addr) const;
int fd = -1; int fd = -1;
@ -104,4 +100,5 @@ private:
void make_sockaddr(const char *address, uint16_t port, struct sockaddr_in &sockaddr); void make_sockaddr(const char *address, uint16_t port, struct sockaddr_in &sockaddr);
}; };
#endif // HAL_OS_SOCKETS
#endif // AP_NETWORKING_SOCKETS_ENABLED #endif // AP_NETWORKING_SOCKETS_ENABLED