asyncio: sync with Tulip

- Tulip issue 185: Add a create_task() method to event loops. The create_task()
  method can be overriden in custom event loop to implement their own task
  class. For example, greenio and Pulsar projects use their own task class. The
  create_task() method is now preferred over creating directly task using the
  Task class.
- tests: fix a warning
- fix typo in the name of a test function
- Update AbstractEventLoop: add new event loop methods; update also the unit test
This commit is contained in:
Victor Stinner 2014-07-08 11:29:25 +02:00
parent 630a4f63c5
commit 896a25ab30
9 changed files with 63 additions and 5 deletions

View File

@ -151,6 +151,12 @@ class BaseEventLoop(events.AbstractEventLoop):
% (self.__class__.__name__, self.is_running(), % (self.__class__.__name__, self.is_running(),
self.is_closed(), self.get_debug())) self.is_closed(), self.get_debug()))
def create_task(self, coro):
"""Schedule a coroutine object.
Return a task object."""
return tasks.Task(coro, loop=self)
def _make_socket_transport(self, sock, protocol, waiter=None, *, def _make_socket_transport(self, sock, protocol, waiter=None, *,
extra=None, server=None): extra=None, server=None):
"""Create socket transport.""" """Create socket transport."""

View File

@ -200,6 +200,10 @@ class AbstractEventLoop:
"""Return whether the event loop is currently running.""" """Return whether the event loop is currently running."""
raise NotImplementedError raise NotImplementedError
def is_closed(self):
"""Returns True if the event loop was closed."""
raise NotImplementedError
def close(self): def close(self):
"""Close the loop. """Close the loop.
@ -225,6 +229,11 @@ class AbstractEventLoop:
def time(self): def time(self):
raise NotImplementedError raise NotImplementedError
# Method scheduling a coroutine object: create a task.
def create_task(self, coro):
raise NotImplementedError
# Methods for interacting with threads. # Methods for interacting with threads.
def call_soon_threadsafe(self, callback, *args): def call_soon_threadsafe(self, callback, *args):

View File

@ -213,7 +213,7 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
res = self._client_connected_cb(self._stream_reader, res = self._client_connected_cb(self._stream_reader,
self._stream_writer) self._stream_writer)
if coroutines.iscoroutine(res): if coroutines.iscoroutine(res):
tasks.Task(res, loop=self._loop) self._loop.create_task(res)
def connection_lost(self, exc): def connection_lost(self, exc):
if exc is None: if exc is None:

View File

@ -505,7 +505,9 @@ def async(coro_or_future, *, loop=None):
raise ValueError('loop argument must agree with Future') raise ValueError('loop argument must agree with Future')
return coro_or_future return coro_or_future
elif coroutines.iscoroutine(coro_or_future): elif coroutines.iscoroutine(coro_or_future):
task = Task(coro_or_future, loop=loop) if loop is None:
loop = events.get_event_loop()
task = loop.create_task(coro_or_future)
if task._source_traceback: if task._source_traceback:
del task._source_traceback[-1] del task._source_traceback[-1]
return task return task

View File

@ -48,7 +48,7 @@ def run_briefly(loop):
def once(): def once():
pass pass
gen = once() gen = once()
t = tasks.Task(gen, loop=loop) t = loop.create_task(gen)
# Don't log a warning if the task is not done after run_until_complete(). # Don't log a warning if the task is not done after run_until_complete().
# It occurs if the loop is stopped or if a task raises a BaseException. # It occurs if the loop is stopped or if a task raises a BaseException.
t._log_destroy_pending = False t._log_destroy_pending = False

View File

@ -12,6 +12,7 @@ from test.support import IPV6_ENABLED
import asyncio import asyncio
from asyncio import base_events from asyncio import base_events
from asyncio import events
from asyncio import constants from asyncio import constants
from asyncio import test_utils from asyncio import test_utils
@ -526,6 +527,29 @@ class BaseEventLoopTests(test_utils.TestCase):
PYTHONASYNCIODEBUG='1') PYTHONASYNCIODEBUG='1')
self.assertEqual(stdout.rstrip(), b'False') self.assertEqual(stdout.rstrip(), b'False')
def test_create_task(self):
class MyTask(asyncio.Task):
pass
@asyncio.coroutine
def test():
pass
class EventLoop(base_events.BaseEventLoop):
def create_task(self, coro):
return MyTask(coro, loop=loop)
loop = EventLoop()
self.set_event_loop(loop)
coro = test()
task = asyncio.async(coro, loop=loop)
self.assertIsInstance(task, MyTask)
# make warnings quiet
task._log_destroy_pending = False
coro.close()
class MyProto(asyncio.Protocol): class MyProto(asyncio.Protocol):
done = None done = None

View File

@ -1968,8 +1968,12 @@ class AbstractEventLoopTests(unittest.TestCase):
NotImplementedError, loop.stop) NotImplementedError, loop.stop)
self.assertRaises( self.assertRaises(
NotImplementedError, loop.is_running) NotImplementedError, loop.is_running)
self.assertRaises(
NotImplementedError, loop.is_closed)
self.assertRaises( self.assertRaises(
NotImplementedError, loop.close) NotImplementedError, loop.close)
self.assertRaises(
NotImplementedError, loop.create_task, None)
self.assertRaises( self.assertRaises(
NotImplementedError, loop.call_later, None, None) NotImplementedError, loop.call_later, None, None)
self.assertRaises( self.assertRaises(
@ -2027,6 +2031,16 @@ class AbstractEventLoopTests(unittest.TestCase):
mock.sentinel) mock.sentinel)
self.assertRaises( self.assertRaises(
NotImplementedError, loop.subprocess_exec, f) NotImplementedError, loop.subprocess_exec, f)
self.assertRaises(
NotImplementedError, loop.set_exception_handler, f)
self.assertRaises(
NotImplementedError, loop.default_exception_handler, f)
self.assertRaises(
NotImplementedError, loop.call_exception_handler, f)
self.assertRaises(
NotImplementedError, loop.get_debug)
self.assertRaises(
NotImplementedError, loop.set_debug, f)
class ProtocolsAbsTests(unittest.TestCase): class ProtocolsAbsTests(unittest.TestCase):

View File

@ -301,12 +301,12 @@ class FutureTests(test_utils.TestCase):
def test_future_exception_never_retrieved(self, m_log): def test_future_exception_never_retrieved(self, m_log):
self.loop.set_debug(True) self.loop.set_debug(True)
def memroy_error(): def memory_error():
try: try:
raise MemoryError() raise MemoryError()
except BaseException as exc: except BaseException as exc:
return exc return exc
exc = memroy_error() exc = memory_error()
future = asyncio.Future(loop=self.loop) future = asyncio.Future(loop=self.loop)
source_traceback = future._source_traceback source_traceback = future._source_traceback

View File

@ -233,6 +233,9 @@ class TaskTests(test_utils.TestCase):
self.assertRegex(repr(task), self.assertRegex(repr(task),
'<Task .* wait_for=%s>' % re.escape(repr(fut))) '<Task .* wait_for=%s>' % re.escape(repr(fut)))
fut.set_result(None)
self.loop.run_until_complete(task)
def test_task_basics(self): def test_task_basics(self):
@asyncio.coroutine @asyncio.coroutine
def outer(): def outer():