From a170fa162dc03f0a014373349e548954fff2e567 Mon Sep 17 00:00:00 2001 From: Christian Heimes Date: Fri, 15 Sep 2017 20:27:30 +0200 Subject: [PATCH] bpo-31346: Use PROTOCOL_TLS_CLIENT/SERVER (#3058) Replaces PROTOCOL_TLSv* and PROTOCOL_SSLv23 with PROTOCOL_TLS_CLIENT and PROTOCOL_TLS_SERVER. Signed-off-by: Christian Heimes --- Lib/asyncio/test_utils.py | 2 +- Lib/ftplib.py | 4 +- Lib/ssl.py | 7 +- Lib/test/test_asyncio/test_events.py | 16 +- Lib/test/test_ftplib.py | 18 +- Lib/test/test_httplib.py | 15 +- Lib/test/test_imaplib.py | 18 +- Lib/test/test_logging.py | 2 +- Lib/test/test_poplib.py | 10 +- Lib/test/test_smtpnet.py | 8 +- Lib/test/test_ssl.py | 528 +++++++++--------- Lib/test/test_urllib2_localnet.py | 2 +- .../2017-09-04-16-21-18.bpo-31346.xni1VR.rst | 1 + 13 files changed, 321 insertions(+), 310 deletions(-) create mode 100644 Misc/NEWS.d/next/Tests/2017-09-04-16-21-18.bpo-31346.xni1VR.rst diff --git a/Lib/asyncio/test_utils.py b/Lib/asyncio/test_utils.py index ecb3fca097b..65805947fda 100644 --- a/Lib/asyncio/test_utils.py +++ b/Lib/asyncio/test_utils.py @@ -45,7 +45,7 @@ def dummy_ssl_context(): if ssl is None: return None else: - return ssl.SSLContext(ssl.PROTOCOL_SSLv23) + return ssl.SSLContext(ssl.PROTOCOL_TLS) def run_briefly(loop): diff --git a/Lib/ftplib.py b/Lib/ftplib.py index a02e595cb02..05840d49236 100644 --- a/Lib/ftplib.py +++ b/Lib/ftplib.py @@ -719,7 +719,7 @@ else: '221 Goodbye.' >>> ''' - ssl_version = ssl.PROTOCOL_SSLv23 + ssl_version = ssl.PROTOCOL_TLS_CLIENT def __init__(self, host='', user='', passwd='', acct='', keyfile=None, certfile=None, context=None, @@ -753,7 +753,7 @@ else: '''Set up secure control connection by using TLS/SSL.''' if isinstance(self.sock, ssl.SSLSocket): raise ValueError("Already using TLS") - if self.ssl_version >= ssl.PROTOCOL_SSLv23: + if self.ssl_version >= ssl.PROTOCOL_TLS: resp = self.voidcmd('AUTH TLS') else: resp = self.voidcmd('AUTH SSL') diff --git a/Lib/ssl.py b/Lib/ssl.py index 2849deee07e..24f24b17bf1 100644 --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -522,7 +522,7 @@ def create_default_context(purpose=Purpose.SERVER_AUTH, *, cafile=None, context.load_default_certs(purpose) return context -def _create_unverified_context(protocol=PROTOCOL_TLS, *, cert_reqs=None, +def _create_unverified_context(protocol=PROTOCOL_TLS, *, cert_reqs=CERT_NONE, check_hostname=False, purpose=Purpose.SERVER_AUTH, certfile=None, keyfile=None, cafile=None, capath=None, cadata=None): @@ -541,9 +541,12 @@ def _create_unverified_context(protocol=PROTOCOL_TLS, *, cert_reqs=None, # by default. context = SSLContext(protocol) + if not check_hostname: + context.check_hostname = False if cert_reqs is not None: context.verify_mode = cert_reqs - context.check_hostname = check_hostname + if check_hostname: + context.check_hostname = True if keyfile and not certfile: raise ValueError("certfile must be specified") diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py index 33421ce4c37..736f703c2fb 100644 --- a/Lib/test/test_asyncio/test_events.py +++ b/Lib/test/test_asyncio/test_events.py @@ -824,13 +824,13 @@ class EventLoopTestsMixin: 'SSL not supported with proactor event loops before Python 3.5' ) - server_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) server_context.load_cert_chain(ONLYCERT, ONLYKEY) if hasattr(server_context, 'check_hostname'): server_context.check_hostname = False server_context.verify_mode = ssl.CERT_NONE - client_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) if hasattr(server_context, 'check_hostname'): client_context.check_hostname = False client_context.verify_mode = ssl.CERT_NONE @@ -985,7 +985,7 @@ class EventLoopTestsMixin: self.loop.run_until_complete(f) def _create_ssl_context(self, certfile, keyfile=None): - sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) sslcontext.options |= ssl.OP_NO_SSLv2 sslcontext.load_cert_chain(certfile, keyfile) return sslcontext @@ -1082,7 +1082,7 @@ class EventLoopTestsMixin: server, host, port = self._make_ssl_server( lambda: proto, SIGNED_CERTFILE) - sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) sslcontext_client.options |= ssl.OP_NO_SSLv2 sslcontext_client.verify_mode = ssl.CERT_REQUIRED if hasattr(sslcontext_client, 'check_hostname'): @@ -1116,7 +1116,7 @@ class EventLoopTestsMixin: server, path = self._make_ssl_unix_server( lambda: proto, SIGNED_CERTFILE) - sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) sslcontext_client.options |= ssl.OP_NO_SSLv2 sslcontext_client.verify_mode = ssl.CERT_REQUIRED if hasattr(sslcontext_client, 'check_hostname'): @@ -1150,7 +1150,7 @@ class EventLoopTestsMixin: server, host, port = self._make_ssl_server( lambda: proto, SIGNED_CERTFILE) - sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) sslcontext_client.options |= ssl.OP_NO_SSLv2 sslcontext_client.verify_mode = ssl.CERT_REQUIRED sslcontext_client.load_verify_locations( @@ -1183,7 +1183,7 @@ class EventLoopTestsMixin: server, path = self._make_ssl_unix_server( lambda: proto, SIGNED_CERTFILE) - sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) sslcontext_client.options |= ssl.OP_NO_SSLv2 sslcontext_client.verify_mode = ssl.CERT_REQUIRED sslcontext_client.load_verify_locations(cafile=SIGNING_CA) @@ -1212,7 +1212,7 @@ class EventLoopTestsMixin: server, host, port = self._make_ssl_server( lambda: proto, SIGNED_CERTFILE) - sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) sslcontext_client.options |= ssl.OP_NO_SSLv2 sslcontext_client.verify_mode = ssl.CERT_REQUIRED sslcontext_client.load_verify_locations(cafile=SIGNING_CA) diff --git a/Lib/test/test_ftplib.py b/Lib/test/test_ftplib.py index 5880a1e941d..f1b0185b2bf 100644 --- a/Lib/test/test_ftplib.py +++ b/Lib/test/test_ftplib.py @@ -902,17 +902,11 @@ class TestTLS_FTPClass(TestCase): self.client.auth() self.assertRaises(ValueError, self.client.auth) - def test_auth_ssl(self): - try: - self.client.ssl_version = ssl.PROTOCOL_SSLv23 - self.client.auth() - self.assertRaises(ValueError, self.client.auth) - finally: - self.client.ssl_version = ssl.PROTOCOL_TLSv1 - def test_context(self): self.client.quit() - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE self.assertRaises(ValueError, ftplib.FTP_TLS, keyfile=CERTFILE, context=ctx) self.assertRaises(ValueError, ftplib.FTP_TLS, certfile=CERTFILE, @@ -941,9 +935,9 @@ class TestTLS_FTPClass(TestCase): def test_check_hostname(self): self.client.quit() - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - ctx.verify_mode = ssl.CERT_REQUIRED - ctx.check_hostname = True + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) + self.assertEqual(ctx.check_hostname, True) ctx.load_verify_locations(CAFILE) self.client = ftplib.FTP_TLS(context=ctx, timeout=TIMEOUT) diff --git a/Lib/test/test_httplib.py b/Lib/test/test_httplib.py index 68f6946a3a1..ab798a2525e 100644 --- a/Lib/test/test_httplib.py +++ b/Lib/test/test_httplib.py @@ -1583,8 +1583,9 @@ class HTTPSTest(TestCase): import ssl support.requires('network') with support.transient_internet('self-signed.pythontest.net'): - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.verify_mode = ssl.CERT_REQUIRED + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + self.assertEqual(context.verify_mode, ssl.CERT_REQUIRED) + self.assertEqual(context.check_hostname, True) context.load_verify_locations(CERT_selfsigned_pythontestdotnet) h = client.HTTPSConnection('self-signed.pythontest.net', 443, context=context) h.request('GET', '/') @@ -1599,8 +1600,7 @@ class HTTPSTest(TestCase): import ssl support.requires('network') with support.transient_internet('self-signed.pythontest.net'): - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.verify_mode = ssl.CERT_REQUIRED + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) context.load_verify_locations(CERT_localhost) h = client.HTTPSConnection('self-signed.pythontest.net', 443, context=context) with self.assertRaises(ssl.SSLError) as exc_info: @@ -1620,8 +1620,7 @@ class HTTPSTest(TestCase): # The (valid) cert validates the HTTP hostname import ssl server = self.make_server(CERT_localhost) - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.verify_mode = ssl.CERT_REQUIRED + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) context.load_verify_locations(CERT_localhost) h = client.HTTPSConnection('localhost', server.port, context=context) self.addCleanup(h.close) @@ -1634,9 +1633,7 @@ class HTTPSTest(TestCase): # The (valid) cert doesn't validate the HTTP hostname import ssl server = self.make_server(CERT_fakehostname) - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.verify_mode = ssl.CERT_REQUIRED - context.check_hostname = True + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) context.load_verify_locations(CERT_fakehostname) h = client.HTTPSConnection('localhost', server.port, context=context) with self.assertRaises(ssl.CertificateError): diff --git a/Lib/test/test_imaplib.py b/Lib/test/test_imaplib.py index 2b62b05a594..4a45be65724 100644 --- a/Lib/test/test_imaplib.py +++ b/Lib/test/test_imaplib.py @@ -479,9 +479,9 @@ class NewIMAPSSLTests(NewIMAPTestsMixin, unittest.TestCase): server_class = SecureTCPServer def test_ssl_raises(self): - ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - ssl_context.verify_mode = ssl.CERT_REQUIRED - ssl_context.check_hostname = True + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + self.assertEqual(ssl_context.verify_mode, ssl.CERT_REQUIRED) + self.assertEqual(ssl_context.check_hostname, True) ssl_context.load_verify_locations(CAFILE) with self.assertRaisesRegex(ssl.CertificateError, @@ -492,9 +492,7 @@ class NewIMAPSSLTests(NewIMAPTestsMixin, unittest.TestCase): client.shutdown() def test_ssl_verified(self): - ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - ssl_context.verify_mode = ssl.CERT_REQUIRED - ssl_context.check_hostname = True + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ssl_context.load_verify_locations(CAFILE) _, server = self._setup(SimpleIMAPHandler) @@ -871,9 +869,7 @@ class ThreadedNetworkedTestsSSL(ThreadedNetworkedTests): @reap_threads def test_ssl_verified(self): - ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - ssl_context.verify_mode = ssl.CERT_REQUIRED - ssl_context.check_hostname = True + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ssl_context.load_verify_locations(CAFILE) with self.assertRaisesRegex( @@ -953,7 +949,9 @@ class RemoteIMAP_SSLTest(RemoteIMAPTest): pass def create_ssl_context(self): - ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE ssl_context.load_cert_chain(CERTFILE) return ssl_context diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py index d264d786720..fb0de9de046 100644 --- a/Lib/test/test_logging.py +++ b/Lib/test/test_logging.py @@ -1792,7 +1792,7 @@ class HTTPHandlerTest(BaseTest): else: here = os.path.dirname(__file__) localhost_cert = os.path.join(here, "keycert.pem") - sslctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) sslctx.load_cert_chain(localhost_cert) context = ssl.create_default_context(cafile=localhost_cert) diff --git a/Lib/test/test_poplib.py b/Lib/test/test_poplib.py index 721ea252011..9ba678f2039 100644 --- a/Lib/test/test_poplib.py +++ b/Lib/test/test_poplib.py @@ -352,10 +352,10 @@ class TestPOP3Class(TestCase): @requires_ssl def test_stls_context(self): expected = b'+OK Begin TLS negotiation' - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.load_verify_locations(CAFILE) - ctx.verify_mode = ssl.CERT_REQUIRED - ctx.check_hostname = True + self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) + self.assertEqual(ctx.check_hostname, True) with self.assertRaises(ssl.CertificateError): resp = self.client.stls(context=ctx) self.client = poplib.POP3("localhost", self.server.port, timeout=3) @@ -392,7 +392,9 @@ class TestPOP3_SSLClass(TestPOP3Class): self.assertIn('POP3_SSL', poplib.__all__) def test_context(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE self.assertRaises(ValueError, poplib.POP3_SSL, self.server.host, self.server.port, keyfile=CERTFILE, context=ctx) self.assertRaises(ValueError, poplib.POP3_SSL, self.server.host, diff --git a/Lib/test/test_smtpnet.py b/Lib/test/test_smtpnet.py index cc9bab43f4f..b69cd9de627 100644 --- a/Lib/test/test_smtpnet.py +++ b/Lib/test/test_smtpnet.py @@ -25,7 +25,9 @@ class SmtpTest(unittest.TestCase): def test_connect_starttls(self): support.get_attribute(smtplib, 'SMTP_SSL') - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE with support.transient_internet(self.testServer): server = smtplib.SMTP(self.testServer, self.remotePort) try: @@ -58,7 +60,9 @@ class SmtpSSLTest(unittest.TestCase): server.quit() def test_connect_using_sslcontext(self): - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE support.get_attribute(smtplib, 'SMTP_SSL') with support.transient_internet(self.testServer): server = smtplib.SMTP_SSL(self.testServer, self.remotePort, context=context) diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index fb5958f1a5e..2978b8b3ebd 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -60,7 +60,9 @@ CRLFILE = data_file("revocation.crl") # Two keys and certs signed by the same CA (for SNI tests) SIGNED_CERTFILE = data_file("keycert3.pem") +SIGNED_CERTFILE_HOSTNAME = 'localhost' SIGNED_CERTFILE2 = data_file("keycert4.pem") +SIGNED_CERTFILE2_HOSTNAME = 'fakehostname' # Same certificate as pycacert.pem, but without extra text in file SIGNING_CA = data_file("capath", "ceff1710.0") # cert with all kinds of subject alt names @@ -147,6 +149,8 @@ def test_wrap_socket(sock, ssl_version=ssl.PROTOCOL_TLS, *, **kwargs): context = ssl.SSLContext(ssl_version) if cert_reqs is not None: + if cert_reqs == ssl.CERT_NONE: + context.check_hostname = False context.verify_mode = cert_reqs if ca_certs is not None: context.load_verify_locations(ca_certs) @@ -156,6 +160,28 @@ def test_wrap_socket(sock, ssl_version=ssl.PROTOCOL_TLS, *, context.set_ciphers(ciphers) return context.wrap_socket(sock, **kwargs) + +def testing_context(server_cert=SIGNED_CERTFILE): + """Create context + + client_context, server_context, hostname = testing_context() + """ + if server_cert == SIGNED_CERTFILE: + hostname = SIGNED_CERTFILE_HOSTNAME + elif server_cert == SIGNED_CERTFILE2: + hostname = SIGNED_CERTFILE2_HOSTNAME + else: + raise ValueError(server_cert) + + client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + client_context.load_verify_locations(SIGNING_CA) + + server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + server_context.load_cert_chain(server_cert) + + return client_context, server_context, hostname + + class BasicSocketTests(unittest.TestCase): def test_constants(self): @@ -177,6 +203,7 @@ class BasicSocketTests(unittest.TestCase): if ssl.OPENSSL_VERSION_INFO >= (1, 0, 1): ssl.OP_NO_TLSv1_1 ssl.OP_NO_TLSv1_2 + self.assertEqual(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv23) def test_str_for_enums(self): # Make sure that the PROTOCOL_* constants have enum-like string @@ -447,8 +474,7 @@ class BasicSocketTests(unittest.TestCase): self.addCleanup(sock.close) with self.assertRaises(ssl.SSLError): test_wrap_socket(sock, - certfile=certfile, - ssl_version=ssl.PROTOCOL_TLSv1) + certfile=certfile) def test_empty_cert(self): """Wrapping with an empty cert file""" @@ -621,7 +647,7 @@ class BasicSocketTests(unittest.TestCase): def test_server_side(self): # server_hostname doesn't work for server sockets - ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) with socket.socket() as sock: self.assertRaises(ValueError, ctx.wrap_socket, sock, True, server_hostname="some.hostname") @@ -772,7 +798,7 @@ class BasicSocketTests(unittest.TestCase): with self.assertRaises(NotImplementedError) as cx: test_wrap_socket(s, cert_reqs=ssl.CERT_NONE) self.assertEqual(str(cx.exception), "only stream sockets are supported") - ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) with self.assertRaises(NotImplementedError) as cx: ctx.wrap_socket(s) self.assertEqual(str(cx.exception), "only stream sockets are supported") @@ -877,7 +903,7 @@ class ContextTests(unittest.TestCase): self.assertEqual(ctx.protocol, proto) def test_ciphers(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.set_ciphers("ALL") ctx.set_ciphers("DEFAULT") with self.assertRaisesRegex(ssl.SSLError, "No cipher can be selected"): @@ -885,7 +911,7 @@ class ContextTests(unittest.TestCase): @unittest.skipIf(ssl.OPENSSL_VERSION_INFO < (1, 0, 2, 0, 0), 'OpenSSL too old') def test_get_ciphers(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.set_ciphers('AESGCM') names = set(d['name'] for d in ctx.get_ciphers()) self.assertIn('AES256-GCM-SHA384', names) @@ -893,7 +919,7 @@ class ContextTests(unittest.TestCase): @skip_if_broken_ubuntu_ssl def test_options(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) # OP_ALL | OP_NO_SSLv2 | OP_NO_SSLv3 is the default value default = (ssl.OP_ALL | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3) # SSLContext also enables these by default @@ -912,8 +938,8 @@ class ContextTests(unittest.TestCase): with self.assertRaises(ValueError): ctx.options = 0 - def test_verify_mode(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + def test_verify_mode_protocol(self): + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS) # Default value self.assertEqual(ctx.verify_mode, ssl.CERT_NONE) ctx.verify_mode = ssl.CERT_OPTIONAL @@ -927,10 +953,19 @@ class ContextTests(unittest.TestCase): with self.assertRaises(ValueError): ctx.verify_mode = 42 + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + self.assertEqual(ctx.verify_mode, ssl.CERT_NONE) + self.assertFalse(ctx.check_hostname) + + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) + self.assertTrue(ctx.check_hostname) + + @unittest.skipUnless(have_verify_flags(), "verify_flags need OpenSSL > 0.9.8") def test_verify_flags(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) # default value tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0) self.assertEqual(ctx.verify_flags, ssl.VERIFY_DEFAULT | tf) @@ -948,7 +983,7 @@ class ContextTests(unittest.TestCase): ctx.verify_flags = None def test_load_cert_chain(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) # Combined key and cert in a single file ctx.load_cert_chain(CERTFILE, keyfile=None) ctx.load_cert_chain(CERTFILE, keyfile=CERTFILE) @@ -961,7 +996,7 @@ class ContextTests(unittest.TestCase): with self.assertRaisesRegex(ssl.SSLError, "PEM lib"): ctx.load_cert_chain(EMPTYCERT) # Separate key and cert - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) ctx.load_cert_chain(ONLYCERT, ONLYKEY) ctx.load_cert_chain(certfile=ONLYCERT, keyfile=ONLYKEY) ctx.load_cert_chain(certfile=BYTES_ONLYCERT, keyfile=BYTES_ONLYKEY) @@ -972,7 +1007,7 @@ class ContextTests(unittest.TestCase): with self.assertRaisesRegex(ssl.SSLError, "PEM lib"): ctx.load_cert_chain(certfile=ONLYKEY, keyfile=ONLYCERT) # Mismatching key and cert - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) with self.assertRaisesRegex(ssl.SSLError, "key values mismatch"): ctx.load_cert_chain(CAFILE_CACERT, ONLYKEY) # Password protected key and cert @@ -1031,7 +1066,7 @@ class ContextTests(unittest.TestCase): ctx.load_cert_chain(CERTFILE, password=getpass_exception) def test_load_verify_locations(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) ctx.load_verify_locations(CERTFILE) ctx.load_verify_locations(cafile=CERTFILE, capath=None) ctx.load_verify_locations(BYTES_CERTFILE) @@ -1059,7 +1094,7 @@ class ContextTests(unittest.TestCase): neuronio_der = ssl.PEM_cert_to_DER_cert(neuronio_pem) # test PEM - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) self.assertEqual(ctx.cert_store_stats()["x509_ca"], 0) ctx.load_verify_locations(cadata=cacert_pem) self.assertEqual(ctx.cert_store_stats()["x509_ca"], 1) @@ -1070,20 +1105,20 @@ class ContextTests(unittest.TestCase): self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) # combined - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) combined = "\n".join((cacert_pem, neuronio_pem)) ctx.load_verify_locations(cadata=combined) self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) # with junk around the certs - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) combined = ["head", cacert_pem, "other", neuronio_pem, "again", neuronio_pem, "tail"] ctx.load_verify_locations(cadata="\n".join(combined)) self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) # test DER - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.load_verify_locations(cadata=cacert_der) ctx.load_verify_locations(cadata=neuronio_der) self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) @@ -1092,13 +1127,13 @@ class ContextTests(unittest.TestCase): self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) # combined - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) combined = b"".join((cacert_der, neuronio_der)) ctx.load_verify_locations(cadata=combined) self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) # error cases - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) self.assertRaises(TypeError, ctx.load_verify_locations, cadata=object) with self.assertRaisesRegex(ssl.SSLError, "no start line"): @@ -1108,7 +1143,7 @@ class ContextTests(unittest.TestCase): def test_load_dh_params(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) ctx.load_dh_params(DHFILE) if os.name != 'nt': ctx.load_dh_params(BYTES_DHFILE) @@ -1141,12 +1176,12 @@ class ContextTests(unittest.TestCase): def test_set_default_verify_paths(self): # There's not much we can do to test that it acts as expected, # so just check it doesn't crash or raise an exception. - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.set_default_verify_paths() @unittest.skipUnless(ssl.HAS_ECDH, "ECDH disabled on this OpenSSL build") def test_set_ecdh_curve(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) ctx.set_ecdh_curve("prime256v1") ctx.set_ecdh_curve(b"prime256v1") self.assertRaises(TypeError, ctx.set_ecdh_curve) @@ -1156,7 +1191,7 @@ class ContextTests(unittest.TestCase): @needs_sni def test_sni_callback(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) # set_servername_callback expects a callable, or None self.assertRaises(TypeError, ctx.set_servername_callback) @@ -1173,7 +1208,7 @@ class ContextTests(unittest.TestCase): def test_sni_callback_refcycle(self): # Reference cycles through the servername callback are detected # and cleared. - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) def dummycallback(sock, servername, ctx, cycle=ctx): pass ctx.set_servername_callback(dummycallback) @@ -1183,7 +1218,7 @@ class ContextTests(unittest.TestCase): self.assertIs(wr(), None) def test_cert_store_stats(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) self.assertEqual(ctx.cert_store_stats(), {'x509_ca': 0, 'crl': 0, 'x509': 0}) ctx.load_cert_chain(CERTFILE) @@ -1197,7 +1232,7 @@ class ContextTests(unittest.TestCase): {'x509_ca': 1, 'crl': 0, 'x509': 2}) def test_get_ca_certs(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) self.assertEqual(ctx.get_ca_certs(), []) # CERTFILE is not flagged as X509v3 Basic Constraints: CA:TRUE ctx.load_verify_locations(CERTFILE) @@ -1225,24 +1260,24 @@ class ContextTests(unittest.TestCase): self.assertEqual(ctx.get_ca_certs(True), [der]) def test_load_default_certs(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.load_default_certs() - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.load_default_certs(ssl.Purpose.SERVER_AUTH) ctx.load_default_certs() - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.load_default_certs(ssl.Purpose.CLIENT_AUTH) - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) self.assertRaises(TypeError, ctx.load_default_certs, None) self.assertRaises(TypeError, ctx.load_default_certs, 'SERVER_AUTH') @unittest.skipIf(sys.platform == "win32", "not-Windows specific") @unittest.skipIf(IS_LIBRESSL, "LibreSSL doesn't support env vars") def test_load_default_certs_env(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) with support.EnvironmentVarGuard() as env: env["SSL_CERT_DIR"] = CAPATH env["SSL_CERT_FILE"] = CERTFILE @@ -1252,11 +1287,11 @@ class ContextTests(unittest.TestCase): @unittest.skipUnless(sys.platform == "win32", "Windows specific") @unittest.skipIf(hasattr(sys, "gettotalrefcount"), "Debug build does not share environment between CRTs") def test_load_default_certs_env_windows(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.load_default_certs() stats = ctx.cert_store_stats() - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) with support.EnvironmentVarGuard() as env: env["SSL_CERT_DIR"] = CAPATH env["SSL_CERT_FILE"] = CERTFILE @@ -1282,28 +1317,27 @@ class ContextTests(unittest.TestCase): def test_create_default_context(self): ctx = ssl.create_default_context() - self.assertEqual(ctx.protocol, ssl.PROTOCOL_SSLv23) + self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS) self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) self.assertTrue(ctx.check_hostname) self._assert_context_options(ctx) - with open(SIGNING_CA) as f: cadata = f.read() ctx = ssl.create_default_context(cafile=SIGNING_CA, capath=CAPATH, cadata=cadata) - self.assertEqual(ctx.protocol, ssl.PROTOCOL_SSLv23) + self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS) self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) self._assert_context_options(ctx) ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - self.assertEqual(ctx.protocol, ssl.PROTOCOL_SSLv23) + self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS) self.assertEqual(ctx.verify_mode, ssl.CERT_NONE) self._assert_context_options(ctx) def test__create_stdlib_context(self): ctx = ssl._create_stdlib_context() - self.assertEqual(ctx.protocol, ssl.PROTOCOL_SSLv23) + self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS) self.assertEqual(ctx.verify_mode, ssl.CERT_NONE) self.assertFalse(ctx.check_hostname) self._assert_context_options(ctx) @@ -1322,12 +1356,12 @@ class ContextTests(unittest.TestCase): self._assert_context_options(ctx) ctx = ssl._create_stdlib_context(purpose=ssl.Purpose.CLIENT_AUTH) - self.assertEqual(ctx.protocol, ssl.PROTOCOL_SSLv23) + self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS) self.assertEqual(ctx.verify_mode, ssl.CERT_NONE) self._assert_context_options(ctx) def test_check_hostname(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS) self.assertFalse(ctx.check_hostname) # Requires CERT_REQUIRED or CERT_OPTIONAL @@ -1390,7 +1424,7 @@ class SSLErrorTests(unittest.TestCase): def test_lib_reason(self): # Test the library and reason attributes - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) with self.assertRaises(ssl.SSLError) as cm: ctx.load_dh_params(CERTFILE) self.assertEqual(cm.exception.library, 'PEM') @@ -1401,7 +1435,9 @@ class SSLErrorTests(unittest.TestCase): def test_subclass(self): # Check that the appropriate SSLError subclass is raised # (this only tests one of them) - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE with socket.socket() as s: s.bind(("127.0.0.1", 0)) s.listen() @@ -1562,7 +1598,7 @@ class SimpleBackgroundTests(unittest.TestCase): def test_connect_with_context(self): # Same as test_connect, but with a separately created context - ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS) with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s: s.connect(self.server_addr) self.assertEqual({}, s.getpeercert()) @@ -1582,7 +1618,7 @@ class SimpleBackgroundTests(unittest.TestCase): # This should fail because we have no verification certs. Connection # failure crashes ThreadedEchoServer, so run this in an independent # test method. - ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS) ctx.verify_mode = ssl.CERT_REQUIRED s = ctx.wrap_socket(socket.socket(socket.AF_INET)) self.addCleanup(s.close) @@ -1595,7 +1631,7 @@ class SimpleBackgroundTests(unittest.TestCase): # OpenSSL 0.9.8n and 1.0.0, as a result the capath directory must # contain both versions of each certificate (same content, different # filename) for this test to be portable across OpenSSL releases. - ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS) ctx.verify_mode = ssl.CERT_REQUIRED ctx.load_verify_locations(capath=CAPATH) with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s: @@ -1603,7 +1639,7 @@ class SimpleBackgroundTests(unittest.TestCase): cert = s.getpeercert() self.assertTrue(cert) # Same with a bytes `capath` argument - ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS) ctx.verify_mode = ssl.CERT_REQUIRED ctx.load_verify_locations(capath=BYTES_CAPATH) with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s: @@ -1615,7 +1651,7 @@ class SimpleBackgroundTests(unittest.TestCase): with open(SIGNING_CA) as f: pem = f.read() der = ssl.PEM_cert_to_DER_cert(pem) - ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS) ctx.verify_mode = ssl.CERT_REQUIRED ctx.load_verify_locations(cadata=pem) with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s: @@ -1624,7 +1660,7 @@ class SimpleBackgroundTests(unittest.TestCase): self.assertTrue(cert) # same with DER - ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS) ctx.verify_mode = ssl.CERT_REQUIRED ctx.load_verify_locations(cadata=der) with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s: @@ -1696,11 +1732,11 @@ class SimpleBackgroundTests(unittest.TestCase): def test_get_ca_certs_capath(self): # capath certs are loaded on request - ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - ctx.verify_mode = ssl.CERT_REQUIRED + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.load_verify_locations(capath=CAPATH) self.assertEqual(ctx.get_ca_certs(), []) - with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s: + with ctx.wrap_socket(socket.socket(socket.AF_INET), + server_hostname='localhost') as s: s.connect(self.server_addr) cert = s.getpeercert() self.assertTrue(cert) @@ -1709,10 +1745,12 @@ class SimpleBackgroundTests(unittest.TestCase): @needs_sni def test_context_setget(self): # Check that the context of a connected socket can be replaced. - ctx1 = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - ctx2 = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx1 = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx1.load_verify_locations(capath=CAPATH) + ctx2 = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx2.load_verify_locations(capath=CAPATH) s = socket.socket(socket.AF_INET) - with ctx1.wrap_socket(s) as ss: + with ctx1.wrap_socket(s, server_hostname='localhost') as ss: ss.connect(self.server_addr) self.assertIs(ss.context, ctx1) self.assertIs(ss._sslobj.context, ctx1) @@ -1763,11 +1801,12 @@ class SimpleBackgroundTests(unittest.TestCase): sock.connect(self.server_addr) incoming = ssl.MemoryBIO() outgoing = ssl.MemoryBIO() - ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - ctx.verify_mode = ssl.CERT_REQUIRED + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + self.assertTrue(ctx.check_hostname) + self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) ctx.load_verify_locations(SIGNING_CA) - ctx.check_hostname = True - sslobj = ctx.wrap_bio(incoming, outgoing, False, 'localhost') + sslobj = ctx.wrap_bio(incoming, outgoing, False, + SIGNED_CERTFILE_HOSTNAME) self.assertIs(sslobj._sslobj.owner, sslobj) self.assertIsNone(sslobj.cipher()) self.assertIsNone(sslobj.version()) @@ -1796,7 +1835,7 @@ class SimpleBackgroundTests(unittest.TestCase): sock.connect(self.server_addr) incoming = ssl.MemoryBIO() outgoing = ssl.MemoryBIO() - ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS) ctx.verify_mode = ssl.CERT_NONE sslobj = ctx.wrap_bio(incoming, outgoing, False) self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake) @@ -2008,7 +2047,7 @@ class ThreadedEchoServer(threading.Thread): else: self.context = ssl.SSLContext(ssl_version if ssl_version is not None - else ssl.PROTOCOL_TLSv1) + else ssl.PROTOCOL_TLS_SERVER) self.context.verify_mode = (certreqs if certreqs is not None else ssl.CERT_NONE) if cacerts: @@ -2271,7 +2310,7 @@ def try_protocol_combo(server_protocol, client_protocol, expect_success, # NOTE: we must enable "ALL" ciphers on the client, otherwise an # SSLv23 client will send an SSLv3 hello (rather than SSLv2) # starting from OpenSSL 1.0.0 (see issue #8322). - if client_context.protocol == ssl.PROTOCOL_SSLv23: + if client_context.protocol == ssl.PROTOCOL_TLS: client_context.set_ciphers("ALL") for ctx in (client_context, server_context): @@ -2317,17 +2356,13 @@ class ThreadedTests(unittest.TestCase): server_params_test(context, context, chatty=True, connectionchatty=True) - client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - client_context.load_verify_locations(SIGNING_CA) - server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - # server_context.load_verify_locations(SIGNING_CA) - server_context.load_cert_chain(SIGNED_CERTFILE2) + client_context, server_context, hostname = testing_context() with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_SERVER): server_params_test(client_context=client_context, server_context=server_context, chatty=True, connectionchatty=True, - sni_name='fakehostname') + sni_name=hostname) client_context.check_hostname = False with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_CLIENT): @@ -2335,7 +2370,7 @@ class ThreadedTests(unittest.TestCase): server_params_test(client_context=server_context, server_context=client_context, chatty=True, connectionchatty=True, - sni_name='fakehostname') + sni_name=hostname) self.assertIn('called a function you should not call', str(e.exception)) @@ -2355,44 +2390,41 @@ class ThreadedTests(unittest.TestCase): self.assertIn('called a function you should not call', str(e.exception)) - def test_getpeercert(self): if support.verbose: sys.stdout.write("\n") - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - context.verify_mode = ssl.CERT_REQUIRED - context.load_verify_locations(CERTFILE) - context.load_cert_chain(CERTFILE) - server = ThreadedEchoServer(context=context, chatty=False) + + client_context, server_context, hostname = testing_context() + server = ThreadedEchoServer(context=server_context, chatty=False) with server: - s = context.wrap_socket(socket.socket(), - do_handshake_on_connect=False) - s.connect((HOST, server.port)) - # getpeercert() raise ValueError while the handshake isn't - # done. - with self.assertRaises(ValueError): - s.getpeercert() - s.do_handshake() - cert = s.getpeercert() - self.assertTrue(cert, "Can't get peer certificate.") - cipher = s.cipher() - if support.verbose: - sys.stdout.write(pprint.pformat(cert) + '\n') - sys.stdout.write("Connection cipher is " + str(cipher) + '.\n') - if 'subject' not in cert: - self.fail("No subject field in certificate: %s." % - pprint.pformat(cert)) - if ((('organizationName', 'Python Software Foundation'),) - not in cert['subject']): - self.fail( - "Missing or invalid 'organizationName' field in certificate subject; " - "should be 'Python Software Foundation'.") - self.assertIn('notBefore', cert) - self.assertIn('notAfter', cert) - before = ssl.cert_time_to_seconds(cert['notBefore']) - after = ssl.cert_time_to_seconds(cert['notAfter']) - self.assertLess(before, after) - s.close() + with client_context.wrap_socket(socket.socket(), + do_handshake_on_connect=False, + server_hostname=hostname) as s: + s.connect((HOST, server.port)) + # getpeercert() raise ValueError while the handshake isn't + # done. + with self.assertRaises(ValueError): + s.getpeercert() + s.do_handshake() + cert = s.getpeercert() + self.assertTrue(cert, "Can't get peer certificate.") + cipher = s.cipher() + if support.verbose: + sys.stdout.write(pprint.pformat(cert) + '\n') + sys.stdout.write("Connection cipher is " + str(cipher) + '.\n') + if 'subject' not in cert: + self.fail("No subject field in certificate: %s." % + pprint.pformat(cert)) + if ((('organizationName', 'Python Software Foundation'),) + not in cert['subject']): + self.fail( + "Missing or invalid 'organizationName' field in certificate subject; " + "should be 'Python Software Foundation'.") + self.assertIn('notBefore', cert) + self.assertIn('notAfter', cert) + before = ssl.cert_time_to_seconds(cert['notBefore']) + after = ssl.cert_time_to_seconds(cert['notAfter']) + self.assertLess(before, after) @unittest.skipUnless(have_verify_flags(), "verify_flags need OpenSSL > 0.9.8") @@ -2400,39 +2432,38 @@ class ThreadedTests(unittest.TestCase): if support.verbose: sys.stdout.write("\n") - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - server_context.load_cert_chain(SIGNED_CERTFILE) + client_context, server_context, hostname = testing_context() - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.verify_mode = ssl.CERT_REQUIRED - context.load_verify_locations(SIGNING_CA) tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0) - self.assertEqual(context.verify_flags, ssl.VERIFY_DEFAULT | tf) + self.assertEqual(client_context.verify_flags, ssl.VERIFY_DEFAULT | tf) # VERIFY_DEFAULT should pass server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with context.wrap_socket(socket.socket()) as s: + with client_context.wrap_socket(socket.socket(), + server_hostname=hostname) as s: s.connect((HOST, server.port)) cert = s.getpeercert() self.assertTrue(cert, "Can't get peer certificate.") # VERIFY_CRL_CHECK_LEAF without a loaded CRL file fails - context.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF + client_context.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with context.wrap_socket(socket.socket()) as s: + with client_context.wrap_socket(socket.socket(), + server_hostname=hostname) as s: with self.assertRaisesRegex(ssl.SSLError, "certificate verify failed"): s.connect((HOST, server.port)) # now load a CRL file. The CRL file is signed by the CA. - context.load_verify_locations(CRLFILE) + client_context.load_verify_locations(CRLFILE) server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with context.wrap_socket(socket.socket()) as s: + with client_context.wrap_socket(socket.socket(), + server_hostname=hostname) as s: s.connect((HOST, server.port)) cert = s.getpeercert() self.assertTrue(cert, "Can't get peer certificate.") @@ -2441,19 +2472,13 @@ class ThreadedTests(unittest.TestCase): if support.verbose: sys.stdout.write("\n") - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - server_context.load_cert_chain(SIGNED_CERTFILE) - - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.verify_mode = ssl.CERT_REQUIRED - context.check_hostname = True - context.load_verify_locations(SIGNING_CA) + client_context, server_context, hostname = testing_context() # correct hostname should verify server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with context.wrap_socket(socket.socket(), - server_hostname="localhost") as s: + with client_context.wrap_socket(socket.socket(), + server_hostname=hostname) as s: s.connect((HOST, server.port)) cert = s.getpeercert() self.assertTrue(cert, "Can't get peer certificate.") @@ -2461,8 +2486,8 @@ class ThreadedTests(unittest.TestCase): # incorrect hostname should raise an exception server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with context.wrap_socket(socket.socket(), - server_hostname="invalid") as s: + with client_context.wrap_socket(socket.socket(), + server_hostname="invalid") as s: with self.assertRaisesRegex(ssl.CertificateError, "hostname 'invalid' doesn't match 'localhost'"): s.connect((HOST, server.port)) @@ -2473,7 +2498,7 @@ class ThreadedTests(unittest.TestCase): with socket.socket() as s: with self.assertRaisesRegex(ValueError, "check_hostname requires server_hostname"): - context.wrap_socket(s) + client_context.wrap_socket(s) def test_wrong_cert(self): """Connecting when the server rejects the client's certificate @@ -2489,9 +2514,7 @@ class ThreadedTests(unittest.TestCase): connectionchatty=False) with server, \ socket.socket() as sock, \ - test_wrap_socket(sock, - certfile=certfile, - ssl_version=ssl.PROTOCOL_TLSv1) as s: + test_wrap_socket(sock, certfile=certfile) as s: try: # Expect either an SSL error about the server rejecting # the connection, or a low-level connection reset (which @@ -2561,7 +2584,7 @@ class ThreadedTests(unittest.TestCase): server = ThreadedEchoServer(context=server_context, chatty=True) with server: with context.wrap_socket(socket.socket(), - server_hostname="localhost") as s: + server_hostname=SIGNED_CERTFILE_HOSTNAME) as s: try: s.connect((HOST, server.port)) except ssl.SSLError as e: @@ -2582,28 +2605,28 @@ class ThreadedTests(unittest.TestCase): try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True) try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL) try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED) - try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS, False) if hasattr(ssl, 'PROTOCOL_SSLv3'): try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False) try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False) # SSLv23 client with specific SSL options if no_sslv2_implies_sslv3_hello(): # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs - try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False, + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS, False, client_options=ssl.OP_NO_SSLv2) - try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False, + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS, False, client_options=ssl.OP_NO_SSLv3) - try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False, + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS, False, client_options=ssl.OP_NO_TLSv1) @skip_if_broken_ubuntu_ssl - def test_protocol_sslv23(self): + def test_PROTOCOL_TLS(self): """Connecting to an SSLv23 server with various client options""" if support.verbose: sys.stdout.write("\n") if hasattr(ssl, 'PROTOCOL_SSLv2'): try: - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv2, True) + try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv2, True) except OSError as x: # this fails on some older versions of OpenSSL (0.9.7l, for instance) if support.verbose: @@ -2611,28 +2634,28 @@ class ThreadedTests(unittest.TestCase): " SSL2 client to SSL23 server test unexpectedly failed:\n %s\n" % str(x)) if hasattr(ssl, 'PROTOCOL_SSLv3'): - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1') + try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False) + try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True) + try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, 'TLSv1') if hasattr(ssl, 'PROTOCOL_SSLv3'): - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_OPTIONAL) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL) if hasattr(ssl, 'PROTOCOL_SSLv3'): - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_REQUIRED) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED) # Server with specific SSL options if hasattr(ssl, 'PROTOCOL_SSLv3'): - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, + try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False, server_options=ssl.OP_NO_SSLv3) # Will choose TLSv1 - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, + try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True, server_options=ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, False, + try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, False, server_options=ssl.OP_NO_TLSv1) @@ -2648,12 +2671,12 @@ class ThreadedTests(unittest.TestCase): try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED) if hasattr(ssl, 'PROTOCOL_SSLv2'): try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False) - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False, + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLS, False, client_options=ssl.OP_NO_SSLv3) try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False) if no_sslv2_implies_sslv3_hello(): # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLS, False, client_options=ssl.OP_NO_SSLv2) @skip_if_broken_ubuntu_ssl @@ -2668,7 +2691,7 @@ class ThreadedTests(unittest.TestCase): try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False) if hasattr(ssl, 'PROTOCOL_SSLv3'): try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False) - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv23, False, + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLS, False, client_options=ssl.OP_NO_TLSv1) @skip_if_broken_ubuntu_ssl @@ -2684,14 +2707,13 @@ class ThreadedTests(unittest.TestCase): try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False) if hasattr(ssl, 'PROTOCOL_SSLv3'): try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv3, False) - try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv23, False, + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLS, False, client_options=ssl.OP_NO_TLSv1_1) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1') + try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1') try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1, False) try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_1, False) - @skip_if_broken_ubuntu_ssl @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_2"), "TLS version 1.2 not supported.") @@ -2707,10 +2729,10 @@ class ThreadedTests(unittest.TestCase): try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv2, False) if hasattr(ssl, 'PROTOCOL_SSLv3'): try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv3, False) - try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv23, False, + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLS, False, client_options=ssl.OP_NO_TLSv1_2) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2') + try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2') try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False) try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False) try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False) @@ -2721,7 +2743,6 @@ class ThreadedTests(unittest.TestCase): msgs = (b"msg 1", b"MSG 2", b"STARTTLS", b"MSG 3", b"msg 4", b"ENDTLS", b"msg 5", b"msg 6") server = ThreadedEchoServer(CERTFILE, - ssl_version=ssl.PROTOCOL_TLSv1, starttls_server=True, chatty=True, connectionchatty=True) @@ -2749,7 +2770,7 @@ class ThreadedTests(unittest.TestCase): sys.stdout.write( " client: read %r from server, starting TLS...\n" % msg) - conn = test_wrap_socket(s, ssl_version=ssl.PROTOCOL_TLSv1) + conn = test_wrap_socket(s) wrapped = True elif indata == b"ENDTLS" and msg.startswith(b"ok"): # ENDTLS ok, switch back to clear text @@ -2836,7 +2857,7 @@ class ThreadedTests(unittest.TestCase): server = ThreadedEchoServer(CERTFILE, certreqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1, + ssl_version=ssl.PROTOCOL_TLS_SERVER, cacerts=CERTFILE, chatty=True, connectionchatty=False) @@ -2846,7 +2867,7 @@ class ThreadedTests(unittest.TestCase): certfile=CERTFILE, ca_certs=CERTFILE, cert_reqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1) + ssl_version=ssl.PROTOCOL_TLS_CLIENT) s.connect((HOST, server.port)) # helper methods for standardising recv* method signatures def _recv_into(): @@ -2988,7 +3009,7 @@ class ThreadedTests(unittest.TestCase): def test_nonblocking_send(self): server = ThreadedEchoServer(CERTFILE, certreqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1, + ssl_version=ssl.PROTOCOL_TLS_SERVER, cacerts=CERTFILE, chatty=True, connectionchatty=False) @@ -2998,7 +3019,7 @@ class ThreadedTests(unittest.TestCase): certfile=CERTFILE, ca_certs=CERTFILE, cert_reqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1) + ssl_version=ssl.PROTOCOL_TLS_CLIENT) s.connect((HOST, server.port)) s.setblocking(False) @@ -3067,7 +3088,7 @@ class ThreadedTests(unittest.TestCase): def test_server_accept(self): # Issue #16357: accept() on a SSLSocket created through # SSLContext.wrap_socket(). - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context = ssl.SSLContext(ssl.PROTOCOL_TLS) context.verify_mode = ssl.CERT_REQUIRED context.load_verify_locations(CERTFILE) context.load_cert_chain(CERTFILE) @@ -3104,28 +3125,28 @@ class ThreadedTests(unittest.TestCase): self.assertEqual(peer, client_addr) def test_getpeercert_enotconn(self): - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context = ssl.SSLContext(ssl.PROTOCOL_TLS) with context.wrap_socket(socket.socket()) as sock: with self.assertRaises(OSError) as cm: sock.getpeercert() self.assertEqual(cm.exception.errno, errno.ENOTCONN) def test_do_handshake_enotconn(self): - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context = ssl.SSLContext(ssl.PROTOCOL_TLS) with context.wrap_socket(socket.socket()) as sock: with self.assertRaises(OSError) as cm: sock.do_handshake() self.assertEqual(cm.exception.errno, errno.ENOTCONN) def test_default_ciphers(self): - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context = ssl.SSLContext(ssl.PROTOCOL_TLS) try: # Force a set of weak ciphers on our client context context.set_ciphers("DES") except ssl.SSLError: self.skipTest("no DES cipher available") with ThreadedEchoServer(CERTFILE, - ssl_version=ssl.PROTOCOL_SSLv23, + ssl_version=ssl.PROTOCOL_TLS, chatty=False) as server: with context.wrap_socket(socket.socket()) as s: with self.assertRaises(OSError): @@ -3137,14 +3158,19 @@ class ThreadedTests(unittest.TestCase): Basic tests for SSLSocket.version(). More tests are done in the test_protocol_*() methods. """ - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE with ThreadedEchoServer(CERTFILE, - ssl_version=ssl.PROTOCOL_TLSv1, + ssl_version=ssl.PROTOCOL_TLS_SERVER, chatty=False) as server: with context.wrap_socket(socket.socket()) as s: self.assertIs(s.version(), None) s.connect((HOST, server.port)) - self.assertEqual(s.version(), 'TLSv1') + if ssl.OPENSSL_VERSION_INFO >= (1, 0, 2): + self.assertEqual(s.version(), 'TLSv1.2') + else: # 0.9.8 to 1.0.1 + self.assertIn(s.version(), ('TLSv1', 'TLSv1.2')) self.assertIs(s.version(), None) @unittest.skipUnless(ssl.HAS_TLSv1_3, @@ -3169,7 +3195,7 @@ class ThreadedTests(unittest.TestCase): def test_default_ecdh_curve(self): # Issue #21015: elliptic curve-based Diffie Hellman key exchange # should be enabled by default on SSL contexts. - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context = ssl.SSLContext(ssl.PROTOCOL_TLS) context.load_cert_chain(CERTFILE) # TLSv1.3 defaults to PFS key agreement and no longer has KEA in # cipher name. @@ -3194,7 +3220,7 @@ class ThreadedTests(unittest.TestCase): server = ThreadedEchoServer(CERTFILE, certreqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1, + ssl_version=ssl.PROTOCOL_TLS_SERVER, cacerts=CERTFILE, chatty=True, connectionchatty=False) @@ -3204,7 +3230,7 @@ class ThreadedTests(unittest.TestCase): certfile=CERTFILE, ca_certs=CERTFILE, cert_reqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1) + ssl_version=ssl.PROTOCOL_TLS_CLIENT) s.connect((HOST, server.port)) # get the data cb_data = s.get_channel_binding("tls-unique") @@ -3229,7 +3255,7 @@ class ThreadedTests(unittest.TestCase): certfile=CERTFILE, ca_certs=CERTFILE, cert_reqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1) + ssl_version=ssl.PROTOCOL_TLS_CLIENT) s.connect((HOST, server.port)) new_cb_data = s.get_channel_binding("tls-unique") if support.verbose: @@ -3246,10 +3272,10 @@ class ThreadedTests(unittest.TestCase): s.close() def test_compression(self): - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.load_cert_chain(CERTFILE) - stats = server_params_test(context, context, - chatty=True, connectionchatty=True) + client_context, server_context, hostname = testing_context() + stats = server_params_test(client_context, server_context, + chatty=True, connectionchatty=True, + sni_name=hostname) if support.verbose: sys.stdout.write(" got compression: {!r}\n".format(stats['compression'])) self.assertIn(stats['compression'], { None, 'ZLIB', 'RLE' }) @@ -3257,21 +3283,22 @@ class ThreadedTests(unittest.TestCase): @unittest.skipUnless(hasattr(ssl, 'OP_NO_COMPRESSION'), "ssl.OP_NO_COMPRESSION needed for this test") def test_compression_disabled(self): - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.load_cert_chain(CERTFILE) - context.options |= ssl.OP_NO_COMPRESSION - stats = server_params_test(context, context, - chatty=True, connectionchatty=True) + client_context, server_context, hostname = testing_context() + client_context.options |= ssl.OP_NO_COMPRESSION + server_context.options |= ssl.OP_NO_COMPRESSION + stats = server_params_test(client_context, server_context, + chatty=True, connectionchatty=True, + sni_name=hostname) self.assertIs(stats['compression'], None) def test_dh_params(self): # Check we can get a connection with ephemeral Diffie-Hellman - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.load_cert_chain(CERTFILE) - context.load_dh_params(DHFILE) - context.set_ciphers("kEDH") - stats = server_params_test(context, context, - chatty=True, connectionchatty=True) + client_context, server_context, hostname = testing_context() + server_context.load_dh_params(DHFILE) + server_context.set_ciphers("kEDH") + stats = server_params_test(client_context, server_context, + chatty=True, connectionchatty=True, + sni_name=hostname) cipher = stats["cipher"][0] parts = cipher.split("-") if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts: @@ -3279,22 +3306,20 @@ class ThreadedTests(unittest.TestCase): def test_selected_alpn_protocol(self): # selected_alpn_protocol() is None unless ALPN is used. - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.load_cert_chain(CERTFILE) - stats = server_params_test(context, context, - chatty=True, connectionchatty=True) + client_context, server_context, hostname = testing_context() + stats = server_params_test(client_context, server_context, + chatty=True, connectionchatty=True, + sni_name=hostname) self.assertIs(stats['client_alpn_protocol'], None) @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support required") def test_selected_alpn_protocol_if_server_uses_alpn(self): # selected_alpn_protocol() is None unless ALPN is used by the client. - client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - client_context.load_verify_locations(CERTFILE) - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - server_context.load_cert_chain(CERTFILE) + client_context, server_context, hostname = testing_context() server_context.set_alpn_protocols(['foo', 'bar']) stats = server_params_test(client_context, server_context, - chatty=True, connectionchatty=True) + chatty=True, connectionchatty=True, + sni_name=hostname) self.assertIs(stats['client_alpn_protocol'], None) @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support needed for this test") @@ -3307,18 +3332,16 @@ class ThreadedTests(unittest.TestCase): (['http/3.0', 'http/4.0'], None) ] for client_protocols, expected in protocol_tests: - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) - server_context.load_cert_chain(CERTFILE) + client_context, server_context, hostname = testing_context() server_context.set_alpn_protocols(server_protocols) - client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) - client_context.load_cert_chain(CERTFILE) client_context.set_alpn_protocols(client_protocols) try: stats = server_params_test(client_context, server_context, chatty=True, - connectionchatty=True) + connectionchatty=True, + sni_name=hostname) except ssl.SSLError as e: stats = e @@ -3341,10 +3364,10 @@ class ThreadedTests(unittest.TestCase): def test_selected_npn_protocol(self): # selected_npn_protocol() is None unless NPN is used - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.load_cert_chain(CERTFILE) - stats = server_params_test(context, context, - chatty=True, connectionchatty=True) + client_context, server_context, hostname = testing_context() + stats = server_params_test(client_context, server_context, + chatty=True, connectionchatty=True, + sni_name=hostname) self.assertIs(stats['client_npn_protocol'], None) @unittest.skipUnless(ssl.HAS_NPN, "NPN support needed for this test") @@ -3357,15 +3380,12 @@ class ThreadedTests(unittest.TestCase): (['abc', 'def'], 'abc') ] for client_protocols, expected in protocol_tests: - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - server_context.load_cert_chain(CERTFILE) + client_context, server_context, hostname = testing_context() server_context.set_npn_protocols(server_protocols) - client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - client_context.load_cert_chain(CERTFILE) client_context.set_npn_protocols(client_protocols) stats = server_params_test(client_context, server_context, - chatty=True, connectionchatty=True) - + chatty=True, connectionchatty=True, + sni_name=hostname) msg = "failed trying %s (s) and %s (c).\n" \ "was expecting %s, but got %%s from the %%s" \ % (str(server_protocols), str(client_protocols), @@ -3377,12 +3397,11 @@ class ThreadedTests(unittest.TestCase): self.assertEqual(server_result, expected, msg % (server_result, "server")) def sni_contexts(self): - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) server_context.load_cert_chain(SIGNED_CERTFILE) - other_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + other_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) other_context.load_cert_chain(SIGNED_CERTFILE2) - client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - client_context.verify_mode = ssl.CERT_REQUIRED + client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) client_context.load_verify_locations(SIGNING_CA) return server_context, other_context, client_context @@ -3395,6 +3414,8 @@ class ThreadedTests(unittest.TestCase): calls = [] server_context, other_context, client_context = self.sni_contexts() + client_context.check_hostname = False + def servername_cb(ssl_sock, server_name, initial_context): calls.append((server_name, initial_context)) if server_name is not None: @@ -3416,7 +3437,7 @@ class ThreadedTests(unittest.TestCase): chatty=True, sni_name=None) self.assertEqual(calls, [(None, server_context)]) - self.check_common_name(stats, 'localhost') + self.check_common_name(stats, SIGNED_CERTFILE_HOSTNAME) # Check disabling the callback calls = [] @@ -3426,7 +3447,7 @@ class ThreadedTests(unittest.TestCase): chatty=True, sni_name='notfunny') # Certificate didn't change - self.check_common_name(stats, 'localhost') + self.check_common_name(stats, SIGNED_CERTFILE_HOSTNAME) self.assertEqual(calls, []) @needs_sni @@ -3437,7 +3458,6 @@ class ThreadedTests(unittest.TestCase): def cb_returning_alert(ssl_sock, server_name, initial_context): return ssl.ALERT_DESCRIPTION_ACCESS_DENIED server_context.set_servername_callback(cb_returning_alert) - with self.assertRaises(ssl.SSLError) as cm: stats = server_params_test(client_context, server_context, chatty=False, @@ -3480,11 +3500,7 @@ class ThreadedTests(unittest.TestCase): self.assertIn("TypeError", stderr.getvalue()) def test_shared_ciphers(self): - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - server_context.load_cert_chain(SIGNED_CERTFILE) - client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - client_context.verify_mode = ssl.CERT_REQUIRED - client_context.load_verify_locations(SIGNING_CA) + client_context, server_context, hostname = testing_context() if ssl.OPENSSL_VERSION_INFO >= (1, 0, 2): client_context.set_ciphers("AES128:AES256") server_context.set_ciphers("AES256") @@ -3496,7 +3512,8 @@ class ThreadedTests(unittest.TestCase): alg1 = "3DES" alg2 = "DES-CBC3" - stats = server_params_test(client_context, server_context) + stats = server_params_test(client_context, server_context, + sni_name=hostname) ciphers = stats['server_shared_ciphers'][0] self.assertGreater(len(ciphers), 0) for name, tls_version, bits in ciphers: @@ -3504,14 +3521,12 @@ class ThreadedTests(unittest.TestCase): self.fail(name) def test_read_write_after_close_raises_valuerror(self): - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - context.verify_mode = ssl.CERT_REQUIRED - context.load_verify_locations(CERTFILE) - context.load_cert_chain(CERTFILE) - server = ThreadedEchoServer(context=context, chatty=False) + client_context, server_context, hostname = testing_context() + server = ThreadedEchoServer(context=server_context, chatty=False) with server: - s = context.wrap_socket(socket.socket()) + s = client_context.wrap_socket(socket.socket(), + server_hostname=hostname) s.connect((HOST, server.port)) s.close() @@ -3523,7 +3538,7 @@ class ThreadedTests(unittest.TestCase): with open(support.TESTFN, 'wb') as f: f.write(TEST_DATA) self.addCleanup(support.unlink, support.TESTFN) - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context = ssl.SSLContext(ssl.PROTOCOL_TLS) context.verify_mode = ssl.CERT_REQUIRED context.load_verify_locations(CERTFILE) context.load_cert_chain(CERTFILE) @@ -3536,14 +3551,11 @@ class ThreadedTests(unittest.TestCase): self.assertEqual(s.recv(1024), TEST_DATA) def test_session(self): - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - server_context.load_cert_chain(SIGNED_CERTFILE) - client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - client_context.verify_mode = ssl.CERT_REQUIRED - client_context.load_verify_locations(SIGNING_CA) + client_context, server_context, hostname = testing_context() # first connection without session - stats = server_params_test(client_context, server_context) + stats = server_params_test(client_context, server_context, + sni_name=hostname) session = stats['session'] self.assertTrue(session.id) self.assertGreater(session.time, 0) @@ -3557,7 +3569,8 @@ class ThreadedTests(unittest.TestCase): self.assertEqual(sess_stat['hits'], 0) # reuse session - stats = server_params_test(client_context, server_context, session=session) + stats = server_params_test(client_context, server_context, + session=session, sni_name=hostname) sess_stat = server_context.session_stats() self.assertEqual(sess_stat['accept'], 2) self.assertEqual(sess_stat['hits'], 1) @@ -3570,7 +3583,8 @@ class ThreadedTests(unittest.TestCase): self.assertGreaterEqual(session2.timeout, session.timeout) # another one without session - stats = server_params_test(client_context, server_context) + stats = server_params_test(client_context, server_context, + sni_name=hostname) self.assertFalse(stats['session_reused']) session3 = stats['session'] self.assertNotEqual(session3.id, session.id) @@ -3580,7 +3594,8 @@ class ThreadedTests(unittest.TestCase): self.assertEqual(sess_stat['hits'], 1) # reuse session again - stats = server_params_test(client_context, server_context, session=session) + stats = server_params_test(client_context, server_context, + session=session, sni_name=hostname) self.assertTrue(stats['session_reused']) session4 = stats['session'] self.assertEqual(session4.id, session.id) @@ -3592,23 +3607,17 @@ class ThreadedTests(unittest.TestCase): self.assertEqual(sess_stat['hits'], 2) def test_session_handling(self): - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - context.verify_mode = ssl.CERT_REQUIRED - context.load_verify_locations(CERTFILE) - context.load_cert_chain(CERTFILE) - - context2 = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - context2.verify_mode = ssl.CERT_REQUIRED - context2.load_verify_locations(CERTFILE) - context2.load_cert_chain(CERTFILE) + client_context, server_context, hostname = testing_context() + client_context2, _, _ = testing_context() # TODO: session reuse does not work with TLS 1.3 - context.options |= ssl.OP_NO_TLSv1_3 - context2.options |= ssl.OP_NO_TLSv1_3 + client_context.options |= ssl.OP_NO_TLSv1_3 + client_context2.options |= ssl.OP_NO_TLSv1_3 - server = ThreadedEchoServer(context=context, chatty=False) + server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with context.wrap_socket(socket.socket()) as s: + with client_context.wrap_socket(socket.socket(), + server_hostname=hostname) as s: # session is None before handshake self.assertEqual(s.session, None) self.assertEqual(s.session_reused, None) @@ -3619,7 +3628,8 @@ class ThreadedTests(unittest.TestCase): s.session = object self.assertEqual(str(e.exception), 'Value is not a SSLSession.') - with context.wrap_socket(socket.socket()) as s: + with client_context.wrap_socket(socket.socket(), + server_hostname=hostname) as s: s.connect((HOST, server.port)) # cannot set session after handshake with self.assertRaises(ValueError) as e: @@ -3627,7 +3637,8 @@ class ThreadedTests(unittest.TestCase): self.assertEqual(str(e.exception), 'Cannot set session after handshake.') - with context.wrap_socket(socket.socket()) as s: + with client_context.wrap_socket(socket.socket(), + server_hostname=hostname) as s: # can set session before handshake and before the # connection was established s.session = session @@ -3636,7 +3647,8 @@ class ThreadedTests(unittest.TestCase): self.assertEqual(s.session, session) self.assertEqual(s.session_reused, True) - with context2.wrap_socket(socket.socket()) as s: + with client_context2.wrap_socket(socket.socket(), + server_hostname=hostname) as s: # cannot re-use session with a different SSLContext with self.assertRaises(ValueError) as e: s.session = session diff --git a/Lib/test/test_urllib2_localnet.py b/Lib/test/test_urllib2_localnet.py index 741d136a92c..b2d1e5f9804 100644 --- a/Lib/test/test_urllib2_localnet.py +++ b/Lib/test/test_urllib2_localnet.py @@ -594,7 +594,7 @@ class TestUrlopen(unittest.TestCase): def cb_sni(ssl_sock, server_name, initial_context): nonlocal sni_name sni_name = server_name - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) context.set_servername_callback(cb_sni) handler = self.start_https_server(context=context, certfile=CERT_localhost) context = ssl.create_default_context(cafile=CERT_localhost) diff --git a/Misc/NEWS.d/next/Tests/2017-09-04-16-21-18.bpo-31346.xni1VR.rst b/Misc/NEWS.d/next/Tests/2017-09-04-16-21-18.bpo-31346.xni1VR.rst new file mode 100644 index 00000000000..87bc8cfabaf --- /dev/null +++ b/Misc/NEWS.d/next/Tests/2017-09-04-16-21-18.bpo-31346.xni1VR.rst @@ -0,0 +1 @@ +Prefer PROTOCOL_TLS_CLIENT and PROTOCOL_TLS_SERVER protocols for SSLContext.