276 lines
7.8 KiB
Python
276 lines
7.8 KiB
Python
"""Utilities shared by tests."""
|
|
|
|
import collections
|
|
import contextlib
|
|
import io
|
|
import os
|
|
import sys
|
|
import threading
|
|
import time
|
|
import unittest
|
|
import unittest.mock
|
|
from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer
|
|
try:
|
|
import ssl
|
|
except ImportError: # pragma: no cover
|
|
ssl = None
|
|
|
|
from . import tasks
|
|
from . import base_events
|
|
from . import events
|
|
from . import selectors
|
|
|
|
|
|
if sys.platform == 'win32': # pragma: no cover
|
|
from .windows_utils import socketpair
|
|
else:
|
|
from socket import socketpair # pragma: no cover
|
|
|
|
|
|
def dummy_ssl_context():
    """Return a plain client-less SSL context for tests.

    Returns None when the interpreter was built without the ``ssl`` module,
    so callers can use the result directly as an ``ssl=`` argument.
    """
    if ssl is not None:
        return ssl.SSLContext(ssl.PROTOCOL_SSLv23)
    return None
|
|
|
|
|
|
def run_briefly(loop):
    """Give *loop* a chance to process its pending callbacks.

    A do-nothing coroutine is wrapped in a Task and run to completion with
    ``loop.run_until_complete``.  The underlying coroutine object is closed
    afterwards whether or not the run succeeded.
    """
    @tasks.coroutine
    def once():
        pass

    coro = once()
    task = tasks.Task(coro, loop=loop)
    try:
        loop.run_until_complete(task)
    finally:
        # Always release the coroutine, even if the loop run raised.
        coro.close()
|
|
|
|
|
|
def run_until(loop, pred, timeout=None):
    """Drive *loop* until ``pred()`` is true or *timeout* seconds elapse.

    Args:
        loop: the event loop to run.
        pred: zero-argument callable, re-evaluated between loop runs.
        timeout: maximum number of seconds to wait, or None to wait
            without limit.

    Returns:
        True once ``pred()`` returns a true value, False if the timeout
        expired first.
    """
    # Sleep in short slices so the predicate is re-checked promptly; sleeping
    # for the whole remaining timeout would delay success until the deadline.
    poll_interval = 0.001
    if timeout is not None:
        # monotonic() is not affected by wall-clock adjustments.
        deadline = time.monotonic() + timeout
    while not pred():
        if timeout is not None:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return False
            loop.run_until_complete(
                tasks.sleep(min(remaining, poll_interval), loop=loop))
        else:
            run_briefly(loop)
    return True
|
|
|
|
|
|
def run_once(loop):
    """Run a single iteration of *loop*.

    ``loop.stop()`` is requested *before* ``loop.run_forever()`` starts, so
    ``run_forever()`` returns as soon as it processes that stop request
    (the ``_raise_stop_error()`` callback scheduled by ``stop()``).

    Caveat: this won't work if the test waits for I/O events, because
    ``_raise_stop_error()`` runs before any I/O event callbacks.
    """
    # Order matters: the stop request must already be queued when the
    # loop begins running.
    loop.stop()
    loop.run_forever()
|
|
|
|
|
|
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Serve a trivial WSGI app in a background thread for a with-block.

    The app answers every request with ``200 OK`` and the body
    ``b'Test message'``.  Request logging and server error reporting are
    silenced.

    Args:
        host: interface to bind.
        port: port to bind; 0 lets the OS pick a free one.
        use_ssl: wrap each accepted connection in TLS, using the test key
            and certificate files shipped with the asyncio tests.

    Yields:
        The running server.  Its ``address`` attribute is the bound
        ``(host, port)`` pair.

    Raises:
        RuntimeError: if *use_ssl* is true but the ``ssl`` module is
            unavailable.  Without this check, every request would fail
            silently inside the server thread.
    """
    if use_ssl and ssl is None:
        raise RuntimeError('use_ssl=True requires the ssl module')

    class SilentWSGIRequestHandler(WSGIRequestHandler):
        def get_stderr(self):
            return io.StringIO()

        def log_message(self, format, *args):
            pass

    class SilentWSGIServer(WSGIServer):
        def handle_error(self, request, client_address):
            pass

    class SSLWSGIServer(SilentWSGIServer):
        def finish_request(self, request, client_address):
            # The relative location of our test directory (which
            # contains the ssl key and certificate files) differs
            # between the stdlib and stand-alone asyncio.
            # Prefer our own if we can find it.
            here = os.path.join(os.path.dirname(__file__), '..', 'tests')
            if not os.path.isdir(here):
                here = os.path.join(os.path.dirname(os.__file__),
                                    'test', 'test_asyncio')
            keyfile = os.path.join(here, 'ssl_key.pem')
            certfile = os.path.join(here, 'ssl_cert.pem')
            # ssl.wrap_socket() was removed in Python 3.12.  An explicit
            # server-side context is equivalent and works on all versions.
            protocol = getattr(ssl, 'PROTOCOL_TLS_SERVER',
                               ssl.PROTOCOL_SSLv23)
            context = ssl.SSLContext(protocol)
            context.load_cert_chain(certfile, keyfile)
            ssock = context.wrap_socket(request, server_side=True)
            try:
                # The with-statement closes the TLS socket even when the
                # handler fails.
                with ssock:
                    self.RequestHandlerClass(ssock, client_address, self)
            except OSError:
                # maybe socket has been closed by peer
                pass

    def app(environ, start_response):
        status = '200 OK'
        headers = [('Content-type', 'text/plain')]
        start_response(status, headers)
        return [b'Test message']

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    server_class = SSLWSGIServer if use_ssl else SilentWSGIServer
    httpd = make_server(host, port, app,
                        server_class, SilentWSGIRequestHandler)
    try:
        httpd.address = httpd.server_address
        server_thread = threading.Thread(target=httpd.serve_forever)
        server_thread.start()
        try:
            yield httpd
        finally:
            httpd.shutdown()
            server_thread.join()
    finally:
        # Close the listening socket only once the serving thread is gone
        # (or never started).
        httpd.server_close()
|
|
|
|
|
|
def make_test_protocol(base):
    """Instantiate a subclass of *base* whose public API is all mocks.

    Every attribute reported by ``dir(base)`` except dunder names is
    replaced by a fresh ``unittest.mock.Mock`` that returns None.  The
    returned instance is still an instance of *base*, so ``isinstance``
    checks keep passing.
    """
    def is_dunder(attr):
        return attr.startswith('__') and attr.endswith('__')

    overrides = {attr: unittest.mock.Mock(return_value=None)
                 for attr in dir(base) if not is_dunder(attr)}
    bases = (base,) + base.__bases__
    protocol_cls = type('TestProtocol', bases, overrides)
    return protocol_cls()
|
|
|
|
|
|
class TestSelector(selectors.BaseSelector):
    """Selector stand-in that records registrations but never reports I/O.

    Every key is created with ``fd=0``, and ``select()`` always returns an
    empty list.
    """

    def __init__(self):
        # Maps each registered file object to its SelectorKey.
        self.keys = {}

    def register(self, fileobj, events, data=None):
        """Store a key for *fileobj* (replacing any previous one) and return it."""
        new_key = selectors.SelectorKey(fileobj=fileobj, fd=0,
                                        events=events, data=data)
        self.keys[fileobj] = new_key
        return new_key

    def unregister(self, fileobj):
        """Remove and return the key for *fileobj*; KeyError if unknown."""
        return self.keys.pop(fileobj)

    def select(self, timeout):
        """Report no ready file objects, whatever *timeout* is."""
        return []

    def get_map(self):
        """Return the live file-object -> key mapping (not a copy)."""
        return self.keys
|
|
|
|
|
|
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests that manages its own virtual clock.

    ``time()`` returns a counter that starts at 0 and moves only when the
    time generator passed to ``__init__`` says so.  ``call_at()`` records
    every requested deadline.  After each loop iteration (once the ready
    handlers have run), each recorded deadline is sent into the generator,
    and the clock is advanced by the value the generator yields back.

    The generator should look like this::

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value a ``yield`` expression evaluates to (``when`` above) is the
    absolute loop time of a scheduled handler.  The value the generator
    yields is the amount by which to move the loop's time forward.
    """

    def __init__(self, gen=None):
        super().__init__()

        if gen is None:
            def gen():
                yield
            # The placeholder generator is not checked by close().
            self._check_on_close = False
        else:
            self._check_on_close = True

        self._gen = gen()
        # Prime the generator: run it to its first yield so that later
        # send() calls deliver deadlines to it.
        next(self._gen)
        self._time = 0
        # Deadlines passed to call_at() since the last _run_once().
        self._timers = []
        self._selector = TestSelector()

        # fd -> handle, as registered by add_reader() / add_writer().
        self.readers = {}
        self.writers = {}
        self.reset_counters()

    def time(self):
        """Return the current virtual time."""
        return self._time

    def advance_time(self, advance):
        """Move test time forward by *advance*.

        A falsy value (0, or None from a bare ``yield``) leaves the clock
        unchanged.
        """
        if advance:
            self._time += advance

    def close(self):
        """Check that a user-supplied time generator has finished.

        Sends a final 0 into the generator.  If the generator does not stop,
        raises AssertionError.  Skipped when no generator was supplied.
        """
        if self._check_on_close:
            try:
                self._gen.send(0)
            except StopIteration:
                pass
            else:  # pragma: no cover
                raise AssertionError("Time generator is not finished")

    @staticmethod
    def _remove_registered(registry, counts, fd):
        """Count a removal of *fd* and drop it from *registry* if present.

        Returns True if *fd* was registered, False otherwise.
        """
        counts[fd] += 1
        if fd in registry:
            del registry[fd]
            return True
        return False

    @staticmethod
    def _check_registered(registry, fd, callback, args):
        """Raise AssertionError unless *fd* is registered with callback/args.

        Raises explicitly instead of using ``assert`` statements, so the
        checks still run under ``python -O``.
        """
        if fd not in registry:
            raise AssertionError('fd {} is not registered'.format(fd))
        handle = registry[fd]
        if handle._callback != callback:
            raise AssertionError('{!r} != {!r}'.format(
                handle._callback, callback))
        if handle._args != args:
            raise AssertionError('{!r} != {!r}'.format(
                handle._args, args))

    def add_reader(self, fd, callback, *args):
        """Record a read handle for *fd*; no real I/O polling happens."""
        self.readers[fd] = events.make_handle(callback, args)

    def remove_reader(self, fd):
        """Unregister the reader for *fd*; return True if one was present."""
        return self._remove_registered(
            self.readers, self.remove_reader_count, fd)

    def assert_reader(self, fd, callback, *args):
        """Fail unless *fd* has a reader with exactly *callback* and *args*."""
        self._check_registered(self.readers, fd, callback, args)

    def add_writer(self, fd, callback, *args):
        """Record a write handle for *fd*; no real I/O polling happens."""
        self.writers[fd] = events.make_handle(callback, args)

    def remove_writer(self, fd):
        """Unregister the writer for *fd*; return True if one was present."""
        return self._remove_registered(
            self.writers, self.remove_writer_count, fd)

    def assert_writer(self, fd, callback, *args):
        """Fail unless *fd* has a writer with exactly *callback* and *args*."""
        self._check_registered(self.writers, fd, callback, args)

    def reset_counters(self):
        """Reset the per-fd counts of remove_reader()/remove_writer() calls."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        """Run one base-loop iteration, then feed recorded deadlines to the
        time generator, advancing the clock by each value it yields."""
        super()._run_once()
        for when in self._timers:
            advance = self._gen.send(when)
            self.advance_time(advance)
        self._timers = []

    def call_at(self, when, callback, *args):
        """Record *when* for the time generator, then schedule normally."""
        self._timers.append(when)
        return super().call_at(when, callback, *args)

    def _process_events(self, event_list):
        """Ignore *event_list*; nothing is dispatched."""
        return

    def _write_to_self(self):
        # Intentionally a no-op.  NOTE(review): presumably the base loop uses
        # this to wake itself up, which the virtual-clock loop does not
        # need; confirm.
        pass
|