asyncio: Add streams.start_server(), by Gustavo Carneiro.

This commit is contained in:
Guido van Rossum 2013-11-19 11:43:38 -08:00
parent 4a9ee26750
commit 1540b16ff4
2 changed files with 117 additions and 2 deletions

View File

@ -1,6 +1,8 @@
"""Stream-related things."""
__all__ = ['StreamReader', 'StreamReaderProtocol', 'open_connection']
__all__ = ['StreamReader', 'StreamReaderProtocol',
'open_connection', 'start_server',
]
import collections
@ -43,6 +45,42 @@ def open_connection(host=None, port=None, *,
return reader, writer
@tasks.coroutine
def start_server(client_connected_cb, host=None, port=None, *,
loop=None, limit=_DEFAULT_LIMIT, **kwds):
"""Start a socket server, call back for each client connected.
The first parameter, `client_connected_cb`, takes two parameters:
client_reader, client_writer. client_reader is a StreamReader
object, while client_writer is a StreamWriter object. This
parameter can either be a plain callback function or a coroutine;
if it is a coroutine, it will be automatically converted into a
Task.
The rest of the arguments are all the usual arguments to
loop.create_server() except protocol_factory; most common are
positional host and port, with various optional keyword arguments
following. The return value is the same as loop.create_server().
Additional optional keyword arguments are loop (to set the event loop
instance to use) and limit (to set the buffer limit passed to the
StreamReader).
The return value is the same as loop.create_server(), i.e. a
Server object which can be used to stop the service.
"""
if loop is None:
loop = events.get_event_loop()
def factory():
reader = StreamReader(limit=limit, loop=loop)
protocol = StreamReaderProtocol(reader, client_connected_cb,
loop=loop)
return protocol
return (yield from loop.create_server(factory, host, port, **kwds))
class StreamReaderProtocol(protocols.Protocol):
"""Trivial helper class to adapt between Protocol and StreamReader.
@ -52,13 +90,24 @@ class StreamReaderProtocol(protocols.Protocol):
call inappropriate methods of the protocol.)
"""
def __init__(self, stream_reader):
def __init__(self, stream_reader, client_connected_cb=None, loop=None):
self._stream_reader = stream_reader
self._stream_writer = None
self._drain_waiter = None
self._paused = False
self._client_connected_cb = client_connected_cb
self._loop = loop # May be None; we may never need it.
def connection_made(self, transport):
self._stream_reader.set_transport(transport)
if self._client_connected_cb is not None:
self._stream_writer = StreamWriter(transport, self,
self._stream_reader,
self._loop)
res = self._client_connected_cb(self._stream_reader,
self._stream_writer)
if tasks.iscoroutine(res):
tasks.Task(res, loop=self._loop)
def connection_lost(self, exc):
if exc is None:

View File

@ -359,6 +359,72 @@ class StreamReaderTests(unittest.TestCase):
test_utils.run_briefly(self.loop)
self.assertIs(stream._waiter, None)
def test_start_server(self):
class MyServer:
def __init__(self, loop):
self.server = None
self.loop = loop
@tasks.coroutine
def handle_client(self, client_reader, client_writer):
data = yield from client_reader.readline()
client_writer.write(data)
def start(self):
self.server = self.loop.run_until_complete(
streams.start_server(self.handle_client,
'127.0.0.1', 12345,
loop=self.loop))
def handle_client_callback(self, client_reader, client_writer):
task = tasks.Task(client_reader.readline(), loop=self.loop)
def done(task):
client_writer.write(task.result())
task.add_done_callback(done)
def start_callback(self):
self.server = self.loop.run_until_complete(
streams.start_server(self.handle_client_callback,
'127.0.0.1', 12345,
loop=self.loop))
def stop(self):
if self.server is not None:
self.server.close()
self.loop.run_until_complete(self.server.wait_closed())
self.server = None
@tasks.coroutine
def client():
reader, writer = yield from streams.open_connection(
'127.0.0.1', 12345, loop=self.loop)
# send a line
writer.write(b"hello world!\n")
# read it back
msgback = yield from reader.readline()
writer.close()
return msgback
# test the server variant with a coroutine as client handler
server = MyServer(self.loop)
server.start()
msg = self.loop.run_until_complete(tasks.Task(client(),
loop=self.loop))
server.stop()
self.assertEqual(msg, b"hello world!\n")
# test the server variant with a callback as client handler
server = MyServer(self.loop)
server.start_callback()
msg = self.loop.run_until_complete(tasks.Task(client(),
loop=self.loop))
server.stop()
self.assertEqual(msg, b"hello world!\n")
if __name__ == '__main__':
unittest.main()