asyncio: Add support for UNIX Domain Sockets.

This commit is contained in:
Yury Selivanov 2014-02-18 12:15:06 -05:00
parent c36e504c53
commit 88a5bf0b2e
10 changed files with 750 additions and 205 deletions

View File

@ -407,6 +407,13 @@ class BaseEventLoop(events.AbstractEventLoop):
sock.setblocking(False) sock.setblocking(False)
transport, protocol = yield from self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname)
return transport, protocol
@tasks.coroutine
def _create_connection_transport(self, sock, protocol_factory, ssl,
server_hostname):
protocol = protocol_factory() protocol = protocol_factory()
waiter = futures.Future(loop=self) waiter = futures.Future(loop=self)
if ssl: if ssl:

View File

@ -220,6 +220,32 @@ class AbstractEventLoop:
""" """
raise NotImplementedError raise NotImplementedError
def create_unix_connection(self, protocol_factory, path, *,
ssl=None, sock=None,
server_hostname=None):
raise NotImplementedError
def create_unix_server(self, protocol_factory, path, *,
sock=None, backlog=100, ssl=None):
"""A coroutine which creates a UNIX Domain Socket server.
The return valud is a Server object, which can be used to stop
the service.
path is a str, representing a file systsem path to bind the
server socket to.
sock can optionally be specified in order to use a preexisting
socket object.
backlog is the maximum number of queued connections passed to
listen() (defaults to 100).
ssl can be set to an SSLContext to enable SSL over the
accepted connections.
"""
raise NotImplementedError
def create_datagram_endpoint(self, protocol_factory, def create_datagram_endpoint(self, protocol_factory,
local_addr=None, remote_addr=None, *, local_addr=None, remote_addr=None, *,
family=0, proto=0, flags=0): family=0, proto=0, flags=0):

View File

@ -1,9 +1,13 @@
"""Stream-related things.""" """Stream-related things."""
__all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol', __all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol',
'open_connection', 'start_server', 'IncompleteReadError', 'open_connection', 'start_server',
'open_unix_connection', 'start_unix_server',
'IncompleteReadError',
] ]
import socket
from . import events from . import events
from . import futures from . import futures
from . import protocols from . import protocols
@ -93,6 +97,39 @@ def start_server(client_connected_cb, host=None, port=None, *,
return (yield from loop.create_server(factory, host, port, **kwds)) return (yield from loop.create_server(factory, host, port, **kwds))
if hasattr(socket, 'AF_UNIX'):
# UNIX Domain Sockets are supported on this platform
@tasks.coroutine
def open_unix_connection(path=None, *,
loop=None, limit=_DEFAULT_LIMIT, **kwds):
"""Similar to `open_connection` but works with UNIX Domain Sockets."""
if loop is None:
loop = events.get_event_loop()
reader = StreamReader(limit=limit, loop=loop)
protocol = StreamReaderProtocol(reader, loop=loop)
transport, _ = yield from loop.create_unix_connection(
lambda: protocol, path, **kwds)
writer = StreamWriter(transport, protocol, reader, loop)
return reader, writer
@tasks.coroutine
def start_unix_server(client_connected_cb, path=None, *,
loop=None, limit=_DEFAULT_LIMIT, **kwds):
"""Similar to `start_server` but works with UNIX Domain Sockets."""
if loop is None:
loop = events.get_event_loop()
def factory():
reader = StreamReader(limit=limit, loop=loop)
protocol = StreamReaderProtocol(reader, client_connected_cb,
loop=loop)
return protocol
return (yield from loop.create_unix_server(factory, path, **kwds))
class FlowControlMixin(protocols.Protocol): class FlowControlMixin(protocols.Protocol):
"""Reusable flow control logic for StreamWriter.drain(). """Reusable flow control logic for StreamWriter.drain().

View File

@ -4,12 +4,18 @@ import collections
import contextlib import contextlib
import io import io
import os import os
import socket
import socketserver
import sys import sys
import tempfile
import threading import threading
import time import time
import unittest import unittest
import unittest.mock import unittest.mock
from http.server import HTTPServer
from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer
try: try:
import ssl import ssl
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
@ -70,42 +76,51 @@ def run_once(loop):
loop.run_forever() loop.run_forever()
@contextlib.contextmanager class SilentWSGIRequestHandler(WSGIRequestHandler):
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
class SilentWSGIRequestHandler(WSGIRequestHandler): def get_stderr(self):
def get_stderr(self): return io.StringIO()
return io.StringIO()
def log_message(self, format, *args): def log_message(self, format, *args):
pass
class SilentWSGIServer(WSGIServer):
def handle_error(self, request, client_address):
pass
class SSLWSGIServerMixin:
def finish_request(self, request, client_address):
# The relative location of our test directory (which
# contains the ssl key and certificate files) differs
# between the stdlib and stand-alone asyncio.
# Prefer our own if we can find it.
here = os.path.join(os.path.dirname(__file__), '..', 'tests')
if not os.path.isdir(here):
here = os.path.join(os.path.dirname(os.__file__),
'test', 'test_asyncio')
keyfile = os.path.join(here, 'ssl_key.pem')
certfile = os.path.join(here, 'ssl_cert.pem')
ssock = ssl.wrap_socket(request,
keyfile=keyfile,
certfile=certfile,
server_side=True)
try:
self.RequestHandlerClass(ssock, client_address, self)
ssock.close()
except OSError:
# maybe socket has been closed by peer
pass pass
class SilentWSGIServer(WSGIServer):
def handle_error(self, request, client_address):
pass
class SSLWSGIServer(SilentWSGIServer): class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
def finish_request(self, request, client_address): pass
# The relative location of our test directory (which
# contains the ssl key and certificate files) differs
# between the stdlib and stand-alone asyncio. def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
# Prefer our own if we can find it.
here = os.path.join(os.path.dirname(__file__), '..', 'tests')
if not os.path.isdir(here):
here = os.path.join(os.path.dirname(os.__file__),
'test', 'test_asyncio')
keyfile = os.path.join(here, 'ssl_key.pem')
certfile = os.path.join(here, 'ssl_cert.pem')
ssock = ssl.wrap_socket(request,
keyfile=keyfile,
certfile=certfile,
server_side=True)
try:
self.RequestHandlerClass(ssock, client_address, self)
ssock.close()
except OSError:
# maybe socket has been closed by peer
pass
def app(environ, start_response): def app(environ, start_response):
status = '200 OK' status = '200 OK'
@ -115,9 +130,9 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
# Run the test WSGI server in a separate thread in order not to # Run the test WSGI server in a separate thread in order not to
# interfere with event handling in the main thread # interfere with event handling in the main thread
server_class = SSLWSGIServer if use_ssl else SilentWSGIServer server_class = server_ssl_cls if use_ssl else server_cls
httpd = make_server(host, port, app, httpd = server_class(address, SilentWSGIRequestHandler)
server_class, SilentWSGIRequestHandler) httpd.set_app(app)
httpd.address = httpd.server_address httpd.address = httpd.server_address
server_thread = threading.Thread(target=httpd.serve_forever) server_thread = threading.Thread(target=httpd.serve_forever)
server_thread.start() server_thread.start()
@ -129,6 +144,75 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
server_thread.join() server_thread.join()
if hasattr(socket, 'AF_UNIX'):
class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
def server_bind(self):
socketserver.UnixStreamServer.server_bind(self)
self.server_name = '127.0.0.1'
self.server_port = 80
class UnixWSGIServer(UnixHTTPServer, WSGIServer):
def server_bind(self):
UnixHTTPServer.server_bind(self)
self.setup_environ()
def get_request(self):
request, client_addr = super().get_request()
# Code in the stdlib expects that get_request
# will return a socket and a tuple (host, port).
# However, this isn't true for UNIX sockets,
# as the second return value will be a path;
# hence we return some fake data sufficient
# to get the tests going
return request, ('127.0.0.1', '')
class SilentUnixWSGIServer(UnixWSGIServer):
def handle_error(self, request, client_address):
pass
class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
pass
def gen_unix_socket_path():
with tempfile.NamedTemporaryFile() as file:
return file.name
@contextlib.contextmanager
def unix_socket_path():
path = gen_unix_socket_path()
try:
yield path
finally:
try:
os.unlink(path)
except OSError:
pass
@contextlib.contextmanager
def run_test_unix_server(*, use_ssl=False):
with unix_socket_path() as path:
yield from _run_test_server(address=path, use_ssl=use_ssl,
server_cls=SilentUnixWSGIServer,
server_ssl_cls=UnixSSLWSGIServer)
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
yield from _run_test_server(address=(host, port), use_ssl=use_ssl,
server_cls=SilentWSGIServer,
server_ssl_cls=SSLWSGIServer)
def make_test_protocol(base): def make_test_protocol(base):
dct = {} dct = {}
for name in dir(base): for name in dir(base):
@ -275,5 +359,6 @@ class TestLoop(base_events.BaseEventLoop):
def _write_to_self(self): def _write_to_self(self):
pass pass
def MockCallback(**kwargs): def MockCallback(**kwargs):
return unittest.mock.Mock(spec=['__call__'], **kwargs) return unittest.mock.Mock(spec=['__call__'], **kwargs)

View File

@ -11,6 +11,7 @@ import sys
import threading import threading
from . import base_events
from . import base_subprocess from . import base_subprocess
from . import constants from . import constants
from . import events from . import events
@ -31,9 +32,9 @@ if sys.platform == 'win32': # pragma: no cover
class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
"""Unix event loop """Unix event loop.
Adds signal handling to SelectorEventLoop Adds signal handling and UNIX Domain Socket support to SelectorEventLoop.
""" """
def __init__(self, selector=None): def __init__(self, selector=None):
@ -164,6 +165,76 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
def _child_watcher_callback(self, pid, returncode, transp): def _child_watcher_callback(self, pid, returncode, transp):
self.call_soon_threadsafe(transp._process_exited, returncode) self.call_soon_threadsafe(transp._process_exited, returncode)
@tasks.coroutine
def create_unix_connection(self, protocol_factory, path, *,
ssl=None, sock=None,
server_hostname=None):
assert server_hostname is None or isinstance(server_hostname, str)
if ssl:
if server_hostname is None:
raise ValueError(
'you have to pass server_hostname when using ssl')
else:
if server_hostname is not None:
raise ValueError('server_hostname is only meaningful with ssl')
if path is not None:
if sock is not None:
raise ValueError(
'path and sock can not be specified at the same time')
try:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
sock.setblocking(False)
yield from self.sock_connect(sock, path)
except OSError:
if sock is not None:
sock.close()
raise
else:
if sock is None:
raise ValueError('no path and sock were specified')
sock.setblocking(False)
transport, protocol = yield from self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname)
return transport, protocol
@tasks.coroutine
def create_unix_server(self, protocol_factory, path=None, *,
sock=None, backlog=100, ssl=None):
if isinstance(ssl, bool):
raise TypeError('ssl argument must be an SSLContext or None')
if path is not None:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
sock.bind(path)
except OSError as exc:
if exc.errno == errno.EADDRINUSE:
# Let's improve the error message by adding
# with what exact address it occurs.
msg = 'Address {!r} is already in use'.format(path)
raise OSError(errno.EADDRINUSE, msg) from None
else:
raise
else:
if sock is None:
raise ValueError(
'path was not specified, and no sock specified')
if sock.family != socket.AF_UNIX:
raise ValueError(
'A UNIX Domain Socket was expected, got {!r}'.format(sock))
server = base_events.Server(self, [sock])
sock.listen(backlog)
sock.setblocking(False)
self._start_serving(protocol_factory, sock, ssl, server)
return server
def _set_nonblocking(fd): def _set_nonblocking(fd):
flags = fcntl.fcntl(fd, fcntl.F_GETFL) flags = fcntl.fcntl(fd, fcntl.F_GETFL)

View File

@ -212,7 +212,7 @@ class BaseEventLoopTests(unittest.TestCase):
idx = -1 idx = -1
data = [10.0, 10.0, 10.3, 13.0] data = [10.0, 10.0, 10.3, 13.0]
self.loop._scheduled = [asyncio.TimerHandle(11.0, lambda:True, ())] self.loop._scheduled = [asyncio.TimerHandle(11.0, lambda: True, ())]
self.loop._run_once() self.loop._run_once()
self.assertEqual(logging.DEBUG, m_logger.log.call_args[0][0]) self.assertEqual(logging.DEBUG, m_logger.log.call_args[0][0])

View File

@ -39,13 +39,14 @@ def data_file(filename):
return fullname return fullname
raise FileNotFoundError(filename) raise FileNotFoundError(filename)
ONLYCERT = data_file('ssl_cert.pem') ONLYCERT = data_file('ssl_cert.pem')
ONLYKEY = data_file('ssl_key.pem') ONLYKEY = data_file('ssl_key.pem')
SIGNED_CERTFILE = data_file('keycert3.pem') SIGNED_CERTFILE = data_file('keycert3.pem')
SIGNING_CA = data_file('pycacert.pem') SIGNING_CA = data_file('pycacert.pem')
class MyProto(asyncio.Protocol): class MyBaseProto(asyncio.Protocol):
done = None done = None
def __init__(self, loop=None): def __init__(self, loop=None):
@ -59,7 +60,6 @@ class MyProto(asyncio.Protocol):
self.transport = transport self.transport = transport
assert self.state == 'INITIAL', self.state assert self.state == 'INITIAL', self.state
self.state = 'CONNECTED' self.state = 'CONNECTED'
transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n')
def data_received(self, data): def data_received(self, data):
assert self.state == 'CONNECTED', self.state assert self.state == 'CONNECTED', self.state
@ -76,6 +76,12 @@ class MyProto(asyncio.Protocol):
self.done.set_result(None) self.done.set_result(None)
class MyProto(MyBaseProto):
def connection_made(self, transport):
super().connection_made(transport)
transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n')
class MyDatagramProto(asyncio.DatagramProtocol): class MyDatagramProto(asyncio.DatagramProtocol):
done = None done = None
@ -357,22 +363,30 @@ class EventLoopTestsMixin:
r.close() r.close()
self.assertGreaterEqual(len(data), 200) self.assertGreaterEqual(len(data), 200)
def _basetest_sock_client_ops(self, httpd, sock):
sock.setblocking(False)
self.loop.run_until_complete(
self.loop.sock_connect(sock, httpd.address))
self.loop.run_until_complete(
self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
data = self.loop.run_until_complete(
self.loop.sock_recv(sock, 1024))
# consume data
self.loop.run_until_complete(
self.loop.sock_recv(sock, 1024))
sock.close()
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
def test_sock_client_ops(self): def test_sock_client_ops(self):
with test_utils.run_test_server() as httpd: with test_utils.run_test_server() as httpd:
sock = socket.socket() sock = socket.socket()
sock.setblocking(False) self._basetest_sock_client_ops(httpd, sock)
self.loop.run_until_complete(
self.loop.sock_connect(sock, httpd.address))
self.loop.run_until_complete(
self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
data = self.loop.run_until_complete(
self.loop.sock_recv(sock, 1024))
# consume data
self.loop.run_until_complete(
self.loop.sock_recv(sock, 1024))
sock.close()
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK')) @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
def test_unix_sock_client_ops(self):
with test_utils.run_test_unix_server() as httpd:
sock = socket.socket(socket.AF_UNIX)
self._basetest_sock_client_ops(httpd, sock)
def test_sock_client_fail(self): def test_sock_client_fail(self):
# Make sure that we will get an unused port # Make sure that we will get an unused port
@ -485,16 +499,26 @@ class EventLoopTestsMixin:
self.loop.run_forever() self.loop.run_forever()
self.assertEqual(caught, 1) self.assertEqual(caught, 1)
def _basetest_create_connection(self, connection_fut):
tr, pr = self.loop.run_until_complete(connection_fut)
self.assertIsInstance(tr, asyncio.Transport)
self.assertIsInstance(pr, asyncio.Protocol)
self.loop.run_until_complete(pr.done)
self.assertGreater(pr.nbytes, 0)
tr.close()
def test_create_connection(self): def test_create_connection(self):
with test_utils.run_test_server() as httpd: with test_utils.run_test_server() as httpd:
f = self.loop.create_connection( conn_fut = self.loop.create_connection(
lambda: MyProto(loop=self.loop), *httpd.address) lambda: MyProto(loop=self.loop), *httpd.address)
tr, pr = self.loop.run_until_complete(f) self._basetest_create_connection(conn_fut)
self.assertIsInstance(tr, asyncio.Transport)
self.assertIsInstance(pr, asyncio.Protocol) @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
self.loop.run_until_complete(pr.done) def test_create_unix_connection(self):
self.assertGreater(pr.nbytes, 0) with test_utils.run_test_unix_server() as httpd:
tr.close() conn_fut = self.loop.create_unix_connection(
lambda: MyProto(loop=self.loop), httpd.address)
self._basetest_create_connection(conn_fut)
def test_create_connection_sock(self): def test_create_connection_sock(self):
with test_utils.run_test_server() as httpd: with test_utils.run_test_server() as httpd:
@ -524,20 +548,37 @@ class EventLoopTestsMixin:
self.assertGreater(pr.nbytes, 0) self.assertGreater(pr.nbytes, 0)
tr.close() tr.close()
def _basetest_create_ssl_connection(self, connection_fut):
tr, pr = self.loop.run_until_complete(connection_fut)
self.assertIsInstance(tr, asyncio.Transport)
self.assertIsInstance(pr, asyncio.Protocol)
self.assertTrue('ssl' in tr.__class__.__name__.lower())
self.assertIsNotNone(tr.get_extra_info('sockname'))
self.loop.run_until_complete(pr.done)
self.assertGreater(pr.nbytes, 0)
tr.close()
@unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipIf(ssl is None, 'No ssl module')
def test_create_ssl_connection(self): def test_create_ssl_connection(self):
with test_utils.run_test_server(use_ssl=True) as httpd: with test_utils.run_test_server(use_ssl=True) as httpd:
f = self.loop.create_connection( conn_fut = self.loop.create_connection(
lambda: MyProto(loop=self.loop), *httpd.address, lambda: MyProto(loop=self.loop),
*httpd.address,
ssl=test_utils.dummy_ssl_context()) ssl=test_utils.dummy_ssl_context())
tr, pr = self.loop.run_until_complete(f)
self.assertIsInstance(tr, asyncio.Transport) self._basetest_create_ssl_connection(conn_fut)
self.assertIsInstance(pr, asyncio.Protocol)
self.assertTrue('ssl' in tr.__class__.__name__.lower()) @unittest.skipIf(ssl is None, 'No ssl module')
self.assertIsNotNone(tr.get_extra_info('sockname')) @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
self.loop.run_until_complete(pr.done) def test_create_ssl_unix_connection(self):
self.assertGreater(pr.nbytes, 0) with test_utils.run_test_unix_server(use_ssl=True) as httpd:
tr.close() conn_fut = self.loop.create_unix_connection(
lambda: MyProto(loop=self.loop),
httpd.address,
ssl=test_utils.dummy_ssl_context(),
server_hostname='127.0.0.1')
self._basetest_create_ssl_connection(conn_fut)
def test_create_connection_local_addr(self): def test_create_connection_local_addr(self):
with test_utils.run_test_server() as httpd: with test_utils.run_test_server() as httpd:
@ -561,14 +602,8 @@ class EventLoopTestsMixin:
self.assertIn(str(httpd.address), cm.exception.strerror) self.assertIn(str(httpd.address), cm.exception.strerror)
def test_create_server(self): def test_create_server(self):
proto = None proto = MyProto()
f = self.loop.create_server(lambda: proto, '0.0.0.0', 0)
def factory():
nonlocal proto
proto = MyProto()
return proto
f = self.loop.create_server(factory, '0.0.0.0', 0)
server = self.loop.run_until_complete(f) server = self.loop.run_until_complete(f)
self.assertEqual(len(server.sockets), 1) self.assertEqual(len(server.sockets), 1)
sock = server.sockets[0] sock = server.sockets[0]
@ -605,38 +640,76 @@ class EventLoopTestsMixin:
# close server # close server
server.close() server.close()
def _make_ssl_server(self, factory, certfile, keyfile=None): def _make_unix_server(self, factory, **kwargs):
path = test_utils.gen_unix_socket_path()
self.addCleanup(lambda: os.path.exists(path) and os.unlink(path))
f = self.loop.create_unix_server(factory, path, **kwargs)
server = self.loop.run_until_complete(f)
return server, path
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
def test_create_unix_server(self):
proto = MyProto()
server, path = self._make_unix_server(lambda: proto)
self.assertEqual(len(server.sockets), 1)
client = socket.socket(socket.AF_UNIX)
client.connect(path)
client.sendall(b'xxx')
test_utils.run_briefly(self.loop)
test_utils.run_until(self.loop, lambda: proto is not None, 10)
self.assertIsInstance(proto, MyProto)
self.assertEqual('INITIAL', proto.state)
test_utils.run_briefly(self.loop)
self.assertEqual('CONNECTED', proto.state)
test_utils.run_until(self.loop, lambda: proto.nbytes > 0,
timeout=10)
self.assertEqual(3, proto.nbytes)
# close connection
proto.transport.close()
test_utils.run_briefly(self.loop) # windows iocp
self.assertEqual('CLOSED', proto.state)
# the client socket must be closed after to avoid ECONNRESET upon
# recv()/send() on the serving socket
client.close()
# close server
server.close()
def _create_ssl_context(self, certfile, keyfile=None):
sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext.options |= ssl.OP_NO_SSLv2 sslcontext.options |= ssl.OP_NO_SSLv2
sslcontext.load_cert_chain(certfile, keyfile) sslcontext.load_cert_chain(certfile, keyfile)
return sslcontext
f = self.loop.create_server( def _make_ssl_server(self, factory, certfile, keyfile=None):
factory, '127.0.0.1', 0, ssl=sslcontext) sslcontext = self._create_ssl_context(certfile, keyfile)
f = self.loop.create_server(factory, '127.0.0.1', 0, ssl=sslcontext)
server = self.loop.run_until_complete(f) server = self.loop.run_until_complete(f)
sock = server.sockets[0] sock = server.sockets[0]
host, port = sock.getsockname() host, port = sock.getsockname()
self.assertEqual(host, '127.0.0.1') self.assertEqual(host, '127.0.0.1')
return server, host, port return server, host, port
def _make_ssl_unix_server(self, factory, certfile, keyfile=None):
sslcontext = self._create_ssl_context(certfile, keyfile)
return self._make_unix_server(factory, ssl=sslcontext)
@unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipIf(ssl is None, 'No ssl module')
def test_create_server_ssl(self): def test_create_server_ssl(self):
proto = None proto = MyProto(loop=self.loop)
server, host, port = self._make_ssl_server(
lambda: proto, ONLYCERT, ONLYKEY)
class ClientMyProto(MyProto): f_c = self.loop.create_connection(MyBaseProto, host, port,
def connection_made(self, transport):
self.transport = transport
assert self.state == 'INITIAL', self.state
self.state = 'CONNECTED'
def factory():
nonlocal proto
proto = MyProto(loop=self.loop)
return proto
server, host, port = self._make_ssl_server(factory, ONLYCERT, ONLYKEY)
f_c = self.loop.create_connection(ClientMyProto, host, port,
ssl=test_utils.dummy_ssl_context()) ssl=test_utils.dummy_ssl_context())
client, pr = self.loop.run_until_complete(f_c) client, pr = self.loop.run_until_complete(f_c)
@ -666,17 +739,46 @@ class EventLoopTestsMixin:
# stop serving # stop serving
server.close() server.close()
@unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
def test_create_unix_server_ssl(self):
proto = MyProto(loop=self.loop)
server, path = self._make_ssl_unix_server(
lambda: proto, ONLYCERT, ONLYKEY)
f_c = self.loop.create_unix_connection(
MyBaseProto, path, ssl=test_utils.dummy_ssl_context(),
server_hostname='')
client, pr = self.loop.run_until_complete(f_c)
client.write(b'xxx')
test_utils.run_briefly(self.loop)
self.assertIsInstance(proto, MyProto)
test_utils.run_briefly(self.loop)
self.assertEqual('CONNECTED', proto.state)
test_utils.run_until(self.loop, lambda: proto.nbytes > 0,
timeout=10)
self.assertEqual(3, proto.nbytes)
# close connection
proto.transport.close()
self.loop.run_until_complete(proto.done)
self.assertEqual('CLOSED', proto.state)
# the client socket must be closed after to avoid ECONNRESET upon
# recv()/send() on the serving socket
client.close()
# stop serving
server.close()
@unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module')
def test_create_server_ssl_verify_failed(self): def test_create_server_ssl_verify_failed(self):
proto = None proto = MyProto(loop=self.loop)
server, host, port = self._make_ssl_server(
def factory(): lambda: proto, SIGNED_CERTFILE)
nonlocal proto
proto = MyProto(loop=self.loop)
return proto
server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE)
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext_client.options |= ssl.OP_NO_SSLv2 sslcontext_client.options |= ssl.OP_NO_SSLv2
@ -695,17 +797,38 @@ class EventLoopTestsMixin:
self.assertIsNone(proto.transport) self.assertIsNone(proto.transport)
server.close() server.close()
@unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module')
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
def test_create_unix_server_ssl_verify_failed(self):
proto = MyProto(loop=self.loop)
server, path = self._make_ssl_unix_server(
lambda: proto, SIGNED_CERTFILE)
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext_client.options |= ssl.OP_NO_SSLv2
sslcontext_client.verify_mode = ssl.CERT_REQUIRED
if hasattr(sslcontext_client, 'check_hostname'):
sslcontext_client.check_hostname = True
# no CA loaded
f_c = self.loop.create_unix_connection(MyProto, path,
ssl=sslcontext_client,
server_hostname='invalid')
with self.assertRaisesRegex(ssl.SSLError,
'certificate verify failed '):
self.loop.run_until_complete(f_c)
# close connection
self.assertIsNone(proto.transport)
server.close()
@unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module')
def test_create_server_ssl_match_failed(self): def test_create_server_ssl_match_failed(self):
proto = None proto = MyProto(loop=self.loop)
server, host, port = self._make_ssl_server(
def factory(): lambda: proto, SIGNED_CERTFILE)
nonlocal proto
proto = MyProto(loop=self.loop)
return proto
server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE)
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext_client.options |= ssl.OP_NO_SSLv2 sslcontext_client.options |= ssl.OP_NO_SSLv2
@ -727,17 +850,38 @@ class EventLoopTestsMixin:
proto.transport.close() proto.transport.close()
server.close() server.close()
@unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module')
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
def test_create_unix_server_ssl_verified(self):
proto = MyProto(loop=self.loop)
server, path = self._make_ssl_unix_server(
lambda: proto, SIGNED_CERTFILE)
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext_client.options |= ssl.OP_NO_SSLv2
sslcontext_client.verify_mode = ssl.CERT_REQUIRED
sslcontext_client.load_verify_locations(cafile=SIGNING_CA)
if hasattr(sslcontext_client, 'check_hostname'):
sslcontext_client.check_hostname = True
# Connection succeeds with correct CA and server hostname.
f_c = self.loop.create_unix_connection(MyProto, path,
ssl=sslcontext_client,
server_hostname='localhost')
client, pr = self.loop.run_until_complete(f_c)
# close connection
proto.transport.close()
client.close()
server.close()
@unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module')
def test_create_server_ssl_verified(self): def test_create_server_ssl_verified(self):
proto = None proto = MyProto(loop=self.loop)
server, host, port = self._make_ssl_server(
def factory(): lambda: proto, SIGNED_CERTFILE)
nonlocal proto
proto = MyProto(loop=self.loop)
return proto
server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE)
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext_client.options |= ssl.OP_NO_SSLv2 sslcontext_client.options |= ssl.OP_NO_SSLv2
@ -915,19 +1059,15 @@ class EventLoopTestsMixin:
@unittest.skipUnless(sys.platform != 'win32', @unittest.skipUnless(sys.platform != 'win32',
"Don't support pipes for Windows") "Don't support pipes for Windows")
def test_read_pipe(self): def test_read_pipe(self):
proto = None proto = MyReadPipeProto(loop=self.loop)
def factory():
nonlocal proto
proto = MyReadPipeProto(loop=self.loop)
return proto
rpipe, wpipe = os.pipe() rpipe, wpipe = os.pipe()
pipeobj = io.open(rpipe, 'rb', 1024) pipeobj = io.open(rpipe, 'rb', 1024)
@asyncio.coroutine @asyncio.coroutine
def connect(): def connect():
t, p = yield from self.loop.connect_read_pipe(factory, pipeobj) t, p = yield from self.loop.connect_read_pipe(
lambda: proto, pipeobj)
self.assertIs(p, proto) self.assertIs(p, proto)
self.assertIs(t, proto.transport) self.assertIs(t, proto.transport)
self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) self.assertEqual(['INITIAL', 'CONNECTED'], proto.state)
@ -959,19 +1099,14 @@ class EventLoopTestsMixin:
# Issue #20495: The test hangs on FreeBSD 7.2 but pass on FreeBSD 9 # Issue #20495: The test hangs on FreeBSD 7.2 but pass on FreeBSD 9
@support.requires_freebsd_version(8) @support.requires_freebsd_version(8)
def test_read_pty_output(self): def test_read_pty_output(self):
proto = None proto = MyReadPipeProto(loop=self.loop)
def factory():
nonlocal proto
proto = MyReadPipeProto(loop=self.loop)
return proto
master, slave = os.openpty() master, slave = os.openpty()
master_read_obj = io.open(master, 'rb', 0) master_read_obj = io.open(master, 'rb', 0)
@asyncio.coroutine @asyncio.coroutine
def connect(): def connect():
t, p = yield from self.loop.connect_read_pipe(factory, t, p = yield from self.loop.connect_read_pipe(lambda: proto,
master_read_obj) master_read_obj)
self.assertIs(p, proto) self.assertIs(p, proto)
self.assertIs(t, proto.transport) self.assertIs(t, proto.transport)
@ -999,21 +1134,17 @@ class EventLoopTestsMixin:
@unittest.skipUnless(sys.platform != 'win32', @unittest.skipUnless(sys.platform != 'win32',
"Don't support pipes for Windows") "Don't support pipes for Windows")
def test_write_pipe(self): def test_write_pipe(self):
proto = None proto = MyWritePipeProto(loop=self.loop)
transport = None transport = None
def factory():
nonlocal proto
proto = MyWritePipeProto(loop=self.loop)
return proto
rpipe, wpipe = os.pipe() rpipe, wpipe = os.pipe()
pipeobj = io.open(wpipe, 'wb', 1024) pipeobj = io.open(wpipe, 'wb', 1024)
@asyncio.coroutine @asyncio.coroutine
def connect(): def connect():
nonlocal transport nonlocal transport
t, p = yield from self.loop.connect_write_pipe(factory, pipeobj) t, p = yield from self.loop.connect_write_pipe(
lambda: proto, pipeobj)
self.assertIs(p, proto) self.assertIs(p, proto)
self.assertIs(t, proto.transport) self.assertIs(t, proto.transport)
self.assertEqual('CONNECTED', proto.state) self.assertEqual('CONNECTED', proto.state)
@ -1045,21 +1176,16 @@ class EventLoopTestsMixin:
@unittest.skipUnless(sys.platform != 'win32', @unittest.skipUnless(sys.platform != 'win32',
"Don't support pipes for Windows") "Don't support pipes for Windows")
def test_write_pipe_disconnect_on_close(self): def test_write_pipe_disconnect_on_close(self):
proto = None proto = MyWritePipeProto(loop=self.loop)
transport = None transport = None
def factory():
nonlocal proto
proto = MyWritePipeProto(loop=self.loop)
return proto
rsock, wsock = test_utils.socketpair() rsock, wsock = test_utils.socketpair()
pipeobj = io.open(wsock.detach(), 'wb', 1024) pipeobj = io.open(wsock.detach(), 'wb', 1024)
@asyncio.coroutine @asyncio.coroutine
def connect(): def connect():
nonlocal transport nonlocal transport
t, p = yield from self.loop.connect_write_pipe(factory, t, p = yield from self.loop.connect_write_pipe(lambda: proto,
pipeobj) pipeobj)
self.assertIs(p, proto) self.assertIs(p, proto)
self.assertIs(t, proto.transport) self.assertIs(t, proto.transport)
@ -1084,21 +1210,16 @@ class EventLoopTestsMixin:
# older than 10.6 (Snow Leopard) # older than 10.6 (Snow Leopard)
@support.requires_mac_ver(10, 6) @support.requires_mac_ver(10, 6)
def test_write_pty(self): def test_write_pty(self):
proto = None proto = MyWritePipeProto(loop=self.loop)
transport = None transport = None
def factory():
nonlocal proto
proto = MyWritePipeProto(loop=self.loop)
return proto
master, slave = os.openpty() master, slave = os.openpty()
slave_write_obj = io.open(slave, 'wb', 0) slave_write_obj = io.open(slave, 'wb', 0)
@asyncio.coroutine @asyncio.coroutine
def connect(): def connect():
nonlocal transport nonlocal transport
t, p = yield from self.loop.connect_write_pipe(factory, t, p = yield from self.loop.connect_write_pipe(lambda: proto,
slave_write_obj) slave_write_obj)
self.assertIs(p, proto) self.assertIs(p, proto)
self.assertIs(t, proto.transport) self.assertIs(t, proto.transport)

View File

@ -55,7 +55,8 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
self.loop.remove_reader = unittest.mock.Mock() self.loop.remove_reader = unittest.mock.Mock()
self.loop.remove_writer = unittest.mock.Mock() self.loop.remove_writer = unittest.mock.Mock()
waiter = asyncio.Future(loop=self.loop) waiter = asyncio.Future(loop=self.loop)
transport = self.loop._make_ssl_transport(m, asyncio.Protocol(), m, waiter) transport = self.loop._make_ssl_transport(
m, asyncio.Protocol(), m, waiter)
self.assertIsInstance(transport, _SelectorSslTransport) self.assertIsInstance(transport, _SelectorSslTransport)
@unittest.mock.patch('asyncio.selector_events.ssl', None) @unittest.mock.patch('asyncio.selector_events.ssl', None)

View File

@ -1,6 +1,8 @@
"""Tests for streams.py.""" """Tests for streams.py."""
import functools
import gc import gc
import socket
import unittest import unittest
import unittest.mock import unittest.mock
try: try:
@ -32,48 +34,85 @@ class StreamReaderTests(unittest.TestCase):
stream = asyncio.StreamReader() stream = asyncio.StreamReader()
self.assertIs(stream._loop, m_events.get_event_loop.return_value) self.assertIs(stream._loop, m_events.get_event_loop.return_value)
def _basetest_open_connection(self, open_connection_fut):
reader, writer = self.loop.run_until_complete(open_connection_fut)
writer.write(b'GET / HTTP/1.0\r\n\r\n')
f = reader.readline()
data = self.loop.run_until_complete(f)
self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
f = reader.read()
data = self.loop.run_until_complete(f)
self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
writer.close()
def test_open_connection(self): def test_open_connection(self):
with test_utils.run_test_server() as httpd: with test_utils.run_test_server() as httpd:
f = asyncio.open_connection(*httpd.address, loop=self.loop) conn_fut = asyncio.open_connection(*httpd.address,
reader, writer = self.loop.run_until_complete(f) loop=self.loop)
writer.write(b'GET / HTTP/1.0\r\n\r\n') self._basetest_open_connection(conn_fut)
f = reader.readline()
data = self.loop.run_until_complete(f)
self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
f = reader.read()
data = self.loop.run_until_complete(f)
self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
writer.close() @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
def test_open_unix_connection(self):
with test_utils.run_test_unix_server() as httpd:
conn_fut = asyncio.open_unix_connection(httpd.address,
loop=self.loop)
self._basetest_open_connection(conn_fut)
def _basetest_open_connection_no_loop_ssl(self, open_connection_fut):
try:
reader, writer = self.loop.run_until_complete(open_connection_fut)
finally:
asyncio.set_event_loop(None)
writer.write(b'GET / HTTP/1.0\r\n\r\n')
f = reader.read()
data = self.loop.run_until_complete(f)
self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
writer.close()
@unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipIf(ssl is None, 'No ssl module')
def test_open_connection_no_loop_ssl(self): def test_open_connection_no_loop_ssl(self):
with test_utils.run_test_server(use_ssl=True) as httpd: with test_utils.run_test_server(use_ssl=True) as httpd:
try: conn_fut = asyncio.open_connection(
asyncio.set_event_loop(self.loop) *httpd.address,
f = asyncio.open_connection(*httpd.address, ssl=test_utils.dummy_ssl_context(),
ssl=test_utils.dummy_ssl_context()) loop=self.loop)
reader, writer = self.loop.run_until_complete(f)
finally:
asyncio.set_event_loop(None)
writer.write(b'GET / HTTP/1.0\r\n\r\n')
f = reader.read()
data = self.loop.run_until_complete(f)
self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
writer.close() self._basetest_open_connection_no_loop_ssl(conn_fut)
@unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
def test_open_unix_connection_no_loop_ssl(self):
with test_utils.run_test_unix_server(use_ssl=True) as httpd:
conn_fut = asyncio.open_unix_connection(
httpd.address,
ssl=test_utils.dummy_ssl_context(),
server_hostname='',
loop=self.loop)
self._basetest_open_connection_no_loop_ssl(conn_fut)
def _basetest_open_connection_error(self, open_connection_fut):
reader, writer = self.loop.run_until_complete(open_connection_fut)
writer._protocol.connection_lost(ZeroDivisionError())
f = reader.read()
with self.assertRaises(ZeroDivisionError):
self.loop.run_until_complete(f)
writer.close()
test_utils.run_briefly(self.loop)
def test_open_connection_error(self): def test_open_connection_error(self):
with test_utils.run_test_server() as httpd: with test_utils.run_test_server() as httpd:
f = asyncio.open_connection(*httpd.address, loop=self.loop) conn_fut = asyncio.open_connection(*httpd.address,
reader, writer = self.loop.run_until_complete(f) loop=self.loop)
writer._protocol.connection_lost(ZeroDivisionError()) self._basetest_open_connection_error(conn_fut)
f = reader.read()
with self.assertRaises(ZeroDivisionError):
self.loop.run_until_complete(f)
writer.close() @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
test_utils.run_briefly(self.loop) def test_open_unix_connection_error(self):
with test_utils.run_test_unix_server() as httpd:
conn_fut = asyncio.open_unix_connection(httpd.address,
loop=self.loop)
self._basetest_open_connection_error(conn_fut)
def test_feed_empty_data(self): def test_feed_empty_data(self):
stream = asyncio.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
@ -415,10 +454,86 @@ class StreamReaderTests(unittest.TestCase):
client_writer.write(data) client_writer.write(data)
def start(self): def start(self):
sock = socket.socket()
sock.bind(('127.0.0.1', 0))
self.server = self.loop.run_until_complete( self.server = self.loop.run_until_complete(
asyncio.start_server(self.handle_client, asyncio.start_server(self.handle_client,
'127.0.0.1', 12345, sock=sock,
loop=self.loop)) loop=self.loop))
return sock.getsockname()
def handle_client_callback(self, client_reader, client_writer):
task = asyncio.Task(client_reader.readline(), loop=self.loop)
def done(task):
client_writer.write(task.result())
task.add_done_callback(done)
def start_callback(self):
sock = socket.socket()
sock.bind(('127.0.0.1', 0))
addr = sock.getsockname()
sock.close()
self.server = self.loop.run_until_complete(
asyncio.start_server(self.handle_client_callback,
host=addr[0], port=addr[1],
loop=self.loop))
return addr
def stop(self):
if self.server is not None:
self.server.close()
self.loop.run_until_complete(self.server.wait_closed())
self.server = None
@asyncio.coroutine
def client(addr):
reader, writer = yield from asyncio.open_connection(
*addr, loop=self.loop)
# send a line
writer.write(b"hello world!\n")
# read it back
msgback = yield from reader.readline()
writer.close()
return msgback
# test the server variant with a coroutine as client handler
server = MyServer(self.loop)
addr = server.start()
msg = self.loop.run_until_complete(asyncio.Task(client(addr),
loop=self.loop))
server.stop()
self.assertEqual(msg, b"hello world!\n")
# test the server variant with a callback as client handler
server = MyServer(self.loop)
addr = server.start_callback()
msg = self.loop.run_until_complete(asyncio.Task(client(addr),
loop=self.loop))
server.stop()
self.assertEqual(msg, b"hello world!\n")
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
def test_start_unix_server(self):
class MyServer:
def __init__(self, loop, path):
self.server = None
self.loop = loop
self.path = path
@asyncio.coroutine
def handle_client(self, client_reader, client_writer):
data = yield from client_reader.readline()
client_writer.write(data)
def start(self):
self.server = self.loop.run_until_complete(
asyncio.start_unix_server(self.handle_client,
path=self.path,
loop=self.loop))
def handle_client_callback(self, client_reader, client_writer): def handle_client_callback(self, client_reader, client_writer):
task = asyncio.Task(client_reader.readline(), loop=self.loop) task = asyncio.Task(client_reader.readline(), loop=self.loop)
@ -430,9 +545,9 @@ class StreamReaderTests(unittest.TestCase):
def start_callback(self): def start_callback(self):
self.server = self.loop.run_until_complete( self.server = self.loop.run_until_complete(
asyncio.start_server(self.handle_client_callback, asyncio.start_unix_server(self.handle_client_callback,
'127.0.0.1', 12345, path=self.path,
loop=self.loop)) loop=self.loop))
def stop(self): def stop(self):
if self.server is not None: if self.server is not None:
@ -441,9 +556,9 @@ class StreamReaderTests(unittest.TestCase):
self.server = None self.server = None
@asyncio.coroutine @asyncio.coroutine
def client(): def client(path):
reader, writer = yield from asyncio.open_connection( reader, writer = yield from asyncio.open_unix_connection(
'127.0.0.1', 12345, loop=self.loop) path, loop=self.loop)
# send a line # send a line
writer.write(b"hello world!\n") writer.write(b"hello world!\n")
# read it back # read it back
@ -452,20 +567,22 @@ class StreamReaderTests(unittest.TestCase):
return msgback return msgback
# test the server variant with a coroutine as client handler # test the server variant with a coroutine as client handler
server = MyServer(self.loop) with test_utils.unix_socket_path() as path:
server.start() server = MyServer(self.loop, path)
msg = self.loop.run_until_complete(asyncio.Task(client(), server.start()
loop=self.loop)) msg = self.loop.run_until_complete(asyncio.Task(client(path),
server.stop() loop=self.loop))
self.assertEqual(msg, b"hello world!\n") server.stop()
self.assertEqual(msg, b"hello world!\n")
# test the server variant with a callback as client handler # test the server variant with a callback as client handler
server = MyServer(self.loop) with test_utils.unix_socket_path() as path:
server.start_callback() server = MyServer(self.loop, path)
msg = self.loop.run_until_complete(asyncio.Task(client(), server.start_callback()
loop=self.loop)) msg = self.loop.run_until_complete(asyncio.Task(client(path),
server.stop() loop=self.loop))
self.assertEqual(msg, b"hello world!\n") server.stop()
self.assertEqual(msg, b"hello world!\n")
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -7,8 +7,10 @@ import io
import os import os
import pprint import pprint
import signal import signal
import socket
import stat import stat
import sys import sys
import tempfile
import threading import threading
import unittest import unittest
import unittest.mock import unittest.mock
@ -24,7 +26,7 @@ from asyncio import unix_events
@unittest.skipUnless(signal, 'Signals are not supported') @unittest.skipUnless(signal, 'Signals are not supported')
class SelectorEventLoopTests(unittest.TestCase): class SelectorEventLoopSignalTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.loop = asyncio.SelectorEventLoop() self.loop = asyncio.SelectorEventLoop()
@ -200,6 +202,84 @@ class SelectorEventLoopTests(unittest.TestCase):
m_signal.set_wakeup_fd.assert_called_once_with(-1) m_signal.set_wakeup_fd.assert_called_once_with(-1)
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'),
'UNIX Sockets are not supported')
class SelectorEventLoopUnixSocketTests(unittest.TestCase):
def setUp(self):
self.loop = asyncio.SelectorEventLoop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
def test_create_unix_server_existing_path_sock(self):
with test_utils.unix_socket_path() as path:
sock = socket.socket(socket.AF_UNIX)
sock.bind(path)
coro = self.loop.create_unix_server(lambda: None, path)
with self.assertRaisesRegexp(OSError,
'Address.*is already in use'):
self.loop.run_until_complete(coro)
def test_create_unix_server_existing_path_nonsock(self):
with tempfile.NamedTemporaryFile() as file:
coro = self.loop.create_unix_server(lambda: None, file.name)
with self.assertRaisesRegexp(OSError,
'Address.*is already in use'):
self.loop.run_until_complete(coro)
def test_create_unix_server_ssl_bool(self):
coro = self.loop.create_unix_server(lambda: None, path='spam',
ssl=True)
with self.assertRaisesRegex(TypeError,
'ssl argument must be an SSLContext'):
self.loop.run_until_complete(coro)
def test_create_unix_server_nopath_nosock(self):
coro = self.loop.create_unix_server(lambda: None, path=None)
with self.assertRaisesRegex(ValueError,
'path was not specified, and no sock'):
self.loop.run_until_complete(coro)
def test_create_unix_server_path_inetsock(self):
coro = self.loop.create_unix_server(lambda: None, path=None,
sock=socket.socket())
with self.assertRaisesRegex(ValueError,
'A UNIX Domain Socket was expected'):
self.loop.run_until_complete(coro)
def test_create_unix_connection_path_sock(self):
coro = self.loop.create_unix_connection(
lambda: None, '/dev/null', sock=object())
with self.assertRaisesRegex(ValueError, 'path and sock can not be'):
self.loop.run_until_complete(coro)
def test_create_unix_connection_nopath_nosock(self):
coro = self.loop.create_unix_connection(
lambda: None, None)
with self.assertRaisesRegex(ValueError,
'no path and sock were specified'):
self.loop.run_until_complete(coro)
def test_create_unix_connection_nossl_serverhost(self):
coro = self.loop.create_unix_connection(
lambda: None, '/dev/null', server_hostname='spam')
with self.assertRaisesRegex(ValueError,
'server_hostname is only meaningful'):
self.loop.run_until_complete(coro)
def test_create_unix_connection_ssl_noserverhost(self):
coro = self.loop.create_unix_connection(
lambda: None, '/dev/null', ssl=True)
with self.assertRaisesRegexp(
ValueError, 'you have to pass server_hostname when using ssl'):
self.loop.run_until_complete(coro)
class UnixReadPipeTransportTests(unittest.TestCase): class UnixReadPipeTransportTests(unittest.TestCase):
def setUp(self): def setUp(self):