diff --git a/Lib/ssl.py b/Lib/ssl.py index 3162f56a372..5e5a5ce091d 100644 --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -553,16 +553,11 @@ class SSLSocket(socket): SSL channel, and the address of the remote client.""" newsock, addr = socket.accept(self) - return (SSLSocket(sock=newsock, - keyfile=self.keyfile, certfile=self.certfile, - server_side=True, - cert_reqs=self.cert_reqs, - ssl_version=self.ssl_version, - ca_certs=self.ca_certs, - ciphers=self.ciphers, - do_handshake_on_connect= - self.do_handshake_on_connect), - addr) + newsock = self.context.wrap_socket(newsock, + do_handshake_on_connect=self.do_handshake_on_connect, + suppress_ragged_eofs=self.suppress_ragged_eofs, + server_side=True) + return newsock, addr def get_channel_binding(self, cb_type="tls-unique"): """Get channel binding data for current connection. Raise ValueError diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index 4ce98b6ef25..74abbd23b34 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -1796,6 +1796,42 @@ else: t.join() server.close() + def test_server_accept(self): + # Issue #16357: accept() on a SSLSocket created through + # SSLContext.wrap_socket(). + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context.verify_mode = ssl.CERT_REQUIRED + context.load_verify_locations(CERTFILE) + context.load_cert_chain(CERTFILE) + server = socket.socket(socket.AF_INET) + host = "127.0.0.1" + port = support.bind_port(server) + server = context.wrap_socket(server, server_side=True) + + evt = threading.Event() + remote = None + peer = None + def serve(): + nonlocal remote, peer + server.listen(5) + # Block on the accept and wait on the connection to close. + evt.set() + remote, peer = server.accept() + remote.recv(1) + + t = threading.Thread(target=serve) + t.start() + # Client wait until server setup and perform a connect. + evt.wait() + client = context.wrap_socket(socket.socket()) + client.connect((host, port)) + client_addr = client.getsockname() + client.close() + t.join() + # Sanity checks. + self.assertIsInstance(remote, ssl.SSLSocket) + self.assertEqual(peer, client_addr) + def test_default_ciphers(self): context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) try: diff --git a/Misc/NEWS b/Misc/NEWS index 0646a74c6bf..9c0ea03fc43 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -80,6 +80,9 @@ Core and Builtins Library ------- +- Issue #16357: fix calling accept() on a SSLSocket created through + SSLContext.wrap_socket(). Original patch by Jeff McNeil. + - Issue #16409: The reporthook callback made by the legacy urllib.request.urlretrieve API now properly supplies a constant non-zero block_size as it did in Python 3.2 and 2.7. This matches the behavior of