bpo-46805: Add low level UDP socket functions to asyncio (GH-31455)

This commit is contained in:
Alex Grönholm 2022-03-13 18:42:29 +02:00 committed by GitHub
parent 7e473e94a5
commit 9f04ee569c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 489 additions and 7 deletions

View File

@ -922,6 +922,29 @@ convenient.
.. versionadded:: 3.7
.. coroutinemethod:: loop.sock_recvfrom(sock, bufsize)
Receive a datagram of up to *bufsize* from *sock*. Asynchronous version of
:meth:`socket.recvfrom() <socket.socket.recvfrom>`.
Return a tuple of (received data, remote address).
*sock* must be a non-blocking socket.
.. versionadded:: 3.11
.. coroutinemethod:: loop.sock_recvfrom_into(sock, buf, nbytes=0)
Receive a datagram of up to *nbytes* from *sock* into *buf*.
Asynchronous version of
:meth:`socket.recvfrom_into() <socket.socket.recvfrom_into>`.
Return a tuple of (number of bytes received, remote address).
*sock* must be a non-blocking socket.
.. versionadded:: 3.11
.. coroutinemethod:: loop.sock_sendall(sock, data)
Send *data* to the *sock* socket. Asynchronous version of
@ -940,6 +963,18 @@ convenient.
method, before Python 3.7 it returned a :class:`Future`.
Since Python 3.7, this is an ``async def`` method.
.. coroutinemethod:: loop.sock_sendto(sock, data, address)
Send a datagram from *sock* to *address*.
Asynchronous version of
:meth:`socket.sendto() <socket.socket.sendto>`.
Return the number of bytes sent.
*sock* must be a non-blocking socket.
.. versionadded:: 3.11
.. coroutinemethod:: loop.sock_connect(sock, address)
Connect *sock* to a remote socket at *address*.

View File

@ -189,9 +189,18 @@ See also the main documentation section about the
* - ``await`` :meth:`loop.sock_recv_into`
- Receive data from the :class:`~socket.socket` into a buffer.
* - ``await`` :meth:`loop.sock_recvfrom`
- Receive a datagram from the :class:`~socket.socket`.
* - ``await`` :meth:`loop.sock_recvfrom_into`
- Receive a datagram from the :class:`~socket.socket` into a buffer.
* - ``await`` :meth:`loop.sock_sendall`
- Send data to the :class:`~socket.socket`.
* - ``await`` :meth:`loop.sock_sendto`
- Send a datagram via the :class:`~socket.socket` to the given address.
* - ``await`` :meth:`loop.sock_connect`
- Connect the :class:`~socket.socket`.

View File

@ -226,6 +226,15 @@ New Modules
Improved Modules
================
asyncio
-------
* Add raw datagram socket functions to the event loop:
:meth:`~asyncio.AbstractEventLoop.sock_sendto`,
:meth:`~asyncio.AbstractEventLoop.sock_recvfrom` and
:meth:`~asyncio.AbstractEventLoop.sock_recvfrom_into`.
(Contributed by Alex Grönholm in :issue:`46805`.)
fractions
---------

View File

@ -546,9 +546,18 @@ class AbstractEventLoop:
async def sock_recv_into(self, sock, buf):
raise NotImplementedError
async def sock_recvfrom(self, sock, bufsize):
raise NotImplementedError
async def sock_recvfrom_into(self, sock, buf, nbytes=0):
raise NotImplementedError
async def sock_sendall(self, sock, data):
raise NotImplementedError
async def sock_sendto(self, sock, data, address):
raise NotImplementedError
async def sock_connect(self, sock, address):
raise NotImplementedError

View File

@ -700,9 +700,21 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
async def sock_recv_into(self, sock, buf):
return await self._proactor.recv_into(sock, buf)
async def sock_recvfrom(self, sock, bufsize):
return await self._proactor.recvfrom(sock, bufsize)
async def sock_recvfrom_into(self, sock, buf, nbytes=0):
if not nbytes:
nbytes = len(buf)
return await self._proactor.recvfrom_into(sock, buf, nbytes)
async def sock_sendall(self, sock, data):
return await self._proactor.send(sock, data)
async def sock_sendto(self, sock, data, address):
return await self._proactor.sendto(sock, data, 0, address)
async def sock_connect(self, sock, address):
return await self._proactor.connect(sock, address)

View File

@ -434,6 +434,88 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
else:
fut.set_result(nbytes)
async def sock_recvfrom(self, sock, bufsize):
"""Receive a datagram from a datagram socket.
The return value is a tuple of (bytes, address) representing the
datagram received and the address it came from.
The maximum amount of data to be received at once is specified by
nbytes.
"""
base_events._check_ssl_socket(sock)
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")
try:
return sock.recvfrom(bufsize)
except (BlockingIOError, InterruptedError):
pass
fut = self.create_future()
fd = sock.fileno()
self._ensure_fd_no_transport(fd)
handle = self._add_reader(fd, self._sock_recvfrom, fut, sock, bufsize)
fut.add_done_callback(
functools.partial(self._sock_read_done, fd, handle=handle))
return await fut
def _sock_recvfrom(self, fut, sock, bufsize):
# _sock_recvfrom() can add itself as an I/O callback if the operation
# can't be done immediately. Don't use it directly, call
# sock_recvfrom().
if fut.done():
return
try:
result = sock.recvfrom(bufsize)
except (BlockingIOError, InterruptedError):
return # try again next time
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as exc:
fut.set_exception(exc)
else:
fut.set_result(result)
async def sock_recvfrom_into(self, sock, buf, nbytes=0):
"""Receive data from the socket.
The received data is written into *buf* (a writable buffer).
The return value is a tuple of (number of bytes written, address).
"""
base_events._check_ssl_socket(sock)
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")
if not nbytes:
nbytes = len(buf)
try:
return sock.recvfrom_into(buf, nbytes)
except (BlockingIOError, InterruptedError):
pass
fut = self.create_future()
fd = sock.fileno()
self._ensure_fd_no_transport(fd)
handle = self._add_reader(fd, self._sock_recvfrom_into, fut, sock, buf,
nbytes)
fut.add_done_callback(
functools.partial(self._sock_read_done, fd, handle=handle))
return await fut
def _sock_recvfrom_into(self, fut, sock, buf, bufsize):
# _sock_recv_into() can add itself as an I/O callback if the operation
# can't be done immediately. Don't use it directly, call
# sock_recv_into().
if fut.done():
return
try:
result = sock.recvfrom_into(buf, bufsize)
except (BlockingIOError, InterruptedError):
return # try again next time
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as exc:
fut.set_exception(exc)
else:
fut.set_result(result)
async def sock_sendall(self, sock, data):
"""Send data to the socket.
@ -487,6 +569,48 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
else:
pos[0] = start
async def sock_sendto(self, sock, data, address):
"""Send data to the socket.
The socket must be connected to a remote socket. This method continues
to send data from data until either all data has been sent or an
error occurs. None is returned on success. On error, an exception is
raised, and there is no way to determine how much data, if any, was
successfully processed by the receiving end of the connection.
"""
base_events._check_ssl_socket(sock)
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")
try:
return sock.sendto(data, address)
except (BlockingIOError, InterruptedError):
pass
fut = self.create_future()
fd = sock.fileno()
self._ensure_fd_no_transport(fd)
# use a trick with a list in closure to store a mutable state
handle = self._add_writer(fd, self._sock_sendto, fut, sock, data,
address)
fut.add_done_callback(
functools.partial(self._sock_write_done, fd, handle=handle))
return await fut
def _sock_sendto(self, fut, sock, data, address):
if fut.done():
# Future cancellation can be scheduled on previous loop iteration
return
try:
n = sock.sendto(data, 0, address)
except (BlockingIOError, InterruptedError):
return
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as exc:
fut.set_exception(exc)
else:
fut.set_result(n)
async def sock_connect(self, sock, address):
"""Connect to a remote socket at address.

View File

@ -512,6 +512,26 @@ class IocpProactor:
return self._register(ov, conn, finish_recv)
def recvfrom_into(self, conn, buf, flags=0):
self._register_with_iocp(conn)
ov = _overlapped.Overlapped(NULL)
try:
ov.WSARecvFromInto(conn.fileno(), buf, flags)
except BrokenPipeError:
return self._result((0, None))
def finish_recv(trans, key, ov):
try:
return ov.getresult()
except OSError as exc:
if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED,
_overlapped.ERROR_OPERATION_ABORTED):
raise ConnectionResetError(*exc.args)
else:
raise
return self._register(ov, conn, finish_recv)
def sendto(self, conn, buf, flags=0, addr=None):
self._register_with_iocp(conn)
ov = _overlapped.Overlapped(NULL)

View File

@ -5,11 +5,11 @@ import unittest
from asyncio import proactor_events
from itertools import cycle, islice
from unittest.mock import patch, Mock
from test.test_asyncio import utils as test_utils
from test import support
from test.support import socket_helper
def tearDownModule():
asyncio.set_event_loop_policy(None)
@ -380,6 +380,79 @@ class BaseSockTestsMixin:
self.loop.run_until_complete(
self._basetest_huge_content_recvinto(httpd.address))
async def _basetest_datagram_recvfrom(self, server_address):
# Happy path, sock.sendto() returns immediately
data = b'\x01' * 4096
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
sock.setblocking(False)
await self.loop.sock_sendto(sock, data, server_address)
received_data, from_addr = await self.loop.sock_recvfrom(
sock, 4096)
self.assertEqual(received_data, data)
self.assertEqual(from_addr, server_address)
def test_recvfrom(self):
with test_utils.run_udp_echo_server() as server_address:
self.loop.run_until_complete(
self._basetest_datagram_recvfrom(server_address))
async def _basetest_datagram_recvfrom_into(self, server_address):
# Happy path, sock.sendto() returns immediately
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
sock.setblocking(False)
buf = bytearray(4096)
data = b'\x01' * 4096
await self.loop.sock_sendto(sock, data, server_address)
num_bytes, from_addr = await self.loop.sock_recvfrom_into(
sock, buf)
self.assertEqual(num_bytes, 4096)
self.assertEqual(buf, data)
self.assertEqual(from_addr, server_address)
buf = bytearray(8192)
await self.loop.sock_sendto(sock, data, server_address)
num_bytes, from_addr = await self.loop.sock_recvfrom_into(
sock, buf, 4096)
self.assertEqual(num_bytes, 4096)
self.assertEqual(buf[:4096], data[:4096])
self.assertEqual(from_addr, server_address)
def test_recvfrom_into(self):
with test_utils.run_udp_echo_server() as server_address:
self.loop.run_until_complete(
self._basetest_datagram_recvfrom_into(server_address))
async def _basetest_datagram_sendto_blocking(self, server_address):
# Sad path, sock.sendto() raises BlockingIOError
# This involves patching sock.sendto() to raise BlockingIOError but
# sendto() is not used by the proactor event loop
data = b'\x01' * 4096
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
sock.setblocking(False)
mock_sock = Mock(sock)
mock_sock.gettimeout = sock.gettimeout
mock_sock.sendto.configure_mock(side_effect=BlockingIOError)
mock_sock.fileno = sock.fileno
self.loop.call_soon(
lambda: setattr(mock_sock, 'sendto', sock.sendto)
)
await self.loop.sock_sendto(mock_sock, data, server_address)
received_data, from_addr = await self.loop.sock_recvfrom(
sock, 4096)
self.assertEqual(received_data, data)
self.assertEqual(from_addr, server_address)
def test_sendto_blocking(self):
if sys.platform == 'win32':
if isinstance(self.loop, asyncio.ProactorEventLoop):
raise unittest.SkipTest('Not relevant to ProactorEventLoop')
with test_utils.run_udp_echo_server() as server_address:
self.loop.run_until_complete(
self._basetest_datagram_sendto_blocking(server_address))
@socket_helper.skip_unless_bind_unix_socket
def test_unix_sock_client_ops(self):
with test_utils.run_test_unix_server() as httpd:

View File

@ -281,6 +281,31 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
server_ssl_cls=SSLWSGIServer)
def echo_datagrams(sock):
while True:
data, addr = sock.recvfrom(4096)
if data == b'STOP':
sock.close()
break
else:
sock.sendto(data, addr)
@contextlib.contextmanager
def run_udp_echo_server(*, host='127.0.0.1', port=0):
addr_info = socket.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
family, type, proto, _, sockaddr = addr_info[0]
sock = socket.socket(family, type, proto)
sock.bind((host, port))
thread = threading.Thread(target=lambda: echo_datagrams(sock))
thread.start()
try:
yield sock.getsockname()
finally:
sock.sendto(b'STOP', sock.getsockname())
thread.join()
def make_test_protocol(base):
dct = {}
for name in dir(base):

View File

@ -0,0 +1,4 @@
Added raw datagram socket functions for asyncio:
:meth:`~asyncio.AbstractEventLoop.sock_sendto`,
:meth:`~asyncio.AbstractEventLoop.sock_recvfrom` and
:meth:`~asyncio.AbstractEventLoop.sock_recvfrom_into`.

View File

@ -905,4 +905,42 @@ _overlapped_Overlapped_WSARecvFrom(OverlappedObject *self, PyObject *const *args
exit:
return return_value;
}
/*[clinic end generated code: output=ee2ec2f93c8d334b input=a9049054013a1b77]*/
PyDoc_STRVAR(_overlapped_Overlapped_WSARecvFromInto__doc__,
"WSARecvFromInto($self, handle, buf, size, flags=0, /)\n"
"--\n"
"\n"
"Start overlapped receive.");
#define _OVERLAPPED_OVERLAPPED_WSARECVFROMINTO_METHODDEF \
{"WSARecvFromInto", (PyCFunction)(void(*)(void))_overlapped_Overlapped_WSARecvFromInto, METH_FASTCALL, _overlapped_Overlapped_WSARecvFromInto__doc__},
static PyObject *
_overlapped_Overlapped_WSARecvFromInto_impl(OverlappedObject *self,
HANDLE handle, Py_buffer *bufobj,
DWORD size, DWORD flags);
static PyObject *
_overlapped_Overlapped_WSARecvFromInto(OverlappedObject *self, PyObject *const *args, Py_ssize_t nargs)
{
PyObject *return_value = NULL;
HANDLE handle;
Py_buffer bufobj = {NULL, NULL};
DWORD size;
DWORD flags = 0;
if (!_PyArg_ParseStack(args, nargs, ""F_HANDLE"y*k|k:WSARecvFromInto",
&handle, &bufobj, &size, &flags)) {
goto exit;
}
return_value = _overlapped_Overlapped_WSARecvFromInto_impl(self, handle, &bufobj, size, flags);
exit:
/* Cleanup for bufobj */
if (bufobj.obj) {
PyBuffer_Release(&bufobj);
}
return return_value;
}
/*[clinic end generated code: output=5c9b17890ef29d52 input=a9049054013a1b77]*/

View File

@ -64,7 +64,7 @@ class _overlapped.Overlapped "OverlappedObject *" "&OverlappedType"
enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_READINTO, TYPE_WRITE,
TYPE_ACCEPT, TYPE_CONNECT, TYPE_DISCONNECT, TYPE_CONNECT_NAMED_PIPE,
TYPE_WAIT_NAMED_PIPE_AND_CONNECT, TYPE_TRANSMIT_FILE, TYPE_READ_FROM,
TYPE_WRITE_TO};
TYPE_WRITE_TO, TYPE_READ_FROM_INTO};
typedef struct {
PyObject_HEAD
@ -91,6 +91,17 @@ typedef struct {
struct sockaddr_in6 address;
int address_length;
} read_from;
/* Data used for reading from a connectionless socket:
TYPE_READ_FROM_INTO */
struct {
// A (number of bytes read, (host, port)) tuple
PyObject* result;
/* Buffer passed by the user */
Py_buffer *user_buffer;
struct sockaddr_in6 address;
int address_length;
} read_from_into;
};
} OverlappedObject;
@ -662,6 +673,13 @@ Overlapped_clear(OverlappedObject *self)
}
break;
}
case TYPE_READ_FROM_INTO: {
if (self->read_from_into.result) {
// We've received a message, free the result tuple.
Py_CLEAR(self->read_from_into.result);
}
break;
}
case TYPE_WRITE:
case TYPE_WRITE_TO:
case TYPE_READINTO: {
@ -866,6 +884,11 @@ _overlapped_Overlapped_getresult_impl(OverlappedObject *self, BOOL wait)
{
break;
}
else if (self->type == TYPE_READ_FROM_INTO &&
self->read_from_into.result != NULL)
{
break;
}
/* fall through */
default:
return SetFromWindowsErr(err);
@ -914,6 +937,30 @@ _overlapped_Overlapped_getresult_impl(OverlappedObject *self, BOOL wait)
Py_INCREF(self->read_from.result);
return self->read_from.result;
case TYPE_READ_FROM_INTO:
// unparse the address
addr = unparse_address((SOCKADDR*)&self->read_from_into.address,
self->read_from_into.address_length);
if (addr == NULL) {
return NULL;
}
// The result is a two item tuple: (number of bytes read, address)
self->read_from_into.result = PyTuple_New(2);
if (self->read_from_into.result == NULL) {
Py_CLEAR(addr);
return NULL;
}
// first item: number of bytes read
PyTuple_SET_ITEM(self->read_from_into.result, 0,
PyLong_FromUnsignedLong((unsigned long)transferred));
// second item: address
PyTuple_SET_ITEM(self->read_from_into.result, 1, addr);
Py_INCREF(self->read_from_into.result);
return self->read_from_into.result;
default:
return PyLong_FromUnsignedLong((unsigned long) transferred);
}
@ -1053,6 +1100,7 @@ do_WSARecv(OverlappedObject *self, HANDLE handle,
}
}
/*[clinic input]
_overlapped.Overlapped.WSARecv
@ -1617,6 +1665,13 @@ Overlapped_traverse(OverlappedObject *self, visitproc visit, void *arg)
case TYPE_READ_FROM:
Py_VISIT(self->read_from.result);
Py_VISIT(self->read_from.allocated_buffer);
break;
case TYPE_READ_FROM_INTO:
Py_VISIT(self->read_from_into.result);
if (self->read_from_into.user_buffer->obj) {
Py_VISIT(&self->read_from_into.user_buffer->obj);
}
break;
}
return 0;
}
@ -1766,8 +1821,8 @@ _overlapped_Overlapped_WSARecvFrom_impl(OverlappedObject *self,
DWORD flags)
/*[clinic end generated code: output=13832a2025b86860 input=1b2663fa130e0286]*/
{
DWORD nread;
PyObject *buf;
DWORD nread;
WSABUF wsabuf;
int ret;
DWORD err;
@ -1785,8 +1840,8 @@ _overlapped_Overlapped_WSARecvFrom_impl(OverlappedObject *self,
return NULL;
}
wsabuf.len = size;
wsabuf.buf = PyBytes_AS_STRING(buf);
wsabuf.len = size;
self->type = TYPE_READ_FROM;
self->handle = handle;
@ -1802,8 +1857,7 @@ _overlapped_Overlapped_WSARecvFrom_impl(OverlappedObject *self,
Py_END_ALLOW_THREADS
self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS);
switch(err) {
switch (err) {
case ERROR_BROKEN_PIPE:
mark_as_completed(&self->overlapped);
return SetFromWindowsErr(err);
@ -1817,6 +1871,74 @@ _overlapped_Overlapped_WSARecvFrom_impl(OverlappedObject *self,
}
}
/*[clinic input]
_overlapped.Overlapped.WSARecvFromInto
handle: HANDLE
buf as bufobj: Py_buffer
size: DWORD
flags: DWORD = 0
/
Start overlapped receive.
[clinic start generated code]*/
static PyObject *
_overlapped_Overlapped_WSARecvFromInto_impl(OverlappedObject *self,
HANDLE handle, Py_buffer *bufobj,
DWORD size, DWORD flags)
/*[clinic end generated code: output=30c7ea171a691757 input=4be4b08d03531e76]*/
{
DWORD nread;
WSABUF wsabuf;
int ret;
DWORD err;
if (self->type != TYPE_NONE) {
PyErr_SetString(PyExc_ValueError, "operation already attempted");
return NULL;
}
#if SIZEOF_SIZE_T > SIZEOF_LONG
if (bufobj->len > (Py_ssize_t)ULONG_MAX) {
PyErr_SetString(PyExc_ValueError, "buffer too large");
return NULL;
}
#endif
wsabuf.buf = bufobj->buf;
wsabuf.len = size;
self->type = TYPE_READ_FROM_INTO;
self->handle = handle;
self->read_from_into.user_buffer = bufobj;
memset(&self->read_from_into.address, 0, sizeof(self->read_from_into.address));
self->read_from_into.address_length = sizeof(self->read_from_into.address);
Py_BEGIN_ALLOW_THREADS
ret = WSARecvFrom((SOCKET)handle, &wsabuf, 1, &nread, &flags,
(SOCKADDR*)&self->read_from_into.address,
&self->read_from_into.address_length,
&self->overlapped, NULL);
Py_END_ALLOW_THREADS
self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS);
switch (err) {
case ERROR_BROKEN_PIPE:
mark_as_completed(&self->overlapped);
return SetFromWindowsErr(err);
case ERROR_SUCCESS:
case ERROR_MORE_DATA:
case ERROR_IO_PENDING:
Py_RETURN_NONE;
default:
self->type = TYPE_NOT_STARTED;
return SetFromWindowsErr(err);
}
}
#include "clinic/overlapped.c.h"
static PyMethodDef Overlapped_methods[] = {
@ -1826,6 +1948,8 @@ static PyMethodDef Overlapped_methods[] = {
_OVERLAPPED_OVERLAPPED_READFILEINTO_METHODDEF
_OVERLAPPED_OVERLAPPED_WSARECV_METHODDEF
_OVERLAPPED_OVERLAPPED_WSARECVINTO_METHODDEF
_OVERLAPPED_OVERLAPPED_WSARECVFROM_METHODDEF
_OVERLAPPED_OVERLAPPED_WSARECVFROMINTO_METHODDEF
_OVERLAPPED_OVERLAPPED_WRITEFILE_METHODDEF
_OVERLAPPED_OVERLAPPED_WSASEND_METHODDEF
_OVERLAPPED_OVERLAPPED_ACCEPTEX_METHODDEF