From bafd4b5ac83b6cc0b7455290a04c4bfad34bdc90 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 28 May 2019 12:52:15 +0300 Subject: [PATCH] bpo-29883: Asyncio proactor udp (GH-13440) Follow-up for #1067 https://bugs.python.org/issue29883 --- Doc/library/asyncio-eventloop.rst | 5 +- Doc/library/asyncio-platforms.rst | 3 - Lib/asyncio/proactor_events.py | 169 +++++++- Lib/asyncio/windows_events.py | 46 +++ Lib/test/test_asyncio/test_events.py | 9 - Lib/test/test_asyncio/test_proactor_events.py | 278 +++++++++++++ .../2018-09-15-11-36-55.bpo-29883.HErerE.rst | 2 + Modules/overlapped.c | 372 +++++++++++++++++- 8 files changed, 838 insertions(+), 46 deletions(-) create mode 100644 Misc/NEWS.d/next/Windows/2018-09-15-11-36-55.bpo-29883.HErerE.rst diff --git a/Doc/library/asyncio-eventloop.rst b/Doc/library/asyncio-eventloop.rst index 4acd23f5946..8673f84e963 100644 --- a/Doc/library/asyncio-eventloop.rst +++ b/Doc/library/asyncio-eventloop.rst @@ -504,8 +504,6 @@ Opening network connections transport. If specified, *local_addr* and *remote_addr* should be omitted (must be :const:`None`). - On Windows, with :class:`ProactorEventLoop`, this method is not supported. - See :ref:`UDP echo client protocol ` and :ref:`UDP echo server protocol ` examples. @@ -513,6 +511,9 @@ Opening network connections The *family*, *proto*, *flags*, *reuse_address*, *reuse_port, *allow_broadcast*, and *sock* parameters were added. + .. versionchanged:: 3.8 + Added support for Windows. + .. coroutinemethod:: loop.create_unix_connection(protocol_factory, \ path=None, \*, ssl=None, sock=None, \ server_hostname=None, ssl_handshake_timeout=None) diff --git a/Doc/library/asyncio-platforms.rst b/Doc/library/asyncio-platforms.rst index 81d840e2327..7e4a70f91c6 100644 --- a/Doc/library/asyncio-platforms.rst +++ b/Doc/library/asyncio-platforms.rst @@ -53,9 +53,6 @@ All event loops on Windows do not support the following methods: :class:`ProactorEventLoop` has the following limitations: -* The :meth:`loop.create_datagram_endpoint` method - is not supported. - * The :meth:`loop.add_reader` and :meth:`loop.add_writer` methods are not supported. diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py index 6a53b2edaac..9b8ae064a89 100644 --- a/Lib/asyncio/proactor_events.py +++ b/Lib/asyncio/proactor_events.py @@ -11,6 +11,7 @@ import os import socket import warnings import signal +import collections from . import base_events from . import constants @@ -23,6 +24,24 @@ from . import trsock from .log import logger +def _set_socket_extra(transport, sock): + transport._extra['socket'] = trsock.TransportSocket(sock) + + try: + transport._extra['sockname'] = sock.getsockname() + except socket.error: + if transport._loop.get_debug(): + logger.warning( + "getsockname() failed on %r", sock, exc_info=True) + + if 'peername' not in transport._extra: + try: + transport._extra['peername'] = sock.getpeername() + except socket.error: + # UDP sockets may not have a peer name + transport._extra['peername'] = None + + class _ProactorBasePipeTransport(transports._FlowControlMixin, transports.BaseTransport): """Base class for pipe and socket transports.""" @@ -430,6 +449,134 @@ class _ProactorWritePipeTransport(_ProactorBaseWritePipeTransport): self.close() +class _ProactorDatagramTransport(_ProactorBasePipeTransport): + max_size = 256 * 1024 + def __init__(self, loop, sock, protocol, address=None, + waiter=None, extra=None): + self._address = address + self._empty_waiter = None + # We don't need to call _protocol.connection_made() since our base + # constructor does it for us. + super().__init__(loop, sock, protocol, waiter=waiter, extra=extra) + + # The base constructor sets _buffer = None, so we set it here + self._buffer = collections.deque() + self._loop.call_soon(self._loop_reading) + + def _set_extra(self, sock): + _set_socket_extra(self, sock) + + def get_write_buffer_size(self): + return sum(len(data) for data, _ in self._buffer) + + def abort(self): + self._force_close(None) + + def sendto(self, data, addr=None): + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError('data argument must be bytes-like object (%r)', + type(data)) + + if not data: + return + + if self._address is not None and addr not in (None, self._address): + raise ValueError( + f'Invalid address: must be None or {self._address}') + + if self._conn_lost and self._address: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + logger.warning('socket.sendto() raised exception.') + self._conn_lost += 1 + return + + # Ensure that what we buffer is immutable. + self._buffer.append((bytes(data), addr)) + + if self._write_fut is None: + # No current write operations are active, kick one off + self._loop_writing() + # else: A write operation is already kicked off + + self._maybe_pause_protocol() + + def _loop_writing(self, fut=None): + try: + if self._conn_lost: + return + + assert fut is self._write_fut + self._write_fut = None + if fut: + # We are in a _loop_writing() done callback, get the result + fut.result() + + if not self._buffer or (self._conn_lost and self._address): + # The connection has been closed + if self._closing: + self._loop.call_soon(self._call_connection_lost, None) + return + + data, addr = self._buffer.popleft() + if self._address is not None: + self._write_fut = self._loop._proactor.send(self._sock, + data) + else: + self._write_fut = self._loop._proactor.sendto(self._sock, + data, + addr=addr) + except OSError as exc: + self._protocol.error_received(exc) + except Exception as exc: + self._fatal_error(exc, 'Fatal write error on datagram transport') + else: + self._write_fut.add_done_callback(self._loop_writing) + self._maybe_resume_protocol() + + def _loop_reading(self, fut=None): + data = None + try: + if self._conn_lost: + return + + assert self._read_fut is fut or (self._read_fut is None and + self._closing) + + self._read_fut = None + if fut is not None: + res = fut.result() + + if self._closing: + # since close() has been called we ignore any read data + data = None + return + + if self._address is not None: + data, addr = res, self._address + else: + data, addr = res + + if self._conn_lost: + return + if self._address is not None: + self._read_fut = self._loop._proactor.recv(self._sock, + self.max_size) + else: + self._read_fut = self._loop._proactor.recvfrom(self._sock, + self.max_size) + except OSError as exc: + self._protocol.error_received(exc) + except exceptions.CancelledError: + if not self._closing: + raise + else: + if self._read_fut is not None: + self._read_fut.add_done_callback(self._loop_reading) + finally: + if data: + self._protocol.datagram_received(data, addr) + + class _ProactorDuplexPipeTransport(_ProactorReadPipeTransport, _ProactorBaseWritePipeTransport, transports.Transport): @@ -455,22 +602,7 @@ class _ProactorSocketTransport(_ProactorReadPipeTransport, base_events._set_nodelay(sock) def _set_extra(self, sock): - self._extra['socket'] = trsock.TransportSocket(sock) - - try: - self._extra['sockname'] = sock.getsockname() - except (socket.error, AttributeError): - if self._loop.get_debug(): - logger.warning( - "getsockname() failed on %r", sock, exc_info=True) - - if 'peername' not in self._extra: - try: - self._extra['peername'] = sock.getpeername() - except (socket.error, AttributeError): - if self._loop.get_debug(): - logger.warning("getpeername() failed on %r", - sock, exc_info=True) + _set_socket_extra(self, sock) def can_write_eof(self): return True @@ -515,6 +647,11 @@ class BaseProactorEventLoop(base_events.BaseEventLoop): extra=extra, server=server) return ssl_protocol._app_transport + def _make_datagram_transport(self, sock, protocol, + address=None, waiter=None, extra=None): + return _ProactorDatagramTransport(self, sock, protocol, address, + waiter, extra) + def _make_duplex_pipe_transport(self, sock, protocol, waiter=None, extra=None): return _ProactorDuplexPipeTransport(self, diff --git a/Lib/asyncio/windows_events.py b/Lib/asyncio/windows_events.py index 61b40ba52a6..ac51109ff1a 100644 --- a/Lib/asyncio/windows_events.py +++ b/Lib/asyncio/windows_events.py @@ -483,6 +483,44 @@ class IocpProactor: return self._register(ov, conn, finish_recv) + def recvfrom(self, conn, nbytes, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + try: + ov.WSARecvFrom(conn.fileno(), nbytes, flags) + except BrokenPipeError: + return self._result((b'', None)) + + def finish_recv(trans, key, ov): + try: + return ov.getresult() + except OSError as exc: + if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED, + _overlapped.ERROR_OPERATION_ABORTED): + raise ConnectionResetError(*exc.args) + else: + raise + + return self._register(ov, conn, finish_recv) + + def sendto(self, conn, buf, flags=0, addr=None): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + + ov.WSASendTo(conn.fileno(), buf, flags, addr) + + def finish_send(trans, key, ov): + try: + return ov.getresult() + except OSError as exc: + if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED, + _overlapped.ERROR_OPERATION_ABORTED): + raise ConnectionResetError(*exc.args) + else: + raise + + return self._register(ov, conn, finish_send) + def send(self, conn, buf, flags=0): self._register_with_iocp(conn) ov = _overlapped.Overlapped(NULL) @@ -532,6 +570,14 @@ class IocpProactor: return future def connect(self, conn, address): + if conn.type == socket.SOCK_DGRAM: + # WSAConnect will complete immediately for UDP sockets so we don't + # need to register any IOCP operation + _overlapped.WSAConnect(conn.fileno(), address) + fut = self._loop.create_future() + fut.set_result(None) + return fut + self._register_with_iocp(conn) # The socket needs to be locally bound before we call ConnectEx(). try: diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py index e89db99df31..045654e87a8 100644 --- a/Lib/test/test_asyncio/test_events.py +++ b/Lib/test/test_asyncio/test_events.py @@ -1249,11 +1249,6 @@ class EventLoopTestsMixin: server.transport.close() def test_create_datagram_endpoint_sock(self): - if (sys.platform == 'win32' and - isinstance(self.loop, proactor_events.BaseProactorEventLoop)): - raise unittest.SkipTest( - 'UDP is not supported with proactor event loops') - sock = None local_address = ('127.0.0.1', 0) infos = self.loop.run_until_complete( @@ -2004,10 +1999,6 @@ if sys.platform == 'win32': def test_writer_callback_cancel(self): raise unittest.SkipTest("IocpEventLoop does not have add_writer()") - def test_create_datagram_endpoint(self): - raise unittest.SkipTest( - "IocpEventLoop does not have create_datagram_endpoint()") - def test_remove_fds_after_closing(self): raise unittest.SkipTest("IocpEventLoop does not have add_reader()") else: diff --git a/Lib/test/test_asyncio/test_proactor_events.py b/Lib/test/test_asyncio/test_proactor_events.py index 5952ccccce0..2e9995d3280 100644 --- a/Lib/test/test_asyncio/test_proactor_events.py +++ b/Lib/test/test_asyncio/test_proactor_events.py @@ -4,6 +4,7 @@ import io import socket import unittest import sys +from collections import deque from unittest import mock import asyncio @@ -12,6 +13,7 @@ from asyncio.proactor_events import BaseProactorEventLoop from asyncio.proactor_events import _ProactorSocketTransport from asyncio.proactor_events import _ProactorWritePipeTransport from asyncio.proactor_events import _ProactorDuplexPipeTransport +from asyncio.proactor_events import _ProactorDatagramTransport from test import support from test.test_asyncio import utils as test_utils @@ -725,6 +727,208 @@ class ProactorSocketTransportBufferedProtoTests(test_utils.TestCase): self.assertFalse(tr.is_reading()) +class ProactorDatagramTransportTests(test_utils.TestCase): + + def setUp(self): + super().setUp() + self.loop = self.new_test_loop() + self.proactor = mock.Mock() + self.loop._proactor = self.proactor + self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol) + self.sock = mock.Mock(spec_set=socket.socket) + self.sock.fileno.return_value = 7 + + def datagram_transport(self, address=None): + self.sock.getpeername.side_effect = None if address else OSError + transport = _ProactorDatagramTransport(self.loop, self.sock, + self.protocol, + address=address) + self.addCleanup(close_transport, transport) + return transport + + def test_sendto(self): + data = b'data' + transport = self.datagram_transport() + transport.sendto(data, ('0.0.0.0', 1234)) + self.assertTrue(self.proactor.sendto.called) + self.proactor.sendto.assert_called_with( + self.sock, data, addr=('0.0.0.0', 1234)) + + def test_sendto_bytearray(self): + data = bytearray(b'data') + transport = self.datagram_transport() + transport.sendto(data, ('0.0.0.0', 1234)) + self.assertTrue(self.proactor.sendto.called) + self.proactor.sendto.assert_called_with( + self.sock, b'data', addr=('0.0.0.0', 1234)) + + def test_sendto_memoryview(self): + data = memoryview(b'data') + transport = self.datagram_transport() + transport.sendto(data, ('0.0.0.0', 1234)) + self.assertTrue(self.proactor.sendto.called) + self.proactor.sendto.assert_called_with( + self.sock, b'data', addr=('0.0.0.0', 1234)) + + def test_sendto_no_data(self): + transport = self.datagram_transport() + transport._buffer.append((b'data', ('0.0.0.0', 12345))) + transport.sendto(b'', ()) + self.assertFalse(self.sock.sendto.called) + self.assertEqual( + [(b'data', ('0.0.0.0', 12345))], list(transport._buffer)) + + def test_sendto_buffer(self): + transport = self.datagram_transport() + transport._buffer.append((b'data1', ('0.0.0.0', 12345))) + transport._write_fut = object() + transport.sendto(b'data2', ('0.0.0.0', 12345)) + self.assertFalse(self.proactor.sendto.called) + self.assertEqual( + [(b'data1', ('0.0.0.0', 12345)), + (b'data2', ('0.0.0.0', 12345))], + list(transport._buffer)) + + def test_sendto_buffer_bytearray(self): + data2 = bytearray(b'data2') + transport = self.datagram_transport() + transport._buffer.append((b'data1', ('0.0.0.0', 12345))) + transport._write_fut = object() + transport.sendto(data2, ('0.0.0.0', 12345)) + self.assertFalse(self.proactor.sendto.called) + self.assertEqual( + [(b'data1', ('0.0.0.0', 12345)), + (b'data2', ('0.0.0.0', 12345))], + list(transport._buffer)) + self.assertIsInstance(transport._buffer[1][0], bytes) + + def test_sendto_buffer_memoryview(self): + data2 = memoryview(b'data2') + transport = self.datagram_transport() + transport._buffer.append((b'data1', ('0.0.0.0', 12345))) + transport._write_fut = object() + transport.sendto(data2, ('0.0.0.0', 12345)) + self.assertFalse(self.proactor.sendto.called) + self.assertEqual( + [(b'data1', ('0.0.0.0', 12345)), + (b'data2', ('0.0.0.0', 12345))], + list(transport._buffer)) + self.assertIsInstance(transport._buffer[1][0], bytes) + + @mock.patch('asyncio.proactor_events.logger') + def test_sendto_exception(self, m_log): + data = b'data' + err = self.proactor.sendto.side_effect = RuntimeError() + + transport = self.datagram_transport() + transport._fatal_error = mock.Mock() + transport.sendto(data, ()) + + self.assertTrue(transport._fatal_error.called) + transport._fatal_error.assert_called_with( + err, + 'Fatal write error on datagram transport') + transport._conn_lost = 1 + + transport._address = ('123',) + transport.sendto(data) + transport.sendto(data) + transport.sendto(data) + transport.sendto(data) + transport.sendto(data) + m_log.warning.assert_called_with('socket.sendto() raised exception.') + + def test_sendto_error_received(self): + data = b'data' + + self.sock.sendto.side_effect = ConnectionRefusedError + + transport = self.datagram_transport() + transport._fatal_error = mock.Mock() + transport.sendto(data, ()) + + self.assertEqual(transport._conn_lost, 0) + self.assertFalse(transport._fatal_error.called) + + def test_sendto_error_received_connected(self): + data = b'data' + + self.proactor.send.side_effect = ConnectionRefusedError + + transport = self.datagram_transport(address=('0.0.0.0', 1)) + transport._fatal_error = mock.Mock() + transport.sendto(data) + + self.assertFalse(transport._fatal_error.called) + self.assertTrue(self.protocol.error_received.called) + + def test_sendto_str(self): + transport = self.datagram_transport() + self.assertRaises(TypeError, transport.sendto, 'str', ()) + + def test_sendto_connected_addr(self): + transport = self.datagram_transport(address=('0.0.0.0', 1)) + self.assertRaises( + ValueError, transport.sendto, b'str', ('0.0.0.0', 2)) + + def test_sendto_closing(self): + transport = self.datagram_transport(address=(1,)) + transport.close() + self.assertEqual(transport._conn_lost, 1) + transport.sendto(b'data', (1,)) + self.assertEqual(transport._conn_lost, 2) + + def test__loop_writing_closing(self): + transport = self.datagram_transport() + transport._closing = True + transport._loop_writing() + self.assertIsNone(transport._write_fut) + test_utils.run_briefly(self.loop) + self.sock.close.assert_called_with() + self.protocol.connection_lost.assert_called_with(None) + + def test__loop_writing_exception(self): + err = self.proactor.sendto.side_effect = RuntimeError() + + transport = self.datagram_transport() + transport._fatal_error = mock.Mock() + transport._buffer.append((b'data', ())) + transport._loop_writing() + + transport._fatal_error.assert_called_with( + err, + 'Fatal write error on datagram transport') + + def test__loop_writing_error_received(self): + self.proactor.sendto.side_effect = ConnectionRefusedError + + transport = self.datagram_transport() + transport._fatal_error = mock.Mock() + transport._buffer.append((b'data', ())) + transport._loop_writing() + + self.assertFalse(transport._fatal_error.called) + + def test__loop_writing_error_received_connection(self): + self.proactor.send.side_effect = ConnectionRefusedError + + transport = self.datagram_transport(address=('0.0.0.0', 1)) + transport._fatal_error = mock.Mock() + transport._buffer.append((b'data', ())) + transport._loop_writing() + + self.assertFalse(transport._fatal_error.called) + self.assertTrue(self.protocol.error_received.called) + + @mock.patch('asyncio.base_events.logger.error') + def test_fatal_error_connected(self, m_exc): + transport = self.datagram_transport(address=('0.0.0.0', 1)) + err = ConnectionRefusedError() + transport._fatal_error(err) + self.assertFalse(self.protocol.error_received.called) + m_exc.assert_not_called() + + class BaseProactorEventLoopTests(test_utils.TestCase): def setUp(self): @@ -864,6 +1068,80 @@ class BaseProactorEventLoopTests(test_utils.TestCase): self.assertFalse(sock2.close.called) self.assertFalse(future2.cancel.called) + def datagram_transport(self): + self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol) + return self.loop._make_datagram_transport(self.sock, self.protocol) + + def test_make_datagram_transport(self): + tr = self.datagram_transport() + self.assertIsInstance(tr, _ProactorDatagramTransport) + close_transport(tr) + + def test_datagram_loop_writing(self): + tr = self.datagram_transport() + tr._buffer.appendleft((b'data', ('127.0.0.1', 12068))) + tr._loop_writing() + self.loop._proactor.sendto.assert_called_with(self.sock, b'data', addr=('127.0.0.1', 12068)) + self.loop._proactor.sendto.return_value.add_done_callback.\ + assert_called_with(tr._loop_writing) + + close_transport(tr) + + def test_datagram_loop_reading(self): + tr = self.datagram_transport() + tr._loop_reading() + self.loop._proactor.recvfrom.assert_called_with(self.sock, 256 * 1024) + self.assertFalse(self.protocol.datagram_received.called) + self.assertFalse(self.protocol.error_received.called) + close_transport(tr) + + def test_datagram_loop_reading_data(self): + res = asyncio.Future(loop=self.loop) + res.set_result((b'data', ('127.0.0.1', 12068))) + + tr = self.datagram_transport() + tr._read_fut = res + tr._loop_reading(res) + self.loop._proactor.recvfrom.assert_called_with(self.sock, 256 * 1024) + self.protocol.datagram_received.assert_called_with(b'data', ('127.0.0.1', 12068)) + close_transport(tr) + + def test_datagram_loop_reading_no_data(self): + res = asyncio.Future(loop=self.loop) + res.set_result((b'', ('127.0.0.1', 12068))) + + tr = self.datagram_transport() + self.assertRaises(AssertionError, tr._loop_reading, res) + + tr.close = mock.Mock() + tr._read_fut = res + tr._loop_reading(res) + self.assertTrue(self.loop._proactor.recvfrom.called) + self.assertFalse(self.protocol.error_received.called) + self.assertFalse(tr.close.called) + close_transport(tr) + + def test_datagram_loop_reading_aborted(self): + err = self.loop._proactor.recvfrom.side_effect = ConnectionAbortedError() + + tr = self.datagram_transport() + tr._fatal_error = mock.Mock() + tr._protocol.error_received = mock.Mock() + tr._loop_reading() + tr._protocol.error_received.assert_called_with(err) + close_transport(tr) + + def test_datagram_loop_writing_aborted(self): + err = self.loop._proactor.sendto.side_effect = ConnectionAbortedError() + + tr = self.datagram_transport() + tr._fatal_error = mock.Mock() + tr._protocol.error_received = mock.Mock() + tr._buffer.appendleft((b'Hello', ('127.0.0.1', 12068))) + tr._loop_writing() + tr._protocol.error_received.assert_called_with(err) + close_transport(tr) + @unittest.skipIf(sys.platform != 'win32', 'Proactor is supported on Windows only') diff --git a/Misc/NEWS.d/next/Windows/2018-09-15-11-36-55.bpo-29883.HErerE.rst b/Misc/NEWS.d/next/Windows/2018-09-15-11-36-55.bpo-29883.HErerE.rst new file mode 100644 index 00000000000..b6d1375c775 --- /dev/null +++ b/Misc/NEWS.d/next/Windows/2018-09-15-11-36-55.bpo-29883.HErerE.rst @@ -0,0 +1,2 @@ +Add Windows support for UDP transports for the Proactor Event Loop. Patch by +Adam Meily. diff --git a/Modules/overlapped.c b/Modules/overlapped.c index e5a209bf758..aad531e4789 100644 --- a/Modules/overlapped.c +++ b/Modules/overlapped.c @@ -39,7 +39,8 @@ enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_READINTO, TYPE_WRITE, TYPE_ACCEPT, TYPE_CONNECT, TYPE_DISCONNECT, TYPE_CONNECT_NAMED_PIPE, - TYPE_WAIT_NAMED_PIPE_AND_CONNECT, TYPE_TRANSMIT_FILE}; + TYPE_WAIT_NAMED_PIPE_AND_CONNECT, TYPE_TRANSMIT_FILE, TYPE_READ_FROM, + TYPE_WRITE_TO}; typedef struct { PyObject_HEAD @@ -53,8 +54,19 @@ typedef struct { union { /* Buffer allocated by us: TYPE_READ and TYPE_ACCEPT */ PyObject *allocated_buffer; - /* Buffer passed by the user: TYPE_WRITE and TYPE_READINTO */ + /* Buffer passed by the user: TYPE_WRITE, TYPE_WRITE_TO, and TYPE_READINTO */ Py_buffer user_buffer; + + /* Data used for reading from a connectionless socket: + TYPE_READ_FROM */ + struct { + // A (buffer, (host, port)) tuple + PyObject *result; + // The actual read buffer + PyObject *allocated_buffer; + struct sockaddr_in6 address; + int address_length; + } read_from; }; } OverlappedObject; @@ -570,16 +582,32 @@ static int Overlapped_clear(OverlappedObject *self) { switch (self->type) { - case TYPE_READ: - case TYPE_ACCEPT: - Py_CLEAR(self->allocated_buffer); - break; - case TYPE_WRITE: - case TYPE_READINTO: - if (self->user_buffer.obj) { - PyBuffer_Release(&self->user_buffer); + case TYPE_READ: + case TYPE_ACCEPT: { + Py_CLEAR(self->allocated_buffer); + break; + } + case TYPE_READ_FROM: { + // An initial call to WSARecvFrom will only allocate the buffer. + // The result tuple of (message, address) is only + // allocated _after_ a message has been received. + if(self->read_from.result) { + // We've received a message, free the result tuple. + Py_CLEAR(self->read_from.result); + } + if(self->read_from.allocated_buffer) { + Py_CLEAR(self->read_from.allocated_buffer); + } + break; + } + case TYPE_WRITE: + case TYPE_WRITE_TO: + case TYPE_READINTO: { + if (self->user_buffer.obj) { + PyBuffer_Release(&self->user_buffer); + } + break; } - break; } self->type = TYPE_NOT_STARTED; return 0; @@ -627,6 +655,73 @@ Overlapped_dealloc(OverlappedObject *self) SetLastError(olderr); } + +/* Convert IPv4 sockaddr to a Python str. */ + +static PyObject * +make_ipv4_addr(const struct sockaddr_in *addr) +{ + char buf[INET_ADDRSTRLEN]; + if (inet_ntop(AF_INET, &addr->sin_addr, buf, sizeof(buf)) == NULL) { + PyErr_SetFromErrno(PyExc_OSError); + return NULL; + } + return PyUnicode_FromString(buf); +} + +#ifdef ENABLE_IPV6 +/* Convert IPv6 sockaddr to a Python str. */ + +static PyObject * +make_ipv6_addr(const struct sockaddr_in6 *addr) +{ + char buf[INET6_ADDRSTRLEN]; + if (inet_ntop(AF_INET6, &addr->sin6_addr, buf, sizeof(buf)) == NULL) { + PyErr_SetFromErrno(PyExc_OSError); + return NULL; + } + return PyUnicode_FromString(buf); +} +#endif + +static PyObject* +unparse_address(LPSOCKADDR Address, DWORD Length) +{ + /* The function is adopted from mocketmodule.c makesockaddr()*/ + + switch(Address->sa_family) { + case AF_INET: { + const struct sockaddr_in *a = (const struct sockaddr_in *)Address; + PyObject *addrobj = make_ipv4_addr(a); + PyObject *ret = NULL; + if (addrobj) { + ret = Py_BuildValue("Oi", addrobj, ntohs(a->sin_port)); + Py_DECREF(addrobj); + } + return ret; + } +#ifdef ENABLE_IPV6 + case AF_INET6: { + const struct sockaddr_in6 *a = (const struct sockaddr_in6 *)Address; + PyObject *addrobj = make_ipv6_addr(a); + PyObject *ret = NULL; + if (addrobj) { + ret = Py_BuildValue("OiII", + addrobj, + ntohs(a->sin6_port), + ntohl(a->sin6_flowinfo), + a->sin6_scope_id); + Py_DECREF(addrobj); + } + return ret; + } +#endif /* ENABLE_IPV6 */ + default: { + return SetFromWindowsErr(ERROR_INVALID_PARAMETER); + } + } +} + PyDoc_STRVAR( Overlapped_cancel_doc, "cancel() -> None\n\n" @@ -670,6 +765,7 @@ Overlapped_getresult(OverlappedObject *self, PyObject *args) DWORD transferred = 0; BOOL ret; DWORD err; + PyObject *addr; if (!PyArg_ParseTuple(args, "|" F_BOOL, &wait)) return NULL; @@ -695,8 +791,15 @@ Overlapped_getresult(OverlappedObject *self, PyObject *args) case ERROR_MORE_DATA: break; case ERROR_BROKEN_PIPE: - if (self->type == TYPE_READ || self->type == TYPE_READINTO) + if (self->type == TYPE_READ || self->type == TYPE_READINTO) { break; + } + else if (self->type == TYPE_READ_FROM && + (self->read_from.result != NULL || + self->read_from.allocated_buffer != NULL)) + { + break; + } /* fall through */ default: return SetFromWindowsErr(err); @@ -708,8 +811,43 @@ Overlapped_getresult(OverlappedObject *self, PyObject *args) if (transferred != PyBytes_GET_SIZE(self->allocated_buffer) && _PyBytes_Resize(&self->allocated_buffer, transferred)) return NULL; + Py_INCREF(self->allocated_buffer); return self->allocated_buffer; + case TYPE_READ_FROM: + assert(PyBytes_CheckExact(self->read_from.allocated_buffer)); + + if (transferred != PyBytes_GET_SIZE( + self->read_from.allocated_buffer) && + _PyBytes_Resize(&self->read_from.allocated_buffer, transferred)) + { + return NULL; + } + + // unparse the address + addr = unparse_address((SOCKADDR*)&self->read_from.address, + self->read_from.address_length); + + if (addr == NULL) { + return NULL; + } + + // The result is a two item tuple: (message, address) + self->read_from.result = PyTuple_New(2); + if (self->read_from.result == NULL) { + Py_CLEAR(addr); + return NULL; + } + + // first item: message + Py_INCREF(self->read_from.allocated_buffer); + PyTuple_SET_ITEM(self->read_from.result, 0, + self->read_from.allocated_buffer); + // second item: address + PyTuple_SET_ITEM(self->read_from.result, 1, addr); + + Py_INCREF(self->read_from.result); + return self->read_from.result; default: return PyLong_FromUnsignedLong((unsigned long) transferred); } @@ -1121,7 +1259,6 @@ parse_address(PyObject *obj, SOCKADDR *Address, int Length) return -1; } - PyDoc_STRVAR( Overlapped_ConnectEx_doc, "ConnectEx(client_handle, address_as_bytes) -> Overlapped[None]\n\n" @@ -1314,7 +1451,7 @@ PyDoc_STRVAR( "Connect to the pipe for asynchronous I/O (overlapped)."); static PyObject * -ConnectPipe(OverlappedObject *self, PyObject *args) +overlapped_ConnectPipe(PyObject *self, PyObject *args) { PyObject *AddressObj; wchar_t *Address; @@ -1362,15 +1499,213 @@ Overlapped_traverse(OverlappedObject *self, visitproc visit, void *arg) Py_VISIT(self->allocated_buffer); break; case TYPE_WRITE: + case TYPE_WRITE_TO: case TYPE_READINTO: if (self->user_buffer.obj) { Py_VISIT(&self->user_buffer.obj); } break; + case TYPE_READ_FROM: + if(self->read_from.result) { + Py_VISIT(self->read_from.result); + } + if(self->read_from.allocated_buffer) { + Py_VISIT(self->read_from.allocated_buffer); + } } return 0; } +// UDP functions + +PyDoc_STRVAR( + WSAConnect_doc, + "WSAConnect(client_handle, address_as_bytes) -> Overlapped[None]\n\n" + "Bind a remote address to a connectionless (UDP) socket"); + +/* + * Note: WSAConnect does not support Overlapped I/O so this function should + * _only_ be used for connectionless sockets (UDP). + */ +static PyObject * +overlapped_WSAConnect(PyObject *self, PyObject *args) +{ + SOCKET ConnectSocket; + PyObject *AddressObj; + char AddressBuf[sizeof(struct sockaddr_in6)]; + SOCKADDR *Address = (SOCKADDR*)AddressBuf; + int Length; + int err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &ConnectSocket, &AddressObj)) { + return NULL; + } + + Length = sizeof(AddressBuf); + Length = parse_address(AddressObj, Address, Length); + if (Length < 0) { + return NULL; + } + + Py_BEGIN_ALLOW_THREADS + // WSAConnect does not support overlapped I/O so this call will + // successfully complete immediately. + err = WSAConnect(ConnectSocket, Address, Length, + NULL, NULL, NULL, NULL); + Py_END_ALLOW_THREADS + + if (err == 0) { + Py_RETURN_NONE; + } + else { + return SetFromWindowsErr(WSAGetLastError()); + } +} + +PyDoc_STRVAR( + Overlapped_WSASendTo_doc, + "WSASendTo(handle, buf, flags, address_as_bytes) -> " + "Overlapped[bytes_transferred]\n\n" + "Start overlapped sendto over a connectionless (UDP) socket"); + +static PyObject * +Overlapped_WSASendTo(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD flags; + PyObject *AddressObj; + char AddressBuf[sizeof(struct sockaddr_in6)]; + SOCKADDR *Address = (SOCKADDR*)AddressBuf; + int AddressLength; + DWORD written; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O" F_DWORD "O", + &handle, &bufobj, &flags, &AddressObj)) + { + return NULL; + } + + // Parse the "to" address + AddressLength = sizeof(AddressBuf); + AddressLength = parse_address(AddressObj, Address, AddressLength); + if (AddressLength < 0) { + return NULL; + } + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->user_buffer)) { + return NULL; + } + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->user_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->user_buffer); + PyErr_SetString(PyExc_ValueError, "buffer too large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE_TO; + self->handle = handle; + wsabuf.len = (DWORD)self->user_buffer.len; + wsabuf.buf = self->user_buffer.buf; + + Py_BEGIN_ALLOW_THREADS + ret = WSASendTo((SOCKET)handle, &wsabuf, 1, &written, flags, + Address, AddressLength, &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret == SOCKET_ERROR ? WSAGetLastError() : + ERROR_SUCCESS); + + switch(err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + + + +PyDoc_STRVAR( + Overlapped_WSARecvFrom_doc, + "RecvFile(handle, size, flags) -> Overlapped[(message, (host, port))]\n\n" + "Start overlapped receive"); + +static PyObject * +Overlapped_WSARecvFrom(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD flags = 0; + DWORD nread; + PyObject *buf; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD "|" F_DWORD, + &handle, &size, &flags)) + { + return NULL; + } + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) { + return NULL; + } + + wsabuf.len = size; + wsabuf.buf = PyBytes_AS_STRING(buf); + + self->type = TYPE_READ_FROM; + self->handle = handle; + self->read_from.allocated_buffer = buf; + memset(&self->read_from.address, 0, sizeof(self->read_from.address)); + self->read_from.address_length = sizeof(self->read_from.address); + + Py_BEGIN_ALLOW_THREADS + ret = WSARecvFrom((SOCKET)handle, &wsabuf, 1, &nread, &flags, + (SOCKADDR*)&self->read_from.address, + &self->read_from.address_length, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + + switch(err) { + case ERROR_BROKEN_PIPE: + mark_as_completed(&self->overlapped); + return SetFromWindowsErr(err); + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + static PyMethodDef Overlapped_methods[] = { {"getresult", (PyCFunction) Overlapped_getresult, @@ -1399,6 +1734,10 @@ static PyMethodDef Overlapped_methods[] = { METH_VARARGS, Overlapped_TransmitFile_doc}, {"ConnectNamedPipe", (PyCFunction) Overlapped_ConnectNamedPipe, METH_VARARGS, Overlapped_ConnectNamedPipe_doc}, + {"WSARecvFrom", (PyCFunction) Overlapped_WSARecvFrom, + METH_VARARGS, Overlapped_WSARecvFrom_doc }, + {"WSASendTo", (PyCFunction) Overlapped_WSASendTo, + METH_VARARGS, Overlapped_WSASendTo_doc }, {NULL} }; @@ -1484,9 +1823,10 @@ static PyMethodDef overlapped_functions[] = { METH_VARARGS, SetEvent_doc}, {"ResetEvent", overlapped_ResetEvent, METH_VARARGS, ResetEvent_doc}, - {"ConnectPipe", - (PyCFunction) ConnectPipe, + {"ConnectPipe", overlapped_ConnectPipe, METH_VARARGS, ConnectPipe_doc}, + {"WSAConnect", overlapped_WSAConnect, + METH_VARARGS, WSAConnect_doc}, {NULL} };