mirror of https://github.com/python/cpython
asyncio: Better-looking errors when ssl module cannot be imported. In part by Arnaud Faure.
This commit is contained in:
parent
a8d630a6e6
commit
28dff0d823
|
@ -466,6 +466,8 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
ssl=None,
|
||||
reuse_address=None):
|
||||
"""XXX"""
|
||||
if isinstance(ssl, bool):
|
||||
raise TypeError('ssl argument must be an SSLContext or None')
|
||||
if host is not None or port is not None:
|
||||
if sock is not None:
|
||||
raise ValueError(
|
||||
|
|
|
@ -90,12 +90,13 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
|
|||
except (BlockingIOError, InterruptedError):
|
||||
pass
|
||||
|
||||
def _start_serving(self, protocol_factory, sock, ssl=None, server=None):
|
||||
def _start_serving(self, protocol_factory, sock,
|
||||
sslcontext=None, server=None):
|
||||
self.add_reader(sock.fileno(), self._accept_connection,
|
||||
protocol_factory, sock, ssl, server)
|
||||
protocol_factory, sock, sslcontext, server)
|
||||
|
||||
def _accept_connection(self, protocol_factory, sock, ssl=None,
|
||||
server=None):
|
||||
def _accept_connection(self, protocol_factory, sock,
|
||||
sslcontext=None, server=None):
|
||||
try:
|
||||
conn, addr = sock.accept()
|
||||
conn.setblocking(False)
|
||||
|
@ -113,13 +114,13 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
|
|||
self.remove_reader(sock.fileno())
|
||||
self.call_later(constants.ACCEPT_RETRY_DELAY,
|
||||
self._start_serving,
|
||||
protocol_factory, sock, ssl, server)
|
||||
protocol_factory, sock, sslcontext, server)
|
||||
else:
|
||||
raise # The event loop will catch, log and ignore it.
|
||||
else:
|
||||
if ssl:
|
||||
if sslcontext:
|
||||
self._make_ssl_transport(
|
||||
conn, protocol_factory(), ssl, None,
|
||||
conn, protocol_factory(), sslcontext, None,
|
||||
server_side=True, extra={'peername': addr}, server=server)
|
||||
else:
|
||||
self._make_socket_transport(
|
||||
|
@ -558,17 +559,23 @@ class _SelectorSslTransport(_SelectorTransport):
|
|||
def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None,
|
||||
server_side=False, server_hostname=None,
|
||||
extra=None, server=None):
|
||||
if ssl is None:
|
||||
raise RuntimeError('stdlib ssl module not available')
|
||||
|
||||
if server_side:
|
||||
assert isinstance(
|
||||
sslcontext, ssl.SSLContext), 'Must pass an SSLContext'
|
||||
if not sslcontext:
|
||||
raise ValueError('Server side ssl needs a valid SSLContext')
|
||||
else:
|
||||
# Client-side may pass ssl=True to use a default context.
|
||||
# The default is the same as used by urllib.
|
||||
if sslcontext is None:
|
||||
if not sslcontext:
|
||||
# Client side may pass ssl=True to use a default
|
||||
# context; in that case the sslcontext passed is None.
|
||||
# The default is the same as used by urllib with
|
||||
# cadefault=True.
|
||||
sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
|
||||
sslcontext.options |= ssl.OP_NO_SSLv2
|
||||
sslcontext.set_default_verify_paths()
|
||||
sslcontext.verify_mode = ssl.CERT_REQUIRED
|
||||
|
||||
wrap_kwargs = {
|
||||
'server_side': server_side,
|
||||
'do_handshake_on_connect': False,
|
||||
|
|
|
@ -43,6 +43,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
|
|||
self.assertIsInstance(
|
||||
self.loop._make_socket_transport(m, m), _SelectorSocketTransport)
|
||||
|
||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||
def test_make_ssl_transport(self):
|
||||
m = unittest.mock.Mock()
|
||||
self.loop.add_reader = unittest.mock.Mock()
|
||||
|
@ -52,6 +53,16 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
|
|||
self.assertIsInstance(
|
||||
self.loop._make_ssl_transport(m, m, m, m), _SelectorSslTransport)
|
||||
|
||||
@unittest.mock.patch('asyncio.selector_events.ssl', None)
|
||||
def test_make_ssl_transport_without_ssl_error(self):
|
||||
m = unittest.mock.Mock()
|
||||
self.loop.add_reader = unittest.mock.Mock()
|
||||
self.loop.add_writer = unittest.mock.Mock()
|
||||
self.loop.remove_reader = unittest.mock.Mock()
|
||||
self.loop.remove_writer = unittest.mock.Mock()
|
||||
with self.assertRaises(RuntimeError):
|
||||
self.loop._make_ssl_transport(m, m, m, m)
|
||||
|
||||
def test_close(self):
|
||||
ssock = self.loop._ssock
|
||||
ssock.fileno.return_value = 7
|
||||
|
@ -1277,6 +1288,15 @@ class SelectorSslTransportTests(unittest.TestCase):
|
|||
server_hostname='localhost')
|
||||
|
||||
|
||||
class SelectorSslWithoutSslTransportTests(unittest.TestCase):
|
||||
|
||||
@unittest.mock.patch('asyncio.selector_events.ssl', None)
|
||||
def test_ssl_transport_requires_ssl_module(self):
|
||||
Mock = unittest.mock.Mock
|
||||
with self.assertRaises(RuntimeError):
|
||||
transport = _SelectorSslTransport(Mock(), Mock(), Mock(), Mock())
|
||||
|
||||
|
||||
class SelectorDatagramTransportTests(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
|
Loading…
Reference in New Issue