diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index e9953682251..331d28d0a45 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -1,6 +1,8 @@ """Stream-related things.""" -__all__ = ['StreamReader', 'StreamReaderProtocol', 'open_connection'] +__all__ = ['StreamReader', 'StreamReaderProtocol', + 'open_connection', 'start_server', + ] import collections @@ -43,6 +45,42 @@ def open_connection(host=None, port=None, *, return reader, writer +@tasks.coroutine +def start_server(client_connected_cb, host=None, port=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """Start a socket server, call back for each client connected. + + The first parameter, `client_connected_cb`, takes two parameters: + client_reader, client_writer. client_reader is a StreamReader + object, while client_writer is a StreamWriter object. This + parameter can either be a plain callback function or a coroutine; + if it is a coroutine, it will be automatically converted into a + Task. + + The rest of the arguments are all the usual arguments to + loop.create_server() except protocol_factory; most common are + positional host and port, with various optional keyword arguments + following. The return value is the same as loop.create_server(). + + Additional optional keyword arguments are loop (to set the event loop + instance to use) and limit (to set the buffer limit passed to the + StreamReader). + + The return value is the same as loop.create_server(), i.e. a + Server object which can be used to stop the service. + """ + if loop is None: + loop = events.get_event_loop() + + def factory(): + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader, client_connected_cb, + loop=loop) + return protocol + + return (yield from loop.create_server(factory, host, port, **kwds)) + + class StreamReaderProtocol(protocols.Protocol): """Trivial helper class to adapt between Protocol and StreamReader. @@ -52,13 +90,24 @@ class StreamReaderProtocol(protocols.Protocol): call inappropriate methods of the protocol.) """ - def __init__(self, stream_reader): + def __init__(self, stream_reader, client_connected_cb=None, loop=None): self._stream_reader = stream_reader + self._stream_writer = None self._drain_waiter = None self._paused = False + self._client_connected_cb = client_connected_cb + self._loop = loop # May be None; we may never need it. def connection_made(self, transport): self._stream_reader.set_transport(transport) + if self._client_connected_cb is not None: + self._stream_writer = StreamWriter(transport, self, + self._stream_reader, + self._loop) + res = self._client_connected_cb(self._stream_reader, + self._stream_writer) + if tasks.iscoroutine(res): + tasks.Task(res, loop=self._loop) def connection_lost(self, exc): if exc is None: diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 69e2246f441..5516c15873e 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -359,6 +359,72 @@ class StreamReaderTests(unittest.TestCase): test_utils.run_briefly(self.loop) self.assertIs(stream._waiter, None) + def test_start_server(self): + + class MyServer: + + def __init__(self, loop): + self.server = None + self.loop = loop + + @tasks.coroutine + def handle_client(self, client_reader, client_writer): + data = yield from client_reader.readline() + client_writer.write(data) + + def start(self): + self.server = self.loop.run_until_complete( + streams.start_server(self.handle_client, + '127.0.0.1', 12345, + loop=self.loop)) + + def handle_client_callback(self, client_reader, client_writer): + task = tasks.Task(client_reader.readline(), loop=self.loop) + + def done(task): + client_writer.write(task.result()) + + task.add_done_callback(done) + + def start_callback(self): + self.server = self.loop.run_until_complete( + streams.start_server(self.handle_client_callback, + '127.0.0.1', 12345, + loop=self.loop)) + + def stop(self): + if self.server is not None: + self.server.close() + self.loop.run_until_complete(self.server.wait_closed()) + self.server = None + + @tasks.coroutine + def client(): + reader, writer = yield from streams.open_connection( + '127.0.0.1', 12345, loop=self.loop) + # send a line + writer.write(b"hello world!\n") + # read it back + msgback = yield from reader.readline() + writer.close() + return msgback + + # test the server variant with a coroutine as client handler + server = MyServer(self.loop) + server.start() + msg = self.loop.run_until_complete(tasks.Task(client(), + loop=self.loop)) + server.stop() + self.assertEqual(msg, b"hello world!\n") + + # test the server variant with a callback as client handler + server = MyServer(self.loop) + server.start_callback() + msg = self.loop.run_until_complete(tasks.Task(client(), + loop=self.loop)) + server.stop() + self.assertEqual(msg, b"hello world!\n") + if __name__ == '__main__': unittest.main()