diff --git a/Lib/asyncio/windows_events.py b/Lib/asyncio/windows_events.py index bbeada8772c..1ffac999013 100644 --- a/Lib/asyncio/windows_events.py +++ b/Lib/asyncio/windows_events.py @@ -46,6 +46,22 @@ class _OverlappedFuture(futures.Future): return super().cancel() +class _WaitHandleFuture(futures.Future): + """Subclass of Future which represents a wait handle.""" + + def __init__(self, wait_handle, *, loop=None): + super().__init__(loop=loop) + self._wait_handle = wait_handle + + def cancel(self): + super().cancel() + try: + _overlapped.UnregisterWait(self._wait_handle) + except OSError as e: + if e.winerror != _overlapped.ERROR_IO_PENDING: + raise + + class PipeServer(object): """Class representing a pipe server. @@ -271,6 +287,30 @@ class IocpProactor: return windows_utils.PipeHandle(handle) return self._register(ov, None, finish, wait_for_post=True) + def wait_for_handle(self, handle, timeout=None): + if timeout is None: + ms = _winapi.INFINITE + else: + ms = int(timeout * 1000 + 0.5) + + # We only create ov so we can use ov.address as a key for the cache. + ov = _overlapped.Overlapped(NULL) + wh = _overlapped.RegisterWaitWithQueue( + handle, self._iocp, ov.address, ms) + f = _WaitHandleFuture(wh, loop=self._loop) + + def finish(timed_out, _, ov): + if not f.cancelled(): + try: + _overlapped.UnregisterWait(wh) + except OSError as e: + if e.winerror != _overlapped.ERROR_IO_PENDING: + raise + return not timed_out + + self._cache[ov.address] = (f, ov, None, finish) + return f + def _register_with_iocp(self, obj): # To get notifications of finished ops on this objects sent to the # completion port, were must register the handle. diff --git a/Lib/test/test_asyncio/test_windows_events.py b/Lib/test/test_asyncio/test_windows_events.py index 969360c1074..553ea34345b 100644 --- a/Lib/test/test_asyncio/test_windows_events.py +++ b/Lib/test/test_asyncio/test_windows_events.py @@ -5,13 +5,17 @@ import unittest if sys.platform != 'win32': raise unittest.SkipTest('Windows only') +import _winapi + import asyncio from asyncio import windows_events +from asyncio import futures from asyncio import protocols from asyncio import streams from asyncio import transports from asyncio import test_utils +from asyncio import _overlapped class UpperProto(protocols.Protocol): @@ -94,6 +98,42 @@ class ProactorTests(unittest.TestCase): return 'done' + def test_wait_for_handle(self): + event = _overlapped.CreateEvent(None, True, False, None) + self.addCleanup(_winapi.CloseHandle, event) + + # Wait for unset event with 0.2s timeout; + # result should be False at timeout + f = self.loop._proactor.wait_for_handle(event, 0.2) + start = self.loop.time() + self.loop.run_until_complete(f) + elapsed = self.loop.time() - start + self.assertFalse(f.result()) + self.assertTrue(0.18 < elapsed < 0.22, elapsed) + + _overlapped.SetEvent(event) + + # Wait for for set event; + # result should be True immediately + f = self.loop._proactor.wait_for_handle(event, 10) + start = self.loop.time() + self.loop.run_until_complete(f) + elapsed = self.loop.time() - start + self.assertTrue(f.result()) + self.assertTrue(0 <= elapsed < 0.02, elapsed) + + _overlapped.ResetEvent(event) + + # Wait for unset event with a cancelled future; + # CancelledError should be raised immediately + f = self.loop._proactor.wait_for_handle(event, 10) + f.cancel() + start = self.loop.time() + with self.assertRaises(futures.CancelledError): + self.loop.run_until_complete(f) + elapsed = self.loop.time() - start + self.assertTrue(0 <= elapsed < 0.02, elapsed) + if __name__ == '__main__': unittest.main() diff --git a/Modules/overlapped.c b/Modules/overlapped.c index 6a1d9e4a6d8..625c76eff4e 100644 --- a/Modules/overlapped.c +++ b/Modules/overlapped.c @@ -227,6 +227,172 @@ overlapped_PostQueuedCompletionStatus(PyObject *self, PyObject *args) Py_RETURN_NONE; } +/* + * Wait for a handle + */ + +struct PostCallbackData { + HANDLE CompletionPort; + LPOVERLAPPED Overlapped; +}; + +static VOID CALLBACK +PostToQueueCallback(PVOID lpParameter, BOOL TimerOrWaitFired) +{ + struct PostCallbackData *p = (struct PostCallbackData*) lpParameter; + + PostQueuedCompletionStatus(p->CompletionPort, TimerOrWaitFired, + 0, p->Overlapped); + /* ignore possible error! */ + PyMem_Free(p); +} + +PyDoc_STRVAR( + RegisterWaitWithQueue_doc, + "RegisterWaitWithQueue(Object, CompletionPort, Overlapped, Timeout)\n" + " -> WaitHandle\n\n" + "Register wait for Object; when complete CompletionPort is notified.\n"); + +static PyObject * +overlapped_RegisterWaitWithQueue(PyObject *self, PyObject *args) +{ + HANDLE NewWaitObject; + HANDLE Object; + ULONG Milliseconds; + struct PostCallbackData data, *pdata; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE F_POINTER F_DWORD, + &Object, + &data.CompletionPort, + &data.Overlapped, + &Milliseconds)) + return NULL; + + pdata = PyMem_Malloc(sizeof(struct PostCallbackData)); + if (pdata == NULL) + return SetFromWindowsErr(0); + + *pdata = data; + + if (!RegisterWaitForSingleObject( + &NewWaitObject, Object, (WAITORTIMERCALLBACK)PostToQueueCallback, + pdata, Milliseconds, + WT_EXECUTEINWAITTHREAD | WT_EXECUTEONLYONCE)) + { + PyMem_Free(pdata); + return SetFromWindowsErr(0); + } + + return Py_BuildValue(F_HANDLE, NewWaitObject); +} + +PyDoc_STRVAR( + UnregisterWait_doc, + "UnregisterWait(WaitHandle) -> None\n\n" + "Unregister wait handle.\n"); + +static PyObject * +overlapped_UnregisterWait(PyObject *self, PyObject *args) +{ + HANDLE WaitHandle; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE, &WaitHandle)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = UnregisterWait(WaitHandle); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +/* + * Event functions -- currently only used by tests + */ + +PyDoc_STRVAR( + CreateEvent_doc, + "CreateEvent(EventAttributes, ManualReset, InitialState, Name)" + " -> Handle\n\n" + "Create an event. EventAttributes must be None.\n"); + +static PyObject * +overlapped_CreateEvent(PyObject *self, PyObject *args) +{ + PyObject *EventAttributes; + BOOL ManualReset; + BOOL InitialState; + Py_UNICODE *Name; + HANDLE Event; + + if (!PyArg_ParseTuple(args, "O" F_BOOL F_BOOL "Z", + &EventAttributes, &ManualReset, + &InitialState, &Name)) + return NULL; + + if (EventAttributes != Py_None) { + PyErr_SetString(PyExc_ValueError, "EventAttributes must be None"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS + Event = CreateEventW(NULL, ManualReset, InitialState, Name); + Py_END_ALLOW_THREADS + + if (Event == NULL) + return SetFromWindowsErr(0); + return Py_BuildValue(F_HANDLE, Event); +} + +PyDoc_STRVAR( + SetEvent_doc, + "SetEvent(Handle) -> None\n\n" + "Set event.\n"); + +static PyObject * +overlapped_SetEvent(PyObject *self, PyObject *args) +{ + HANDLE Handle; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE, &Handle)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = SetEvent(Handle); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +PyDoc_STRVAR( + ResetEvent_doc, + "ResetEvent(Handle) -> None\n\n" + "Reset event.\n"); + +static PyObject * +overlapped_ResetEvent(PyObject *self, PyObject *args) +{ + HANDLE Handle; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE, &Handle)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = ResetEvent(Handle); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + /* * Bind socket handle to local port without doing slow getaddrinfo() */ @@ -1147,6 +1313,16 @@ static PyMethodDef overlapped_functions[] = { METH_VARARGS, FormatMessage_doc}, {"BindLocal", overlapped_BindLocal, METH_VARARGS, BindLocal_doc}, + {"RegisterWaitWithQueue", overlapped_RegisterWaitWithQueue, + METH_VARARGS, RegisterWaitWithQueue_doc}, + {"UnregisterWait", overlapped_UnregisterWait, + METH_VARARGS, UnregisterWait_doc}, + {"CreateEvent", overlapped_CreateEvent, + METH_VARARGS, CreateEvent_doc}, + {"SetEvent", overlapped_SetEvent, + METH_VARARGS, SetEvent_doc}, + {"ResetEvent", overlapped_ResetEvent, + METH_VARARGS, ResetEvent_doc}, {NULL} };