asyncio: Add Task.current_task() class method.
This commit is contained in:
parent
2f8c83568c
commit
1a605ed5a3
|
@ -122,6 +122,22 @@ class Task(futures.Future):
|
|||
# Weak set containing all tasks alive.
|
||||
_all_tasks = weakref.WeakSet()
|
||||
|
||||
# Dictionary containing tasks that are currently active in
|
||||
# all running event loops. {EventLoop: Task}
|
||||
_current_tasks = {}
|
||||
|
||||
@classmethod
|
||||
def current_task(cls, loop=None):
|
||||
"""Return the currently running task in an event loop or None.
|
||||
|
||||
By default the current task for the current event loop is returned.
|
||||
|
||||
None is returned when called not in the context of a Task.
|
||||
"""
|
||||
if loop is None:
|
||||
loop = events.get_event_loop()
|
||||
return cls._current_tasks.get(loop)
|
||||
|
||||
@classmethod
|
||||
def all_tasks(cls, loop=None):
|
||||
"""Return a set of all tasks for an event loop.
|
||||
|
@ -252,6 +268,8 @@ class Task(futures.Future):
|
|||
self._must_cancel = False
|
||||
coro = self._coro
|
||||
self._fut_waiter = None
|
||||
|
||||
self.__class__._current_tasks[self._loop] = self
|
||||
# Call either coro.throw(exc) or coro.send(value).
|
||||
try:
|
||||
if exc is not None:
|
||||
|
@ -302,6 +320,8 @@ class Task(futures.Future):
|
|||
self._step, None,
|
||||
RuntimeError(
|
||||
'Task got bad yield: {!r}'.format(result)))
|
||||
finally:
|
||||
self.__class__._current_tasks.pop(self._loop)
|
||||
self = None
|
||||
|
||||
def _wakeup(self, future):
|
||||
|
|
|
@ -88,7 +88,7 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
|
|||
class SSLWSGIServer(SilentWSGIServer):
|
||||
def finish_request(self, request, client_address):
|
||||
# The relative location of our test directory (which
|
||||
# contains the sample key and certificate files) differs
|
||||
# contains the ssl key and certificate files) differs
|
||||
# between the stdlib and stand-alone Tulip/asyncio.
|
||||
# Prefer our own if we can find it.
|
||||
here = os.path.join(os.path.dirname(__file__), '..', 'tests')
|
||||
|
|
|
@ -1113,6 +1113,42 @@ class TaskTests(unittest.TestCase):
|
|||
self.assertEqual(res, 'test')
|
||||
self.assertIsNone(t2.result())
|
||||
|
||||
def test_current_task(self):
|
||||
self.assertIsNone(tasks.Task.current_task(loop=self.loop))
|
||||
@tasks.coroutine
|
||||
def coro(loop):
|
||||
self.assertTrue(tasks.Task.current_task(loop=loop) is task)
|
||||
|
||||
task = tasks.Task(coro(self.loop), loop=self.loop)
|
||||
self.loop.run_until_complete(task)
|
||||
self.assertIsNone(tasks.Task.current_task(loop=self.loop))
|
||||
|
||||
def test_current_task_with_interleaving_tasks(self):
|
||||
self.assertIsNone(tasks.Task.current_task(loop=self.loop))
|
||||
|
||||
fut1 = futures.Future(loop=self.loop)
|
||||
fut2 = futures.Future(loop=self.loop)
|
||||
|
||||
@tasks.coroutine
|
||||
def coro1(loop):
|
||||
self.assertTrue(tasks.Task.current_task(loop=loop) is task1)
|
||||
yield from fut1
|
||||
self.assertTrue(tasks.Task.current_task(loop=loop) is task1)
|
||||
fut2.set_result(True)
|
||||
|
||||
@tasks.coroutine
|
||||
def coro2(loop):
|
||||
self.assertTrue(tasks.Task.current_task(loop=loop) is task2)
|
||||
fut1.set_result(True)
|
||||
yield from fut2
|
||||
self.assertTrue(tasks.Task.current_task(loop=loop) is task2)
|
||||
|
||||
task1 = tasks.Task(coro1(self.loop), loop=self.loop)
|
||||
task2 = tasks.Task(coro2(self.loop), loop=self.loop)
|
||||
|
||||
self.loop.run_until_complete(tasks.wait((task1, task2), loop=self.loop))
|
||||
self.assertIsNone(tasks.Task.current_task(loop=self.loop))
|
||||
|
||||
# Some thorough tests for cancellation propagation through
|
||||
# coroutines, tasks and wait().
|
||||
|
||||
|
|
Loading…
Reference in New Issue