diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 172a463ef80..017437552fc 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -707,8 +707,6 @@ class BaseEventLoop(events.AbstractEventLoop): raise ValueError( 'host and port was not specified and no sock specified') - sock.setblocking(False) - transport, protocol = yield from self._create_connection_transport( sock, protocol_factory, ssl, server_hostname) if self._debug: @@ -721,14 +719,17 @@ class BaseEventLoop(events.AbstractEventLoop): @coroutine def _create_connection_transport(self, sock, protocol_factory, ssl, - server_hostname): + server_hostname, server_side=False): + + sock.setblocking(False) + protocol = protocol_factory() waiter = self.create_future() if ssl: sslcontext = None if isinstance(ssl, bool) else ssl transport = self._make_ssl_transport( sock, protocol, sslcontext, waiter, - server_side=False, server_hostname=server_hostname) + server_side=server_side, server_hostname=server_hostname) else: transport = self._make_socket_transport(sock, protocol, waiter) @@ -979,6 +980,25 @@ class BaseEventLoop(events.AbstractEventLoop): logger.info("%r is serving", server) return server + @coroutine + def connect_accepted_socket(self, protocol_factory, sock, *, ssl=None): + """Handle an accepted connection. + + This is used by servers that accept connections outside of + asyncio but that use asyncio to handle connections. + + This method is a coroutine. When completed, the coroutine + returns a (transport, protocol) pair. + """ + transport, protocol = yield from self._create_connection_transport( + sock, protocol_factory, ssl, '', server_side=True) + if self._debug: + # Get the socket from the transport because SSL transport closes + # the old socket and creates a new SSL socket + sock = transport.get_extra_info('socket') + logger.debug("%r handled: (%r, %r)", sock, transport, protocol) + return transport, protocol + @coroutine def connect_read_pipe(self, protocol_factory, pipe): protocol = protocol_factory() diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py index d0777758a7d..5c186cebb68 100644 --- a/Lib/test/test_asyncio/test_events.py +++ b/Lib/test/test_asyncio/test_events.py @@ -744,6 +744,85 @@ class EventLoopTestsMixin: self.assertEqual(cm.exception.errno, errno.EADDRINUSE) self.assertIn(str(httpd.address), cm.exception.strerror) + def test_connect_accepted_socket(self, server_ssl=None, client_ssl=None): + loop = self.loop + + class MyProto(MyBaseProto): + + def connection_lost(self, exc): + super().connection_lost(exc) + loop.call_soon(loop.stop) + + def data_received(self, data): + super().data_received(data) + self.transport.write(expected_response) + + lsock = socket.socket() + lsock.bind(('127.0.0.1', 0)) + lsock.listen(1) + addr = lsock.getsockname() + + message = b'test data' + reponse = None + expected_response = b'roger' + + def client(): + global response + try: + csock = socket.socket() + if client_ssl is not None: + csock = client_ssl.wrap_socket(csock) + csock.connect(addr) + csock.sendall(message) + response = csock.recv(99) + csock.close() + except Exception as exc: + print( + "Failure in client thread in test_connect_accepted_socket", + exc) + + thread = threading.Thread(target=client, daemon=True) + thread.start() + + conn, _ = lsock.accept() + proto = MyProto(loop=loop) + proto.loop = loop + f = loop.create_task( + loop.connect_accepted_socket( + (lambda : proto), conn, ssl=server_ssl)) + loop.run_forever() + conn.close() + lsock.close() + + thread.join(1) + self.assertFalse(thread.is_alive()) + self.assertEqual(proto.state, 'CLOSED') + self.assertEqual(proto.nbytes, len(message)) + self.assertEqual(response, expected_response) + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_ssl_connect_accepted_socket(self): + if (sys.platform == 'win32' and + sys.version_info < (3, 5) and + isinstance(self.loop, proactor_events.BaseProactorEventLoop) + ): + raise unittest.SkipTest( + 'SSL not supported with proactor event loops before Python 3.5' + ) + + server_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + server_context.load_cert_chain(ONLYCERT, ONLYKEY) + if hasattr(server_context, 'check_hostname'): + server_context.check_hostname = False + server_context.verify_mode = ssl.CERT_NONE + + client_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + if hasattr(server_context, 'check_hostname'): + client_context.check_hostname = False + client_context.verify_mode = ssl.CERT_NONE + + self.test_connect_accepted_socket(server_context, client_context) + @mock.patch('asyncio.base_events.socket') def create_server_multiple_hosts(self, family, hosts, mock_sock): @asyncio.coroutine diff --git a/Misc/NEWS b/Misc/NEWS index b7ad75b2628..3b7411cdf71 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -72,6 +72,9 @@ Library - Issue #26930: Update Windows builds to use OpenSSL 1.0.2h. +- Issue #27392: Add loop.connect_accepted_socket(). + Patch by Jim Fulton. + IDLE ----