From 18c913e2b1ca9f7c92fa7ae22e260237f241a8b7 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Tue, 27 Apr 2010 10:59:39 +0000 Subject: [PATCH] Merged revisions 80529 via svnmerge from svn+ssh://pythondev@svn.python.org/python/trunk ........ r80529 | antoine.pitrou | 2010-04-27 12:32:58 +0200 (mar., 27 avril 2010) | 4 lines Qualify or remove or bare excepts. Simplify exception handling in places. Remove uses of test_support.TestFailed. ........ --- Lib/test/test_ssl.py | 264 +++++++++++++++++-------------------------- 1 file changed, 104 insertions(+), 160 deletions(-) diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index 921205d854c..a5af557513c 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -44,7 +44,7 @@ class BasicTests(unittest.TestCase): s.connect(("svn.python.org", 443)) c = s.getpeercert() if c: - raise support.TestFailed("Peer cert %s shouldn't be here!") + self.fail("Peer cert %s shouldn't be here!") s.close() # this should fail because we have no verification certs @@ -94,8 +94,7 @@ class BasicTests(unittest.TestCase): d1 = ssl.PEM_cert_to_DER_cert(pem) p2 = ssl.DER_cert_to_PEM_cert(d1) d2 = ssl.PEM_cert_to_DER_cert(p2) - if (d1 != d2): - raise support.TestFailed("PEM-to-DER or DER-to-PEM translation failed") + self.assertEqual(d1, d2) def test_openssl_version(self): n = ssl.OPENSSL_VERSION_NUMBER @@ -169,7 +168,7 @@ class NetworkedTests(unittest.TestCase): s.connect(("svn.python.org", 443)) c = s.getpeercert() if c: - raise support.TestFailed("Peer cert %s shouldn't be here!") + self.fail("Peer cert %s shouldn't be here!") s.close() # this should fail because we have no verification certs @@ -188,8 +187,6 @@ class NetworkedTests(unittest.TestCase): ca_certs=SVN_PYTHON_ORG_ROOT_CERT) try: s.connect(("svn.python.org", 443)) - except ssl.SSLError as x: - raise support.TestFailed("Unexpected exception %s" % x) finally: s.close() @@ -240,7 +237,7 @@ class NetworkedTests(unittest.TestCase): pem = ssl.get_server_certificate(("svn.python.org", 443)) if not pem: - raise support.TestFailed("No server certificate on svn.python.org:443!") + self.fail("No server certificate on svn.python.org:443!") return @@ -251,11 +248,11 @@ class NetworkedTests(unittest.TestCase): if support.verbose: sys.stdout.write("%s\n" % x) else: - raise support.TestFailed("Got server certificate %s for svn.python.org!" % pem) + self.fail("Got server certificate %s for svn.python.org!" % pem) pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=SVN_PYTHON_ORG_ROOT_CERT) if not pem: - raise support.TestFailed("No server certificate on svn.python.org:443!") + self.fail("No server certificate on svn.python.org:443!") if support.verbose: sys.stdout.write("\nVerified certificate for svn.python.org:443 is\n%s\n" % pem) @@ -317,19 +314,16 @@ else: ca_certs=self.server.cacerts, cert_reqs=self.server.certreqs, ciphers=self.server.ciphers) - except: + except ssl.SSLError: + # XXX Various errors can have happened here, for example + # a mismatching protocol version, an invalid certificate, + # or a low-level bug. This should be made more discriminating. if self.server.chatty: handle_error("\n server: bad connection attempt from " + repr(self.addr) + ":\n") - if not self.server.expect_bad_connects: - # here, we want to stop the server, because this shouldn't - # happen in the context of our test case - self.running = False - # normally, we'd just stop here, but for the test - # harness, we want to stop the server - self.server.stop() + self.running = False + self.server.stop() self.close() return False - else: if self.server.certreqs == ssl.CERT_REQUIRED: cert = self.sslconn.getpeercert() @@ -410,11 +404,9 @@ else: # normally, we'd just stop here, but for the test # harness, we want to stop the server self.server.stop() - except: - handle_error('') def __init__(self, certificate, ssl_version=None, - certreqs=None, cacerts=None, expect_bad_connects=False, + certreqs=None, cacerts=None, chatty=True, connectionchatty=False, starttls_server=False, ciphers=None): if ssl_version is None: @@ -426,7 +418,6 @@ else: self.certreqs = certreqs self.cacerts = cacerts self.ciphers = ciphers - self.expect_bad_connects = expect_bad_connects self.chatty = chatty self.connectionchatty = connectionchatty self.starttls_server = starttls_server @@ -460,9 +451,6 @@ else: pass except KeyboardInterrupt: self.stop() - except: - if self.chatty: - handle_error("Test server failure:\n") self.sock.close() def stop (self): @@ -621,7 +609,7 @@ else: def handle_close(self): self.close() - if test_support.verbose: + if support.verbose: sys.stdout.write(" server: closed connection %s\n" % self.socket) def handle_error(self): @@ -690,13 +678,9 @@ else: s.connect((HOST, server.port)) except ssl.SSLError as x: if support.verbose: - sys.stdout.write("\nSSLError is %s\n" % x) - except socket.error as x: - if support.verbose: - sys.stdout.write("\nsocket.error is %s\n" % x) + sys.stdout.write("\nSSLError is %s\n" % x.args[1]) else: - raise support.TestFailed( - "Use of invalid cert should have failed!") + self.fail("Use of invalid cert should have failed!") finally: server.stop() server.join() @@ -722,18 +706,12 @@ else: client_protocol = protocol try: s = ssl.wrap_socket(socket.socket(), - server_side=False, certfile=client_certfile, ca_certs=cacertsfile, - cert_reqs=certreqs, ciphers=ciphers, + cert_reqs=certreqs, ssl_version=client_protocol) s.connect((HOST, server.port)) - except ssl.SSLError as x: - raise support.TestFailed("Unexpected SSL error: " + str(x)) - except Exception as x: - raise support.TestFailed("Unexpected exception: " + str(x)) - else: bindata = indata.encode('ASCII', 'strict') for arg in [bindata, bytearray(bindata), memoryview(bindata)]: if connectionchatty: @@ -747,7 +725,7 @@ else: sys.stdout.write(" client: read %s\n" % repr(outdata)) outdata = str(outdata, 'ASCII', 'strict') if outdata != indata.lower(): - raise support.TestFailed( + self.fail( "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n" % (repr(outdata[:min(len(outdata),20)]), len(outdata), repr(indata[:min(len(indata),20)].lower()), len(indata))) @@ -788,12 +766,17 @@ else: CERTFILE, CERTFILE, client_protocol, ciphers="ALL", chatty=False, connectionchatty=False) - except support.TestFailed: + # Protocol mismatch can result in either an SSLError, or a + # "Connection reset by peer" error. + except ssl.SSLError: if expectedToWork: raise + except socket.error as e: + if expectedToWork or e.errno != errno.ECONNRESET: + raise else: if not expectedToWork: - raise support.TestFailed( + self.fail( "Client protocol %s succeeded with server protocol %s!" % (ssl.get_protocol_name(client_protocol), ssl.get_protocol_name(server_protocol))) @@ -825,41 +808,27 @@ else: flag.wait() # try to connect try: - try: - s = ssl.wrap_socket(socket.socket(), - certfile=CERTFILE, - ca_certs=CERTFILE, - cert_reqs=ssl.CERT_REQUIRED, - ssl_version=ssl.PROTOCOL_SSLv23) - s.connect((HOST, server.port)) - except ssl.SSLError as x: - raise support.TestFailed( - "Unexpected SSL error: " + str(x)) - except Exception as x: - raise support.TestFailed( - "Unexpected exception: " + str(x)) - else: - if not s: - raise support.TestFailed( - "Can't SSL-handshake with test server") - cert = s.getpeercert() - if not cert: - raise support.TestFailed( - "Can't get peer certificate.") - cipher = s.cipher() - if support.verbose: - sys.stdout.write(pprint.pformat(cert) + '\n') - sys.stdout.write("Connection cipher is " + str(cipher) + '.\n') - if 'subject' not in cert: - raise support.TestFailed( - "No subject field in certificate: %s." % - pprint.pformat(cert)) - if ((('organizationName', 'Python Software Foundation'),) - not in cert['subject']): - raise support.TestFailed( - "Missing or invalid 'organizationName' field in certificate subject; " - "should be 'Python Software Foundation'.") - s.close() + s = ssl.wrap_socket(socket.socket(), + certfile=CERTFILE, + ca_certs=CERTFILE, + cert_reqs=ssl.CERT_REQUIRED, + ssl_version=ssl.PROTOCOL_SSLv23) + s.connect((HOST, server.port)) + cert = s.getpeercert() + self.assertTrue(cert, "Can't get peer certificate.") + cipher = s.cipher() + if support.verbose: + sys.stdout.write(pprint.pformat(cert) + '\n') + sys.stdout.write("Connection cipher is " + str(cipher) + '.\n') + if 'subject' not in cert: + self.fail("No subject field in certificate: %s." % + pprint.pformat(cert)) + if ((('organizationName', 'Python Software Foundation'),) + not in cert['subject']): + self.fail( + "Missing or invalid 'organizationName' field in certificate subject; " + "should be 'Python Software Foundation'.") + s.close() finally: server.stop() server.join() @@ -905,8 +874,7 @@ else: except IOError: pass else: - raise test_support.TestFailed( - 'connecting to closed SSL socket should have failed') + self.fail('connecting to closed SSL socket should have failed') t = threading.Thread(target=listener) t.start() @@ -930,7 +898,7 @@ else: sys.stdout.write("\n") try: tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv2, True) - except support.TestFailed as x: + except (ssl.SSLError, socket.error) as x: # this fails on some older versions of OpenSSL (0.9.7l, for instance) if support.verbose: sys.stdout.write( @@ -984,55 +952,51 @@ else: # try to connect wrapped = False try: - try: - s = socket.socket() - s.setblocking(1) - s.connect((HOST, server.port)) - except Exception as x: - raise support.TestFailed("Unexpected exception: " + str(x)) - else: + s = socket.socket() + s.setblocking(1) + s.connect((HOST, server.port)) + if support.verbose: + sys.stdout.write("\n") + for indata in msgs: + msg = indata.encode('ASCII', 'replace') if support.verbose: - sys.stdout.write("\n") - for indata in msgs: - msg = indata.encode('ASCII', 'replace') - if support.verbose: - sys.stdout.write( - " client: sending %s...\n" % repr(msg)) - if wrapped: - conn.write(msg) - outdata = conn.read() - else: - s.send(msg) - outdata = s.recv(1024) - if (indata == "STARTTLS" and - str(outdata, 'ASCII', 'replace').strip().lower().startswith("ok")): - if support.verbose: - msg = str(outdata, 'ASCII', 'replace') - sys.stdout.write( - " client: read %s from server, starting TLS...\n" - % repr(msg)) - conn = ssl.wrap_socket(s, ssl_version=ssl.PROTOCOL_TLSv1) - wrapped = True - elif (indata == "ENDTLS" and - str(outdata, 'ASCII', 'replace').strip().lower().startswith("ok")): - if support.verbose: - msg = str(outdata, 'ASCII', 'replace') - sys.stdout.write( - " client: read %s from server, ending TLS...\n" - % repr(msg)) - s = conn.unwrap() - wrapped = False - else: - if support.verbose: - msg = str(outdata, 'ASCII', 'replace') - sys.stdout.write( - " client: read %s from server\n" % repr(msg)) - if support.verbose: - sys.stdout.write(" client: closing connection.\n") + sys.stdout.write( + " client: sending %s...\n" % repr(msg)) if wrapped: - conn.write("over\n".encode("ASCII", "strict")) + conn.write(msg) + outdata = conn.read() else: - s.send("over\n".encode("ASCII", "strict")) + s.send(msg) + outdata = s.recv(1024) + if (indata == "STARTTLS" and + str(outdata, 'ASCII', 'replace').strip().lower().startswith("ok")): + if support.verbose: + msg = str(outdata, 'ASCII', 'replace') + sys.stdout.write( + " client: read %s from server, starting TLS...\n" + % repr(msg)) + conn = ssl.wrap_socket(s, ssl_version=ssl.PROTOCOL_TLSv1) + wrapped = True + elif (indata == "ENDTLS" and + str(outdata, 'ASCII', 'replace').strip().lower().startswith("ok")): + if support.verbose: + msg = str(outdata, 'ASCII', 'replace') + sys.stdout.write( + " client: read %s from server, ending TLS...\n" + % repr(msg)) + s = conn.unwrap() + wrapped = False + else: + if support.verbose: + msg = str(outdata, 'ASCII', 'replace') + sys.stdout.write( + " client: read %s from server\n" % repr(msg)) + if support.verbose: + sys.stdout.write(" client: closing connection.\n") + if wrapped: + conn.write("over\n".encode("ASCII", "strict")) + else: + s.send("over\n".encode("ASCII", "strict")) if wrapped: conn.close() else: @@ -1066,17 +1030,7 @@ else: " client: read %d bytes from remote server '%s'\n" % (len(d2), server)) f.close() - except: - msg = ''.join(traceback.format_exception(*sys.exc_info())) - if support.verbose: - sys.stdout.write('\n' + msg) - raise support.TestFailed(msg) - else: - if not (d1 == d2): - print("d1 is", len(d1), repr(d1)) - print("d2 is", len(d2), repr(d2)) - raise support.TestFailed( - "Couldn't fetch data from HTTPS server") + self.assertEqual(d1, d2) finally: if support.verbose: sys.stdout.write('stopping server\n') @@ -1099,25 +1053,20 @@ else: # try to connect try: s = ssl.wrap_socket(socket.socket()) - s.connect((HOST, server.port)) - except ssl.SSLError as x: - raise support.TestFailed("Unexpected SSL error: " + str(x)) - except Exception as x: - raise support.TestFailed("Unexpected exception: " + str(x)) - else: + s.connect(('127.0.0.1', server.port)) if support.verbose: sys.stdout.write( " client: sending %s...\n" % (repr(indata))) - s.sendall(indata.encode('ASCII', 'strict')) - outdata = s.recv() + s.write(indata.encode('ASCII', 'strict')) + outdata = s.read() if support.verbose: sys.stdout.write(" client: read %s\n" % repr(outdata)) outdata = str(outdata, 'ASCII', 'strict') if outdata != indata.lower(): - raise support.TestFailed( + self.fail( "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n" - % (repr(outdata[:min(len(outdata),20)]), len(outdata), - repr(indata[:min(len(indata),20)].lower()), len(indata))) + % (outdata[:min(len(outdata),20)], len(outdata), + indata[:min(len(indata),20)].lower(), len(indata))) s.write("over\n".encode("ASCII", "strict")) if support.verbose: sys.stdout.write(" client: closing connection.\n") @@ -1142,19 +1091,14 @@ else: # wait for it to start flag.wait() # try to connect + s = ssl.wrap_socket(socket.socket(), + server_side=False, + certfile=CERTFILE, + ca_certs=CERTFILE, + cert_reqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLSv1) + s.connect((HOST, server.port)) try: - s = ssl.wrap_socket(socket.socket(), - server_side=False, - certfile=CERTFILE, - ca_certs=CERTFILE, - cert_reqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1) - s.connect((HOST, server.port)) - except ssl.SSLError as x: - self.fail("Unexpected SSL error: " + str(x)) - except Exception as x: - self.fail("Unexpected exception: " + str(x)) - else: # helper methods for standardising recv* method signatures def _recv_into(): b = bytearray(b"\0"*100)