asyncio: Add support for UNIX Domain Sockets.

This commit is contained in:
Yury Selivanov 2014-02-18 12:15:06 -05:00
parent c36e504c53
commit 88a5bf0b2e
10 changed files with 750 additions and 205 deletions

View File

@ -407,6 +407,13 @@ class BaseEventLoop(events.AbstractEventLoop):
sock.setblocking(False)
transport, protocol = yield from self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname)
return transport, protocol
@tasks.coroutine
def _create_connection_transport(self, sock, protocol_factory, ssl,
server_hostname):
protocol = protocol_factory()
waiter = futures.Future(loop=self)
if ssl:

View File

@ -220,6 +220,32 @@ class AbstractEventLoop:
"""
raise NotImplementedError
def create_unix_connection(self, protocol_factory, path, *,
ssl=None, sock=None,
server_hostname=None):
raise NotImplementedError
def create_unix_server(self, protocol_factory, path, *,
sock=None, backlog=100, ssl=None):
"""A coroutine which creates a UNIX Domain Socket server.
The return valud is a Server object, which can be used to stop
the service.
path is a str, representing a file systsem path to bind the
server socket to.
sock can optionally be specified in order to use a preexisting
socket object.
backlog is the maximum number of queued connections passed to
listen() (defaults to 100).
ssl can be set to an SSLContext to enable SSL over the
accepted connections.
"""
raise NotImplementedError
def create_datagram_endpoint(self, protocol_factory,
local_addr=None, remote_addr=None, *,
family=0, proto=0, flags=0):

View File

@ -1,9 +1,13 @@
"""Stream-related things."""
__all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol',
'open_connection', 'start_server', 'IncompleteReadError',
'open_connection', 'start_server',
'open_unix_connection', 'start_unix_server',
'IncompleteReadError',
]
import socket
from . import events
from . import futures
from . import protocols
@ -93,6 +97,39 @@ def start_server(client_connected_cb, host=None, port=None, *,
return (yield from loop.create_server(factory, host, port, **kwds))
if hasattr(socket, 'AF_UNIX'):
# UNIX Domain Sockets are supported on this platform
@tasks.coroutine
def open_unix_connection(path=None, *,
loop=None, limit=_DEFAULT_LIMIT, **kwds):
"""Similar to `open_connection` but works with UNIX Domain Sockets."""
if loop is None:
loop = events.get_event_loop()
reader = StreamReader(limit=limit, loop=loop)
protocol = StreamReaderProtocol(reader, loop=loop)
transport, _ = yield from loop.create_unix_connection(
lambda: protocol, path, **kwds)
writer = StreamWriter(transport, protocol, reader, loop)
return reader, writer
@tasks.coroutine
def start_unix_server(client_connected_cb, path=None, *,
loop=None, limit=_DEFAULT_LIMIT, **kwds):
"""Similar to `start_server` but works with UNIX Domain Sockets."""
if loop is None:
loop = events.get_event_loop()
def factory():
reader = StreamReader(limit=limit, loop=loop)
protocol = StreamReaderProtocol(reader, client_connected_cb,
loop=loop)
return protocol
return (yield from loop.create_unix_server(factory, path, **kwds))
class FlowControlMixin(protocols.Protocol):
"""Reusable flow control logic for StreamWriter.drain().

View File

@ -4,12 +4,18 @@ import collections
import contextlib
import io
import os
import socket
import socketserver
import sys
import tempfile
import threading
import time
import unittest
import unittest.mock
from http.server import HTTPServer
from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer
try:
import ssl
except ImportError: # pragma: no cover
@ -70,42 +76,51 @@ def run_once(loop):
loop.run_forever()
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
class SilentWSGIRequestHandler(WSGIRequestHandler):
class SilentWSGIRequestHandler(WSGIRequestHandler):
def get_stderr(self):
return io.StringIO()
def get_stderr(self):
return io.StringIO()
def log_message(self, format, *args):
def log_message(self, format, *args):
pass
class SilentWSGIServer(WSGIServer):
def handle_error(self, request, client_address):
pass
class SSLWSGIServerMixin:
def finish_request(self, request, client_address):
# The relative location of our test directory (which
# contains the ssl key and certificate files) differs
# between the stdlib and stand-alone asyncio.
# Prefer our own if we can find it.
here = os.path.join(os.path.dirname(__file__), '..', 'tests')
if not os.path.isdir(here):
here = os.path.join(os.path.dirname(os.__file__),
'test', 'test_asyncio')
keyfile = os.path.join(here, 'ssl_key.pem')
certfile = os.path.join(here, 'ssl_cert.pem')
ssock = ssl.wrap_socket(request,
keyfile=keyfile,
certfile=certfile,
server_side=True)
try:
self.RequestHandlerClass(ssock, client_address, self)
ssock.close()
except OSError:
# maybe socket has been closed by peer
pass
class SilentWSGIServer(WSGIServer):
def handle_error(self, request, client_address):
pass
class SSLWSGIServer(SilentWSGIServer):
def finish_request(self, request, client_address):
# The relative location of our test directory (which
# contains the ssl key and certificate files) differs
# between the stdlib and stand-alone asyncio.
# Prefer our own if we can find it.
here = os.path.join(os.path.dirname(__file__), '..', 'tests')
if not os.path.isdir(here):
here = os.path.join(os.path.dirname(os.__file__),
'test', 'test_asyncio')
keyfile = os.path.join(here, 'ssl_key.pem')
certfile = os.path.join(here, 'ssl_cert.pem')
ssock = ssl.wrap_socket(request,
keyfile=keyfile,
certfile=certfile,
server_side=True)
try:
self.RequestHandlerClass(ssock, client_address, self)
ssock.close()
except OSError:
# maybe socket has been closed by peer
pass
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
pass
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
def app(environ, start_response):
status = '200 OK'
@ -115,9 +130,9 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
# Run the test WSGI server in a separate thread in order not to
# interfere with event handling in the main thread
server_class = SSLWSGIServer if use_ssl else SilentWSGIServer
httpd = make_server(host, port, app,
server_class, SilentWSGIRequestHandler)
server_class = server_ssl_cls if use_ssl else server_cls
httpd = server_class(address, SilentWSGIRequestHandler)
httpd.set_app(app)
httpd.address = httpd.server_address
server_thread = threading.Thread(target=httpd.serve_forever)
server_thread.start()
@ -129,6 +144,75 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
server_thread.join()
if hasattr(socket, 'AF_UNIX'):
class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
def server_bind(self):
socketserver.UnixStreamServer.server_bind(self)
self.server_name = '127.0.0.1'
self.server_port = 80
class UnixWSGIServer(UnixHTTPServer, WSGIServer):
def server_bind(self):
UnixHTTPServer.server_bind(self)
self.setup_environ()
def get_request(self):
request, client_addr = super().get_request()
# Code in the stdlib expects that get_request
# will return a socket and a tuple (host, port).
# However, this isn't true for UNIX sockets,
# as the second return value will be a path;
# hence we return some fake data sufficient
# to get the tests going
return request, ('127.0.0.1', '')
class SilentUnixWSGIServer(UnixWSGIServer):
def handle_error(self, request, client_address):
pass
class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
pass
def gen_unix_socket_path():
with tempfile.NamedTemporaryFile() as file:
return file.name
@contextlib.contextmanager
def unix_socket_path():
path = gen_unix_socket_path()
try:
yield path
finally:
try:
os.unlink(path)
except OSError:
pass
@contextlib.contextmanager
def run_test_unix_server(*, use_ssl=False):
with unix_socket_path() as path:
yield from _run_test_server(address=path, use_ssl=use_ssl,
server_cls=SilentUnixWSGIServer,
server_ssl_cls=UnixSSLWSGIServer)
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
yield from _run_test_server(address=(host, port), use_ssl=use_ssl,
server_cls=SilentWSGIServer,
server_ssl_cls=SSLWSGIServer)
def make_test_protocol(base):
dct = {}
for name in dir(base):
@ -275,5 +359,6 @@ class TestLoop(base_events.BaseEventLoop):
def _write_to_self(self):
pass
def MockCallback(**kwargs):
return unittest.mock.Mock(spec=['__call__'], **kwargs)

View File

@ -11,6 +11,7 @@ import sys
import threading
from . import base_events
from . import base_subprocess
from . import constants
from . import events
@ -31,9 +32,9 @@ if sys.platform == 'win32': # pragma: no cover
class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
"""Unix event loop
"""Unix event loop.
Adds signal handling to SelectorEventLoop
Adds signal handling and UNIX Domain Socket support to SelectorEventLoop.
"""
def __init__(self, selector=None):
@ -164,6 +165,76 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
def _child_watcher_callback(self, pid, returncode, transp):
self.call_soon_threadsafe(transp._process_exited, returncode)
@tasks.coroutine
def create_unix_connection(self, protocol_factory, path, *,
ssl=None, sock=None,
server_hostname=None):
assert server_hostname is None or isinstance(server_hostname, str)
if ssl:
if server_hostname is None:
raise ValueError(
'you have to pass server_hostname when using ssl')
else:
if server_hostname is not None:
raise ValueError('server_hostname is only meaningful with ssl')
if path is not None:
if sock is not None:
raise ValueError(
'path and sock can not be specified at the same time')
try:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
sock.setblocking(False)
yield from self.sock_connect(sock, path)
except OSError:
if sock is not None:
sock.close()
raise
else:
if sock is None:
raise ValueError('no path and sock were specified')
sock.setblocking(False)
transport, protocol = yield from self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname)
return transport, protocol
@tasks.coroutine
def create_unix_server(self, protocol_factory, path=None, *,
sock=None, backlog=100, ssl=None):
if isinstance(ssl, bool):
raise TypeError('ssl argument must be an SSLContext or None')
if path is not None:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
sock.bind(path)
except OSError as exc:
if exc.errno == errno.EADDRINUSE:
# Let's improve the error message by adding
# with what exact address it occurs.
msg = 'Address {!r} is already in use'.format(path)
raise OSError(errno.EADDRINUSE, msg) from None
else:
raise
else:
if sock is None:
raise ValueError(
'path was not specified, and no sock specified')
if sock.family != socket.AF_UNIX:
raise ValueError(
'A UNIX Domain Socket was expected, got {!r}'.format(sock))
server = base_events.Server(self, [sock])
sock.listen(backlog)
sock.setblocking(False)
self._start_serving(protocol_factory, sock, ssl, server)
return server
def _set_nonblocking(fd):
flags = fcntl.fcntl(fd, fcntl.F_GETFL)

View File

@ -212,7 +212,7 @@ class BaseEventLoopTests(unittest.TestCase):
idx = -1
data = [10.0, 10.0, 10.3, 13.0]
self.loop._scheduled = [asyncio.TimerHandle(11.0, lambda:True, ())]
self.loop._scheduled = [asyncio.TimerHandle(11.0, lambda: True, ())]
self.loop._run_once()
self.assertEqual(logging.DEBUG, m_logger.log.call_args[0][0])

View File

@ -39,13 +39,14 @@ def data_file(filename):
return fullname
raise FileNotFoundError(filename)
ONLYCERT = data_file('ssl_cert.pem')
ONLYKEY = data_file('ssl_key.pem')
SIGNED_CERTFILE = data_file('keycert3.pem')
SIGNING_CA = data_file('pycacert.pem')
class MyProto(asyncio.Protocol):
class MyBaseProto(asyncio.Protocol):
done = None
def __init__(self, loop=None):
@ -59,7 +60,6 @@ class MyProto(asyncio.Protocol):
self.transport = transport
assert self.state == 'INITIAL', self.state
self.state = 'CONNECTED'
transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n')
def data_received(self, data):
assert self.state == 'CONNECTED', self.state
@ -76,6 +76,12 @@ class MyProto(asyncio.Protocol):
self.done.set_result(None)
class MyProto(MyBaseProto):
def connection_made(self, transport):
super().connection_made(transport)
transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n')
class MyDatagramProto(asyncio.DatagramProtocol):
done = None
@ -357,22 +363,30 @@ class EventLoopTestsMixin:
r.close()
self.assertGreaterEqual(len(data), 200)
def _basetest_sock_client_ops(self, httpd, sock):
sock.setblocking(False)
self.loop.run_until_complete(
self.loop.sock_connect(sock, httpd.address))
self.loop.run_until_complete(
self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
data = self.loop.run_until_complete(
self.loop.sock_recv(sock, 1024))
# consume data
self.loop.run_until_complete(
self.loop.sock_recv(sock, 1024))
sock.close()
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
def test_sock_client_ops(self):
with test_utils.run_test_server() as httpd:
sock = socket.socket()
sock.setblocking(False)
self.loop.run_until_complete(
self.loop.sock_connect(sock, httpd.address))
self.loop.run_until_complete(
self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
data = self.loop.run_until_complete(
self.loop.sock_recv(sock, 1024))
# consume data
self.loop.run_until_complete(
self.loop.sock_recv(sock, 1024))
sock.close()
self._basetest_sock_client_ops(httpd, sock)
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
def test_unix_sock_client_ops(self):
with test_utils.run_test_unix_server() as httpd:
sock = socket.socket(socket.AF_UNIX)
self._basetest_sock_client_ops(httpd, sock)
def test_sock_client_fail(self):
# Make sure that we will get an unused port
@ -485,16 +499,26 @@ class EventLoopTestsMixin:
self.loop.run_forever()
self.assertEqual(caught, 1)
def _basetest_create_connection(self, connection_fut):
tr, pr = self.loop.run_until_complete(connection_fut)
self.assertIsInstance(tr, asyncio.Transport)
self.assertIsInstance(pr, asyncio.Protocol)
self.loop.run_until_complete(pr.done)
self.assertGreater(pr.nbytes, 0)
tr.close()
def test_create_connection(self):
with test_utils.run_test_server() as httpd:
f = self.loop.create_connection(
conn_fut = self.loop.create_connection(
lambda: MyProto(loop=self.loop), *httpd.address)
tr, pr = self.loop.run_until_complete(f)
self.assertIsInstance(tr, asyncio.Transport)
self.assertIsInstance(pr, asyncio.Protocol)
self.loop.run_until_complete(pr.done)
self.assertGreater(pr.nbytes, 0)
tr.close()
self._basetest_create_connection(conn_fut)
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
def test_create_unix_connection(self):
with test_utils.run_test_unix_server() as httpd:
conn_fut = self.loop.create_unix_connection(
lambda: MyProto(loop=self.loop), httpd.address)
self._basetest_create_connection(conn_fut)
def test_create_connection_sock(self):
with test_utils.run_test_server() as httpd:
@ -524,20 +548,37 @@ class EventLoopTestsMixin:
self.assertGreater(pr.nbytes, 0)
tr.close()
def _basetest_create_ssl_connection(self, connection_fut):
tr, pr = self.loop.run_until_complete(connection_fut)
self.assertIsInstance(tr, asyncio.Transport)
self.assertIsInstance(pr, asyncio.Protocol)
self.assertTrue('ssl' in tr.__class__.__name__.lower())
self.assertIsNotNone(tr.get_extra_info('sockname'))
self.loop.run_until_complete(pr.done)
self.assertGreater(pr.nbytes, 0)
tr.close()
@unittest.skipIf(ssl is None, 'No ssl module')
def test_create_ssl_connection(self):
with test_utils.run_test_server(use_ssl=True) as httpd:
f = self.loop.create_connection(
lambda: MyProto(loop=self.loop), *httpd.address,
conn_fut = self.loop.create_connection(
lambda: MyProto(loop=self.loop),
*httpd.address,
ssl=test_utils.dummy_ssl_context())
tr, pr = self.loop.run_until_complete(f)
self.assertIsInstance(tr, asyncio.Transport)
self.assertIsInstance(pr, asyncio.Protocol)
self.assertTrue('ssl' in tr.__class__.__name__.lower())
self.assertIsNotNone(tr.get_extra_info('sockname'))
self.loop.run_until_complete(pr.done)
self.assertGreater(pr.nbytes, 0)
tr.close()
self._basetest_create_ssl_connection(conn_fut)
@unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
def test_create_ssl_unix_connection(self):
with test_utils.run_test_unix_server(use_ssl=True) as httpd:
conn_fut = self.loop.create_unix_connection(
lambda: MyProto(loop=self.loop),
httpd.address,
ssl=test_utils.dummy_ssl_context(),
server_hostname='127.0.0.1')
self._basetest_create_ssl_connection(conn_fut)
def test_create_connection_local_addr(self):
with test_utils.run_test_server() as httpd:
@ -561,14 +602,8 @@ class EventLoopTestsMixin:
self.assertIn(str(httpd.address), cm.exception.strerror)
def test_create_server(self):
proto = None
def factory():
nonlocal proto
proto = MyProto()
return proto
f = self.loop.create_server(factory, '0.0.0.0', 0)
proto = MyProto()
f = self.loop.create_server(lambda: proto, '0.0.0.0', 0)
server = self.loop.run_until_complete(f)
self.assertEqual(len(server.sockets), 1)
sock = server.sockets[0]
@ -605,38 +640,76 @@ class EventLoopTestsMixin:
# close server
server.close()
def _make_ssl_server(self, factory, certfile, keyfile=None):
def _make_unix_server(self, factory, **kwargs):
path = test_utils.gen_unix_socket_path()
self.addCleanup(lambda: os.path.exists(path) and os.unlink(path))
f = self.loop.create_unix_server(factory, path, **kwargs)
server = self.loop.run_until_complete(f)
return server, path
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
def test_create_unix_server(self):
proto = MyProto()
server, path = self._make_unix_server(lambda: proto)
self.assertEqual(len(server.sockets), 1)
client = socket.socket(socket.AF_UNIX)
client.connect(path)
client.sendall(b'xxx')
test_utils.run_briefly(self.loop)
test_utils.run_until(self.loop, lambda: proto is not None, 10)
self.assertIsInstance(proto, MyProto)
self.assertEqual('INITIAL', proto.state)
test_utils.run_briefly(self.loop)
self.assertEqual('CONNECTED', proto.state)
test_utils.run_until(self.loop, lambda: proto.nbytes > 0,
timeout=10)
self.assertEqual(3, proto.nbytes)
# close connection
proto.transport.close()
test_utils.run_briefly(self.loop) # windows iocp
self.assertEqual('CLOSED', proto.state)
# the client socket must be closed after to avoid ECONNRESET upon
# recv()/send() on the serving socket
client.close()
# close server
server.close()
def _create_ssl_context(self, certfile, keyfile=None):
sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext.options |= ssl.OP_NO_SSLv2
sslcontext.load_cert_chain(certfile, keyfile)
return sslcontext
f = self.loop.create_server(
factory, '127.0.0.1', 0, ssl=sslcontext)
def _make_ssl_server(self, factory, certfile, keyfile=None):
sslcontext = self._create_ssl_context(certfile, keyfile)
f = self.loop.create_server(factory, '127.0.0.1', 0, ssl=sslcontext)
server = self.loop.run_until_complete(f)
sock = server.sockets[0]
host, port = sock.getsockname()
self.assertEqual(host, '127.0.0.1')
return server, host, port
def _make_ssl_unix_server(self, factory, certfile, keyfile=None):
sslcontext = self._create_ssl_context(certfile, keyfile)
return self._make_unix_server(factory, ssl=sslcontext)
@unittest.skipIf(ssl is None, 'No ssl module')
def test_create_server_ssl(self):
proto = None
proto = MyProto(loop=self.loop)
server, host, port = self._make_ssl_server(
lambda: proto, ONLYCERT, ONLYKEY)
class ClientMyProto(MyProto):
def connection_made(self, transport):
self.transport = transport
assert self.state == 'INITIAL', self.state
self.state = 'CONNECTED'
def factory():
nonlocal proto
proto = MyProto(loop=self.loop)
return proto
server, host, port = self._make_ssl_server(factory, ONLYCERT, ONLYKEY)
f_c = self.loop.create_connection(ClientMyProto, host, port,
f_c = self.loop.create_connection(MyBaseProto, host, port,
ssl=test_utils.dummy_ssl_context())
client, pr = self.loop.run_until_complete(f_c)
@ -666,17 +739,46 @@ class EventLoopTestsMixin:
# stop serving
server.close()
@unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
def test_create_unix_server_ssl(self):
proto = MyProto(loop=self.loop)
server, path = self._make_ssl_unix_server(
lambda: proto, ONLYCERT, ONLYKEY)
f_c = self.loop.create_unix_connection(
MyBaseProto, path, ssl=test_utils.dummy_ssl_context(),
server_hostname='')
client, pr = self.loop.run_until_complete(f_c)
client.write(b'xxx')
test_utils.run_briefly(self.loop)
self.assertIsInstance(proto, MyProto)
test_utils.run_briefly(self.loop)
self.assertEqual('CONNECTED', proto.state)
test_utils.run_until(self.loop, lambda: proto.nbytes > 0,
timeout=10)
self.assertEqual(3, proto.nbytes)
# close connection
proto.transport.close()
self.loop.run_until_complete(proto.done)
self.assertEqual('CLOSED', proto.state)
# the client socket must be closed after to avoid ECONNRESET upon
# recv()/send() on the serving socket
client.close()
# stop serving
server.close()
@unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module')
def test_create_server_ssl_verify_failed(self):
proto = None
def factory():
nonlocal proto
proto = MyProto(loop=self.loop)
return proto
server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE)
proto = MyProto(loop=self.loop)
server, host, port = self._make_ssl_server(
lambda: proto, SIGNED_CERTFILE)
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext_client.options |= ssl.OP_NO_SSLv2
@ -695,17 +797,38 @@ class EventLoopTestsMixin:
self.assertIsNone(proto.transport)
server.close()
@unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module')
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
def test_create_unix_server_ssl_verify_failed(self):
proto = MyProto(loop=self.loop)
server, path = self._make_ssl_unix_server(
lambda: proto, SIGNED_CERTFILE)
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext_client.options |= ssl.OP_NO_SSLv2
sslcontext_client.verify_mode = ssl.CERT_REQUIRED
if hasattr(sslcontext_client, 'check_hostname'):
sslcontext_client.check_hostname = True
# no CA loaded
f_c = self.loop.create_unix_connection(MyProto, path,
ssl=sslcontext_client,
server_hostname='invalid')
with self.assertRaisesRegex(ssl.SSLError,
'certificate verify failed '):
self.loop.run_until_complete(f_c)
# close connection
self.assertIsNone(proto.transport)
server.close()
@unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module')
def test_create_server_ssl_match_failed(self):
proto = None
def factory():
nonlocal proto
proto = MyProto(loop=self.loop)
return proto
server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE)
proto = MyProto(loop=self.loop)
server, host, port = self._make_ssl_server(
lambda: proto, SIGNED_CERTFILE)
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext_client.options |= ssl.OP_NO_SSLv2
@ -727,17 +850,38 @@ class EventLoopTestsMixin:
proto.transport.close()
server.close()
@unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module')
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
def test_create_unix_server_ssl_verified(self):
proto = MyProto(loop=self.loop)
server, path = self._make_ssl_unix_server(
lambda: proto, SIGNED_CERTFILE)
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext_client.options |= ssl.OP_NO_SSLv2
sslcontext_client.verify_mode = ssl.CERT_REQUIRED
sslcontext_client.load_verify_locations(cafile=SIGNING_CA)
if hasattr(sslcontext_client, 'check_hostname'):
sslcontext_client.check_hostname = True
# Connection succeeds with correct CA and server hostname.
f_c = self.loop.create_unix_connection(MyProto, path,
ssl=sslcontext_client,
server_hostname='localhost')
client, pr = self.loop.run_until_complete(f_c)
# close connection
proto.transport.close()
client.close()
server.close()
@unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module')
def test_create_server_ssl_verified(self):
proto = None
def factory():
nonlocal proto
proto = MyProto(loop=self.loop)
return proto
server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE)
proto = MyProto(loop=self.loop)
server, host, port = self._make_ssl_server(
lambda: proto, SIGNED_CERTFILE)
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext_client.options |= ssl.OP_NO_SSLv2
@ -915,19 +1059,15 @@ class EventLoopTestsMixin:
@unittest.skipUnless(sys.platform != 'win32',
"Don't support pipes for Windows")
def test_read_pipe(self):
proto = None
def factory():
nonlocal proto
proto = MyReadPipeProto(loop=self.loop)
return proto
proto = MyReadPipeProto(loop=self.loop)
rpipe, wpipe = os.pipe()
pipeobj = io.open(rpipe, 'rb', 1024)
@asyncio.coroutine
def connect():
t, p = yield from self.loop.connect_read_pipe(factory, pipeobj)
t, p = yield from self.loop.connect_read_pipe(
lambda: proto, pipeobj)
self.assertIs(p, proto)
self.assertIs(t, proto.transport)
self.assertEqual(['INITIAL', 'CONNECTED'], proto.state)
@ -959,19 +1099,14 @@ class EventLoopTestsMixin:
# Issue #20495: The test hangs on FreeBSD 7.2 but pass on FreeBSD 9
@support.requires_freebsd_version(8)
def test_read_pty_output(self):
proto = None
def factory():
nonlocal proto
proto = MyReadPipeProto(loop=self.loop)
return proto
proto = MyReadPipeProto(loop=self.loop)
master, slave = os.openpty()
master_read_obj = io.open(master, 'rb', 0)
@asyncio.coroutine
def connect():
t, p = yield from self.loop.connect_read_pipe(factory,
t, p = yield from self.loop.connect_read_pipe(lambda: proto,
master_read_obj)
self.assertIs(p, proto)
self.assertIs(t, proto.transport)
@ -999,21 +1134,17 @@ class EventLoopTestsMixin:
@unittest.skipUnless(sys.platform != 'win32',
"Don't support pipes for Windows")
def test_write_pipe(self):
proto = None
proto = MyWritePipeProto(loop=self.loop)
transport = None
def factory():
nonlocal proto
proto = MyWritePipeProto(loop=self.loop)
return proto
rpipe, wpipe = os.pipe()
pipeobj = io.open(wpipe, 'wb', 1024)
@asyncio.coroutine
def connect():
nonlocal transport
t, p = yield from self.loop.connect_write_pipe(factory, pipeobj)
t, p = yield from self.loop.connect_write_pipe(
lambda: proto, pipeobj)
self.assertIs(p, proto)
self.assertIs(t, proto.transport)
self.assertEqual('CONNECTED', proto.state)
@ -1045,21 +1176,16 @@ class EventLoopTestsMixin:
@unittest.skipUnless(sys.platform != 'win32',
"Don't support pipes for Windows")
def test_write_pipe_disconnect_on_close(self):
proto = None
proto = MyWritePipeProto(loop=self.loop)
transport = None
def factory():
nonlocal proto
proto = MyWritePipeProto(loop=self.loop)
return proto
rsock, wsock = test_utils.socketpair()
pipeobj = io.open(wsock.detach(), 'wb', 1024)
@asyncio.coroutine
def connect():
nonlocal transport
t, p = yield from self.loop.connect_write_pipe(factory,
t, p = yield from self.loop.connect_write_pipe(lambda: proto,
pipeobj)
self.assertIs(p, proto)
self.assertIs(t, proto.transport)
@ -1084,21 +1210,16 @@ class EventLoopTestsMixin:
# older than 10.6 (Snow Leopard)
@support.requires_mac_ver(10, 6)
def test_write_pty(self):
proto = None
proto = MyWritePipeProto(loop=self.loop)
transport = None
def factory():
nonlocal proto
proto = MyWritePipeProto(loop=self.loop)
return proto
master, slave = os.openpty()
slave_write_obj = io.open(slave, 'wb', 0)
@asyncio.coroutine
def connect():
nonlocal transport
t, p = yield from self.loop.connect_write_pipe(factory,
t, p = yield from self.loop.connect_write_pipe(lambda: proto,
slave_write_obj)
self.assertIs(p, proto)
self.assertIs(t, proto.transport)

View File

@ -55,7 +55,8 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
self.loop.remove_reader = unittest.mock.Mock()
self.loop.remove_writer = unittest.mock.Mock()
waiter = asyncio.Future(loop=self.loop)
transport = self.loop._make_ssl_transport(m, asyncio.Protocol(), m, waiter)
transport = self.loop._make_ssl_transport(
m, asyncio.Protocol(), m, waiter)
self.assertIsInstance(transport, _SelectorSslTransport)
@unittest.mock.patch('asyncio.selector_events.ssl', None)

View File

@ -1,6 +1,8 @@
"""Tests for streams.py."""
import functools
import gc
import socket
import unittest
import unittest.mock
try:
@ -32,48 +34,85 @@ class StreamReaderTests(unittest.TestCase):
stream = asyncio.StreamReader()
self.assertIs(stream._loop, m_events.get_event_loop.return_value)
def _basetest_open_connection(self, open_connection_fut):
reader, writer = self.loop.run_until_complete(open_connection_fut)
writer.write(b'GET / HTTP/1.0\r\n\r\n')
f = reader.readline()
data = self.loop.run_until_complete(f)
self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
f = reader.read()
data = self.loop.run_until_complete(f)
self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
writer.close()
def test_open_connection(self):
with test_utils.run_test_server() as httpd:
f = asyncio.open_connection(*httpd.address, loop=self.loop)
reader, writer = self.loop.run_until_complete(f)
writer.write(b'GET / HTTP/1.0\r\n\r\n')
f = reader.readline()
data = self.loop.run_until_complete(f)
self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
f = reader.read()
data = self.loop.run_until_complete(f)
self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
conn_fut = asyncio.open_connection(*httpd.address,
loop=self.loop)
self._basetest_open_connection(conn_fut)
writer.close()
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
def test_open_unix_connection(self):
with test_utils.run_test_unix_server() as httpd:
conn_fut = asyncio.open_unix_connection(httpd.address,
loop=self.loop)
self._basetest_open_connection(conn_fut)
def _basetest_open_connection_no_loop_ssl(self, open_connection_fut):
try:
reader, writer = self.loop.run_until_complete(open_connection_fut)
finally:
asyncio.set_event_loop(None)
writer.write(b'GET / HTTP/1.0\r\n\r\n')
f = reader.read()
data = self.loop.run_until_complete(f)
self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
writer.close()
@unittest.skipIf(ssl is None, 'No ssl module')
def test_open_connection_no_loop_ssl(self):
with test_utils.run_test_server(use_ssl=True) as httpd:
try:
asyncio.set_event_loop(self.loop)
f = asyncio.open_connection(*httpd.address,
ssl=test_utils.dummy_ssl_context())
reader, writer = self.loop.run_until_complete(f)
finally:
asyncio.set_event_loop(None)
writer.write(b'GET / HTTP/1.0\r\n\r\n')
f = reader.read()
data = self.loop.run_until_complete(f)
self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
conn_fut = asyncio.open_connection(
*httpd.address,
ssl=test_utils.dummy_ssl_context(),
loop=self.loop)
writer.close()
self._basetest_open_connection_no_loop_ssl(conn_fut)
@unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
def test_open_unix_connection_no_loop_ssl(self):
with test_utils.run_test_unix_server(use_ssl=True) as httpd:
conn_fut = asyncio.open_unix_connection(
httpd.address,
ssl=test_utils.dummy_ssl_context(),
server_hostname='',
loop=self.loop)
self._basetest_open_connection_no_loop_ssl(conn_fut)
def _basetest_open_connection_error(self, open_connection_fut):
reader, writer = self.loop.run_until_complete(open_connection_fut)
writer._protocol.connection_lost(ZeroDivisionError())
f = reader.read()
with self.assertRaises(ZeroDivisionError):
self.loop.run_until_complete(f)
writer.close()
test_utils.run_briefly(self.loop)
def test_open_connection_error(self):
with test_utils.run_test_server() as httpd:
f = asyncio.open_connection(*httpd.address, loop=self.loop)
reader, writer = self.loop.run_until_complete(f)
writer._protocol.connection_lost(ZeroDivisionError())
f = reader.read()
with self.assertRaises(ZeroDivisionError):
self.loop.run_until_complete(f)
conn_fut = asyncio.open_connection(*httpd.address,
loop=self.loop)
self._basetest_open_connection_error(conn_fut)
writer.close()
test_utils.run_briefly(self.loop)
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
def test_open_unix_connection_error(self):
with test_utils.run_test_unix_server() as httpd:
conn_fut = asyncio.open_unix_connection(httpd.address,
loop=self.loop)
self._basetest_open_connection_error(conn_fut)
def test_feed_empty_data(self):
stream = asyncio.StreamReader(loop=self.loop)
@ -415,10 +454,86 @@ class StreamReaderTests(unittest.TestCase):
client_writer.write(data)
def start(self):
sock = socket.socket()
sock.bind(('127.0.0.1', 0))
self.server = self.loop.run_until_complete(
asyncio.start_server(self.handle_client,
'127.0.0.1', 12345,
sock=sock,
loop=self.loop))
return sock.getsockname()
def handle_client_callback(self, client_reader, client_writer):
task = asyncio.Task(client_reader.readline(), loop=self.loop)
def done(task):
client_writer.write(task.result())
task.add_done_callback(done)
def start_callback(self):
sock = socket.socket()
sock.bind(('127.0.0.1', 0))
addr = sock.getsockname()
sock.close()
self.server = self.loop.run_until_complete(
asyncio.start_server(self.handle_client_callback,
host=addr[0], port=addr[1],
loop=self.loop))
return addr
def stop(self):
if self.server is not None:
self.server.close()
self.loop.run_until_complete(self.server.wait_closed())
self.server = None
@asyncio.coroutine
def client(addr):
reader, writer = yield from asyncio.open_connection(
*addr, loop=self.loop)
# send a line
writer.write(b"hello world!\n")
# read it back
msgback = yield from reader.readline()
writer.close()
return msgback
# test the server variant with a coroutine as client handler
server = MyServer(self.loop)
addr = server.start()
msg = self.loop.run_until_complete(asyncio.Task(client(addr),
loop=self.loop))
server.stop()
self.assertEqual(msg, b"hello world!\n")
# test the server variant with a callback as client handler
server = MyServer(self.loop)
addr = server.start_callback()
msg = self.loop.run_until_complete(asyncio.Task(client(addr),
loop=self.loop))
server.stop()
self.assertEqual(msg, b"hello world!\n")
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
def test_start_unix_server(self):
class MyServer:
def __init__(self, loop, path):
self.server = None
self.loop = loop
self.path = path
@asyncio.coroutine
def handle_client(self, client_reader, client_writer):
data = yield from client_reader.readline()
client_writer.write(data)
def start(self):
self.server = self.loop.run_until_complete(
asyncio.start_unix_server(self.handle_client,
path=self.path,
loop=self.loop))
def handle_client_callback(self, client_reader, client_writer):
task = asyncio.Task(client_reader.readline(), loop=self.loop)
@ -430,9 +545,9 @@ class StreamReaderTests(unittest.TestCase):
def start_callback(self):
self.server = self.loop.run_until_complete(
asyncio.start_server(self.handle_client_callback,
'127.0.0.1', 12345,
loop=self.loop))
asyncio.start_unix_server(self.handle_client_callback,
path=self.path,
loop=self.loop))
def stop(self):
if self.server is not None:
@ -441,9 +556,9 @@ class StreamReaderTests(unittest.TestCase):
self.server = None
@asyncio.coroutine
def client():
reader, writer = yield from asyncio.open_connection(
'127.0.0.1', 12345, loop=self.loop)
def client(path):
reader, writer = yield from asyncio.open_unix_connection(
path, loop=self.loop)
# send a line
writer.write(b"hello world!\n")
# read it back
@ -452,20 +567,22 @@ class StreamReaderTests(unittest.TestCase):
return msgback
# test the server variant with a coroutine as client handler
server = MyServer(self.loop)
server.start()
msg = self.loop.run_until_complete(asyncio.Task(client(),
loop=self.loop))
server.stop()
self.assertEqual(msg, b"hello world!\n")
with test_utils.unix_socket_path() as path:
server = MyServer(self.loop, path)
server.start()
msg = self.loop.run_until_complete(asyncio.Task(client(path),
loop=self.loop))
server.stop()
self.assertEqual(msg, b"hello world!\n")
# test the server variant with a callback as client handler
server = MyServer(self.loop)
server.start_callback()
msg = self.loop.run_until_complete(asyncio.Task(client(),
loop=self.loop))
server.stop()
self.assertEqual(msg, b"hello world!\n")
with test_utils.unix_socket_path() as path:
server = MyServer(self.loop, path)
server.start_callback()
msg = self.loop.run_until_complete(asyncio.Task(client(path),
loop=self.loop))
server.stop()
self.assertEqual(msg, b"hello world!\n")
if __name__ == '__main__':

View File

@ -7,8 +7,10 @@ import io
import os
import pprint
import signal
import socket
import stat
import sys
import tempfile
import threading
import unittest
import unittest.mock
@ -24,7 +26,7 @@ from asyncio import unix_events
@unittest.skipUnless(signal, 'Signals are not supported')
class SelectorEventLoopTests(unittest.TestCase):
class SelectorEventLoopSignalTests(unittest.TestCase):
def setUp(self):
self.loop = asyncio.SelectorEventLoop()
@ -200,6 +202,84 @@ class SelectorEventLoopTests(unittest.TestCase):
m_signal.set_wakeup_fd.assert_called_once_with(-1)
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'),
'UNIX Sockets are not supported')
class SelectorEventLoopUnixSocketTests(unittest.TestCase):
def setUp(self):
self.loop = asyncio.SelectorEventLoop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
def test_create_unix_server_existing_path_sock(self):
with test_utils.unix_socket_path() as path:
sock = socket.socket(socket.AF_UNIX)
sock.bind(path)
coro = self.loop.create_unix_server(lambda: None, path)
with self.assertRaisesRegexp(OSError,
'Address.*is already in use'):
self.loop.run_until_complete(coro)
def test_create_unix_server_existing_path_nonsock(self):
with tempfile.NamedTemporaryFile() as file:
coro = self.loop.create_unix_server(lambda: None, file.name)
with self.assertRaisesRegexp(OSError,
'Address.*is already in use'):
self.loop.run_until_complete(coro)
def test_create_unix_server_ssl_bool(self):
coro = self.loop.create_unix_server(lambda: None, path='spam',
ssl=True)
with self.assertRaisesRegex(TypeError,
'ssl argument must be an SSLContext'):
self.loop.run_until_complete(coro)
def test_create_unix_server_nopath_nosock(self):
coro = self.loop.create_unix_server(lambda: None, path=None)
with self.assertRaisesRegex(ValueError,
'path was not specified, and no sock'):
self.loop.run_until_complete(coro)
def test_create_unix_server_path_inetsock(self):
coro = self.loop.create_unix_server(lambda: None, path=None,
sock=socket.socket())
with self.assertRaisesRegex(ValueError,
'A UNIX Domain Socket was expected'):
self.loop.run_until_complete(coro)
def test_create_unix_connection_path_sock(self):
coro = self.loop.create_unix_connection(
lambda: None, '/dev/null', sock=object())
with self.assertRaisesRegex(ValueError, 'path and sock can not be'):
self.loop.run_until_complete(coro)
def test_create_unix_connection_nopath_nosock(self):
coro = self.loop.create_unix_connection(
lambda: None, None)
with self.assertRaisesRegex(ValueError,
'no path and sock were specified'):
self.loop.run_until_complete(coro)
def test_create_unix_connection_nossl_serverhost(self):
coro = self.loop.create_unix_connection(
lambda: None, '/dev/null', server_hostname='spam')
with self.assertRaisesRegex(ValueError,
'server_hostname is only meaningful'):
self.loop.run_until_complete(coro)
def test_create_unix_connection_ssl_noserverhost(self):
coro = self.loop.create_unix_connection(
lambda: None, '/dev/null', ssl=True)
with self.assertRaisesRegexp(
ValueError, 'you have to pass server_hostname when using ssl'):
self.loop.run_until_complete(coro)
class UnixReadPipeTransportTests(unittest.TestCase):
def setUp(self):