Use context managers in test_ssl to simplify test writing.

This commit is contained in:
Antoine Pitrou 2011-12-21 16:52:40 +01:00
parent 28f8bee5c8
commit 65a3f4b8c5
1 changed files with 35 additions and 68 deletions

View File

@ -881,6 +881,14 @@ else:
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.daemon = True self.daemon = True
def __enter__(self):
self.start(threading.Event())
self.flag.wait()
def __exit__(self, *args):
self.stop()
self.join()
def start(self, flag=None): def start(self, flag=None):
self.flag = flag self.flag = flag
threading.Thread.start(self) threading.Thread.start(self)
@ -993,6 +1001,20 @@ else:
def __str__(self): def __str__(self):
return "<%s %s>" % (self.__class__.__name__, self.server) return "<%s %s>" % (self.__class__.__name__, self.server)
def __enter__(self):
self.start(threading.Event())
self.flag.wait()
def __exit__(self, *args):
if support.verbose:
sys.stdout.write(" cleanup: stopping server.\n")
self.stop()
if support.verbose:
sys.stdout.write(" cleanup: joining server thread.\n")
self.join()
if support.verbose:
sys.stdout.write(" cleanup: successfully joined.\n")
def start (self, flag=None): def start (self, flag=None):
self.flag = flag self.flag = flag
threading.Thread.start(self) threading.Thread.start(self)
@ -1020,12 +1042,7 @@ else:
certreqs=ssl.CERT_REQUIRED, certreqs=ssl.CERT_REQUIRED,
cacerts=CERTFILE, chatty=False, cacerts=CERTFILE, chatty=False,
connectionchatty=False) connectionchatty=False)
flag = threading.Event() with server:
server.start(flag)
# wait for it to start
flag.wait()
# try to connect
try:
try: try:
with socket.socket() as sock: with socket.socket() as sock:
s = ssl.wrap_socket(sock, s = ssl.wrap_socket(sock,
@ -1045,9 +1062,6 @@ else:
sys.stdout.write("\IOError is %s\n" % str(x)) sys.stdout.write("\IOError is %s\n" % str(x))
else: else:
raise AssertionError("Use of invalid cert should have failed!") raise AssertionError("Use of invalid cert should have failed!")
finally:
server.stop()
server.join()
def server_params_test(client_context, server_context, indata=b"FOO\n", def server_params_test(client_context, server_context, indata=b"FOO\n",
chatty=True, connectionchatty=False): chatty=True, connectionchatty=False):
@ -1058,12 +1072,7 @@ else:
server = ThreadedEchoServer(context=server_context, server = ThreadedEchoServer(context=server_context,
chatty=chatty, chatty=chatty,
connectionchatty=False) connectionchatty=False)
flag = threading.Event() with server:
server.start(flag)
# wait for it to start
flag.wait()
# try to connect
try:
s = client_context.wrap_socket(socket.socket()) s = client_context.wrap_socket(socket.socket())
s.connect((HOST, server.port)) s.connect((HOST, server.port))
for arg in [indata, bytearray(indata), memoryview(indata)]: for arg in [indata, bytearray(indata), memoryview(indata)]:
@ -1086,9 +1095,6 @@ else:
if support.verbose: if support.verbose:
sys.stdout.write(" client: closing connection.\n") sys.stdout.write(" client: closing connection.\n")
s.close() s.close()
finally:
server.stop()
server.join()
def try_protocol_combo(server_protocol, client_protocol, expect_success, def try_protocol_combo(server_protocol, client_protocol, expect_success,
certsreqs=None, server_options=0, client_options=0): certsreqs=None, server_options=0, client_options=0):
@ -1157,12 +1163,7 @@ else:
context.load_verify_locations(CERTFILE) context.load_verify_locations(CERTFILE)
context.load_cert_chain(CERTFILE) context.load_cert_chain(CERTFILE)
server = ThreadedEchoServer(context=context, chatty=False) server = ThreadedEchoServer(context=context, chatty=False)
flag = threading.Event() with server:
server.start(flag)
# wait for it to start
flag.wait()
# try to connect
try:
s = context.wrap_socket(socket.socket()) s = context.wrap_socket(socket.socket())
s.connect((HOST, server.port)) s.connect((HOST, server.port))
cert = s.getpeercert() cert = s.getpeercert()
@ -1185,9 +1186,6 @@ else:
after = ssl.cert_time_to_seconds(cert['notAfter']) after = ssl.cert_time_to_seconds(cert['notAfter'])
self.assertLess(before, after) self.assertLess(before, after)
s.close() s.close()
finally:
server.stop()
server.join()
def test_empty_cert(self): def test_empty_cert(self):
"""Connecting with an empty cert file""" """Connecting with an empty cert file"""
@ -1346,13 +1344,8 @@ else:
starttls_server=True, starttls_server=True,
chatty=True, chatty=True,
connectionchatty=True) connectionchatty=True)
flag = threading.Event()
server.start(flag)
# wait for it to start
flag.wait()
# try to connect
wrapped = False wrapped = False
try: with server:
s = socket.socket() s = socket.socket()
s.setblocking(1) s.setblocking(1)
s.connect((HOST, server.port)) s.connect((HOST, server.port))
@ -1399,9 +1392,6 @@ else:
conn.close() conn.close()
else: else:
s.close() s.close()
finally:
server.stop()
server.join()
def test_socketserver(self): def test_socketserver(self):
"""Using a SocketServer to create and manage SSL connections.""" """Using a SocketServer to create and manage SSL connections."""
@ -1437,12 +1427,7 @@ else:
indata = b"FOO\n" indata = b"FOO\n"
server = AsyncoreEchoServer(CERTFILE) server = AsyncoreEchoServer(CERTFILE)
flag = threading.Event() with server:
server.start(flag)
# wait for it to start
flag.wait()
# try to connect
try:
s = ssl.wrap_socket(socket.socket()) s = ssl.wrap_socket(socket.socket())
s.connect(('127.0.0.1', server.port)) s.connect(('127.0.0.1', server.port))
if support.verbose: if support.verbose:
@ -1463,15 +1448,6 @@ else:
s.close() s.close()
if support.verbose: if support.verbose:
sys.stdout.write(" client: connection closed.\n") sys.stdout.write(" client: connection closed.\n")
finally:
if support.verbose:
sys.stdout.write(" cleanup: stopping server.\n")
server.stop()
if support.verbose:
sys.stdout.write(" cleanup: joining server thread.\n")
server.join()
if support.verbose:
sys.stdout.write(" cleanup: successfully joined.\n")
def test_recv_send(self): def test_recv_send(self):
"""Test recv(), send() and friends.""" """Test recv(), send() and friends."""
@ -1484,19 +1460,14 @@ else:
cacerts=CERTFILE, cacerts=CERTFILE,
chatty=True, chatty=True,
connectionchatty=False) connectionchatty=False)
flag = threading.Event() with server:
server.start(flag) s = ssl.wrap_socket(socket.socket(),
# wait for it to start server_side=False,
flag.wait() certfile=CERTFILE,
# try to connect ca_certs=CERTFILE,
s = ssl.wrap_socket(socket.socket(), cert_reqs=ssl.CERT_NONE,
server_side=False, ssl_version=ssl.PROTOCOL_TLSv1)
certfile=CERTFILE, s.connect((HOST, server.port))
ca_certs=CERTFILE,
cert_reqs=ssl.CERT_NONE,
ssl_version=ssl.PROTOCOL_TLSv1)
s.connect((HOST, server.port))
try:
# helper methods for standardising recv* method signatures # helper methods for standardising recv* method signatures
def _recv_into(): def _recv_into():
b = bytearray(b"\0"*100) b = bytearray(b"\0"*100)
@ -1581,12 +1552,8 @@ else:
) )
# consume data # consume data
s.read() s.read()
s.write(b"over\n") s.write(b"over\n")
s.close() s.close()
finally:
server.stop()
server.join()
def test_handshake_timeout(self): def test_handshake_timeout(self):
# Issue #5103: SSL handshake must respect the socket timeout # Issue #5103: SSL handshake must respect the socket timeout