mirror of https://github.com/python/cpython
Use context managers in test_ssl to simplify test writing.
This commit is contained in:
parent
17c07134a9
commit
5b95eb90a7
|
@ -532,6 +532,14 @@ else:
|
|||
threading.Thread.__init__(self)
|
||||
self.daemon = True
|
||||
|
||||
def __enter__(self):
|
||||
self.start(threading.Event())
|
||||
self.flag.wait()
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.stop()
|
||||
self.join()
|
||||
|
||||
def start(self, flag=None):
|
||||
self.flag = flag
|
||||
threading.Thread.start(self)
|
||||
|
@ -638,6 +646,20 @@ else:
|
|||
def __str__(self):
|
||||
return "<%s %s>" % (self.__class__.__name__, self.server)
|
||||
|
||||
def __enter__(self):
|
||||
self.start(threading.Event())
|
||||
self.flag.wait()
|
||||
|
||||
def __exit__(self, *args):
|
||||
if test_support.verbose:
|
||||
sys.stdout.write(" cleanup: stopping server.\n")
|
||||
self.stop()
|
||||
if test_support.verbose:
|
||||
sys.stdout.write(" cleanup: joining server thread.\n")
|
||||
self.join()
|
||||
if test_support.verbose:
|
||||
sys.stdout.write(" cleanup: successfully joined.\n")
|
||||
|
||||
def start(self, flag=None):
|
||||
self.flag = flag
|
||||
threading.Thread.start(self)
|
||||
|
@ -752,12 +774,7 @@ else:
|
|||
server = ThreadedEchoServer(CERTFILE,
|
||||
certreqs=ssl.CERT_REQUIRED,
|
||||
cacerts=CERTFILE, chatty=False)
|
||||
flag = threading.Event()
|
||||
server.start(flag)
|
||||
# wait for it to start
|
||||
flag.wait()
|
||||
# try to connect
|
||||
try:
|
||||
with server:
|
||||
try:
|
||||
s = ssl.wrap_socket(socket.socket(),
|
||||
certfile=certfile,
|
||||
|
@ -771,9 +788,6 @@ else:
|
|||
sys.stdout.write("\nsocket.error is %s\n" % x[1])
|
||||
else:
|
||||
raise AssertionError("Use of invalid cert should have failed!")
|
||||
finally:
|
||||
server.stop()
|
||||
server.join()
|
||||
|
||||
def server_params_test(certfile, protocol, certreqs, cacertsfile,
|
||||
client_certfile, client_protocol=None, indata="FOO\n",
|
||||
|
@ -791,14 +805,10 @@ else:
|
|||
chatty=chatty,
|
||||
connectionchatty=connectionchatty,
|
||||
wrap_accepting_socket=wrap_accepting_socket)
|
||||
flag = threading.Event()
|
||||
server.start(flag)
|
||||
# wait for it to start
|
||||
flag.wait()
|
||||
# try to connect
|
||||
if client_protocol is None:
|
||||
client_protocol = protocol
|
||||
try:
|
||||
with server:
|
||||
# try to connect
|
||||
if client_protocol is None:
|
||||
client_protocol = protocol
|
||||
s = ssl.wrap_socket(socket.socket(),
|
||||
certfile=client_certfile,
|
||||
ca_certs=cacertsfile,
|
||||
|
@ -826,9 +836,6 @@ else:
|
|||
if test_support.verbose:
|
||||
sys.stdout.write(" client: closing connection.\n")
|
||||
s.close()
|
||||
finally:
|
||||
server.stop()
|
||||
server.join()
|
||||
|
||||
def try_protocol_combo(server_protocol,
|
||||
client_protocol,
|
||||
|
@ -930,12 +937,7 @@ else:
|
|||
ssl_version=ssl.PROTOCOL_SSLv23,
|
||||
cacerts=CERTFILE,
|
||||
chatty=False)
|
||||
flag = threading.Event()
|
||||
server.start(flag)
|
||||
# wait for it to start
|
||||
flag.wait()
|
||||
# try to connect
|
||||
try:
|
||||
with server:
|
||||
s = ssl.wrap_socket(socket.socket(),
|
||||
certfile=CERTFILE,
|
||||
ca_certs=CERTFILE,
|
||||
|
@ -957,9 +959,6 @@ else:
|
|||
"Missing or invalid 'organizationName' field in certificate subject; "
|
||||
"should be 'Python Software Foundation'.")
|
||||
s.close()
|
||||
finally:
|
||||
server.stop()
|
||||
server.join()
|
||||
|
||||
def test_empty_cert(self):
|
||||
"""Connecting with an empty cert file"""
|
||||
|
@ -1042,13 +1041,8 @@ else:
|
|||
starttls_server=True,
|
||||
chatty=True,
|
||||
connectionchatty=True)
|
||||
flag = threading.Event()
|
||||
server.start(flag)
|
||||
# wait for it to start
|
||||
flag.wait()
|
||||
# try to connect
|
||||
wrapped = False
|
||||
try:
|
||||
with server:
|
||||
s = socket.socket()
|
||||
s.setblocking(1)
|
||||
s.connect((HOST, server.port))
|
||||
|
@ -1093,9 +1087,6 @@ else:
|
|||
else:
|
||||
s.send("over\n")
|
||||
s.close()
|
||||
finally:
|
||||
server.stop()
|
||||
server.join()
|
||||
|
||||
def test_socketserver(self):
|
||||
"""Using a SocketServer to create and manage SSL connections."""
|
||||
|
@ -1145,12 +1136,7 @@ else:
|
|||
if test_support.verbose:
|
||||
sys.stdout.write("\n")
|
||||
server = AsyncoreEchoServer(CERTFILE)
|
||||
flag = threading.Event()
|
||||
server.start(flag)
|
||||
# wait for it to start
|
||||
flag.wait()
|
||||
# try to connect
|
||||
try:
|
||||
with server:
|
||||
s = ssl.wrap_socket(socket.socket())
|
||||
s.connect(('127.0.0.1', server.port))
|
||||
if test_support.verbose:
|
||||
|
@ -1169,10 +1155,6 @@ else:
|
|||
if test_support.verbose:
|
||||
sys.stdout.write(" client: closing connection.\n")
|
||||
s.close()
|
||||
finally:
|
||||
server.stop()
|
||||
# wait for server thread to end
|
||||
server.join()
|
||||
|
||||
def test_recv_send(self):
|
||||
"""Test recv(), send() and friends."""
|
||||
|
@ -1185,19 +1167,14 @@ else:
|
|||
cacerts=CERTFILE,
|
||||
chatty=True,
|
||||
connectionchatty=False)
|
||||
flag = threading.Event()
|
||||
server.start(flag)
|
||||
# wait for it to start
|
||||
flag.wait()
|
||||
# try to connect
|
||||
s = ssl.wrap_socket(socket.socket(),
|
||||
server_side=False,
|
||||
certfile=CERTFILE,
|
||||
ca_certs=CERTFILE,
|
||||
cert_reqs=ssl.CERT_NONE,
|
||||
ssl_version=ssl.PROTOCOL_TLSv1)
|
||||
s.connect((HOST, server.port))
|
||||
try:
|
||||
with server:
|
||||
s = ssl.wrap_socket(socket.socket(),
|
||||
server_side=False,
|
||||
certfile=CERTFILE,
|
||||
ca_certs=CERTFILE,
|
||||
cert_reqs=ssl.CERT_NONE,
|
||||
ssl_version=ssl.PROTOCOL_TLSv1)
|
||||
s.connect((HOST, server.port))
|
||||
# helper methods for standardising recv* method signatures
|
||||
def _recv_into():
|
||||
b = bytearray("\0"*100)
|
||||
|
@ -1285,9 +1262,6 @@ else:
|
|||
|
||||
s.write("over\n".encode("ASCII", "strict"))
|
||||
s.close()
|
||||
finally:
|
||||
server.stop()
|
||||
server.join()
|
||||
|
||||
def test_handshake_timeout(self):
|
||||
# Issue #5103: SSL handshake must respect the socket timeout
|
||||
|
|
Loading…
Reference in New Issue