asyncio: Add Task.current_task() class method.

This commit is contained in:
Guido van Rossum 2013-12-06 12:57:40 -08:00
parent 2f8c83568c
commit 1a605ed5a3
3 changed files with 57 additions and 1 deletions

View File

@ -122,6 +122,22 @@ class Task(futures.Future):
# Weak set containing all tasks alive.
_all_tasks = weakref.WeakSet()
# Dictionary containing tasks that are currently active in
# all running event loops. {EventLoop: Task}
_current_tasks = {}
@classmethod
def current_task(cls, loop=None):
"""Return the currently running task in an event loop or None.
By default the current task for the current event loop is returned.
None is returned when called not in the context of a Task.
"""
if loop is None:
loop = events.get_event_loop()
return cls._current_tasks.get(loop)
@classmethod
def all_tasks(cls, loop=None):
"""Return a set of all tasks for an event loop.
@ -252,6 +268,8 @@ class Task(futures.Future):
self._must_cancel = False
coro = self._coro
self._fut_waiter = None
self.__class__._current_tasks[self._loop] = self
# Call either coro.throw(exc) or coro.send(value).
try:
if exc is not None:
@ -302,6 +320,8 @@ class Task(futures.Future):
self._step, None,
RuntimeError(
'Task got bad yield: {!r}'.format(result)))
finally:
self.__class__._current_tasks.pop(self._loop)
self = None
def _wakeup(self, future):

View File

@ -88,7 +88,7 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
class SSLWSGIServer(SilentWSGIServer):
def finish_request(self, request, client_address):
# The relative location of our test directory (which
# contains the sample key and certificate files) differs
# contains the ssl key and certificate files) differs
# between the stdlib and stand-alone Tulip/asyncio.
# Prefer our own if we can find it.
here = os.path.join(os.path.dirname(__file__), '..', 'tests')

View File

@ -1113,6 +1113,42 @@ class TaskTests(unittest.TestCase):
self.assertEqual(res, 'test')
self.assertIsNone(t2.result())
def test_current_task(self):
self.assertIsNone(tasks.Task.current_task(loop=self.loop))
@tasks.coroutine
def coro(loop):
self.assertTrue(tasks.Task.current_task(loop=loop) is task)
task = tasks.Task(coro(self.loop), loop=self.loop)
self.loop.run_until_complete(task)
self.assertIsNone(tasks.Task.current_task(loop=self.loop))
def test_current_task_with_interleaving_tasks(self):
self.assertIsNone(tasks.Task.current_task(loop=self.loop))
fut1 = futures.Future(loop=self.loop)
fut2 = futures.Future(loop=self.loop)
@tasks.coroutine
def coro1(loop):
self.assertTrue(tasks.Task.current_task(loop=loop) is task1)
yield from fut1
self.assertTrue(tasks.Task.current_task(loop=loop) is task1)
fut2.set_result(True)
@tasks.coroutine
def coro2(loop):
self.assertTrue(tasks.Task.current_task(loop=loop) is task2)
fut1.set_result(True)
yield from fut2
self.assertTrue(tasks.Task.current_task(loop=loop) is task2)
task1 = tasks.Task(coro1(self.loop), loop=self.loop)
task2 = tasks.Task(coro2(self.loop), loop=self.loop)
self.loop.run_until_complete(tasks.wait((task1, task2), loop=self.loop))
self.assertIsNone(tasks.Task.current_task(loop=self.loop))
# Some thorough tests for cancellation propagation through
# coroutines, tasks and wait().