mirror of https://github.com/python/cpython
asyncio: Add support for UNIX Domain Sockets.
This commit is contained in:
parent
c36e504c53
commit
88a5bf0b2e
|
@ -407,6 +407,13 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
|
||||
sock.setblocking(False)
|
||||
|
||||
transport, protocol = yield from self._create_connection_transport(
|
||||
sock, protocol_factory, ssl, server_hostname)
|
||||
return transport, protocol
|
||||
|
||||
@tasks.coroutine
|
||||
def _create_connection_transport(self, sock, protocol_factory, ssl,
|
||||
server_hostname):
|
||||
protocol = protocol_factory()
|
||||
waiter = futures.Future(loop=self)
|
||||
if ssl:
|
||||
|
|
|
@ -220,6 +220,32 @@ class AbstractEventLoop:
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def create_unix_connection(self, protocol_factory, path, *,
|
||||
ssl=None, sock=None,
|
||||
server_hostname=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def create_unix_server(self, protocol_factory, path, *,
|
||||
sock=None, backlog=100, ssl=None):
|
||||
"""A coroutine which creates a UNIX Domain Socket server.
|
||||
|
||||
The return valud is a Server object, which can be used to stop
|
||||
the service.
|
||||
|
||||
path is a str, representing a file systsem path to bind the
|
||||
server socket to.
|
||||
|
||||
sock can optionally be specified in order to use a preexisting
|
||||
socket object.
|
||||
|
||||
backlog is the maximum number of queued connections passed to
|
||||
listen() (defaults to 100).
|
||||
|
||||
ssl can be set to an SSLContext to enable SSL over the
|
||||
accepted connections.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def create_datagram_endpoint(self, protocol_factory,
|
||||
local_addr=None, remote_addr=None, *,
|
||||
family=0, proto=0, flags=0):
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
"""Stream-related things."""
|
||||
|
||||
__all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol',
|
||||
'open_connection', 'start_server', 'IncompleteReadError',
|
||||
'open_connection', 'start_server',
|
||||
'open_unix_connection', 'start_unix_server',
|
||||
'IncompleteReadError',
|
||||
]
|
||||
|
||||
import socket
|
||||
|
||||
from . import events
|
||||
from . import futures
|
||||
from . import protocols
|
||||
|
@ -93,6 +97,39 @@ def start_server(client_connected_cb, host=None, port=None, *,
|
|||
return (yield from loop.create_server(factory, host, port, **kwds))
|
||||
|
||||
|
||||
if hasattr(socket, 'AF_UNIX'):
|
||||
# UNIX Domain Sockets are supported on this platform
|
||||
|
||||
@tasks.coroutine
|
||||
def open_unix_connection(path=None, *,
|
||||
loop=None, limit=_DEFAULT_LIMIT, **kwds):
|
||||
"""Similar to `open_connection` but works with UNIX Domain Sockets."""
|
||||
if loop is None:
|
||||
loop = events.get_event_loop()
|
||||
reader = StreamReader(limit=limit, loop=loop)
|
||||
protocol = StreamReaderProtocol(reader, loop=loop)
|
||||
transport, _ = yield from loop.create_unix_connection(
|
||||
lambda: protocol, path, **kwds)
|
||||
writer = StreamWriter(transport, protocol, reader, loop)
|
||||
return reader, writer
|
||||
|
||||
|
||||
@tasks.coroutine
|
||||
def start_unix_server(client_connected_cb, path=None, *,
|
||||
loop=None, limit=_DEFAULT_LIMIT, **kwds):
|
||||
"""Similar to `start_server` but works with UNIX Domain Sockets."""
|
||||
if loop is None:
|
||||
loop = events.get_event_loop()
|
||||
|
||||
def factory():
|
||||
reader = StreamReader(limit=limit, loop=loop)
|
||||
protocol = StreamReaderProtocol(reader, client_connected_cb,
|
||||
loop=loop)
|
||||
return protocol
|
||||
|
||||
return (yield from loop.create_unix_server(factory, path, **kwds))
|
||||
|
||||
|
||||
class FlowControlMixin(protocols.Protocol):
|
||||
"""Reusable flow control logic for StreamWriter.drain().
|
||||
|
||||
|
|
|
@ -4,12 +4,18 @@ import collections
|
|||
import contextlib
|
||||
import io
|
||||
import os
|
||||
import socket
|
||||
import socketserver
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
import unittest.mock
|
||||
|
||||
from http.server import HTTPServer
|
||||
from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer
|
||||
|
||||
try:
|
||||
import ssl
|
||||
except ImportError: # pragma: no cover
|
||||
|
@ -70,42 +76,51 @@ def run_once(loop):
|
|||
loop.run_forever()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
|
||||
class SilentWSGIRequestHandler(WSGIRequestHandler):
|
||||
|
||||
class SilentWSGIRequestHandler(WSGIRequestHandler):
|
||||
def get_stderr(self):
|
||||
return io.StringIO()
|
||||
def get_stderr(self):
|
||||
return io.StringIO()
|
||||
|
||||
def log_message(self, format, *args):
|
||||
def log_message(self, format, *args):
|
||||
pass
|
||||
|
||||
|
||||
class SilentWSGIServer(WSGIServer):
|
||||
|
||||
def handle_error(self, request, client_address):
|
||||
pass
|
||||
|
||||
|
||||
class SSLWSGIServerMixin:
|
||||
|
||||
def finish_request(self, request, client_address):
|
||||
# The relative location of our test directory (which
|
||||
# contains the ssl key and certificate files) differs
|
||||
# between the stdlib and stand-alone asyncio.
|
||||
# Prefer our own if we can find it.
|
||||
here = os.path.join(os.path.dirname(__file__), '..', 'tests')
|
||||
if not os.path.isdir(here):
|
||||
here = os.path.join(os.path.dirname(os.__file__),
|
||||
'test', 'test_asyncio')
|
||||
keyfile = os.path.join(here, 'ssl_key.pem')
|
||||
certfile = os.path.join(here, 'ssl_cert.pem')
|
||||
ssock = ssl.wrap_socket(request,
|
||||
keyfile=keyfile,
|
||||
certfile=certfile,
|
||||
server_side=True)
|
||||
try:
|
||||
self.RequestHandlerClass(ssock, client_address, self)
|
||||
ssock.close()
|
||||
except OSError:
|
||||
# maybe socket has been closed by peer
|
||||
pass
|
||||
|
||||
class SilentWSGIServer(WSGIServer):
|
||||
def handle_error(self, request, client_address):
|
||||
pass
|
||||
|
||||
class SSLWSGIServer(SilentWSGIServer):
|
||||
def finish_request(self, request, client_address):
|
||||
# The relative location of our test directory (which
|
||||
# contains the ssl key and certificate files) differs
|
||||
# between the stdlib and stand-alone asyncio.
|
||||
# Prefer our own if we can find it.
|
||||
here = os.path.join(os.path.dirname(__file__), '..', 'tests')
|
||||
if not os.path.isdir(here):
|
||||
here = os.path.join(os.path.dirname(os.__file__),
|
||||
'test', 'test_asyncio')
|
||||
keyfile = os.path.join(here, 'ssl_key.pem')
|
||||
certfile = os.path.join(here, 'ssl_cert.pem')
|
||||
ssock = ssl.wrap_socket(request,
|
||||
keyfile=keyfile,
|
||||
certfile=certfile,
|
||||
server_side=True)
|
||||
try:
|
||||
self.RequestHandlerClass(ssock, client_address, self)
|
||||
ssock.close()
|
||||
except OSError:
|
||||
# maybe socket has been closed by peer
|
||||
pass
|
||||
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
|
||||
pass
|
||||
|
||||
|
||||
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
|
||||
|
||||
def app(environ, start_response):
|
||||
status = '200 OK'
|
||||
|
@ -115,9 +130,9 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
|
|||
|
||||
# Run the test WSGI server in a separate thread in order not to
|
||||
# interfere with event handling in the main thread
|
||||
server_class = SSLWSGIServer if use_ssl else SilentWSGIServer
|
||||
httpd = make_server(host, port, app,
|
||||
server_class, SilentWSGIRequestHandler)
|
||||
server_class = server_ssl_cls if use_ssl else server_cls
|
||||
httpd = server_class(address, SilentWSGIRequestHandler)
|
||||
httpd.set_app(app)
|
||||
httpd.address = httpd.server_address
|
||||
server_thread = threading.Thread(target=httpd.serve_forever)
|
||||
server_thread.start()
|
||||
|
@ -129,6 +144,75 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
|
|||
server_thread.join()
|
||||
|
||||
|
||||
if hasattr(socket, 'AF_UNIX'):
|
||||
|
||||
class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
|
||||
|
||||
def server_bind(self):
|
||||
socketserver.UnixStreamServer.server_bind(self)
|
||||
self.server_name = '127.0.0.1'
|
||||
self.server_port = 80
|
||||
|
||||
|
||||
class UnixWSGIServer(UnixHTTPServer, WSGIServer):
|
||||
|
||||
def server_bind(self):
|
||||
UnixHTTPServer.server_bind(self)
|
||||
self.setup_environ()
|
||||
|
||||
def get_request(self):
|
||||
request, client_addr = super().get_request()
|
||||
# Code in the stdlib expects that get_request
|
||||
# will return a socket and a tuple (host, port).
|
||||
# However, this isn't true for UNIX sockets,
|
||||
# as the second return value will be a path;
|
||||
# hence we return some fake data sufficient
|
||||
# to get the tests going
|
||||
return request, ('127.0.0.1', '')
|
||||
|
||||
|
||||
class SilentUnixWSGIServer(UnixWSGIServer):
|
||||
|
||||
def handle_error(self, request, client_address):
|
||||
pass
|
||||
|
||||
|
||||
class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
|
||||
pass
|
||||
|
||||
|
||||
def gen_unix_socket_path():
|
||||
with tempfile.NamedTemporaryFile() as file:
|
||||
return file.name
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def unix_socket_path():
|
||||
path = gen_unix_socket_path()
|
||||
try:
|
||||
yield path
|
||||
finally:
|
||||
try:
|
||||
os.unlink(path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def run_test_unix_server(*, use_ssl=False):
|
||||
with unix_socket_path() as path:
|
||||
yield from _run_test_server(address=path, use_ssl=use_ssl,
|
||||
server_cls=SilentUnixWSGIServer,
|
||||
server_ssl_cls=UnixSSLWSGIServer)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
|
||||
yield from _run_test_server(address=(host, port), use_ssl=use_ssl,
|
||||
server_cls=SilentWSGIServer,
|
||||
server_ssl_cls=SSLWSGIServer)
|
||||
|
||||
|
||||
def make_test_protocol(base):
|
||||
dct = {}
|
||||
for name in dir(base):
|
||||
|
@ -275,5 +359,6 @@ class TestLoop(base_events.BaseEventLoop):
|
|||
def _write_to_self(self):
|
||||
pass
|
||||
|
||||
|
||||
def MockCallback(**kwargs):
|
||||
return unittest.mock.Mock(spec=['__call__'], **kwargs)
|
||||
|
|
|
@ -11,6 +11,7 @@ import sys
|
|||
import threading
|
||||
|
||||
|
||||
from . import base_events
|
||||
from . import base_subprocess
|
||||
from . import constants
|
||||
from . import events
|
||||
|
@ -31,9 +32,9 @@ if sys.platform == 'win32': # pragma: no cover
|
|||
|
||||
|
||||
class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
|
||||
"""Unix event loop
|
||||
"""Unix event loop.
|
||||
|
||||
Adds signal handling to SelectorEventLoop
|
||||
Adds signal handling and UNIX Domain Socket support to SelectorEventLoop.
|
||||
"""
|
||||
|
||||
def __init__(self, selector=None):
|
||||
|
@ -164,6 +165,76 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
|
|||
def _child_watcher_callback(self, pid, returncode, transp):
|
||||
self.call_soon_threadsafe(transp._process_exited, returncode)
|
||||
|
||||
@tasks.coroutine
|
||||
def create_unix_connection(self, protocol_factory, path, *,
|
||||
ssl=None, sock=None,
|
||||
server_hostname=None):
|
||||
assert server_hostname is None or isinstance(server_hostname, str)
|
||||
if ssl:
|
||||
if server_hostname is None:
|
||||
raise ValueError(
|
||||
'you have to pass server_hostname when using ssl')
|
||||
else:
|
||||
if server_hostname is not None:
|
||||
raise ValueError('server_hostname is only meaningful with ssl')
|
||||
|
||||
if path is not None:
|
||||
if sock is not None:
|
||||
raise ValueError(
|
||||
'path and sock can not be specified at the same time')
|
||||
|
||||
try:
|
||||
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
|
||||
sock.setblocking(False)
|
||||
yield from self.sock_connect(sock, path)
|
||||
except OSError:
|
||||
if sock is not None:
|
||||
sock.close()
|
||||
raise
|
||||
|
||||
else:
|
||||
if sock is None:
|
||||
raise ValueError('no path and sock were specified')
|
||||
sock.setblocking(False)
|
||||
|
||||
transport, protocol = yield from self._create_connection_transport(
|
||||
sock, protocol_factory, ssl, server_hostname)
|
||||
return transport, protocol
|
||||
|
||||
@tasks.coroutine
|
||||
def create_unix_server(self, protocol_factory, path=None, *,
|
||||
sock=None, backlog=100, ssl=None):
|
||||
if isinstance(ssl, bool):
|
||||
raise TypeError('ssl argument must be an SSLContext or None')
|
||||
|
||||
if path is not None:
|
||||
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
|
||||
try:
|
||||
sock.bind(path)
|
||||
except OSError as exc:
|
||||
if exc.errno == errno.EADDRINUSE:
|
||||
# Let's improve the error message by adding
|
||||
# with what exact address it occurs.
|
||||
msg = 'Address {!r} is already in use'.format(path)
|
||||
raise OSError(errno.EADDRINUSE, msg) from None
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
if sock is None:
|
||||
raise ValueError(
|
||||
'path was not specified, and no sock specified')
|
||||
|
||||
if sock.family != socket.AF_UNIX:
|
||||
raise ValueError(
|
||||
'A UNIX Domain Socket was expected, got {!r}'.format(sock))
|
||||
|
||||
server = base_events.Server(self, [sock])
|
||||
sock.listen(backlog)
|
||||
sock.setblocking(False)
|
||||
self._start_serving(protocol_factory, sock, ssl, server)
|
||||
return server
|
||||
|
||||
|
||||
def _set_nonblocking(fd):
|
||||
flags = fcntl.fcntl(fd, fcntl.F_GETFL)
|
||||
|
|
|
@ -212,7 +212,7 @@ class BaseEventLoopTests(unittest.TestCase):
|
|||
|
||||
idx = -1
|
||||
data = [10.0, 10.0, 10.3, 13.0]
|
||||
self.loop._scheduled = [asyncio.TimerHandle(11.0, lambda:True, ())]
|
||||
self.loop._scheduled = [asyncio.TimerHandle(11.0, lambda: True, ())]
|
||||
self.loop._run_once()
|
||||
self.assertEqual(logging.DEBUG, m_logger.log.call_args[0][0])
|
||||
|
||||
|
|
|
@ -39,13 +39,14 @@ def data_file(filename):
|
|||
return fullname
|
||||
raise FileNotFoundError(filename)
|
||||
|
||||
|
||||
ONLYCERT = data_file('ssl_cert.pem')
|
||||
ONLYKEY = data_file('ssl_key.pem')
|
||||
SIGNED_CERTFILE = data_file('keycert3.pem')
|
||||
SIGNING_CA = data_file('pycacert.pem')
|
||||
|
||||
|
||||
class MyProto(asyncio.Protocol):
|
||||
class MyBaseProto(asyncio.Protocol):
|
||||
done = None
|
||||
|
||||
def __init__(self, loop=None):
|
||||
|
@ -59,7 +60,6 @@ class MyProto(asyncio.Protocol):
|
|||
self.transport = transport
|
||||
assert self.state == 'INITIAL', self.state
|
||||
self.state = 'CONNECTED'
|
||||
transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n')
|
||||
|
||||
def data_received(self, data):
|
||||
assert self.state == 'CONNECTED', self.state
|
||||
|
@ -76,6 +76,12 @@ class MyProto(asyncio.Protocol):
|
|||
self.done.set_result(None)
|
||||
|
||||
|
||||
class MyProto(MyBaseProto):
|
||||
def connection_made(self, transport):
|
||||
super().connection_made(transport)
|
||||
transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n')
|
||||
|
||||
|
||||
class MyDatagramProto(asyncio.DatagramProtocol):
|
||||
done = None
|
||||
|
||||
|
@ -357,22 +363,30 @@ class EventLoopTestsMixin:
|
|||
r.close()
|
||||
self.assertGreaterEqual(len(data), 200)
|
||||
|
||||
def _basetest_sock_client_ops(self, httpd, sock):
|
||||
sock.setblocking(False)
|
||||
self.loop.run_until_complete(
|
||||
self.loop.sock_connect(sock, httpd.address))
|
||||
self.loop.run_until_complete(
|
||||
self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
|
||||
data = self.loop.run_until_complete(
|
||||
self.loop.sock_recv(sock, 1024))
|
||||
# consume data
|
||||
self.loop.run_until_complete(
|
||||
self.loop.sock_recv(sock, 1024))
|
||||
sock.close()
|
||||
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
|
||||
|
||||
def test_sock_client_ops(self):
|
||||
with test_utils.run_test_server() as httpd:
|
||||
sock = socket.socket()
|
||||
sock.setblocking(False)
|
||||
self.loop.run_until_complete(
|
||||
self.loop.sock_connect(sock, httpd.address))
|
||||
self.loop.run_until_complete(
|
||||
self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
|
||||
data = self.loop.run_until_complete(
|
||||
self.loop.sock_recv(sock, 1024))
|
||||
# consume data
|
||||
self.loop.run_until_complete(
|
||||
self.loop.sock_recv(sock, 1024))
|
||||
sock.close()
|
||||
self._basetest_sock_client_ops(httpd, sock)
|
||||
|
||||
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
|
||||
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
|
||||
def test_unix_sock_client_ops(self):
|
||||
with test_utils.run_test_unix_server() as httpd:
|
||||
sock = socket.socket(socket.AF_UNIX)
|
||||
self._basetest_sock_client_ops(httpd, sock)
|
||||
|
||||
def test_sock_client_fail(self):
|
||||
# Make sure that we will get an unused port
|
||||
|
@ -485,16 +499,26 @@ class EventLoopTestsMixin:
|
|||
self.loop.run_forever()
|
||||
self.assertEqual(caught, 1)
|
||||
|
||||
def _basetest_create_connection(self, connection_fut):
|
||||
tr, pr = self.loop.run_until_complete(connection_fut)
|
||||
self.assertIsInstance(tr, asyncio.Transport)
|
||||
self.assertIsInstance(pr, asyncio.Protocol)
|
||||
self.loop.run_until_complete(pr.done)
|
||||
self.assertGreater(pr.nbytes, 0)
|
||||
tr.close()
|
||||
|
||||
def test_create_connection(self):
|
||||
with test_utils.run_test_server() as httpd:
|
||||
f = self.loop.create_connection(
|
||||
conn_fut = self.loop.create_connection(
|
||||
lambda: MyProto(loop=self.loop), *httpd.address)
|
||||
tr, pr = self.loop.run_until_complete(f)
|
||||
self.assertIsInstance(tr, asyncio.Transport)
|
||||
self.assertIsInstance(pr, asyncio.Protocol)
|
||||
self.loop.run_until_complete(pr.done)
|
||||
self.assertGreater(pr.nbytes, 0)
|
||||
tr.close()
|
||||
self._basetest_create_connection(conn_fut)
|
||||
|
||||
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
|
||||
def test_create_unix_connection(self):
|
||||
with test_utils.run_test_unix_server() as httpd:
|
||||
conn_fut = self.loop.create_unix_connection(
|
||||
lambda: MyProto(loop=self.loop), httpd.address)
|
||||
self._basetest_create_connection(conn_fut)
|
||||
|
||||
def test_create_connection_sock(self):
|
||||
with test_utils.run_test_server() as httpd:
|
||||
|
@ -524,20 +548,37 @@ class EventLoopTestsMixin:
|
|||
self.assertGreater(pr.nbytes, 0)
|
||||
tr.close()
|
||||
|
||||
def _basetest_create_ssl_connection(self, connection_fut):
|
||||
tr, pr = self.loop.run_until_complete(connection_fut)
|
||||
self.assertIsInstance(tr, asyncio.Transport)
|
||||
self.assertIsInstance(pr, asyncio.Protocol)
|
||||
self.assertTrue('ssl' in tr.__class__.__name__.lower())
|
||||
self.assertIsNotNone(tr.get_extra_info('sockname'))
|
||||
self.loop.run_until_complete(pr.done)
|
||||
self.assertGreater(pr.nbytes, 0)
|
||||
tr.close()
|
||||
|
||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||
def test_create_ssl_connection(self):
|
||||
with test_utils.run_test_server(use_ssl=True) as httpd:
|
||||
f = self.loop.create_connection(
|
||||
lambda: MyProto(loop=self.loop), *httpd.address,
|
||||
conn_fut = self.loop.create_connection(
|
||||
lambda: MyProto(loop=self.loop),
|
||||
*httpd.address,
|
||||
ssl=test_utils.dummy_ssl_context())
|
||||
tr, pr = self.loop.run_until_complete(f)
|
||||
self.assertIsInstance(tr, asyncio.Transport)
|
||||
self.assertIsInstance(pr, asyncio.Protocol)
|
||||
self.assertTrue('ssl' in tr.__class__.__name__.lower())
|
||||
self.assertIsNotNone(tr.get_extra_info('sockname'))
|
||||
self.loop.run_until_complete(pr.done)
|
||||
self.assertGreater(pr.nbytes, 0)
|
||||
tr.close()
|
||||
|
||||
self._basetest_create_ssl_connection(conn_fut)
|
||||
|
||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
|
||||
def test_create_ssl_unix_connection(self):
|
||||
with test_utils.run_test_unix_server(use_ssl=True) as httpd:
|
||||
conn_fut = self.loop.create_unix_connection(
|
||||
lambda: MyProto(loop=self.loop),
|
||||
httpd.address,
|
||||
ssl=test_utils.dummy_ssl_context(),
|
||||
server_hostname='127.0.0.1')
|
||||
|
||||
self._basetest_create_ssl_connection(conn_fut)
|
||||
|
||||
def test_create_connection_local_addr(self):
|
||||
with test_utils.run_test_server() as httpd:
|
||||
|
@ -561,14 +602,8 @@ class EventLoopTestsMixin:
|
|||
self.assertIn(str(httpd.address), cm.exception.strerror)
|
||||
|
||||
def test_create_server(self):
|
||||
proto = None
|
||||
|
||||
def factory():
|
||||
nonlocal proto
|
||||
proto = MyProto()
|
||||
return proto
|
||||
|
||||
f = self.loop.create_server(factory, '0.0.0.0', 0)
|
||||
proto = MyProto()
|
||||
f = self.loop.create_server(lambda: proto, '0.0.0.0', 0)
|
||||
server = self.loop.run_until_complete(f)
|
||||
self.assertEqual(len(server.sockets), 1)
|
||||
sock = server.sockets[0]
|
||||
|
@ -605,38 +640,76 @@ class EventLoopTestsMixin:
|
|||
# close server
|
||||
server.close()
|
||||
|
||||
def _make_ssl_server(self, factory, certfile, keyfile=None):
|
||||
def _make_unix_server(self, factory, **kwargs):
|
||||
path = test_utils.gen_unix_socket_path()
|
||||
self.addCleanup(lambda: os.path.exists(path) and os.unlink(path))
|
||||
|
||||
f = self.loop.create_unix_server(factory, path, **kwargs)
|
||||
server = self.loop.run_until_complete(f)
|
||||
|
||||
return server, path
|
||||
|
||||
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
|
||||
def test_create_unix_server(self):
|
||||
proto = MyProto()
|
||||
server, path = self._make_unix_server(lambda: proto)
|
||||
self.assertEqual(len(server.sockets), 1)
|
||||
|
||||
client = socket.socket(socket.AF_UNIX)
|
||||
client.connect(path)
|
||||
client.sendall(b'xxx')
|
||||
test_utils.run_briefly(self.loop)
|
||||
test_utils.run_until(self.loop, lambda: proto is not None, 10)
|
||||
|
||||
self.assertIsInstance(proto, MyProto)
|
||||
self.assertEqual('INITIAL', proto.state)
|
||||
test_utils.run_briefly(self.loop)
|
||||
self.assertEqual('CONNECTED', proto.state)
|
||||
test_utils.run_until(self.loop, lambda: proto.nbytes > 0,
|
||||
timeout=10)
|
||||
self.assertEqual(3, proto.nbytes)
|
||||
|
||||
# close connection
|
||||
proto.transport.close()
|
||||
test_utils.run_briefly(self.loop) # windows iocp
|
||||
|
||||
self.assertEqual('CLOSED', proto.state)
|
||||
|
||||
# the client socket must be closed after to avoid ECONNRESET upon
|
||||
# recv()/send() on the serving socket
|
||||
client.close()
|
||||
|
||||
# close server
|
||||
server.close()
|
||||
|
||||
def _create_ssl_context(self, certfile, keyfile=None):
|
||||
sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
|
||||
sslcontext.options |= ssl.OP_NO_SSLv2
|
||||
sslcontext.load_cert_chain(certfile, keyfile)
|
||||
return sslcontext
|
||||
|
||||
f = self.loop.create_server(
|
||||
factory, '127.0.0.1', 0, ssl=sslcontext)
|
||||
def _make_ssl_server(self, factory, certfile, keyfile=None):
|
||||
sslcontext = self._create_ssl_context(certfile, keyfile)
|
||||
|
||||
f = self.loop.create_server(factory, '127.0.0.1', 0, ssl=sslcontext)
|
||||
server = self.loop.run_until_complete(f)
|
||||
|
||||
sock = server.sockets[0]
|
||||
host, port = sock.getsockname()
|
||||
self.assertEqual(host, '127.0.0.1')
|
||||
return server, host, port
|
||||
|
||||
def _make_ssl_unix_server(self, factory, certfile, keyfile=None):
|
||||
sslcontext = self._create_ssl_context(certfile, keyfile)
|
||||
return self._make_unix_server(factory, ssl=sslcontext)
|
||||
|
||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||
def test_create_server_ssl(self):
|
||||
proto = None
|
||||
proto = MyProto(loop=self.loop)
|
||||
server, host, port = self._make_ssl_server(
|
||||
lambda: proto, ONLYCERT, ONLYKEY)
|
||||
|
||||
class ClientMyProto(MyProto):
|
||||
def connection_made(self, transport):
|
||||
self.transport = transport
|
||||
assert self.state == 'INITIAL', self.state
|
||||
self.state = 'CONNECTED'
|
||||
|
||||
def factory():
|
||||
nonlocal proto
|
||||
proto = MyProto(loop=self.loop)
|
||||
return proto
|
||||
|
||||
server, host, port = self._make_ssl_server(factory, ONLYCERT, ONLYKEY)
|
||||
|
||||
f_c = self.loop.create_connection(ClientMyProto, host, port,
|
||||
f_c = self.loop.create_connection(MyBaseProto, host, port,
|
||||
ssl=test_utils.dummy_ssl_context())
|
||||
client, pr = self.loop.run_until_complete(f_c)
|
||||
|
||||
|
@ -666,17 +739,46 @@ class EventLoopTestsMixin:
|
|||
# stop serving
|
||||
server.close()
|
||||
|
||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
|
||||
def test_create_unix_server_ssl(self):
|
||||
proto = MyProto(loop=self.loop)
|
||||
server, path = self._make_ssl_unix_server(
|
||||
lambda: proto, ONLYCERT, ONLYKEY)
|
||||
|
||||
f_c = self.loop.create_unix_connection(
|
||||
MyBaseProto, path, ssl=test_utils.dummy_ssl_context(),
|
||||
server_hostname='')
|
||||
|
||||
client, pr = self.loop.run_until_complete(f_c)
|
||||
|
||||
client.write(b'xxx')
|
||||
test_utils.run_briefly(self.loop)
|
||||
self.assertIsInstance(proto, MyProto)
|
||||
test_utils.run_briefly(self.loop)
|
||||
self.assertEqual('CONNECTED', proto.state)
|
||||
test_utils.run_until(self.loop, lambda: proto.nbytes > 0,
|
||||
timeout=10)
|
||||
self.assertEqual(3, proto.nbytes)
|
||||
|
||||
# close connection
|
||||
proto.transport.close()
|
||||
self.loop.run_until_complete(proto.done)
|
||||
self.assertEqual('CLOSED', proto.state)
|
||||
|
||||
# the client socket must be closed after to avoid ECONNRESET upon
|
||||
# recv()/send() on the serving socket
|
||||
client.close()
|
||||
|
||||
# stop serving
|
||||
server.close()
|
||||
|
||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||
@unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module')
|
||||
def test_create_server_ssl_verify_failed(self):
|
||||
proto = None
|
||||
|
||||
def factory():
|
||||
nonlocal proto
|
||||
proto = MyProto(loop=self.loop)
|
||||
return proto
|
||||
|
||||
server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE)
|
||||
proto = MyProto(loop=self.loop)
|
||||
server, host, port = self._make_ssl_server(
|
||||
lambda: proto, SIGNED_CERTFILE)
|
||||
|
||||
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
|
||||
sslcontext_client.options |= ssl.OP_NO_SSLv2
|
||||
|
@ -695,17 +797,38 @@ class EventLoopTestsMixin:
|
|||
self.assertIsNone(proto.transport)
|
||||
server.close()
|
||||
|
||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||
@unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module')
|
||||
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
|
||||
def test_create_unix_server_ssl_verify_failed(self):
|
||||
proto = MyProto(loop=self.loop)
|
||||
server, path = self._make_ssl_unix_server(
|
||||
lambda: proto, SIGNED_CERTFILE)
|
||||
|
||||
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
|
||||
sslcontext_client.options |= ssl.OP_NO_SSLv2
|
||||
sslcontext_client.verify_mode = ssl.CERT_REQUIRED
|
||||
if hasattr(sslcontext_client, 'check_hostname'):
|
||||
sslcontext_client.check_hostname = True
|
||||
|
||||
# no CA loaded
|
||||
f_c = self.loop.create_unix_connection(MyProto, path,
|
||||
ssl=sslcontext_client,
|
||||
server_hostname='invalid')
|
||||
with self.assertRaisesRegex(ssl.SSLError,
|
||||
'certificate verify failed '):
|
||||
self.loop.run_until_complete(f_c)
|
||||
|
||||
# close connection
|
||||
self.assertIsNone(proto.transport)
|
||||
server.close()
|
||||
|
||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||
@unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module')
|
||||
def test_create_server_ssl_match_failed(self):
|
||||
proto = None
|
||||
|
||||
def factory():
|
||||
nonlocal proto
|
||||
proto = MyProto(loop=self.loop)
|
||||
return proto
|
||||
|
||||
server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE)
|
||||
proto = MyProto(loop=self.loop)
|
||||
server, host, port = self._make_ssl_server(
|
||||
lambda: proto, SIGNED_CERTFILE)
|
||||
|
||||
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
|
||||
sslcontext_client.options |= ssl.OP_NO_SSLv2
|
||||
|
@ -727,17 +850,38 @@ class EventLoopTestsMixin:
|
|||
proto.transport.close()
|
||||
server.close()
|
||||
|
||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||
@unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module')
|
||||
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
|
||||
def test_create_unix_server_ssl_verified(self):
|
||||
proto = MyProto(loop=self.loop)
|
||||
server, path = self._make_ssl_unix_server(
|
||||
lambda: proto, SIGNED_CERTFILE)
|
||||
|
||||
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
|
||||
sslcontext_client.options |= ssl.OP_NO_SSLv2
|
||||
sslcontext_client.verify_mode = ssl.CERT_REQUIRED
|
||||
sslcontext_client.load_verify_locations(cafile=SIGNING_CA)
|
||||
if hasattr(sslcontext_client, 'check_hostname'):
|
||||
sslcontext_client.check_hostname = True
|
||||
|
||||
# Connection succeeds with correct CA and server hostname.
|
||||
f_c = self.loop.create_unix_connection(MyProto, path,
|
||||
ssl=sslcontext_client,
|
||||
server_hostname='localhost')
|
||||
client, pr = self.loop.run_until_complete(f_c)
|
||||
|
||||
# close connection
|
||||
proto.transport.close()
|
||||
client.close()
|
||||
server.close()
|
||||
|
||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||
@unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module')
|
||||
def test_create_server_ssl_verified(self):
|
||||
proto = None
|
||||
|
||||
def factory():
|
||||
nonlocal proto
|
||||
proto = MyProto(loop=self.loop)
|
||||
return proto
|
||||
|
||||
server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE)
|
||||
proto = MyProto(loop=self.loop)
|
||||
server, host, port = self._make_ssl_server(
|
||||
lambda: proto, SIGNED_CERTFILE)
|
||||
|
||||
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
|
||||
sslcontext_client.options |= ssl.OP_NO_SSLv2
|
||||
|
@ -915,19 +1059,15 @@ class EventLoopTestsMixin:
|
|||
@unittest.skipUnless(sys.platform != 'win32',
|
||||
"Don't support pipes for Windows")
|
||||
def test_read_pipe(self):
|
||||
proto = None
|
||||
|
||||
def factory():
|
||||
nonlocal proto
|
||||
proto = MyReadPipeProto(loop=self.loop)
|
||||
return proto
|
||||
proto = MyReadPipeProto(loop=self.loop)
|
||||
|
||||
rpipe, wpipe = os.pipe()
|
||||
pipeobj = io.open(rpipe, 'rb', 1024)
|
||||
|
||||
@asyncio.coroutine
|
||||
def connect():
|
||||
t, p = yield from self.loop.connect_read_pipe(factory, pipeobj)
|
||||
t, p = yield from self.loop.connect_read_pipe(
|
||||
lambda: proto, pipeobj)
|
||||
self.assertIs(p, proto)
|
||||
self.assertIs(t, proto.transport)
|
||||
self.assertEqual(['INITIAL', 'CONNECTED'], proto.state)
|
||||
|
@ -959,19 +1099,14 @@ class EventLoopTestsMixin:
|
|||
# Issue #20495: The test hangs on FreeBSD 7.2 but pass on FreeBSD 9
|
||||
@support.requires_freebsd_version(8)
|
||||
def test_read_pty_output(self):
|
||||
proto = None
|
||||
|
||||
def factory():
|
||||
nonlocal proto
|
||||
proto = MyReadPipeProto(loop=self.loop)
|
||||
return proto
|
||||
proto = MyReadPipeProto(loop=self.loop)
|
||||
|
||||
master, slave = os.openpty()
|
||||
master_read_obj = io.open(master, 'rb', 0)
|
||||
|
||||
@asyncio.coroutine
|
||||
def connect():
|
||||
t, p = yield from self.loop.connect_read_pipe(factory,
|
||||
t, p = yield from self.loop.connect_read_pipe(lambda: proto,
|
||||
master_read_obj)
|
||||
self.assertIs(p, proto)
|
||||
self.assertIs(t, proto.transport)
|
||||
|
@ -999,21 +1134,17 @@ class EventLoopTestsMixin:
|
|||
@unittest.skipUnless(sys.platform != 'win32',
|
||||
"Don't support pipes for Windows")
|
||||
def test_write_pipe(self):
|
||||
proto = None
|
||||
proto = MyWritePipeProto(loop=self.loop)
|
||||
transport = None
|
||||
|
||||
def factory():
|
||||
nonlocal proto
|
||||
proto = MyWritePipeProto(loop=self.loop)
|
||||
return proto
|
||||
|
||||
rpipe, wpipe = os.pipe()
|
||||
pipeobj = io.open(wpipe, 'wb', 1024)
|
||||
|
||||
@asyncio.coroutine
|
||||
def connect():
|
||||
nonlocal transport
|
||||
t, p = yield from self.loop.connect_write_pipe(factory, pipeobj)
|
||||
t, p = yield from self.loop.connect_write_pipe(
|
||||
lambda: proto, pipeobj)
|
||||
self.assertIs(p, proto)
|
||||
self.assertIs(t, proto.transport)
|
||||
self.assertEqual('CONNECTED', proto.state)
|
||||
|
@ -1045,21 +1176,16 @@ class EventLoopTestsMixin:
|
|||
@unittest.skipUnless(sys.platform != 'win32',
|
||||
"Don't support pipes for Windows")
|
||||
def test_write_pipe_disconnect_on_close(self):
|
||||
proto = None
|
||||
proto = MyWritePipeProto(loop=self.loop)
|
||||
transport = None
|
||||
|
||||
def factory():
|
||||
nonlocal proto
|
||||
proto = MyWritePipeProto(loop=self.loop)
|
||||
return proto
|
||||
|
||||
rsock, wsock = test_utils.socketpair()
|
||||
pipeobj = io.open(wsock.detach(), 'wb', 1024)
|
||||
|
||||
@asyncio.coroutine
|
||||
def connect():
|
||||
nonlocal transport
|
||||
t, p = yield from self.loop.connect_write_pipe(factory,
|
||||
t, p = yield from self.loop.connect_write_pipe(lambda: proto,
|
||||
pipeobj)
|
||||
self.assertIs(p, proto)
|
||||
self.assertIs(t, proto.transport)
|
||||
|
@ -1084,21 +1210,16 @@ class EventLoopTestsMixin:
|
|||
# older than 10.6 (Snow Leopard)
|
||||
@support.requires_mac_ver(10, 6)
|
||||
def test_write_pty(self):
|
||||
proto = None
|
||||
proto = MyWritePipeProto(loop=self.loop)
|
||||
transport = None
|
||||
|
||||
def factory():
|
||||
nonlocal proto
|
||||
proto = MyWritePipeProto(loop=self.loop)
|
||||
return proto
|
||||
|
||||
master, slave = os.openpty()
|
||||
slave_write_obj = io.open(slave, 'wb', 0)
|
||||
|
||||
@asyncio.coroutine
|
||||
def connect():
|
||||
nonlocal transport
|
||||
t, p = yield from self.loop.connect_write_pipe(factory,
|
||||
t, p = yield from self.loop.connect_write_pipe(lambda: proto,
|
||||
slave_write_obj)
|
||||
self.assertIs(p, proto)
|
||||
self.assertIs(t, proto.transport)
|
||||
|
|
|
@ -55,7 +55,8 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
|
|||
self.loop.remove_reader = unittest.mock.Mock()
|
||||
self.loop.remove_writer = unittest.mock.Mock()
|
||||
waiter = asyncio.Future(loop=self.loop)
|
||||
transport = self.loop._make_ssl_transport(m, asyncio.Protocol(), m, waiter)
|
||||
transport = self.loop._make_ssl_transport(
|
||||
m, asyncio.Protocol(), m, waiter)
|
||||
self.assertIsInstance(transport, _SelectorSslTransport)
|
||||
|
||||
@unittest.mock.patch('asyncio.selector_events.ssl', None)
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
"""Tests for streams.py."""
|
||||
|
||||
import functools
|
||||
import gc
|
||||
import socket
|
||||
import unittest
|
||||
import unittest.mock
|
||||
try:
|
||||
|
@ -32,48 +34,85 @@ class StreamReaderTests(unittest.TestCase):
|
|||
stream = asyncio.StreamReader()
|
||||
self.assertIs(stream._loop, m_events.get_event_loop.return_value)
|
||||
|
||||
def _basetest_open_connection(self, open_connection_fut):
|
||||
reader, writer = self.loop.run_until_complete(open_connection_fut)
|
||||
writer.write(b'GET / HTTP/1.0\r\n\r\n')
|
||||
f = reader.readline()
|
||||
data = self.loop.run_until_complete(f)
|
||||
self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
|
||||
f = reader.read()
|
||||
data = self.loop.run_until_complete(f)
|
||||
self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
|
||||
writer.close()
|
||||
|
||||
def test_open_connection(self):
|
||||
with test_utils.run_test_server() as httpd:
|
||||
f = asyncio.open_connection(*httpd.address, loop=self.loop)
|
||||
reader, writer = self.loop.run_until_complete(f)
|
||||
writer.write(b'GET / HTTP/1.0\r\n\r\n')
|
||||
f = reader.readline()
|
||||
data = self.loop.run_until_complete(f)
|
||||
self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
|
||||
f = reader.read()
|
||||
data = self.loop.run_until_complete(f)
|
||||
self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
|
||||
conn_fut = asyncio.open_connection(*httpd.address,
|
||||
loop=self.loop)
|
||||
self._basetest_open_connection(conn_fut)
|
||||
|
||||
writer.close()
|
||||
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
|
||||
def test_open_unix_connection(self):
|
||||
with test_utils.run_test_unix_server() as httpd:
|
||||
conn_fut = asyncio.open_unix_connection(httpd.address,
|
||||
loop=self.loop)
|
||||
self._basetest_open_connection(conn_fut)
|
||||
|
||||
def _basetest_open_connection_no_loop_ssl(self, open_connection_fut):
|
||||
try:
|
||||
reader, writer = self.loop.run_until_complete(open_connection_fut)
|
||||
finally:
|
||||
asyncio.set_event_loop(None)
|
||||
writer.write(b'GET / HTTP/1.0\r\n\r\n')
|
||||
f = reader.read()
|
||||
data = self.loop.run_until_complete(f)
|
||||
self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
|
||||
|
||||
writer.close()
|
||||
|
||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||
def test_open_connection_no_loop_ssl(self):
|
||||
with test_utils.run_test_server(use_ssl=True) as httpd:
|
||||
try:
|
||||
asyncio.set_event_loop(self.loop)
|
||||
f = asyncio.open_connection(*httpd.address,
|
||||
ssl=test_utils.dummy_ssl_context())
|
||||
reader, writer = self.loop.run_until_complete(f)
|
||||
finally:
|
||||
asyncio.set_event_loop(None)
|
||||
writer.write(b'GET / HTTP/1.0\r\n\r\n')
|
||||
f = reader.read()
|
||||
data = self.loop.run_until_complete(f)
|
||||
self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
|
||||
conn_fut = asyncio.open_connection(
|
||||
*httpd.address,
|
||||
ssl=test_utils.dummy_ssl_context(),
|
||||
loop=self.loop)
|
||||
|
||||
writer.close()
|
||||
self._basetest_open_connection_no_loop_ssl(conn_fut)
|
||||
|
||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
|
||||
def test_open_unix_connection_no_loop_ssl(self):
|
||||
with test_utils.run_test_unix_server(use_ssl=True) as httpd:
|
||||
conn_fut = asyncio.open_unix_connection(
|
||||
httpd.address,
|
||||
ssl=test_utils.dummy_ssl_context(),
|
||||
server_hostname='',
|
||||
loop=self.loop)
|
||||
|
||||
self._basetest_open_connection_no_loop_ssl(conn_fut)
|
||||
|
||||
def _basetest_open_connection_error(self, open_connection_fut):
|
||||
reader, writer = self.loop.run_until_complete(open_connection_fut)
|
||||
writer._protocol.connection_lost(ZeroDivisionError())
|
||||
f = reader.read()
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
self.loop.run_until_complete(f)
|
||||
writer.close()
|
||||
test_utils.run_briefly(self.loop)
|
||||
|
||||
def test_open_connection_error(self):
|
||||
with test_utils.run_test_server() as httpd:
|
||||
f = asyncio.open_connection(*httpd.address, loop=self.loop)
|
||||
reader, writer = self.loop.run_until_complete(f)
|
||||
writer._protocol.connection_lost(ZeroDivisionError())
|
||||
f = reader.read()
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
self.loop.run_until_complete(f)
|
||||
conn_fut = asyncio.open_connection(*httpd.address,
|
||||
loop=self.loop)
|
||||
self._basetest_open_connection_error(conn_fut)
|
||||
|
||||
writer.close()
|
||||
test_utils.run_briefly(self.loop)
|
||||
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
|
||||
def test_open_unix_connection_error(self):
|
||||
with test_utils.run_test_unix_server() as httpd:
|
||||
conn_fut = asyncio.open_unix_connection(httpd.address,
|
||||
loop=self.loop)
|
||||
self._basetest_open_connection_error(conn_fut)
|
||||
|
||||
def test_feed_empty_data(self):
|
||||
stream = asyncio.StreamReader(loop=self.loop)
|
||||
|
@ -415,10 +454,86 @@ class StreamReaderTests(unittest.TestCase):
|
|||
client_writer.write(data)
|
||||
|
||||
def start(self):
|
||||
sock = socket.socket()
|
||||
sock.bind(('127.0.0.1', 0))
|
||||
self.server = self.loop.run_until_complete(
|
||||
asyncio.start_server(self.handle_client,
|
||||
'127.0.0.1', 12345,
|
||||
sock=sock,
|
||||
loop=self.loop))
|
||||
return sock.getsockname()
|
||||
|
||||
def handle_client_callback(self, client_reader, client_writer):
|
||||
task = asyncio.Task(client_reader.readline(), loop=self.loop)
|
||||
|
||||
def done(task):
|
||||
client_writer.write(task.result())
|
||||
|
||||
task.add_done_callback(done)
|
||||
|
||||
def start_callback(self):
|
||||
sock = socket.socket()
|
||||
sock.bind(('127.0.0.1', 0))
|
||||
addr = sock.getsockname()
|
||||
sock.close()
|
||||
self.server = self.loop.run_until_complete(
|
||||
asyncio.start_server(self.handle_client_callback,
|
||||
host=addr[0], port=addr[1],
|
||||
loop=self.loop))
|
||||
return addr
|
||||
|
||||
def stop(self):
|
||||
if self.server is not None:
|
||||
self.server.close()
|
||||
self.loop.run_until_complete(self.server.wait_closed())
|
||||
self.server = None
|
||||
|
||||
@asyncio.coroutine
|
||||
def client(addr):
|
||||
reader, writer = yield from asyncio.open_connection(
|
||||
*addr, loop=self.loop)
|
||||
# send a line
|
||||
writer.write(b"hello world!\n")
|
||||
# read it back
|
||||
msgback = yield from reader.readline()
|
||||
writer.close()
|
||||
return msgback
|
||||
|
||||
# test the server variant with a coroutine as client handler
|
||||
server = MyServer(self.loop)
|
||||
addr = server.start()
|
||||
msg = self.loop.run_until_complete(asyncio.Task(client(addr),
|
||||
loop=self.loop))
|
||||
server.stop()
|
||||
self.assertEqual(msg, b"hello world!\n")
|
||||
|
||||
# test the server variant with a callback as client handler
|
||||
server = MyServer(self.loop)
|
||||
addr = server.start_callback()
|
||||
msg = self.loop.run_until_complete(asyncio.Task(client(addr),
|
||||
loop=self.loop))
|
||||
server.stop()
|
||||
self.assertEqual(msg, b"hello world!\n")
|
||||
|
||||
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
|
||||
def test_start_unix_server(self):
|
||||
|
||||
class MyServer:
|
||||
|
||||
def __init__(self, loop, path):
|
||||
self.server = None
|
||||
self.loop = loop
|
||||
self.path = path
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle_client(self, client_reader, client_writer):
|
||||
data = yield from client_reader.readline()
|
||||
client_writer.write(data)
|
||||
|
||||
def start(self):
|
||||
self.server = self.loop.run_until_complete(
|
||||
asyncio.start_unix_server(self.handle_client,
|
||||
path=self.path,
|
||||
loop=self.loop))
|
||||
|
||||
def handle_client_callback(self, client_reader, client_writer):
|
||||
task = asyncio.Task(client_reader.readline(), loop=self.loop)
|
||||
|
@ -430,9 +545,9 @@ class StreamReaderTests(unittest.TestCase):
|
|||
|
||||
def start_callback(self):
|
||||
self.server = self.loop.run_until_complete(
|
||||
asyncio.start_server(self.handle_client_callback,
|
||||
'127.0.0.1', 12345,
|
||||
loop=self.loop))
|
||||
asyncio.start_unix_server(self.handle_client_callback,
|
||||
path=self.path,
|
||||
loop=self.loop))
|
||||
|
||||
def stop(self):
|
||||
if self.server is not None:
|
||||
|
@ -441,9 +556,9 @@ class StreamReaderTests(unittest.TestCase):
|
|||
self.server = None
|
||||
|
||||
@asyncio.coroutine
|
||||
def client():
|
||||
reader, writer = yield from asyncio.open_connection(
|
||||
'127.0.0.1', 12345, loop=self.loop)
|
||||
def client(path):
|
||||
reader, writer = yield from asyncio.open_unix_connection(
|
||||
path, loop=self.loop)
|
||||
# send a line
|
||||
writer.write(b"hello world!\n")
|
||||
# read it back
|
||||
|
@ -452,20 +567,22 @@ class StreamReaderTests(unittest.TestCase):
|
|||
return msgback
|
||||
|
||||
# test the server variant with a coroutine as client handler
|
||||
server = MyServer(self.loop)
|
||||
server.start()
|
||||
msg = self.loop.run_until_complete(asyncio.Task(client(),
|
||||
loop=self.loop))
|
||||
server.stop()
|
||||
self.assertEqual(msg, b"hello world!\n")
|
||||
with test_utils.unix_socket_path() as path:
|
||||
server = MyServer(self.loop, path)
|
||||
server.start()
|
||||
msg = self.loop.run_until_complete(asyncio.Task(client(path),
|
||||
loop=self.loop))
|
||||
server.stop()
|
||||
self.assertEqual(msg, b"hello world!\n")
|
||||
|
||||
# test the server variant with a callback as client handler
|
||||
server = MyServer(self.loop)
|
||||
server.start_callback()
|
||||
msg = self.loop.run_until_complete(asyncio.Task(client(),
|
||||
loop=self.loop))
|
||||
server.stop()
|
||||
self.assertEqual(msg, b"hello world!\n")
|
||||
with test_utils.unix_socket_path() as path:
|
||||
server = MyServer(self.loop, path)
|
||||
server.start_callback()
|
||||
msg = self.loop.run_until_complete(asyncio.Task(client(path),
|
||||
loop=self.loop))
|
||||
server.stop()
|
||||
self.assertEqual(msg, b"hello world!\n")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -7,8 +7,10 @@ import io
|
|||
import os
|
||||
import pprint
|
||||
import signal
|
||||
import socket
|
||||
import stat
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import unittest
|
||||
import unittest.mock
|
||||
|
@ -24,7 +26,7 @@ from asyncio import unix_events
|
|||
|
||||
|
||||
@unittest.skipUnless(signal, 'Signals are not supported')
|
||||
class SelectorEventLoopTests(unittest.TestCase):
|
||||
class SelectorEventLoopSignalTests(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.loop = asyncio.SelectorEventLoop()
|
||||
|
@ -200,6 +202,84 @@ class SelectorEventLoopTests(unittest.TestCase):
|
|||
m_signal.set_wakeup_fd.assert_called_once_with(-1)
|
||||
|
||||
|
||||
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'),
|
||||
'UNIX Sockets are not supported')
|
||||
class SelectorEventLoopUnixSocketTests(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.loop = asyncio.SelectorEventLoop()
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
def tearDown(self):
|
||||
self.loop.close()
|
||||
|
||||
def test_create_unix_server_existing_path_sock(self):
|
||||
with test_utils.unix_socket_path() as path:
|
||||
sock = socket.socket(socket.AF_UNIX)
|
||||
sock.bind(path)
|
||||
|
||||
coro = self.loop.create_unix_server(lambda: None, path)
|
||||
with self.assertRaisesRegexp(OSError,
|
||||
'Address.*is already in use'):
|
||||
self.loop.run_until_complete(coro)
|
||||
|
||||
def test_create_unix_server_existing_path_nonsock(self):
|
||||
with tempfile.NamedTemporaryFile() as file:
|
||||
coro = self.loop.create_unix_server(lambda: None, file.name)
|
||||
with self.assertRaisesRegexp(OSError,
|
||||
'Address.*is already in use'):
|
||||
self.loop.run_until_complete(coro)
|
||||
|
||||
def test_create_unix_server_ssl_bool(self):
|
||||
coro = self.loop.create_unix_server(lambda: None, path='spam',
|
||||
ssl=True)
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'ssl argument must be an SSLContext'):
|
||||
self.loop.run_until_complete(coro)
|
||||
|
||||
def test_create_unix_server_nopath_nosock(self):
|
||||
coro = self.loop.create_unix_server(lambda: None, path=None)
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'path was not specified, and no sock'):
|
||||
self.loop.run_until_complete(coro)
|
||||
|
||||
def test_create_unix_server_path_inetsock(self):
|
||||
coro = self.loop.create_unix_server(lambda: None, path=None,
|
||||
sock=socket.socket())
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'A UNIX Domain Socket was expected'):
|
||||
self.loop.run_until_complete(coro)
|
||||
|
||||
def test_create_unix_connection_path_sock(self):
|
||||
coro = self.loop.create_unix_connection(
|
||||
lambda: None, '/dev/null', sock=object())
|
||||
with self.assertRaisesRegex(ValueError, 'path and sock can not be'):
|
||||
self.loop.run_until_complete(coro)
|
||||
|
||||
def test_create_unix_connection_nopath_nosock(self):
|
||||
coro = self.loop.create_unix_connection(
|
||||
lambda: None, None)
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'no path and sock were specified'):
|
||||
self.loop.run_until_complete(coro)
|
||||
|
||||
def test_create_unix_connection_nossl_serverhost(self):
|
||||
coro = self.loop.create_unix_connection(
|
||||
lambda: None, '/dev/null', server_hostname='spam')
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'server_hostname is only meaningful'):
|
||||
self.loop.run_until_complete(coro)
|
||||
|
||||
def test_create_unix_connection_ssl_noserverhost(self):
|
||||
coro = self.loop.create_unix_connection(
|
||||
lambda: None, '/dev/null', ssl=True)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'you have to pass server_hostname when using ssl'):
|
||||
|
||||
self.loop.run_until_complete(coro)
|
||||
|
||||
|
||||
class UnixReadPipeTransportTests(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
|
Loading…
Reference in New Issue