From f23746a934177c48eff754411aba54c31d6be2f0 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Mon, 22 Jan 2018 19:11:18 -0500 Subject: [PATCH] bpo-32436: Implement PEP 567 (#5027) --- Include/Python.h | 1 + Include/context.h | 86 + Include/internal/context.h | 41 + Include/internal/hamt.h | 113 + Include/pystate.h | 8 + Lib/asyncio/base_events.py | 21 +- Lib/asyncio/base_futures.py | 8 +- Lib/asyncio/events.py | 15 +- Lib/asyncio/futures.py | 17 +- Lib/asyncio/selector_events.py | 4 +- Lib/asyncio/tasks.py | 24 +- Lib/asyncio/unix_events.py | 2 +- Lib/contextvars.py | 4 + Lib/test/test_asyncio/test_base_events.py | 10 +- Lib/test/test_asyncio/test_futures.py | 14 +- Lib/test/test_asyncio/test_tasks.py | 109 +- Lib/test/test_asyncio/utils.py | 8 +- Lib/test/test_context.py | 1064 ++++++ Makefile.pre.in | 4 + .../2017-12-28-00-20-42.bpo-32436.H159Jv.rst | 1 + Modules/Setup.dist | 1 + Modules/_asynciomodule.c | 209 +- Modules/_contextvarsmodule.c | 75 + Modules/_testcapimodule.c | 8 + Modules/clinic/_asynciomodule.c.h | 29 +- Modules/clinic/_contextvarsmodule.c.h | 21 + Modules/gcmodule.c | 2 + Objects/object.c | 1 + PCbuild/_contextvars.vcxproj | 77 + PCbuild/_contextvars.vcxproj.filters | 16 + PCbuild/_decimal.vcxproj | 2 +- PCbuild/pcbuild.proj | 2 +- PCbuild/pythoncore.vcxproj | 6 + PCbuild/pythoncore.vcxproj.filters | 18 + Python/clinic/context.c.h | 146 + Python/context.c | 1220 +++++++ Python/hamt.c | 2982 +++++++++++++++++ Python/pylifecycle.c | 6 + Python/pystate.c | 9 + Tools/msi/lib/lib_files.wxs | 2 +- setup.py | 3 + 41 files changed, 6269 insertions(+), 120 deletions(-) create mode 100644 Include/context.h create mode 100644 Include/internal/context.h create mode 100644 Include/internal/hamt.h create mode 100644 Lib/contextvars.py create mode 100644 Lib/test/test_context.py create mode 100644 Misc/NEWS.d/next/Core and Builtins/2017-12-28-00-20-42.bpo-32436.H159Jv.rst create mode 100644 Modules/_contextvarsmodule.c create mode 100644 Modules/clinic/_contextvarsmodule.c.h create mode 100644 PCbuild/_contextvars.vcxproj create mode 100644 PCbuild/_contextvars.vcxproj.filters create mode 100644 Python/clinic/context.c.h create mode 100644 Python/context.c create mode 100644 Python/hamt.c diff --git a/Include/Python.h b/Include/Python.h index 20051e7ff08..dd595ea5e4c 100644 --- a/Include/Python.h +++ b/Include/Python.h @@ -109,6 +109,7 @@ #include "pyerrors.h" #include "pystate.h" +#include "context.h" #include "pyarena.h" #include "modsupport.h" diff --git a/Include/context.h b/Include/context.h new file mode 100644 index 00000000000..f872dceee0c --- /dev/null +++ b/Include/context.h @@ -0,0 +1,86 @@ +#ifndef Py_CONTEXT_H +#define Py_CONTEXT_H +#ifdef __cplusplus +extern "C" { +#endif + +#ifndef Py_LIMITED_API + + +PyAPI_DATA(PyTypeObject) PyContext_Type; +typedef struct _pycontextobject PyContext; + +PyAPI_DATA(PyTypeObject) PyContextVar_Type; +typedef struct _pycontextvarobject PyContextVar; + +PyAPI_DATA(PyTypeObject) PyContextToken_Type; +typedef struct _pycontexttokenobject PyContextToken; + + +#define PyContext_CheckExact(o) (Py_TYPE(o) == &PyContext_Type) +#define PyContextVar_CheckExact(o) (Py_TYPE(o) == &PyContextVar_Type) +#define PyContextToken_CheckExact(o) (Py_TYPE(o) == &PyContextToken_Type) + + +PyAPI_FUNC(PyContext *) PyContext_New(void); +PyAPI_FUNC(PyContext *) PyContext_Copy(PyContext *); +PyAPI_FUNC(PyContext *) PyContext_CopyCurrent(void); + +PyAPI_FUNC(int) PyContext_Enter(PyContext *); +PyAPI_FUNC(int) PyContext_Exit(PyContext *); + + +/* Create a new context variable. + + default_value can be NULL. +*/ +PyAPI_FUNC(PyContextVar *) PyContextVar_New( + const char *name, PyObject *default_value); + + +/* Get a value for the variable. + + Returns -1 if an error occurred during lookup. + + Returns 0 if value either was or was not found. + + If value was found, *value will point to it. + If not, it will point to: + + - default_value, if not NULL; + - the default value of "var", if not NULL; + - NULL. + + '*value' will be a new ref, if not NULL. +*/ +PyAPI_FUNC(int) PyContextVar_Get( + PyContextVar *var, PyObject *default_value, PyObject **value); + + +/* Set a new value for the variable. + Returns NULL if an error occurs. +*/ +PyAPI_FUNC(PyContextToken *) PyContextVar_Set( + PyContextVar *var, PyObject *value); + + +/* Reset a variable to its previous value. + Returns 0 on sucess, -1 on error. +*/ +PyAPI_FUNC(int) PyContextVar_Reset( + PyContextVar *var, PyContextToken *token); + + +/* This method is exposed only for CPython tests. Don not use it. */ +PyAPI_FUNC(PyObject *) _PyContext_NewHamtForTests(void); + + +PyAPI_FUNC(int) PyContext_ClearFreeList(void); + + +#endif /* !Py_LIMITED_API */ + +#ifdef __cplusplus +} +#endif +#endif /* !Py_CONTEXT_H */ diff --git a/Include/internal/context.h b/Include/internal/context.h new file mode 100644 index 00000000000..59f88f2614e --- /dev/null +++ b/Include/internal/context.h @@ -0,0 +1,41 @@ +#ifndef Py_INTERNAL_CONTEXT_H +#define Py_INTERNAL_CONTEXT_H + + +#include "internal/hamt.h" + + +struct _pycontextobject { + PyObject_HEAD + PyContext *ctx_prev; + PyHamtObject *ctx_vars; + PyObject *ctx_weakreflist; + int ctx_entered; +}; + + +struct _pycontextvarobject { + PyObject_HEAD + PyObject *var_name; + PyObject *var_default; + PyObject *var_cached; + uint64_t var_cached_tsid; + uint64_t var_cached_tsver; + Py_hash_t var_hash; +}; + + +struct _pycontexttokenobject { + PyObject_HEAD + PyContext *tok_ctx; + PyContextVar *tok_var; + PyObject *tok_oldval; + int tok_used; +}; + + +int _PyContext_Init(void); +void _PyContext_Fini(void); + + +#endif /* !Py_INTERNAL_CONTEXT_H */ diff --git a/Include/internal/hamt.h b/Include/internal/hamt.h new file mode 100644 index 00000000000..52488d0858d --- /dev/null +++ b/Include/internal/hamt.h @@ -0,0 +1,113 @@ +#ifndef Py_INTERNAL_HAMT_H +#define Py_INTERNAL_HAMT_H + + +#define _Py_HAMT_MAX_TREE_DEPTH 7 + + +#define PyHamt_Check(o) (Py_TYPE(o) == &_PyHamt_Type) + + +/* Abstract tree node. */ +typedef struct { + PyObject_HEAD +} PyHamtNode; + + +/* An HAMT immutable mapping collection. */ +typedef struct { + PyObject_HEAD + PyHamtNode *h_root; + PyObject *h_weakreflist; + Py_ssize_t h_count; +} PyHamtObject; + + +/* A struct to hold the state of depth-first traverse of the tree. + + HAMT is an immutable collection. Iterators will hold a strong reference + to it, and every node in the HAMT has strong references to its children. + + So for iterators, we can implement zero allocations and zero reference + inc/dec depth-first iteration. + + - i_nodes: an array of seven pointers to tree nodes + - i_level: the current node in i_nodes + - i_pos: an array of positions within nodes in i_nodes. +*/ +typedef struct { + PyHamtNode *i_nodes[_Py_HAMT_MAX_TREE_DEPTH]; + Py_ssize_t i_pos[_Py_HAMT_MAX_TREE_DEPTH]; + int8_t i_level; +} PyHamtIteratorState; + + +/* Base iterator object. + + Contains the iteration state, a pointer to the HAMT tree, + and a pointer to the 'yield function'. The latter is a simple + function that returns a key/value tuple for the 'Items' iterator, + just a key for the 'Keys' iterator, and a value for the 'Values' + iterator. +*/ +typedef struct { + PyObject_HEAD + PyHamtObject *hi_obj; + PyHamtIteratorState hi_iter; + binaryfunc hi_yield; +} PyHamtIterator; + + +PyAPI_DATA(PyTypeObject) _PyHamt_Type; +PyAPI_DATA(PyTypeObject) _PyHamt_ArrayNode_Type; +PyAPI_DATA(PyTypeObject) _PyHamt_BitmapNode_Type; +PyAPI_DATA(PyTypeObject) _PyHamt_CollisionNode_Type; +PyAPI_DATA(PyTypeObject) _PyHamtKeys_Type; +PyAPI_DATA(PyTypeObject) _PyHamtValues_Type; +PyAPI_DATA(PyTypeObject) _PyHamtItems_Type; + + +/* Create a new HAMT immutable mapping. */ +PyHamtObject * _PyHamt_New(void); + +/* Return a new collection based on "o", but with an additional + key/val pair. */ +PyHamtObject * _PyHamt_Assoc(PyHamtObject *o, PyObject *key, PyObject *val); + +/* Return a new collection based on "o", but without "key". */ +PyHamtObject * _PyHamt_Without(PyHamtObject *o, PyObject *key); + +/* Find "key" in the "o" collection. + + Return: + - -1: An error ocurred. + - 0: "key" wasn't found in "o". + - 1: "key" is in "o"; "*val" is set to its value (a borrowed ref). +*/ +int _PyHamt_Find(PyHamtObject *o, PyObject *key, PyObject **val); + +/* Check if "v" is equal to "w". + + Return: + - 0: v != w + - 1: v == w + - -1: An error occurred. +*/ +int _PyHamt_Eq(PyHamtObject *v, PyHamtObject *w); + +/* Return the size of "o"; equivalent of "len(o)". */ +Py_ssize_t _PyHamt_Len(PyHamtObject *o); + +/* Return a Keys iterator over "o". */ +PyObject * _PyHamt_NewIterKeys(PyHamtObject *o); + +/* Return a Values iterator over "o". */ +PyObject * _PyHamt_NewIterValues(PyHamtObject *o); + +/* Return a Items iterator over "o". */ +PyObject * _PyHamt_NewIterItems(PyHamtObject *o); + +int _PyHamt_Init(void); +void _PyHamt_Fini(void); + +#endif /* !Py_INTERNAL_HAMT_H */ diff --git a/Include/pystate.h b/Include/pystate.h index 5a69e1471a0..d004be5e906 100644 --- a/Include/pystate.h +++ b/Include/pystate.h @@ -143,6 +143,8 @@ typedef struct _is { /* AtExit module */ void (*pyexitfunc)(PyObject *); PyObject *pyexitmodule; + + uint64_t tstate_next_unique_id; } PyInterpreterState; #endif /* !Py_LIMITED_API */ @@ -270,6 +272,12 @@ typedef struct _ts { PyObject *async_gen_firstiter; PyObject *async_gen_finalizer; + PyObject *context; + uint64_t context_ver; + + /* Unique thread state id. */ + uint64_t id; + /* XXX signal handlers should also be here */ } PyThreadState; diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index ca9eee765e3..e722cf26b51 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -489,7 +489,7 @@ class BaseEventLoop(events.AbstractEventLoop): """ return time.monotonic() - def call_later(self, delay, callback, *args): + def call_later(self, delay, callback, *args, context=None): """Arrange for a callback to be called at a given time. Return a Handle: an opaque object with a cancel() method that @@ -505,12 +505,13 @@ class BaseEventLoop(events.AbstractEventLoop): Any positional arguments after the callback will be passed to the callback when it is called. """ - timer = self.call_at(self.time() + delay, callback, *args) + timer = self.call_at(self.time() + delay, callback, *args, + context=context) if timer._source_traceback: del timer._source_traceback[-1] return timer - def call_at(self, when, callback, *args): + def call_at(self, when, callback, *args, context=None): """Like call_later(), but uses an absolute time. Absolute time corresponds to the event loop's time() method. @@ -519,14 +520,14 @@ class BaseEventLoop(events.AbstractEventLoop): if self._debug: self._check_thread() self._check_callback(callback, 'call_at') - timer = events.TimerHandle(when, callback, args, self) + timer = events.TimerHandle(when, callback, args, self, context) if timer._source_traceback: del timer._source_traceback[-1] heapq.heappush(self._scheduled, timer) timer._scheduled = True return timer - def call_soon(self, callback, *args): + def call_soon(self, callback, *args, context=None): """Arrange for a callback to be called as soon as possible. This operates as a FIFO queue: callbacks are called in the @@ -540,7 +541,7 @@ class BaseEventLoop(events.AbstractEventLoop): if self._debug: self._check_thread() self._check_callback(callback, 'call_soon') - handle = self._call_soon(callback, args) + handle = self._call_soon(callback, args, context) if handle._source_traceback: del handle._source_traceback[-1] return handle @@ -555,8 +556,8 @@ class BaseEventLoop(events.AbstractEventLoop): f'a callable object was expected by {method}(), ' f'got {callback!r}') - def _call_soon(self, callback, args): - handle = events.Handle(callback, args, self) + def _call_soon(self, callback, args, context): + handle = events.Handle(callback, args, self, context) if handle._source_traceback: del handle._source_traceback[-1] self._ready.append(handle) @@ -579,12 +580,12 @@ class BaseEventLoop(events.AbstractEventLoop): "Non-thread-safe operation invoked on an event loop other " "than the current one") - def call_soon_threadsafe(self, callback, *args): + def call_soon_threadsafe(self, callback, *args, context=None): """Like call_soon(), but thread-safe.""" self._check_closed() if self._debug: self._check_callback(callback, 'call_soon_threadsafe') - handle = self._call_soon(callback, args) + handle = self._call_soon(callback, args, context) if handle._source_traceback: del handle._source_traceback[-1] self._write_to_self() diff --git a/Lib/asyncio/base_futures.py b/Lib/asyncio/base_futures.py index 008812eda91..5182884e16d 100644 --- a/Lib/asyncio/base_futures.py +++ b/Lib/asyncio/base_futures.py @@ -41,13 +41,13 @@ def _format_callbacks(cb): return format_helpers._format_callback_source(callback, ()) if size == 1: - cb = format_cb(cb[0]) + cb = format_cb(cb[0][0]) elif size == 2: - cb = '{}, {}'.format(format_cb(cb[0]), format_cb(cb[1])) + cb = '{}, {}'.format(format_cb(cb[0][0]), format_cb(cb[1][0])) elif size > 2: - cb = '{}, <{} more>, {}'.format(format_cb(cb[0]), + cb = '{}, <{} more>, {}'.format(format_cb(cb[0][0]), size - 2, - format_cb(cb[-1])) + format_cb(cb[-1][0])) return f'cb=[{cb}]' diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py index d5365dc480d..5c68d4cb97d 100644 --- a/Lib/asyncio/events.py +++ b/Lib/asyncio/events.py @@ -11,6 +11,7 @@ __all__ = ( '_get_running_loop', ) +import contextvars import os import socket import subprocess @@ -32,9 +33,13 @@ class Handle: """Object returned by callback registration methods.""" __slots__ = ('_callback', '_args', '_cancelled', '_loop', - '_source_traceback', '_repr', '__weakref__') + '_source_traceback', '_repr', '__weakref__', + '_context') - def __init__(self, callback, args, loop): + def __init__(self, callback, args, loop, context=None): + if context is None: + context = contextvars.copy_context() + self._context = context self._loop = loop self._callback = callback self._args = args @@ -80,7 +85,7 @@ class Handle: def _run(self): try: - self._callback(*self._args) + self._context.run(self._callback, *self._args) except Exception as exc: cb = format_helpers._format_callback_source( self._callback, self._args) @@ -101,9 +106,9 @@ class TimerHandle(Handle): __slots__ = ['_scheduled', '_when'] - def __init__(self, when, callback, args, loop): + def __init__(self, when, callback, args, loop, context=None): assert when is not None - super().__init__(callback, args, loop) + super().__init__(callback, args, loop, context) if self._source_traceback: del self._source_traceback[-1] self._when = when diff --git a/Lib/asyncio/futures.py b/Lib/asyncio/futures.py index 1c05b2231c1..59621ffb6e7 100644 --- a/Lib/asyncio/futures.py +++ b/Lib/asyncio/futures.py @@ -6,6 +6,7 @@ __all__ = ( ) import concurrent.futures +import contextvars import logging import sys @@ -144,8 +145,8 @@ class Future: return self._callbacks[:] = [] - for callback in callbacks: - self._loop.call_soon(callback, self) + for callback, ctx in callbacks: + self._loop.call_soon(callback, self, context=ctx) def cancelled(self): """Return True if the future was cancelled.""" @@ -192,7 +193,7 @@ class Future: self.__log_traceback = False return self._exception - def add_done_callback(self, fn): + def add_done_callback(self, fn, *, context=None): """Add a callback to be run when the future becomes done. The callback is called with a single argument - the future object. If @@ -200,9 +201,11 @@ class Future: scheduled with call_soon. """ if self._state != _PENDING: - self._loop.call_soon(fn, self) + self._loop.call_soon(fn, self, context=context) else: - self._callbacks.append(fn) + if context is None: + context = contextvars.copy_context() + self._callbacks.append((fn, context)) # New method not in PEP 3148. @@ -211,7 +214,9 @@ class Future: Returns the number of callbacks removed. """ - filtered_callbacks = [f for f in self._callbacks if f != fn] + filtered_callbacks = [(f, ctx) + for (f, ctx) in self._callbacks + if f != fn] removed_count = len(self._callbacks) - len(filtered_callbacks) if removed_count: self._callbacks[:] = filtered_callbacks diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index 5692e38486a..9446ae6a3bc 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -256,7 +256,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): def _add_reader(self, fd, callback, *args): self._check_closed() - handle = events.Handle(callback, args, self) + handle = events.Handle(callback, args, self, None) try: key = self._selector.get_key(fd) except KeyError: @@ -292,7 +292,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): def _add_writer(self, fd, callback, *args): self._check_closed() - handle = events.Handle(callback, args, self) + handle = events.Handle(callback, args, self, None) try: key = self._selector.get_key(fd) except KeyError: diff --git a/Lib/asyncio/tasks.py b/Lib/asyncio/tasks.py index b11808853e2..609b8e8a048 100644 --- a/Lib/asyncio/tasks.py +++ b/Lib/asyncio/tasks.py @@ -10,6 +10,7 @@ __all__ = ( ) import concurrent.futures +import contextvars import functools import inspect import types @@ -96,8 +97,9 @@ class Task(futures._PyFuture): # Inherit Python Task implementation self._must_cancel = False self._fut_waiter = None self._coro = coro + self._context = contextvars.copy_context() - self._loop.call_soon(self._step) + self._loop.call_soon(self._step, context=self._context) _register_task(self) def __del__(self): @@ -229,15 +231,18 @@ class Task(futures._PyFuture): # Inherit Python Task implementation new_exc = RuntimeError( f'Task {self!r} got Future ' f'{result!r} attached to a different loop') - self._loop.call_soon(self._step, new_exc) + self._loop.call_soon( + self._step, new_exc, context=self._context) elif blocking: if result is self: new_exc = RuntimeError( f'Task cannot await on itself: {self!r}') - self._loop.call_soon(self._step, new_exc) + self._loop.call_soon( + self._step, new_exc, context=self._context) else: result._asyncio_future_blocking = False - result.add_done_callback(self._wakeup) + result.add_done_callback( + self._wakeup, context=self._context) self._fut_waiter = result if self._must_cancel: if self._fut_waiter.cancel(): @@ -246,21 +251,24 @@ class Task(futures._PyFuture): # Inherit Python Task implementation new_exc = RuntimeError( f'yield was used instead of yield from ' f'in task {self!r} with {result!r}') - self._loop.call_soon(self._step, new_exc) + self._loop.call_soon( + self._step, new_exc, context=self._context) elif result is None: # Bare yield relinquishes control for one event loop iteration. - self._loop.call_soon(self._step) + self._loop.call_soon(self._step, context=self._context) elif inspect.isgenerator(result): # Yielding a generator is just wrong. new_exc = RuntimeError( f'yield was used instead of yield from for ' f'generator in task {self!r} with {result}') - self._loop.call_soon(self._step, new_exc) + self._loop.call_soon( + self._step, new_exc, context=self._context) else: # Yielding something else is an error. new_exc = RuntimeError(f'Task got bad yield: {result!r}') - self._loop.call_soon(self._step, new_exc) + self._loop.call_soon( + self._step, new_exc, context=self._context) finally: _leave_task(self._loop, self) self = None # Needed to break cycles when an exception occurs. diff --git a/Lib/asyncio/unix_events.py b/Lib/asyncio/unix_events.py index 028a0ca8f83..9b9d0043b50 100644 --- a/Lib/asyncio/unix_events.py +++ b/Lib/asyncio/unix_events.py @@ -92,7 +92,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): except (ValueError, OSError) as exc: raise RuntimeError(str(exc)) - handle = events.Handle(callback, args, self) + handle = events.Handle(callback, args, self, None) self._signal_handlers[sig] = handle try: diff --git a/Lib/contextvars.py b/Lib/contextvars.py new file mode 100644 index 00000000000..d78c80dfe6f --- /dev/null +++ b/Lib/contextvars.py @@ -0,0 +1,4 @@ +from _contextvars import Context, ContextVar, Token, copy_context + + +__all__ = ('Context', 'ContextVar', 'Token', 'copy_context') diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py index fc3b81096da..8d72df6a72e 100644 --- a/Lib/test/test_asyncio/test_base_events.py +++ b/Lib/test/test_asyncio/test_base_events.py @@ -192,14 +192,14 @@ class BaseEventLoopTests(test_utils.TestCase): self.assertRaises(RuntimeError, self.loop.run_until_complete, f) def test__add_callback_handle(self): - h = asyncio.Handle(lambda: False, (), self.loop) + h = asyncio.Handle(lambda: False, (), self.loop, None) self.loop._add_callback(h) self.assertFalse(self.loop._scheduled) self.assertIn(h, self.loop._ready) def test__add_callback_cancelled_handle(self): - h = asyncio.Handle(lambda: False, (), self.loop) + h = asyncio.Handle(lambda: False, (), self.loop, None) h.cancel() self.loop._add_callback(h) @@ -333,9 +333,9 @@ class BaseEventLoopTests(test_utils.TestCase): def test__run_once(self): h1 = asyncio.TimerHandle(time.monotonic() + 5.0, lambda: True, (), - self.loop) + self.loop, None) h2 = asyncio.TimerHandle(time.monotonic() + 10.0, lambda: True, (), - self.loop) + self.loop, None) h1.cancel() @@ -390,7 +390,7 @@ class BaseEventLoopTests(test_utils.TestCase): handle = loop.call_soon(lambda: True) h = asyncio.TimerHandle(time.monotonic() - 1, cb, (self.loop,), - self.loop) + self.loop, None) self.loop._process_events = mock.Mock() self.loop._scheduled.append(h) diff --git a/Lib/test/test_asyncio/test_futures.py b/Lib/test/test_asyncio/test_futures.py index ab45ee39ab9..37f4c6562fe 100644 --- a/Lib/test/test_asyncio/test_futures.py +++ b/Lib/test/test_asyncio/test_futures.py @@ -565,16 +565,22 @@ class BaseFutureTests: @unittest.skipUnless(hasattr(futures, '_CFuture'), 'requires the C _asyncio module') class CFutureTests(BaseFutureTests, test_utils.TestCase): - cls = futures._CFuture + try: + cls = futures._CFuture + except AttributeError: + cls = None @unittest.skipUnless(hasattr(futures, '_CFuture'), 'requires the C _asyncio module') class CSubFutureTests(BaseFutureTests, test_utils.TestCase): - class CSubFuture(futures._CFuture): - pass + try: + class CSubFuture(futures._CFuture): + pass - cls = CSubFuture + cls = CSubFuture + except AttributeError: + cls = None class PyFutureTests(BaseFutureTests, test_utils.TestCase): diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py index 26e4f643d1a..96d2658cb4c 100644 --- a/Lib/test/test_asyncio/test_tasks.py +++ b/Lib/test/test_asyncio/test_tasks.py @@ -2,10 +2,11 @@ import collections import contextlib +import contextvars import functools import gc import io -import os +import random import re import sys import types @@ -1377,9 +1378,9 @@ class BaseTaskTests: self.cb_added = False super().__init__(*args, **kwds) - def add_done_callback(self, fn): + def add_done_callback(self, *args, **kwargs): self.cb_added = True - super().add_done_callback(fn) + super().add_done_callback(*args, **kwargs) fut = Fut(loop=self.loop) result = None @@ -2091,7 +2092,7 @@ class BaseTaskTests: @mock.patch('asyncio.base_events.logger') def test_error_in_call_soon(self, m_log): - def call_soon(callback, *args): + def call_soon(callback, *args, **kwargs): raise ValueError self.loop.call_soon = call_soon @@ -2176,6 +2177,91 @@ class BaseTaskTests: self.loop.run_until_complete(coro()) + def test_context_1(self): + cvar = contextvars.ContextVar('cvar', default='nope') + + async def sub(): + await asyncio.sleep(0.01, loop=loop) + self.assertEqual(cvar.get(), 'nope') + cvar.set('something else') + + async def main(): + self.assertEqual(cvar.get(), 'nope') + subtask = self.new_task(loop, sub()) + cvar.set('yes') + self.assertEqual(cvar.get(), 'yes') + await subtask + self.assertEqual(cvar.get(), 'yes') + + loop = asyncio.new_event_loop() + try: + task = self.new_task(loop, main()) + loop.run_until_complete(task) + finally: + loop.close() + + def test_context_2(self): + cvar = contextvars.ContextVar('cvar', default='nope') + + async def main(): + def fut_on_done(fut): + # This change must not pollute the context + # of the "main()" task. + cvar.set('something else') + + self.assertEqual(cvar.get(), 'nope') + + for j in range(2): + fut = self.new_future(loop) + fut.add_done_callback(fut_on_done) + cvar.set(f'yes{j}') + loop.call_soon(fut.set_result, None) + await fut + self.assertEqual(cvar.get(), f'yes{j}') + + for i in range(3): + # Test that task passed its context to add_done_callback: + cvar.set(f'yes{i}-{j}') + await asyncio.sleep(0.001, loop=loop) + self.assertEqual(cvar.get(), f'yes{i}-{j}') + + loop = asyncio.new_event_loop() + try: + task = self.new_task(loop, main()) + loop.run_until_complete(task) + finally: + loop.close() + + self.assertEqual(cvar.get(), 'nope') + + def test_context_3(self): + # Run 100 Tasks in parallel, each modifying cvar. + + cvar = contextvars.ContextVar('cvar', default=-1) + + async def sub(num): + for i in range(10): + cvar.set(num + i) + await asyncio.sleep( + random.uniform(0.001, 0.05), loop=loop) + self.assertEqual(cvar.get(), num + i) + + async def main(): + tasks = [] + for i in range(100): + task = loop.create_task(sub(random.randint(0, 10))) + tasks.append(task) + + await asyncio.gather(*tasks, loop=loop) + + loop = asyncio.new_event_loop() + try: + loop.run_until_complete(main()) + finally: + loop.close() + + self.assertEqual(cvar.get(), -1) + def add_subclass_tests(cls): BaseTask = cls.Task @@ -2193,9 +2279,9 @@ def add_subclass_tests(cls): self.calls['_schedule_callbacks'] += 1 return super()._schedule_callbacks() - def add_done_callback(self, *args): + def add_done_callback(self, *args, **kwargs): self.calls['add_done_callback'] += 1 - return super().add_done_callback(*args) + return super().add_done_callback(*args, **kwargs) class Task(CommonFuture, BaseTask): def _step(self, *args): @@ -2486,10 +2572,13 @@ class PyIntrospectionTests(unittest.TestCase, BaseTaskIntrospectionTests): @unittest.skipUnless(hasattr(tasks, '_c_register_task'), 'requires the C _asyncio module') class CIntrospectionTests(unittest.TestCase, BaseTaskIntrospectionTests): - _register_task = staticmethod(tasks._c_register_task) - _unregister_task = staticmethod(tasks._c_unregister_task) - _enter_task = staticmethod(tasks._c_enter_task) - _leave_task = staticmethod(tasks._c_leave_task) + if hasattr(tasks, '_c_register_task'): + _register_task = staticmethod(tasks._c_register_task) + _unregister_task = staticmethod(tasks._c_unregister_task) + _enter_task = staticmethod(tasks._c_enter_task) + _leave_task = staticmethod(tasks._c_leave_task) + else: + _register_task = _unregister_task = _enter_task = _leave_task = None class BaseCurrentLoopTests: diff --git a/Lib/test/test_asyncio/utils.py b/Lib/test/test_asyncio/utils.py index f756ec9016f..96dfe2f85b4 100644 --- a/Lib/test/test_asyncio/utils.py +++ b/Lib/test/test_asyncio/utils.py @@ -365,7 +365,7 @@ class TestLoop(base_events.BaseEventLoop): raise AssertionError("Time generator is not finished") def _add_reader(self, fd, callback, *args): - self.readers[fd] = events.Handle(callback, args, self) + self.readers[fd] = events.Handle(callback, args, self, None) def _remove_reader(self, fd): self.remove_reader_count[fd] += 1 @@ -391,7 +391,7 @@ class TestLoop(base_events.BaseEventLoop): raise AssertionError(f'fd {fd} is registered') def _add_writer(self, fd, callback, *args): - self.writers[fd] = events.Handle(callback, args, self) + self.writers[fd] = events.Handle(callback, args, self, None) def _remove_writer(self, fd): self.remove_writer_count[fd] += 1 @@ -457,9 +457,9 @@ class TestLoop(base_events.BaseEventLoop): self.advance_time(advance) self._timers = [] - def call_at(self, when, callback, *args): + def call_at(self, when, callback, *args, context=None): self._timers.append(when) - return super().call_at(when, callback, *args) + return super().call_at(when, callback, *args, context=context) def _process_events(self, event_list): return diff --git a/Lib/test/test_context.py b/Lib/test/test_context.py new file mode 100644 index 00000000000..74d05fc5c41 --- /dev/null +++ b/Lib/test/test_context.py @@ -0,0 +1,1064 @@ +import concurrent.futures +import contextvars +import functools +import gc +import random +import time +import unittest +import weakref + +try: + from _testcapi import hamt +except ImportError: + hamt = None + + +def isolated_context(func): + """Needed to make reftracking test mode work.""" + @functools.wraps(func) + def wrapper(*args, **kwargs): + ctx = contextvars.Context() + return ctx.run(func, *args, **kwargs) + return wrapper + + +class ContextTest(unittest.TestCase): + def test_context_var_new_1(self): + with self.assertRaisesRegex(TypeError, 'takes exactly 1'): + contextvars.ContextVar() + + with self.assertRaisesRegex(TypeError, 'must be a str'): + contextvars.ContextVar(1) + + c = contextvars.ContextVar('a') + self.assertNotEqual(hash(c), hash('a')) + + def test_context_var_new_2(self): + self.assertIsNone(contextvars.ContextVar[int]) + + @isolated_context + def test_context_var_repr_1(self): + c = contextvars.ContextVar('a') + self.assertIn('a', repr(c)) + + c = contextvars.ContextVar('a', default=123) + self.assertIn('123', repr(c)) + + lst = [] + c = contextvars.ContextVar('a', default=lst) + lst.append(c) + self.assertIn('...', repr(c)) + self.assertIn('...', repr(lst)) + + t = c.set(1) + self.assertIn(repr(c), repr(t)) + self.assertNotIn(' used ', repr(t)) + c.reset(t) + self.assertIn(' used ', repr(t)) + + def test_context_subclassing_1(self): + with self.assertRaisesRegex(TypeError, 'not an acceptable base type'): + class MyContextVar(contextvars.ContextVar): + # Potentially we might want ContextVars to be subclassable. + pass + + with self.assertRaisesRegex(TypeError, 'not an acceptable base type'): + class MyContext(contextvars.Context): + pass + + with self.assertRaisesRegex(TypeError, 'not an acceptable base type'): + class MyToken(contextvars.Token): + pass + + def test_context_new_1(self): + with self.assertRaisesRegex(TypeError, 'any arguments'): + contextvars.Context(1) + with self.assertRaisesRegex(TypeError, 'any arguments'): + contextvars.Context(1, a=1) + with self.assertRaisesRegex(TypeError, 'any arguments'): + contextvars.Context(a=1) + contextvars.Context(**{}) + + def test_context_typerrors_1(self): + ctx = contextvars.Context() + + with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'): + ctx[1] + with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'): + 1 in ctx + with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'): + ctx.get(1) + + def test_context_get_context_1(self): + ctx = contextvars.copy_context() + self.assertIsInstance(ctx, contextvars.Context) + + def test_context_run_1(self): + ctx = contextvars.Context() + + with self.assertRaisesRegex(TypeError, 'missing 1 required'): + ctx.run() + + def test_context_run_2(self): + ctx = contextvars.Context() + + def func(*args, **kwargs): + kwargs['spam'] = 'foo' + args += ('bar',) + return args, kwargs + + for f in (func, functools.partial(func)): + # partial doesn't support FASTCALL + + self.assertEqual(ctx.run(f), (('bar',), {'spam': 'foo'})) + self.assertEqual(ctx.run(f, 1), ((1, 'bar'), {'spam': 'foo'})) + + self.assertEqual( + ctx.run(f, a=2), + (('bar',), {'a': 2, 'spam': 'foo'})) + + self.assertEqual( + ctx.run(f, 11, a=2), + ((11, 'bar'), {'a': 2, 'spam': 'foo'})) + + a = {} + self.assertEqual( + ctx.run(f, 11, **a), + ((11, 'bar'), {'spam': 'foo'})) + self.assertEqual(a, {}) + + def test_context_run_3(self): + ctx = contextvars.Context() + + def func(*args, **kwargs): + 1 / 0 + + with self.assertRaises(ZeroDivisionError): + ctx.run(func) + with self.assertRaises(ZeroDivisionError): + ctx.run(func, 1, 2) + with self.assertRaises(ZeroDivisionError): + ctx.run(func, 1, 2, a=123) + + @isolated_context + def test_context_run_4(self): + ctx1 = contextvars.Context() + ctx2 = contextvars.Context() + var = contextvars.ContextVar('var') + + def func2(): + self.assertIsNone(var.get(None)) + + def func1(): + self.assertIsNone(var.get(None)) + var.set('spam') + ctx2.run(func2) + self.assertEqual(var.get(None), 'spam') + + cur = contextvars.copy_context() + self.assertEqual(len(cur), 1) + self.assertEqual(cur[var], 'spam') + return cur + + returned_ctx = ctx1.run(func1) + self.assertEqual(ctx1, returned_ctx) + self.assertEqual(returned_ctx[var], 'spam') + self.assertIn(var, returned_ctx) + + def test_context_run_5(self): + ctx = contextvars.Context() + var = contextvars.ContextVar('var') + + def func(): + self.assertIsNone(var.get(None)) + var.set('spam') + 1 / 0 + + with self.assertRaises(ZeroDivisionError): + ctx.run(func) + + self.assertIsNone(var.get(None)) + + def test_context_run_6(self): + ctx = contextvars.Context() + c = contextvars.ContextVar('a', default=0) + + def fun(): + self.assertEqual(c.get(), 0) + self.assertIsNone(ctx.get(c)) + + c.set(42) + self.assertEqual(c.get(), 42) + self.assertEqual(ctx.get(c), 42) + + ctx.run(fun) + + def test_context_run_7(self): + ctx = contextvars.Context() + + def fun(): + with self.assertRaisesRegex(RuntimeError, 'is already entered'): + ctx.run(fun) + + ctx.run(fun) + + @isolated_context + def test_context_getset_1(self): + c = contextvars.ContextVar('c') + with self.assertRaises(LookupError): + c.get() + + self.assertIsNone(c.get(None)) + + t0 = c.set(42) + self.assertEqual(c.get(), 42) + self.assertEqual(c.get(None), 42) + self.assertIs(t0.old_value, t0.MISSING) + self.assertIs(t0.old_value, contextvars.Token.MISSING) + self.assertIs(t0.var, c) + + t = c.set('spam') + self.assertEqual(c.get(), 'spam') + self.assertEqual(c.get(None), 'spam') + self.assertEqual(t.old_value, 42) + c.reset(t) + + self.assertEqual(c.get(), 42) + self.assertEqual(c.get(None), 42) + + c.set('spam2') + with self.assertRaisesRegex(RuntimeError, 'has already been used'): + c.reset(t) + self.assertEqual(c.get(), 'spam2') + + ctx1 = contextvars.copy_context() + self.assertIn(c, ctx1) + + c.reset(t0) + with self.assertRaisesRegex(RuntimeError, 'has already been used'): + c.reset(t0) + self.assertIsNone(c.get(None)) + + self.assertIn(c, ctx1) + self.assertEqual(ctx1[c], 'spam2') + self.assertEqual(ctx1.get(c, 'aa'), 'spam2') + self.assertEqual(len(ctx1), 1) + self.assertEqual(list(ctx1.items()), [(c, 'spam2')]) + self.assertEqual(list(ctx1.values()), ['spam2']) + self.assertEqual(list(ctx1.keys()), [c]) + self.assertEqual(list(ctx1), [c]) + + ctx2 = contextvars.copy_context() + self.assertNotIn(c, ctx2) + with self.assertRaises(KeyError): + ctx2[c] + self.assertEqual(ctx2.get(c, 'aa'), 'aa') + self.assertEqual(len(ctx2), 0) + self.assertEqual(list(ctx2), []) + + @isolated_context + def test_context_getset_2(self): + v1 = contextvars.ContextVar('v1') + v2 = contextvars.ContextVar('v2') + + t1 = v1.set(42) + with self.assertRaisesRegex(ValueError, 'by a different'): + v2.reset(t1) + + @isolated_context + def test_context_getset_3(self): + c = contextvars.ContextVar('c', default=42) + ctx = contextvars.Context() + + def fun(): + self.assertEqual(c.get(), 42) + with self.assertRaises(KeyError): + ctx[c] + self.assertIsNone(ctx.get(c)) + self.assertEqual(ctx.get(c, 'spam'), 'spam') + self.assertNotIn(c, ctx) + self.assertEqual(list(ctx.keys()), []) + + t = c.set(1) + self.assertEqual(list(ctx.keys()), [c]) + self.assertEqual(ctx[c], 1) + + c.reset(t) + self.assertEqual(list(ctx.keys()), []) + with self.assertRaises(KeyError): + ctx[c] + + ctx.run(fun) + + @isolated_context + def test_context_getset_4(self): + c = contextvars.ContextVar('c', default=42) + ctx = contextvars.Context() + + tok = ctx.run(c.set, 1) + + with self.assertRaisesRegex(ValueError, 'different Context'): + c.reset(tok) + + @isolated_context + def test_context_getset_5(self): + c = contextvars.ContextVar('c', default=42) + c.set([]) + + def fun(): + c.set([]) + c.get().append(42) + self.assertEqual(c.get(), [42]) + + contextvars.copy_context().run(fun) + self.assertEqual(c.get(), []) + + def test_context_copy_1(self): + ctx1 = contextvars.Context() + c = contextvars.ContextVar('c', default=42) + + def ctx1_fun(): + c.set(10) + + ctx2 = ctx1.copy() + self.assertEqual(ctx2[c], 10) + + c.set(20) + self.assertEqual(ctx1[c], 20) + self.assertEqual(ctx2[c], 10) + + ctx2.run(ctx2_fun) + self.assertEqual(ctx1[c], 20) + self.assertEqual(ctx2[c], 30) + + def ctx2_fun(): + self.assertEqual(c.get(), 10) + c.set(30) + self.assertEqual(c.get(), 30) + + ctx1.run(ctx1_fun) + + @isolated_context + def test_context_threads_1(self): + cvar = contextvars.ContextVar('cvar') + + def sub(num): + for i in range(10): + cvar.set(num + i) + time.sleep(random.uniform(0.001, 0.05)) + self.assertEqual(cvar.get(), num + i) + return num + + tp = concurrent.futures.ThreadPoolExecutor(max_workers=10) + try: + results = list(tp.map(sub, range(10))) + finally: + tp.shutdown() + self.assertEqual(results, list(range(10))) + + +# HAMT Tests + + +class HashKey: + _crasher = None + + def __init__(self, hash, name, *, error_on_eq_to=None): + assert hash != -1 + self.name = name + self.hash = hash + self.error_on_eq_to = error_on_eq_to + + def __repr__(self): + return f'' + + def __hash__(self): + if self._crasher is not None and self._crasher.error_on_hash: + raise HashingError + + return self.hash + + def __eq__(self, other): + if not isinstance(other, HashKey): + return NotImplemented + + if self._crasher is not None and self._crasher.error_on_eq: + raise EqError + + if self.error_on_eq_to is not None and self.error_on_eq_to is other: + raise ValueError(f'cannot compare {self!r} to {other!r}') + if other.error_on_eq_to is not None and other.error_on_eq_to is self: + raise ValueError(f'cannot compare {other!r} to {self!r}') + + return (self.name, self.hash) == (other.name, other.hash) + + +class KeyStr(str): + def __hash__(self): + if HashKey._crasher is not None and HashKey._crasher.error_on_hash: + raise HashingError + return super().__hash__() + + def __eq__(self, other): + if HashKey._crasher is not None and HashKey._crasher.error_on_eq: + raise EqError + return super().__eq__(other) + + +class HaskKeyCrasher: + def __init__(self, *, error_on_hash=False, error_on_eq=False): + self.error_on_hash = error_on_hash + self.error_on_eq = error_on_eq + + def __enter__(self): + if HashKey._crasher is not None: + raise RuntimeError('cannot nest crashers') + HashKey._crasher = self + + def __exit__(self, *exc): + HashKey._crasher = None + + +class HashingError(Exception): + pass + + +class EqError(Exception): + pass + + +@unittest.skipIf(hamt is None, '_testcapi lacks "hamt()" function') +class HamtTest(unittest.TestCase): + + def test_hashkey_helper_1(self): + k1 = HashKey(10, 'aaa') + k2 = HashKey(10, 'bbb') + + self.assertNotEqual(k1, k2) + self.assertEqual(hash(k1), hash(k2)) + + d = dict() + d[k1] = 'a' + d[k2] = 'b' + + self.assertEqual(d[k1], 'a') + self.assertEqual(d[k2], 'b') + + def test_hamt_basics_1(self): + h = hamt() + h = None # NoQA + + def test_hamt_basics_2(self): + h = hamt() + self.assertEqual(len(h), 0) + + h2 = h.set('a', 'b') + self.assertIsNot(h, h2) + self.assertEqual(len(h), 0) + self.assertEqual(len(h2), 1) + + self.assertIsNone(h.get('a')) + self.assertEqual(h.get('a', 42), 42) + + self.assertEqual(h2.get('a'), 'b') + + h3 = h2.set('b', 10) + self.assertIsNot(h2, h3) + self.assertEqual(len(h), 0) + self.assertEqual(len(h2), 1) + self.assertEqual(len(h3), 2) + self.assertEqual(h3.get('a'), 'b') + self.assertEqual(h3.get('b'), 10) + + self.assertIsNone(h.get('b')) + self.assertIsNone(h2.get('b')) + + self.assertIsNone(h.get('a')) + self.assertEqual(h2.get('a'), 'b') + + h = h2 = h3 = None + + def test_hamt_basics_3(self): + h = hamt() + o = object() + h1 = h.set('1', o) + h2 = h1.set('1', o) + self.assertIs(h1, h2) + + def test_hamt_basics_4(self): + h = hamt() + h1 = h.set('key', []) + h2 = h1.set('key', []) + self.assertIsNot(h1, h2) + self.assertEqual(len(h1), 1) + self.assertEqual(len(h2), 1) + self.assertIsNot(h1.get('key'), h2.get('key')) + + def test_hamt_collision_1(self): + k1 = HashKey(10, 'aaa') + k2 = HashKey(10, 'bbb') + k3 = HashKey(10, 'ccc') + + h = hamt() + h2 = h.set(k1, 'a') + h3 = h2.set(k2, 'b') + + self.assertEqual(h.get(k1), None) + self.assertEqual(h.get(k2), None) + + self.assertEqual(h2.get(k1), 'a') + self.assertEqual(h2.get(k2), None) + + self.assertEqual(h3.get(k1), 'a') + self.assertEqual(h3.get(k2), 'b') + + h4 = h3.set(k2, 'cc') + h5 = h4.set(k3, 'aa') + + self.assertEqual(h3.get(k1), 'a') + self.assertEqual(h3.get(k2), 'b') + self.assertEqual(h4.get(k1), 'a') + self.assertEqual(h4.get(k2), 'cc') + self.assertEqual(h4.get(k3), None) + self.assertEqual(h5.get(k1), 'a') + self.assertEqual(h5.get(k2), 'cc') + self.assertEqual(h5.get(k2), 'cc') + self.assertEqual(h5.get(k3), 'aa') + + self.assertEqual(len(h), 0) + self.assertEqual(len(h2), 1) + self.assertEqual(len(h3), 2) + self.assertEqual(len(h4), 2) + self.assertEqual(len(h5), 3) + + def test_hamt_stress(self): + COLLECTION_SIZE = 7000 + TEST_ITERS_EVERY = 647 + CRASH_HASH_EVERY = 97 + CRASH_EQ_EVERY = 11 + RUN_XTIMES = 3 + + for _ in range(RUN_XTIMES): + h = hamt() + d = dict() + + for i in range(COLLECTION_SIZE): + key = KeyStr(i) + + if not (i % CRASH_HASH_EVERY): + with HaskKeyCrasher(error_on_hash=True): + with self.assertRaises(HashingError): + h.set(key, i) + + h = h.set(key, i) + + if not (i % CRASH_EQ_EVERY): + with HaskKeyCrasher(error_on_eq=True): + with self.assertRaises(EqError): + h.get(KeyStr(i)) # really trigger __eq__ + + d[key] = i + self.assertEqual(len(d), len(h)) + + if not (i % TEST_ITERS_EVERY): + self.assertEqual(set(h.items()), set(d.items())) + self.assertEqual(len(h.items()), len(d.items())) + + self.assertEqual(len(h), COLLECTION_SIZE) + + for key in range(COLLECTION_SIZE): + self.assertEqual(h.get(KeyStr(key), 'not found'), key) + + keys_to_delete = list(range(COLLECTION_SIZE)) + random.shuffle(keys_to_delete) + for iter_i, i in enumerate(keys_to_delete): + key = KeyStr(i) + + if not (iter_i % CRASH_HASH_EVERY): + with HaskKeyCrasher(error_on_hash=True): + with self.assertRaises(HashingError): + h.delete(key) + + if not (iter_i % CRASH_EQ_EVERY): + with HaskKeyCrasher(error_on_eq=True): + with self.assertRaises(EqError): + h.delete(KeyStr(i)) + + h = h.delete(key) + self.assertEqual(h.get(key, 'not found'), 'not found') + del d[key] + self.assertEqual(len(d), len(h)) + + if iter_i == COLLECTION_SIZE // 2: + hm = h + dm = d.copy() + + if not (iter_i % TEST_ITERS_EVERY): + self.assertEqual(set(h.keys()), set(d.keys())) + self.assertEqual(len(h.keys()), len(d.keys())) + + self.assertEqual(len(d), 0) + self.assertEqual(len(h), 0) + + # ============ + + for key in dm: + self.assertEqual(hm.get(str(key)), dm[key]) + self.assertEqual(len(dm), len(hm)) + + for i, key in enumerate(keys_to_delete): + hm = hm.delete(str(key)) + self.assertEqual(hm.get(str(key), 'not found'), 'not found') + dm.pop(str(key), None) + self.assertEqual(len(d), len(h)) + + if not (i % TEST_ITERS_EVERY): + self.assertEqual(set(h.values()), set(d.values())) + self.assertEqual(len(h.values()), len(d.values())) + + self.assertEqual(len(d), 0) + self.assertEqual(len(h), 0) + self.assertEqual(list(h.items()), []) + + def test_hamt_delete_1(self): + A = HashKey(100, 'A') + B = HashKey(101, 'B') + C = HashKey(102, 'C') + D = HashKey(103, 'D') + E = HashKey(104, 'E') + Z = HashKey(-100, 'Z') + + Er = HashKey(103, 'Er', error_on_eq_to=D) + + h = hamt() + h = h.set(A, 'a') + h = h.set(B, 'b') + h = h.set(C, 'c') + h = h.set(D, 'd') + h = h.set(E, 'e') + + orig_len = len(h) + + # BitmapNode(size=10 bitmap=0b111110000 id=0x10eadc618): + # : 'a' + # : 'b' + # : 'c' + # : 'd' + # : 'e' + + h = h.delete(C) + self.assertEqual(len(h), orig_len - 1) + + with self.assertRaisesRegex(ValueError, 'cannot compare'): + h.delete(Er) + + h = h.delete(D) + self.assertEqual(len(h), orig_len - 2) + + h2 = h.delete(Z) + self.assertIs(h2, h) + + h = h.delete(A) + self.assertEqual(len(h), orig_len - 3) + + self.assertEqual(h.get(A, 42), 42) + self.assertEqual(h.get(B), 'b') + self.assertEqual(h.get(E), 'e') + + def test_hamt_delete_2(self): + A = HashKey(100, 'A') + B = HashKey(201001, 'B') + C = HashKey(101001, 'C') + D = HashKey(103, 'D') + E = HashKey(104, 'E') + Z = HashKey(-100, 'Z') + + Er = HashKey(201001, 'Er', error_on_eq_to=B) + + h = hamt() + h = h.set(A, 'a') + h = h.set(B, 'b') + h = h.set(C, 'c') + h = h.set(D, 'd') + h = h.set(E, 'e') + + orig_len = len(h) + + # BitmapNode(size=8 bitmap=0b1110010000): + # : 'a' + # : 'd' + # : 'e' + # NULL: + # BitmapNode(size=4 bitmap=0b100000000001000000000): + # : 'b' + # : 'c' + + with self.assertRaisesRegex(ValueError, 'cannot compare'): + h.delete(Er) + + h = h.delete(Z) + self.assertEqual(len(h), orig_len) + + h = h.delete(C) + self.assertEqual(len(h), orig_len - 1) + + h = h.delete(B) + self.assertEqual(len(h), orig_len - 2) + + h = h.delete(A) + self.assertEqual(len(h), orig_len - 3) + + self.assertEqual(h.get(D), 'd') + self.assertEqual(h.get(E), 'e') + + h = h.delete(A) + h = h.delete(B) + h = h.delete(D) + h = h.delete(E) + self.assertEqual(len(h), 0) + + def test_hamt_delete_3(self): + A = HashKey(100, 'A') + B = HashKey(101, 'B') + C = HashKey(100100, 'C') + D = HashKey(100100, 'D') + E = HashKey(104, 'E') + + h = hamt() + h = h.set(A, 'a') + h = h.set(B, 'b') + h = h.set(C, 'c') + h = h.set(D, 'd') + h = h.set(E, 'e') + + orig_len = len(h) + + # BitmapNode(size=6 bitmap=0b100110000): + # NULL: + # BitmapNode(size=4 bitmap=0b1000000000000000000001000): + # : 'a' + # NULL: + # CollisionNode(size=4 id=0x108572410): + # : 'c' + # : 'd' + # : 'b' + # : 'e' + + h = h.delete(A) + self.assertEqual(len(h), orig_len - 1) + + h = h.delete(E) + self.assertEqual(len(h), orig_len - 2) + + self.assertEqual(h.get(C), 'c') + self.assertEqual(h.get(B), 'b') + + def test_hamt_delete_4(self): + A = HashKey(100, 'A') + B = HashKey(101, 'B') + C = HashKey(100100, 'C') + D = HashKey(100100, 'D') + E = HashKey(100100, 'E') + + h = hamt() + h = h.set(A, 'a') + h = h.set(B, 'b') + h = h.set(C, 'c') + h = h.set(D, 'd') + h = h.set(E, 'e') + + orig_len = len(h) + + # BitmapNode(size=4 bitmap=0b110000): + # NULL: + # BitmapNode(size=4 bitmap=0b1000000000000000000001000): + # : 'a' + # NULL: + # CollisionNode(size=6 id=0x10515ef30): + # : 'c' + # : 'd' + # : 'e' + # : 'b' + + h = h.delete(D) + self.assertEqual(len(h), orig_len - 1) + + h = h.delete(E) + self.assertEqual(len(h), orig_len - 2) + + h = h.delete(C) + self.assertEqual(len(h), orig_len - 3) + + h = h.delete(A) + self.assertEqual(len(h), orig_len - 4) + + h = h.delete(B) + self.assertEqual(len(h), 0) + + def test_hamt_delete_5(self): + h = hamt() + + keys = [] + for i in range(17): + key = HashKey(i, str(i)) + keys.append(key) + h = h.set(key, f'val-{i}') + + collision_key16 = HashKey(16, '18') + h = h.set(collision_key16, 'collision') + + # ArrayNode(id=0x10f8b9318): + # 0:: + # BitmapNode(size=2 count=1 bitmap=0b1): + # : 'val-0' + # + # ... 14 more BitmapNodes ... + # + # 15:: + # BitmapNode(size=2 count=1 bitmap=0b1): + # : 'val-15' + # + # 16:: + # BitmapNode(size=2 count=1 bitmap=0b1): + # NULL: + # CollisionNode(size=4 id=0x10f2f5af8): + # : 'val-16' + # : 'collision' + + self.assertEqual(len(h), 18) + + h = h.delete(keys[2]) + self.assertEqual(len(h), 17) + + h = h.delete(collision_key16) + self.assertEqual(len(h), 16) + h = h.delete(keys[16]) + self.assertEqual(len(h), 15) + + h = h.delete(keys[1]) + self.assertEqual(len(h), 14) + h = h.delete(keys[1]) + self.assertEqual(len(h), 14) + + for key in keys: + h = h.delete(key) + self.assertEqual(len(h), 0) + + def test_hamt_items_1(self): + A = HashKey(100, 'A') + B = HashKey(201001, 'B') + C = HashKey(101001, 'C') + D = HashKey(103, 'D') + E = HashKey(104, 'E') + F = HashKey(110, 'F') + + h = hamt() + h = h.set(A, 'a') + h = h.set(B, 'b') + h = h.set(C, 'c') + h = h.set(D, 'd') + h = h.set(E, 'e') + h = h.set(F, 'f') + + it = h.items() + self.assertEqual( + set(list(it)), + {(A, 'a'), (B, 'b'), (C, 'c'), (D, 'd'), (E, 'e'), (F, 'f')}) + + def test_hamt_items_2(self): + A = HashKey(100, 'A') + B = HashKey(101, 'B') + C = HashKey(100100, 'C') + D = HashKey(100100, 'D') + E = HashKey(100100, 'E') + F = HashKey(110, 'F') + + h = hamt() + h = h.set(A, 'a') + h = h.set(B, 'b') + h = h.set(C, 'c') + h = h.set(D, 'd') + h = h.set(E, 'e') + h = h.set(F, 'f') + + it = h.items() + self.assertEqual( + set(list(it)), + {(A, 'a'), (B, 'b'), (C, 'c'), (D, 'd'), (E, 'e'), (F, 'f')}) + + def test_hamt_keys_1(self): + A = HashKey(100, 'A') + B = HashKey(101, 'B') + C = HashKey(100100, 'C') + D = HashKey(100100, 'D') + E = HashKey(100100, 'E') + F = HashKey(110, 'F') + + h = hamt() + h = h.set(A, 'a') + h = h.set(B, 'b') + h = h.set(C, 'c') + h = h.set(D, 'd') + h = h.set(E, 'e') + h = h.set(F, 'f') + + self.assertEqual(set(list(h.keys())), {A, B, C, D, E, F}) + self.assertEqual(set(list(h)), {A, B, C, D, E, F}) + + def test_hamt_items_3(self): + h = hamt() + self.assertEqual(len(h.items()), 0) + self.assertEqual(list(h.items()), []) + + def test_hamt_eq_1(self): + A = HashKey(100, 'A') + B = HashKey(101, 'B') + C = HashKey(100100, 'C') + D = HashKey(100100, 'D') + E = HashKey(120, 'E') + + h1 = hamt() + h1 = h1.set(A, 'a') + h1 = h1.set(B, 'b') + h1 = h1.set(C, 'c') + h1 = h1.set(D, 'd') + + h2 = hamt() + h2 = h2.set(A, 'a') + + self.assertFalse(h1 == h2) + self.assertTrue(h1 != h2) + + h2 = h2.set(B, 'b') + self.assertFalse(h1 == h2) + self.assertTrue(h1 != h2) + + h2 = h2.set(C, 'c') + self.assertFalse(h1 == h2) + self.assertTrue(h1 != h2) + + h2 = h2.set(D, 'd2') + self.assertFalse(h1 == h2) + self.assertTrue(h1 != h2) + + h2 = h2.set(D, 'd') + self.assertTrue(h1 == h2) + self.assertFalse(h1 != h2) + + h2 = h2.set(E, 'e') + self.assertFalse(h1 == h2) + self.assertTrue(h1 != h2) + + h2 = h2.delete(D) + self.assertFalse(h1 == h2) + self.assertTrue(h1 != h2) + + h2 = h2.set(E, 'd') + self.assertFalse(h1 == h2) + self.assertTrue(h1 != h2) + + def test_hamt_eq_2(self): + A = HashKey(100, 'A') + Er = HashKey(100, 'Er', error_on_eq_to=A) + + h1 = hamt() + h1 = h1.set(A, 'a') + + h2 = hamt() + h2 = h2.set(Er, 'a') + + with self.assertRaisesRegex(ValueError, 'cannot compare'): + h1 == h2 + + with self.assertRaisesRegex(ValueError, 'cannot compare'): + h1 != h2 + + def test_hamt_gc_1(self): + A = HashKey(100, 'A') + + h = hamt() + h = h.set(0, 0) # empty HAMT node is memoized in hamt.c + ref = weakref.ref(h) + + a = [] + a.append(a) + a.append(h) + b = [] + a.append(b) + b.append(a) + h = h.set(A, b) + + del h, a, b + + gc.collect() + gc.collect() + gc.collect() + + self.assertIsNone(ref()) + + def test_hamt_gc_2(self): + A = HashKey(100, 'A') + B = HashKey(101, 'B') + + h = hamt() + h = h.set(A, 'a') + h = h.set(A, h) + + ref = weakref.ref(h) + hi = h.items() + next(hi) + + del h, hi + + gc.collect() + gc.collect() + gc.collect() + + self.assertIsNone(ref()) + + def test_hamt_in_1(self): + A = HashKey(100, 'A') + AA = HashKey(100, 'A') + + B = HashKey(101, 'B') + + h = hamt() + h = h.set(A, 1) + + self.assertTrue(A in h) + self.assertFalse(B in h) + + with self.assertRaises(EqError): + with HaskKeyCrasher(error_on_eq=True): + AA in h + + with self.assertRaises(HashingError): + with HaskKeyCrasher(error_on_hash=True): + AA in h + + def test_hamt_getitem_1(self): + A = HashKey(100, 'A') + AA = HashKey(100, 'A') + + B = HashKey(101, 'B') + + h = hamt() + h = h.set(A, 1) + + self.assertEqual(h[A], 1) + self.assertEqual(h[AA], 1) + + with self.assertRaises(KeyError): + h[B] + + with self.assertRaises(EqError): + with HaskKeyCrasher(error_on_eq=True): + h[AA] + + with self.assertRaises(HashingError): + with HaskKeyCrasher(error_on_hash=True): + h[AA] + + +if __name__ == "__main__": + unittest.main() diff --git a/Makefile.pre.in b/Makefile.pre.in index d1512671735..162006a2f9c 100644 --- a/Makefile.pre.in +++ b/Makefile.pre.in @@ -354,6 +354,8 @@ PYTHON_OBJS= \ Python/pylifecycle.o \ Python/pymath.o \ Python/pystate.o \ + Python/context.o \ + Python/hamt.o \ Python/pythonrun.o \ Python/pytime.o \ Python/bootstrap_hash.o \ @@ -996,6 +998,7 @@ PYTHON_HEADERS= \ $(srcdir)/Include/pymem.h \ $(srcdir)/Include/pyport.h \ $(srcdir)/Include/pystate.h \ + $(srcdir)/Include/context.h \ $(srcdir)/Include/pystrcmp.h \ $(srcdir)/Include/pystrtod.h \ $(srcdir)/Include/pystrhex.h \ @@ -1023,6 +1026,7 @@ PYTHON_HEADERS= \ $(srcdir)/Include/internal/mem.h \ $(srcdir)/Include/internal/pygetopt.h \ $(srcdir)/Include/internal/pystate.h \ + $(srcdir)/Include/internal/context.h \ $(srcdir)/Include/internal/warnings.h \ $(DTRACE_HEADERS) diff --git a/Misc/NEWS.d/next/Core and Builtins/2017-12-28-00-20-42.bpo-32436.H159Jv.rst b/Misc/NEWS.d/next/Core and Builtins/2017-12-28-00-20-42.bpo-32436.H159Jv.rst new file mode 100644 index 00000000000..8586d771781 --- /dev/null +++ b/Misc/NEWS.d/next/Core and Builtins/2017-12-28-00-20-42.bpo-32436.H159Jv.rst @@ -0,0 +1 @@ +Implement PEP 567 diff --git a/Modules/Setup.dist b/Modules/Setup.dist index 1f2d56c0655..239550e26b5 100644 --- a/Modules/Setup.dist +++ b/Modules/Setup.dist @@ -176,6 +176,7 @@ _symtable symtablemodule.c #array arraymodule.c # array objects #cmath cmathmodule.c _math.c # -lm # complex math library functions #math mathmodule.c _math.c # -lm # math library functions, e.g. sin() +#_contextvars _contextvarsmodule.c # Context Variables #_struct _struct.c # binary structure packing/unpacking #_weakref _weakref.c # basic weak reference support #_testcapi _testcapimodule.c # Python C API test module diff --git a/Modules/_asynciomodule.c b/Modules/_asynciomodule.c index 22ce32c593c..f77ec999b09 100644 --- a/Modules/_asynciomodule.c +++ b/Modules/_asynciomodule.c @@ -36,6 +36,7 @@ static PyObject *asyncio_task_print_stack_func; static PyObject *asyncio_task_repr_info_func; static PyObject *asyncio_InvalidStateError; static PyObject *asyncio_CancelledError; +static PyObject *context_kwname; /* WeakSet containing all alive tasks. */ @@ -59,6 +60,7 @@ typedef enum { PyObject_HEAD \ PyObject *prefix##_loop; \ PyObject *prefix##_callback0; \ + PyContext *prefix##_context0; \ PyObject *prefix##_callbacks; \ PyObject *prefix##_exception; \ PyObject *prefix##_result; \ @@ -77,6 +79,7 @@ typedef struct { FutureObj_HEAD(task) PyObject *task_fut_waiter; PyObject *task_coro; + PyContext *task_context; int task_must_cancel; int task_log_destroy_pending; } TaskObj; @@ -336,11 +339,38 @@ get_event_loop(void) static int -call_soon(PyObject *loop, PyObject *func, PyObject *arg) +call_soon(PyObject *loop, PyObject *func, PyObject *arg, PyContext *ctx) { PyObject *handle; - handle = _PyObject_CallMethodIdObjArgs( - loop, &PyId_call_soon, func, arg, NULL); + PyObject *stack[3]; + Py_ssize_t nargs; + + if (ctx == NULL) { + handle = _PyObject_CallMethodIdObjArgs( + loop, &PyId_call_soon, func, arg, NULL); + } + else { + /* Use FASTCALL to pass a keyword-only argument to call_soon */ + + PyObject *callable = _PyObject_GetAttrId(loop, &PyId_call_soon); + if (callable == NULL) { + return -1; + } + + /* All refs in 'stack' are borrowed. */ + nargs = 1; + stack[0] = func; + if (arg != NULL) { + stack[1] = arg; + nargs++; + } + stack[nargs] = (PyObject *)ctx; + + handle = _PyObject_FastCallKeywords( + callable, stack, nargs, context_kwname); + Py_DECREF(callable); + } + if (handle == NULL) { return -1; } @@ -387,8 +417,11 @@ future_schedule_callbacks(FutureObj *fut) /* There's a 1st callback */ int ret = call_soon( - fut->fut_loop, fut->fut_callback0, (PyObject *)fut); + fut->fut_loop, fut->fut_callback0, + (PyObject *)fut, fut->fut_context0); + Py_CLEAR(fut->fut_callback0); + Py_CLEAR(fut->fut_context0); if (ret) { /* If an error occurs in pure-Python implementation, all callbacks are cleared. */ @@ -413,9 +446,11 @@ future_schedule_callbacks(FutureObj *fut) } for (i = 0; i < len; i++) { - PyObject *cb = PyList_GET_ITEM(fut->fut_callbacks, i); + PyObject *cb_tup = PyList_GET_ITEM(fut->fut_callbacks, i); + PyObject *cb = PyTuple_GET_ITEM(cb_tup, 0); + PyObject *ctx = PyTuple_GET_ITEM(cb_tup, 1); - if (call_soon(fut->fut_loop, cb, (PyObject *)fut)) { + if (call_soon(fut->fut_loop, cb, (PyObject *)fut, (PyContext *)ctx)) { /* If an error occurs in pure-Python implementation, all callbacks are cleared. */ Py_CLEAR(fut->fut_callbacks); @@ -462,6 +497,7 @@ future_init(FutureObj *fut, PyObject *loop) } fut->fut_callback0 = NULL; + fut->fut_context0 = NULL; fut->fut_callbacks = NULL; return 0; @@ -566,7 +602,7 @@ future_get_result(FutureObj *fut, PyObject **result) } static PyObject * -future_add_done_callback(FutureObj *fut, PyObject *arg) +future_add_done_callback(FutureObj *fut, PyObject *arg, PyContext *ctx) { if (!future_is_alive(fut)) { PyErr_SetString(PyExc_RuntimeError, "uninitialized Future object"); @@ -576,7 +612,7 @@ future_add_done_callback(FutureObj *fut, PyObject *arg) if (fut->fut_state != STATE_PENDING) { /* The future is done/cancelled, so schedule the callback right away. */ - if (call_soon(fut->fut_loop, arg, (PyObject*) fut)) { + if (call_soon(fut->fut_loop, arg, (PyObject*) fut, ctx)) { return NULL; } } @@ -602,24 +638,38 @@ future_add_done_callback(FutureObj *fut, PyObject *arg) with a new list and add the new callback to it. */ - if (fut->fut_callbacks != NULL) { - int err = PyList_Append(fut->fut_callbacks, arg); - if (err != 0) { - return NULL; - } - } - else if (fut->fut_callback0 == NULL) { + if (fut->fut_callbacks == NULL && fut->fut_callback0 == NULL) { Py_INCREF(arg); fut->fut_callback0 = arg; + Py_INCREF(ctx); + fut->fut_context0 = ctx; } else { - fut->fut_callbacks = PyList_New(1); - if (fut->fut_callbacks == NULL) { + PyObject *tup = PyTuple_New(2); + if (tup == NULL) { return NULL; } - Py_INCREF(arg); - PyList_SET_ITEM(fut->fut_callbacks, 0, arg); + PyTuple_SET_ITEM(tup, 0, arg); + Py_INCREF(ctx); + PyTuple_SET_ITEM(tup, 1, (PyObject *)ctx); + + if (fut->fut_callbacks != NULL) { + int err = PyList_Append(fut->fut_callbacks, tup); + if (err) { + Py_DECREF(tup); + return NULL; + } + Py_DECREF(tup); + } + else { + fut->fut_callbacks = PyList_New(1); + if (fut->fut_callbacks == NULL) { + return NULL; + } + + PyList_SET_ITEM(fut->fut_callbacks, 0, tup); /* borrow */ + } } } @@ -676,6 +726,7 @@ FutureObj_clear(FutureObj *fut) { Py_CLEAR(fut->fut_loop); Py_CLEAR(fut->fut_callback0); + Py_CLEAR(fut->fut_context0); Py_CLEAR(fut->fut_callbacks); Py_CLEAR(fut->fut_result); Py_CLEAR(fut->fut_exception); @@ -689,6 +740,7 @@ FutureObj_traverse(FutureObj *fut, visitproc visit, void *arg) { Py_VISIT(fut->fut_loop); Py_VISIT(fut->fut_callback0); + Py_VISIT(fut->fut_context0); Py_VISIT(fut->fut_callbacks); Py_VISIT(fut->fut_result); Py_VISIT(fut->fut_exception); @@ -821,6 +873,8 @@ _asyncio.Future.add_done_callback fn: object / + * + context: object = NULL Add a callback to be run when the future becomes done. @@ -830,10 +884,21 @@ scheduled with call_soon. [clinic start generated code]*/ static PyObject * -_asyncio_Future_add_done_callback(FutureObj *self, PyObject *fn) -/*[clinic end generated code: output=819e09629b2ec2b5 input=8f818b39990b027d]*/ +_asyncio_Future_add_done_callback_impl(FutureObj *self, PyObject *fn, + PyObject *context) +/*[clinic end generated code: output=7ce635bbc9554c1e input=15ab0693a96e9533]*/ { - return future_add_done_callback(self, fn); + if (context == NULL) { + context = (PyObject *)PyContext_CopyCurrent(); + if (context == NULL) { + return NULL; + } + PyObject *res = future_add_done_callback( + self, fn, (PyContext *)context); + Py_DECREF(context); + return res; + } + return future_add_done_callback(self, fn, (PyContext *)context); } /*[clinic input] @@ -865,6 +930,7 @@ _asyncio_Future_remove_done_callback(FutureObj *self, PyObject *fn) if (cmp == 1) { /* callback0 == fn */ Py_CLEAR(self->fut_callback0); + Py_CLEAR(self->fut_context0); cleared_callback0 = 1; } } @@ -880,8 +946,9 @@ _asyncio_Future_remove_done_callback(FutureObj *self, PyObject *fn) } if (len == 1) { + PyObject *cb_tup = PyList_GET_ITEM(self->fut_callbacks, 0); int cmp = PyObject_RichCompareBool( - fn, PyList_GET_ITEM(self->fut_callbacks, 0), Py_EQ); + fn, PyTuple_GET_ITEM(cb_tup, 0), Py_EQ); if (cmp == -1) { return NULL; } @@ -903,7 +970,7 @@ _asyncio_Future_remove_done_callback(FutureObj *self, PyObject *fn) int ret; PyObject *item = PyList_GET_ITEM(self->fut_callbacks, i); Py_INCREF(item); - ret = PyObject_RichCompareBool(fn, item, Py_EQ); + ret = PyObject_RichCompareBool(fn, PyTuple_GET_ITEM(item, 0), Py_EQ); if (ret == 0) { if (j < len) { PyList_SET_ITEM(newlist, j, item); @@ -1081,47 +1148,49 @@ static PyObject * FutureObj_get_callbacks(FutureObj *fut) { Py_ssize_t i; - Py_ssize_t len; - PyObject *new_list; ENSURE_FUTURE_ALIVE(fut) - if (fut->fut_callbacks == NULL) { - if (fut->fut_callback0 == NULL) { + if (fut->fut_callback0 == NULL) { + if (fut->fut_callbacks == NULL) { Py_RETURN_NONE; } - else { - new_list = PyList_New(1); - if (new_list == NULL) { - return NULL; - } - Py_INCREF(fut->fut_callback0); - PyList_SET_ITEM(new_list, 0, fut->fut_callback0); - return new_list; - } - } - assert(fut->fut_callbacks != NULL); - - if (fut->fut_callback0 == NULL) { Py_INCREF(fut->fut_callbacks); return fut->fut_callbacks; } - assert(fut->fut_callback0 != NULL); + Py_ssize_t len = 1; + if (fut->fut_callbacks != NULL) { + len += PyList_GET_SIZE(fut->fut_callbacks); + } - len = PyList_GET_SIZE(fut->fut_callbacks); - new_list = PyList_New(len + 1); + + PyObject *new_list = PyList_New(len); if (new_list == NULL) { return NULL; } + PyObject *tup0 = PyTuple_New(2); + if (tup0 == NULL) { + Py_DECREF(new_list); + return NULL; + } + Py_INCREF(fut->fut_callback0); - PyList_SET_ITEM(new_list, 0, fut->fut_callback0); - for (i = 0; i < len; i++) { - PyObject *cb = PyList_GET_ITEM(fut->fut_callbacks, i); - Py_INCREF(cb); - PyList_SET_ITEM(new_list, i + 1, cb); + PyTuple_SET_ITEM(tup0, 0, fut->fut_callback0); + assert(fut->fut_context0 != NULL); + Py_INCREF(fut->fut_context0); + PyTuple_SET_ITEM(tup0, 1, (PyObject *)fut->fut_context0); + + PyList_SET_ITEM(new_list, 0, tup0); + + if (fut->fut_callbacks != NULL) { + for (i = 0; i < PyList_GET_SIZE(fut->fut_callbacks); i++) { + PyObject *cb = PyList_GET_ITEM(fut->fut_callbacks, i); + Py_INCREF(cb); + PyList_SET_ITEM(new_list, i + 1, cb); + } } return new_list; @@ -1912,6 +1981,11 @@ _asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop) return -1; } + self->task_context = PyContext_CopyCurrent(); + if (self->task_context == NULL) { + return -1; + } + self->task_fut_waiter = NULL; self->task_must_cancel = 0; self->task_log_destroy_pending = 1; @@ -1928,6 +2002,7 @@ static int TaskObj_clear(TaskObj *task) { (void)FutureObj_clear((FutureObj*) task); + Py_CLEAR(task->task_context); Py_CLEAR(task->task_coro); Py_CLEAR(task->task_fut_waiter); return 0; @@ -1936,6 +2011,7 @@ TaskObj_clear(TaskObj *task) static int TaskObj_traverse(TaskObj *task, visitproc visit, void *arg) { + Py_VISIT(task->task_context); Py_VISIT(task->task_coro); Py_VISIT(task->task_fut_waiter); (void)FutureObj_traverse((FutureObj*) task, visit, arg); @@ -2451,7 +2527,7 @@ task_call_step_soon(TaskObj *task, PyObject *arg) return -1; } - int ret = call_soon(task->task_loop, cb, NULL); + int ret = call_soon(task->task_loop, cb, NULL, task->task_context); Py_DECREF(cb); return ret; } @@ -2650,7 +2726,8 @@ set_exception: if (wrapper == NULL) { goto fail; } - res = future_add_done_callback((FutureObj*)result, wrapper); + res = future_add_done_callback( + (FutureObj*)result, wrapper, task->task_context); Py_DECREF(wrapper); if (res == NULL) { goto fail; @@ -2724,14 +2801,23 @@ set_exception: goto fail; } - /* result.add_done_callback(task._wakeup) */ wrapper = TaskWakeupMethWrapper_new(task); if (wrapper == NULL) { goto fail; } - res = _PyObject_CallMethodIdObjArgs(result, - &PyId_add_done_callback, - wrapper, NULL); + + /* result.add_done_callback(task._wakeup) */ + PyObject *add_cb = _PyObject_GetAttrId( + result, &PyId_add_done_callback); + if (add_cb == NULL) { + goto fail; + } + PyObject *stack[2]; + stack[0] = wrapper; + stack[1] = (PyObject *)task->task_context; + res = _PyObject_FastCallKeywords( + add_cb, stack, 1, context_kwname); + Py_DECREF(add_cb); Py_DECREF(wrapper); if (res == NULL) { goto fail; @@ -3141,6 +3227,8 @@ module_free(void *m) Py_CLEAR(current_tasks); Py_CLEAR(iscoroutine_typecache); + Py_CLEAR(context_kwname); + module_free_freelists(); } @@ -3164,6 +3252,17 @@ module_init(void) goto fail; } + + context_kwname = PyTuple_New(1); + if (context_kwname == NULL) { + goto fail; + } + PyObject *context_str = PyUnicode_FromString("context"); + if (context_str == NULL) { + goto fail; + } + PyTuple_SET_ITEM(context_kwname, 0, context_str); + #define WITH_MOD(NAME) \ Py_CLEAR(module); \ module = PyImport_ImportModule(NAME); \ diff --git a/Modules/_contextvarsmodule.c b/Modules/_contextvarsmodule.c new file mode 100644 index 00000000000..b7d112dd601 --- /dev/null +++ b/Modules/_contextvarsmodule.c @@ -0,0 +1,75 @@ +#include "Python.h" + +#include "clinic/_contextvarsmodule.c.h" + +/*[clinic input] +module _contextvars +[clinic start generated code]*/ +/*[clinic end generated code: output=da39a3ee5e6b4b0d input=a0955718c8b8cea6]*/ + + +/*[clinic input] +_contextvars.copy_context +[clinic start generated code]*/ + +static PyObject * +_contextvars_copy_context_impl(PyObject *module) +/*[clinic end generated code: output=1fcd5da7225c4fa9 input=89bb9ae485888440]*/ +{ + return (PyObject *)PyContext_CopyCurrent(); +} + + +PyDoc_STRVAR(module_doc, "Context Variables"); + +static PyMethodDef _contextvars_methods[] = { + _CONTEXTVARS_COPY_CONTEXT_METHODDEF + {NULL, NULL} +}; + +static struct PyModuleDef _contextvarsmodule = { + PyModuleDef_HEAD_INIT, /* m_base */ + "_contextvars", /* m_name */ + module_doc, /* m_doc */ + -1, /* m_size */ + _contextvars_methods, /* m_methods */ + NULL, /* m_slots */ + NULL, /* m_traverse */ + NULL, /* m_clear */ + NULL, /* m_free */ +}; + +PyMODINIT_FUNC +PyInit__contextvars(void) +{ + PyObject *m = PyModule_Create(&_contextvarsmodule); + if (m == NULL) { + return NULL; + } + + Py_INCREF(&PyContext_Type); + if (PyModule_AddObject(m, "Context", + (PyObject *)&PyContext_Type) < 0) + { + Py_DECREF(&PyContext_Type); + return NULL; + } + + Py_INCREF(&PyContextVar_Type); + if (PyModule_AddObject(m, "ContextVar", + (PyObject *)&PyContextVar_Type) < 0) + { + Py_DECREF(&PyContextVar_Type); + return NULL; + } + + Py_INCREF(&PyContextToken_Type); + if (PyModule_AddObject(m, "Token", + (PyObject *)&PyContextToken_Type) < 0) + { + Py_DECREF(&PyContextToken_Type); + return NULL; + } + + return m; +} diff --git a/Modules/_testcapimodule.c b/Modules/_testcapimodule.c index 0d6bf45032d..e3be7d3d829 100644 --- a/Modules/_testcapimodule.c +++ b/Modules/_testcapimodule.c @@ -4438,6 +4438,13 @@ test_pythread_tss_key_state(PyObject *self, PyObject *args) } +static PyObject* +new_hamt(PyObject *self, PyObject *args) +{ + return _PyContext_NewHamtForTests(); +} + + static PyMethodDef TestMethods[] = { {"raise_exception", raise_exception, METH_VARARGS}, {"raise_memoryerror", (PyCFunction)raise_memoryerror, METH_NOARGS}, @@ -4655,6 +4662,7 @@ static PyMethodDef TestMethods[] = { {"get_mapping_values", get_mapping_values, METH_O}, {"get_mapping_items", get_mapping_items, METH_O}, {"test_pythread_tss_key_state", test_pythread_tss_key_state, METH_VARARGS}, + {"hamt", new_hamt, METH_NOARGS}, {NULL, NULL} /* sentinel */ }; diff --git a/Modules/clinic/_asynciomodule.c.h b/Modules/clinic/_asynciomodule.c.h index f2e0f405ada..9fc9d6b17c4 100644 --- a/Modules/clinic/_asynciomodule.c.h +++ b/Modules/clinic/_asynciomodule.c.h @@ -110,7 +110,7 @@ PyDoc_STRVAR(_asyncio_Future_set_exception__doc__, {"set_exception", (PyCFunction)_asyncio_Future_set_exception, METH_O, _asyncio_Future_set_exception__doc__}, PyDoc_STRVAR(_asyncio_Future_add_done_callback__doc__, -"add_done_callback($self, fn, /)\n" +"add_done_callback($self, fn, /, *, context=None)\n" "--\n" "\n" "Add a callback to be run when the future becomes done.\n" @@ -120,7 +120,30 @@ PyDoc_STRVAR(_asyncio_Future_add_done_callback__doc__, "scheduled with call_soon."); #define _ASYNCIO_FUTURE_ADD_DONE_CALLBACK_METHODDEF \ - {"add_done_callback", (PyCFunction)_asyncio_Future_add_done_callback, METH_O, _asyncio_Future_add_done_callback__doc__}, + {"add_done_callback", (PyCFunction)_asyncio_Future_add_done_callback, METH_FASTCALL|METH_KEYWORDS, _asyncio_Future_add_done_callback__doc__}, + +static PyObject * +_asyncio_Future_add_done_callback_impl(FutureObj *self, PyObject *fn, + PyObject *context); + +static PyObject * +_asyncio_Future_add_done_callback(FutureObj *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) +{ + PyObject *return_value = NULL; + static const char * const _keywords[] = {"", "context", NULL}; + static _PyArg_Parser _parser = {"O|$O:add_done_callback", _keywords, 0}; + PyObject *fn; + PyObject *context = NULL; + + if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &_parser, + &fn, &context)) { + goto exit; + } + return_value = _asyncio_Future_add_done_callback_impl(self, fn, context); + +exit: + return return_value; +} PyDoc_STRVAR(_asyncio_Future_remove_done_callback__doc__, "remove_done_callback($self, fn, /)\n" @@ -763,4 +786,4 @@ _asyncio__leave_task(PyObject *module, PyObject *const *args, Py_ssize_t nargs, exit: return return_value; } -/*[clinic end generated code: output=616e814431893dcc input=a9049054013a1b77]*/ +/*[clinic end generated code: output=bcbaf1b2480f4aa9 input=a9049054013a1b77]*/ diff --git a/Modules/clinic/_contextvarsmodule.c.h b/Modules/clinic/_contextvarsmodule.c.h new file mode 100644 index 00000000000..b1885e41c35 --- /dev/null +++ b/Modules/clinic/_contextvarsmodule.c.h @@ -0,0 +1,21 @@ +/*[clinic input] +preserve +[clinic start generated code]*/ + +PyDoc_STRVAR(_contextvars_copy_context__doc__, +"copy_context($module, /)\n" +"--\n" +"\n"); + +#define _CONTEXTVARS_COPY_CONTEXT_METHODDEF \ + {"copy_context", (PyCFunction)_contextvars_copy_context, METH_NOARGS, _contextvars_copy_context__doc__}, + +static PyObject * +_contextvars_copy_context_impl(PyObject *module); + +static PyObject * +_contextvars_copy_context(PyObject *module, PyObject *Py_UNUSED(ignored)) +{ + return _contextvars_copy_context_impl(module); +} +/*[clinic end generated code: output=26e07024451baf52 input=a9049054013a1b77]*/ diff --git a/Modules/gcmodule.c b/Modules/gcmodule.c index ea3c294dcff..8ba1093c029 100644 --- a/Modules/gcmodule.c +++ b/Modules/gcmodule.c @@ -24,6 +24,7 @@ */ #include "Python.h" +#include "internal/context.h" #include "internal/mem.h" #include "internal/pystate.h" #include "frameobject.h" /* for PyFrame_ClearFreeList */ @@ -790,6 +791,7 @@ clear_freelists(void) (void)PyDict_ClearFreeList(); (void)PySet_ClearFreeList(); (void)PyAsyncGen_ClearFreeLists(); + (void)PyContext_ClearFreeList(); } /* This is the main function. Read this to understand how the diff --git a/Objects/object.c b/Objects/object.c index 62d7fbebf40..8cec6e2f122 100644 --- a/Objects/object.c +++ b/Objects/object.c @@ -3,6 +3,7 @@ #include "Python.h" #include "internal/pystate.h" +#include "internal/context.h" #include "frameobject.h" #ifdef __cplusplus diff --git a/PCbuild/_contextvars.vcxproj b/PCbuild/_contextvars.vcxproj new file mode 100644 index 00000000000..7418e86570a --- /dev/null +++ b/PCbuild/_contextvars.vcxproj @@ -0,0 +1,77 @@ + + + + + Debug + Win32 + + + Debug + x64 + + + PGInstrument + Win32 + + + PGInstrument + x64 + + + PGUpdate + Win32 + + + PGUpdate + x64 + + + Release + Win32 + + + Release + x64 + + + + {B8BF1D81-09DC-42D4-B406-4F868B33A89E} + _contextvars + Win32Proj + + + + + DynamicLibrary + NotSet + + + + .pyd + + + + + + + + + + <_ProjectFileVersion>10.0.30319.1 + + + + + + + + + + {cf7ac3d1-e2df-41d2-bea6-1e2556cdea26} + false + + + + + + diff --git a/PCbuild/_contextvars.vcxproj.filters b/PCbuild/_contextvars.vcxproj.filters new file mode 100644 index 00000000000..b3002b7ff67 --- /dev/null +++ b/PCbuild/_contextvars.vcxproj.filters @@ -0,0 +1,16 @@ + + + + + + + + {7CBD8910-233D-4E9A-9164-9BA66C1F0E6D} + + + + + Source Files + + + diff --git a/PCbuild/_decimal.vcxproj b/PCbuild/_decimal.vcxproj index b14f31093ff..df9f600cdaf 100644 --- a/PCbuild/_decimal.vcxproj +++ b/PCbuild/_decimal.vcxproj @@ -121,4 +121,4 @@ - \ No newline at end of file + diff --git a/PCbuild/pcbuild.proj b/PCbuild/pcbuild.proj index 848b3b2877f..5e341959bd3 100644 --- a/PCbuild/pcbuild.proj +++ b/PCbuild/pcbuild.proj @@ -49,7 +49,7 @@ - + diff --git a/PCbuild/pythoncore.vcxproj b/PCbuild/pythoncore.vcxproj index bf2ce669f08..fbcd0512139 100644 --- a/PCbuild/pythoncore.vcxproj +++ b/PCbuild/pythoncore.vcxproj @@ -94,6 +94,7 @@ + @@ -112,7 +113,9 @@ + + @@ -232,6 +235,7 @@ + @@ -359,6 +363,7 @@ + @@ -373,6 +378,7 @@ + diff --git a/PCbuild/pythoncore.vcxproj.filters b/PCbuild/pythoncore.vcxproj.filters index 13600cb5c9f..a10686c194b 100644 --- a/PCbuild/pythoncore.vcxproj.filters +++ b/PCbuild/pythoncore.vcxproj.filters @@ -81,6 +81,9 @@ Include + + Include + Include @@ -135,9 +138,15 @@ Include + + Include + Include + + Include + Include @@ -842,6 +851,9 @@ Python + + Python + Python @@ -884,6 +896,9 @@ Python + + Python + Python @@ -998,6 +1013,9 @@ Modules + + Modules + Modules\zlib diff --git a/Python/clinic/context.c.h b/Python/clinic/context.c.h new file mode 100644 index 00000000000..dcf4c214e1b --- /dev/null +++ b/Python/clinic/context.c.h @@ -0,0 +1,146 @@ +/*[clinic input] +preserve +[clinic start generated code]*/ + +PyDoc_STRVAR(_contextvars_Context_get__doc__, +"get($self, key, default=None, /)\n" +"--\n" +"\n"); + +#define _CONTEXTVARS_CONTEXT_GET_METHODDEF \ + {"get", (PyCFunction)_contextvars_Context_get, METH_FASTCALL, _contextvars_Context_get__doc__}, + +static PyObject * +_contextvars_Context_get_impl(PyContext *self, PyObject *key, + PyObject *default_value); + +static PyObject * +_contextvars_Context_get(PyContext *self, PyObject *const *args, Py_ssize_t nargs) +{ + PyObject *return_value = NULL; + PyObject *key; + PyObject *default_value = Py_None; + + if (!_PyArg_UnpackStack(args, nargs, "get", + 1, 2, + &key, &default_value)) { + goto exit; + } + return_value = _contextvars_Context_get_impl(self, key, default_value); + +exit: + return return_value; +} + +PyDoc_STRVAR(_contextvars_Context_items__doc__, +"items($self, /)\n" +"--\n" +"\n"); + +#define _CONTEXTVARS_CONTEXT_ITEMS_METHODDEF \ + {"items", (PyCFunction)_contextvars_Context_items, METH_NOARGS, _contextvars_Context_items__doc__}, + +static PyObject * +_contextvars_Context_items_impl(PyContext *self); + +static PyObject * +_contextvars_Context_items(PyContext *self, PyObject *Py_UNUSED(ignored)) +{ + return _contextvars_Context_items_impl(self); +} + +PyDoc_STRVAR(_contextvars_Context_keys__doc__, +"keys($self, /)\n" +"--\n" +"\n"); + +#define _CONTEXTVARS_CONTEXT_KEYS_METHODDEF \ + {"keys", (PyCFunction)_contextvars_Context_keys, METH_NOARGS, _contextvars_Context_keys__doc__}, + +static PyObject * +_contextvars_Context_keys_impl(PyContext *self); + +static PyObject * +_contextvars_Context_keys(PyContext *self, PyObject *Py_UNUSED(ignored)) +{ + return _contextvars_Context_keys_impl(self); +} + +PyDoc_STRVAR(_contextvars_Context_values__doc__, +"values($self, /)\n" +"--\n" +"\n"); + +#define _CONTEXTVARS_CONTEXT_VALUES_METHODDEF \ + {"values", (PyCFunction)_contextvars_Context_values, METH_NOARGS, _contextvars_Context_values__doc__}, + +static PyObject * +_contextvars_Context_values_impl(PyContext *self); + +static PyObject * +_contextvars_Context_values(PyContext *self, PyObject *Py_UNUSED(ignored)) +{ + return _contextvars_Context_values_impl(self); +} + +PyDoc_STRVAR(_contextvars_Context_copy__doc__, +"copy($self, /)\n" +"--\n" +"\n"); + +#define _CONTEXTVARS_CONTEXT_COPY_METHODDEF \ + {"copy", (PyCFunction)_contextvars_Context_copy, METH_NOARGS, _contextvars_Context_copy__doc__}, + +static PyObject * +_contextvars_Context_copy_impl(PyContext *self); + +static PyObject * +_contextvars_Context_copy(PyContext *self, PyObject *Py_UNUSED(ignored)) +{ + return _contextvars_Context_copy_impl(self); +} + +PyDoc_STRVAR(_contextvars_ContextVar_get__doc__, +"get($self, default=None, /)\n" +"--\n" +"\n"); + +#define _CONTEXTVARS_CONTEXTVAR_GET_METHODDEF \ + {"get", (PyCFunction)_contextvars_ContextVar_get, METH_FASTCALL, _contextvars_ContextVar_get__doc__}, + +static PyObject * +_contextvars_ContextVar_get_impl(PyContextVar *self, PyObject *default_value); + +static PyObject * +_contextvars_ContextVar_get(PyContextVar *self, PyObject *const *args, Py_ssize_t nargs) +{ + PyObject *return_value = NULL; + PyObject *default_value = NULL; + + if (!_PyArg_UnpackStack(args, nargs, "get", + 0, 1, + &default_value)) { + goto exit; + } + return_value = _contextvars_ContextVar_get_impl(self, default_value); + +exit: + return return_value; +} + +PyDoc_STRVAR(_contextvars_ContextVar_set__doc__, +"set($self, value, /)\n" +"--\n" +"\n"); + +#define _CONTEXTVARS_CONTEXTVAR_SET_METHODDEF \ + {"set", (PyCFunction)_contextvars_ContextVar_set, METH_O, _contextvars_ContextVar_set__doc__}, + +PyDoc_STRVAR(_contextvars_ContextVar_reset__doc__, +"reset($self, token, /)\n" +"--\n" +"\n"); + +#define _CONTEXTVARS_CONTEXTVAR_RESET_METHODDEF \ + {"reset", (PyCFunction)_contextvars_ContextVar_reset, METH_O, _contextvars_ContextVar_reset__doc__}, +/*[clinic end generated code: output=d9a675e3a52a14fc input=a9049054013a1b77]*/ diff --git a/Python/context.c b/Python/context.c new file mode 100644 index 00000000000..2f1d0f5c342 --- /dev/null +++ b/Python/context.c @@ -0,0 +1,1220 @@ +#include "Python.h" + +#include "structmember.h" +#include "internal/pystate.h" +#include "internal/context.h" +#include "internal/hamt.h" + + +#define CONTEXT_FREELIST_MAXLEN 255 +static PyContext *ctx_freelist = NULL; +static Py_ssize_t ctx_freelist_len = 0; + + +#include "clinic/context.c.h" +/*[clinic input] +module _contextvars +[clinic start generated code]*/ +/*[clinic end generated code: output=da39a3ee5e6b4b0d input=a0955718c8b8cea6]*/ + + +/////////////////////////// Context API + + +static PyContext * +context_new_empty(void); + +static PyContext * +context_new_from_vars(PyHamtObject *vars); + +static inline PyContext * +context_get(void); + +static PyContextToken * +token_new(PyContext *ctx, PyContextVar *var, PyObject *val); + +static PyContextVar * +contextvar_new(PyObject *name, PyObject *def); + +static int +contextvar_set(PyContextVar *var, PyObject *val); + +static int +contextvar_del(PyContextVar *var); + + +PyObject * +_PyContext_NewHamtForTests(void) +{ + return (PyObject *)_PyHamt_New(); +} + + +PyContext * +PyContext_New(void) +{ + return context_new_empty(); +} + + +PyContext * +PyContext_Copy(PyContext * ctx) +{ + return context_new_from_vars(ctx->ctx_vars); +} + + +PyContext * +PyContext_CopyCurrent(void) +{ + PyContext *ctx = context_get(); + if (ctx == NULL) { + return NULL; + } + + return context_new_from_vars(ctx->ctx_vars); +} + + +int +PyContext_Enter(PyContext *ctx) +{ + if (ctx->ctx_entered) { + PyErr_Format(PyExc_RuntimeError, + "cannot enter context: %R is already entered", ctx); + return -1; + } + + PyThreadState *ts = PyThreadState_Get(); + + ctx->ctx_prev = (PyContext *)ts->context; /* borrow */ + ctx->ctx_entered = 1; + + Py_INCREF(ctx); + ts->context = (PyObject *)ctx; + ts->context_ver++; + + return 0; +} + + +int +PyContext_Exit(PyContext *ctx) +{ + if (!ctx->ctx_entered) { + PyErr_Format(PyExc_RuntimeError, + "cannot exit context: %R has not been entered", ctx); + return -1; + } + + PyThreadState *ts = PyThreadState_Get(); + + if (ts->context != (PyObject *)ctx) { + /* Can only happen if someone misuses the C API */ + PyErr_SetString(PyExc_RuntimeError, + "cannot exit context: thread state references " + "a different context object"); + return -1; + } + + Py_SETREF(ts->context, (PyObject *)ctx->ctx_prev); + ts->context_ver++; + + ctx->ctx_prev = NULL; + ctx->ctx_entered = 0; + + return 0; +} + + +PyContextVar * +PyContextVar_New(const char *name, PyObject *def) +{ + PyObject *pyname = PyUnicode_FromString(name); + if (pyname == NULL) { + return NULL; + } + return contextvar_new(pyname, def); +} + + +int +PyContextVar_Get(PyContextVar *var, PyObject *def, PyObject **val) +{ + assert(PyContextVar_CheckExact(var)); + + PyThreadState *ts = PyThreadState_Get(); + if (ts->context == NULL) { + goto not_found; + } + + if (var->var_cached != NULL && + var->var_cached_tsid == ts->id && + var->var_cached_tsver == ts->context_ver) + { + *val = var->var_cached; + goto found; + } + + assert(PyContext_CheckExact(ts->context)); + PyHamtObject *vars = ((PyContext *)ts->context)->ctx_vars; + + PyObject *found = NULL; + int res = _PyHamt_Find(vars, (PyObject*)var, &found); + if (res < 0) { + goto error; + } + if (res == 1) { + assert(found != NULL); + var->var_cached = found; /* borrow */ + var->var_cached_tsid = ts->id; + var->var_cached_tsver = ts->context_ver; + + *val = found; + goto found; + } + +not_found: + if (def == NULL) { + if (var->var_default != NULL) { + *val = var->var_default; + goto found; + } + + *val = NULL; + goto found; + } + else { + *val = def; + goto found; + } + +found: + Py_XINCREF(*val); + return 0; + +error: + *val = NULL; + return -1; +} + + +PyContextToken * +PyContextVar_Set(PyContextVar *var, PyObject *val) +{ + if (!PyContextVar_CheckExact(var)) { + PyErr_SetString( + PyExc_TypeError, "an instance of ContextVar was expected"); + return NULL; + } + + PyContext *ctx = context_get(); + if (ctx == NULL) { + return NULL; + } + + PyObject *old_val = NULL; + int found = _PyHamt_Find(ctx->ctx_vars, (PyObject *)var, &old_val); + if (found < 0) { + return NULL; + } + + Py_XINCREF(old_val); + PyContextToken *tok = token_new(ctx, var, old_val); + Py_XDECREF(old_val); + + if (contextvar_set(var, val)) { + Py_DECREF(tok); + return NULL; + } + + return tok; +} + + +int +PyContextVar_Reset(PyContextVar *var, PyContextToken *tok) +{ + if (tok->tok_used) { + PyErr_Format(PyExc_RuntimeError, + "%R has already been used once", tok); + return -1; + } + + if (var != tok->tok_var) { + PyErr_Format(PyExc_ValueError, + "%R was created by a different ContextVar", tok); + return -1; + } + + PyContext *ctx = context_get(); + if (ctx != tok->tok_ctx) { + PyErr_Format(PyExc_ValueError, + "%R was created in a different Context", tok); + return -1; + } + + tok->tok_used = 1; + + if (tok->tok_oldval == NULL) { + return contextvar_del(var); + } + else { + return contextvar_set(var, tok->tok_oldval); + } +} + + +/////////////////////////// PyContext + +/*[clinic input] +class _contextvars.Context "PyContext *" "&PyContext_Type" +[clinic start generated code]*/ +/*[clinic end generated code: output=da39a3ee5e6b4b0d input=bdf87f8e0cb580e8]*/ + + +static inline PyContext * +_context_alloc(void) +{ + PyContext *ctx; + if (ctx_freelist_len) { + ctx_freelist_len--; + ctx = ctx_freelist; + ctx_freelist = (PyContext *)ctx->ctx_weakreflist; + ctx->ctx_weakreflist = NULL; + _Py_NewReference((PyObject *)ctx); + } + else { + ctx = PyObject_GC_New(PyContext, &PyContext_Type); + if (ctx == NULL) { + return NULL; + } + } + + ctx->ctx_vars = NULL; + ctx->ctx_prev = NULL; + ctx->ctx_entered = 0; + ctx->ctx_weakreflist = NULL; + + return ctx; +} + + +static PyContext * +context_new_empty(void) +{ + PyContext *ctx = _context_alloc(); + if (ctx == NULL) { + return NULL; + } + + ctx->ctx_vars = _PyHamt_New(); + if (ctx->ctx_vars == NULL) { + Py_DECREF(ctx); + return NULL; + } + + _PyObject_GC_TRACK(ctx); + return ctx; +} + + +static PyContext * +context_new_from_vars(PyHamtObject *vars) +{ + PyContext *ctx = _context_alloc(); + if (ctx == NULL) { + return NULL; + } + + Py_INCREF(vars); + ctx->ctx_vars = vars; + + _PyObject_GC_TRACK(ctx); + return ctx; +} + + +static inline PyContext * +context_get(void) +{ + PyThreadState *ts = PyThreadState_Get(); + PyContext *current_ctx = (PyContext *)ts->context; + if (current_ctx == NULL) { + current_ctx = context_new_empty(); + if (current_ctx == NULL) { + return NULL; + } + ts->context = (PyObject *)current_ctx; + } + return current_ctx; +} + +static int +context_check_key_type(PyObject *key) +{ + if (!PyContextVar_CheckExact(key)) { + // abort(); + PyErr_Format(PyExc_TypeError, + "a ContextVar key was expected, got %R", key); + return -1; + } + return 0; +} + +static PyObject * +context_tp_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + if (PyTuple_Size(args) || (kwds != NULL && PyDict_Size(kwds))) { + PyErr_SetString( + PyExc_TypeError, "Context() does not accept any arguments"); + return NULL; + } + return (PyObject *)PyContext_New(); +} + +static int +context_tp_clear(PyContext *self) +{ + Py_CLEAR(self->ctx_prev); + Py_CLEAR(self->ctx_vars); + return 0; +} + +static int +context_tp_traverse(PyContext *self, visitproc visit, void *arg) +{ + Py_VISIT(self->ctx_prev); + Py_VISIT(self->ctx_vars); + return 0; +} + +static void +context_tp_dealloc(PyContext *self) +{ + _PyObject_GC_UNTRACK(self); + + if (self->ctx_weakreflist != NULL) { + PyObject_ClearWeakRefs((PyObject*)self); + } + (void)context_tp_clear(self); + + if (ctx_freelist_len < CONTEXT_FREELIST_MAXLEN) { + ctx_freelist_len++; + self->ctx_weakreflist = (PyObject *)ctx_freelist; + ctx_freelist = self; + } + else { + Py_TYPE(self)->tp_free(self); + } +} + +static PyObject * +context_tp_iter(PyContext *self) +{ + return _PyHamt_NewIterKeys(self->ctx_vars); +} + +static PyObject * +context_tp_richcompare(PyObject *v, PyObject *w, int op) +{ + if (!PyContext_CheckExact(v) || !PyContext_CheckExact(w) || + (op != Py_EQ && op != Py_NE)) + { + Py_RETURN_NOTIMPLEMENTED; + } + + int res = _PyHamt_Eq( + ((PyContext *)v)->ctx_vars, ((PyContext *)w)->ctx_vars); + if (res < 0) { + return NULL; + } + + if (op == Py_NE) { + res = !res; + } + + if (res) { + Py_RETURN_TRUE; + } + else { + Py_RETURN_FALSE; + } +} + +static Py_ssize_t +context_tp_len(PyContext *self) +{ + return _PyHamt_Len(self->ctx_vars); +} + +static PyObject * +context_tp_subscript(PyContext *self, PyObject *key) +{ + if (context_check_key_type(key)) { + return NULL; + } + PyObject *val = NULL; + int found = _PyHamt_Find(self->ctx_vars, key, &val); + if (found < 0) { + return NULL; + } + if (found == 0) { + PyErr_SetObject(PyExc_KeyError, key); + return NULL; + } + Py_INCREF(val); + return val; +} + +static int +context_tp_contains(PyContext *self, PyObject *key) +{ + if (context_check_key_type(key)) { + return -1; + } + PyObject *val = NULL; + return _PyHamt_Find(self->ctx_vars, key, &val); +} + + +/*[clinic input] +_contextvars.Context.get + key: object + default: object = None + / +[clinic start generated code]*/ + +static PyObject * +_contextvars_Context_get_impl(PyContext *self, PyObject *key, + PyObject *default_value) +/*[clinic end generated code: output=0c54aa7664268189 input=8d4c33c8ecd6d769]*/ +{ + if (context_check_key_type(key)) { + return NULL; + } + + PyObject *val = NULL; + int found = _PyHamt_Find(self->ctx_vars, key, &val); + if (found < 0) { + return NULL; + } + if (found == 0) { + Py_INCREF(default_value); + return default_value; + } + Py_INCREF(val); + return val; +} + + +/*[clinic input] +_contextvars.Context.items +[clinic start generated code]*/ + +static PyObject * +_contextvars_Context_items_impl(PyContext *self) +/*[clinic end generated code: output=fa1655c8a08502af input=2d570d1455004979]*/ +{ + return _PyHamt_NewIterItems(self->ctx_vars); +} + + +/*[clinic input] +_contextvars.Context.keys +[clinic start generated code]*/ + +static PyObject * +_contextvars_Context_keys_impl(PyContext *self) +/*[clinic end generated code: output=177227c6b63ec0e2 input=13005e142fbbf37d]*/ +{ + return _PyHamt_NewIterKeys(self->ctx_vars); +} + + +/*[clinic input] +_contextvars.Context.values +[clinic start generated code]*/ + +static PyObject * +_contextvars_Context_values_impl(PyContext *self) +/*[clinic end generated code: output=d286dabfc8db6dde input=c2cbc40a4470e905]*/ +{ + return _PyHamt_NewIterValues(self->ctx_vars); +} + + +/*[clinic input] +_contextvars.Context.copy +[clinic start generated code]*/ + +static PyObject * +_contextvars_Context_copy_impl(PyContext *self) +/*[clinic end generated code: output=30ba8896c4707a15 input=3e3fd72d598653ab]*/ +{ + return (PyObject *)context_new_from_vars(self->ctx_vars); +} + + +static PyObject * +context_run(PyContext *self, PyObject *const *args, + Py_ssize_t nargs, PyObject *kwnames) +{ + if (nargs < 1) { + PyErr_SetString(PyExc_TypeError, + "run() missing 1 required positional argument"); + return NULL; + } + + if (PyContext_Enter(self)) { + return NULL; + } + + PyObject *call_result = _PyObject_FastCallKeywords( + args[0], args + 1, nargs - 1, kwnames); + + if (PyContext_Exit(self)) { + return NULL; + } + + return call_result; +} + + +static PyMethodDef PyContext_methods[] = { + _CONTEXTVARS_CONTEXT_GET_METHODDEF + _CONTEXTVARS_CONTEXT_ITEMS_METHODDEF + _CONTEXTVARS_CONTEXT_KEYS_METHODDEF + _CONTEXTVARS_CONTEXT_VALUES_METHODDEF + _CONTEXTVARS_CONTEXT_COPY_METHODDEF + {"run", (PyCFunction)context_run, METH_FASTCALL | METH_KEYWORDS, NULL}, + {NULL, NULL} +}; + +static PySequenceMethods PyContext_as_sequence = { + 0, /* sq_length */ + 0, /* sq_concat */ + 0, /* sq_repeat */ + 0, /* sq_item */ + 0, /* sq_slice */ + 0, /* sq_ass_item */ + 0, /* sq_ass_slice */ + (objobjproc)context_tp_contains, /* sq_contains */ + 0, /* sq_inplace_concat */ + 0, /* sq_inplace_repeat */ +}; + +static PyMappingMethods PyContext_as_mapping = { + (lenfunc)context_tp_len, /* mp_length */ + (binaryfunc)context_tp_subscript, /* mp_subscript */ +}; + +PyTypeObject PyContext_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + "Context", + sizeof(PyContext), + .tp_methods = PyContext_methods, + .tp_as_mapping = &PyContext_as_mapping, + .tp_as_sequence = &PyContext_as_sequence, + .tp_iter = (getiterfunc)context_tp_iter, + .tp_dealloc = (destructor)context_tp_dealloc, + .tp_getattro = PyObject_GenericGetAttr, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, + .tp_richcompare = context_tp_richcompare, + .tp_traverse = (traverseproc)context_tp_traverse, + .tp_clear = (inquiry)context_tp_clear, + .tp_new = context_tp_new, + .tp_weaklistoffset = offsetof(PyContext, ctx_weakreflist), + .tp_hash = PyObject_HashNotImplemented, +}; + + +/////////////////////////// ContextVar + + +static int +contextvar_set(PyContextVar *var, PyObject *val) +{ + var->var_cached = NULL; + PyThreadState *ts = PyThreadState_Get(); + + PyContext *ctx = context_get(); + if (ctx == NULL) { + return -1; + } + + PyHamtObject *new_vars = _PyHamt_Assoc( + ctx->ctx_vars, (PyObject *)var, val); + if (new_vars == NULL) { + return -1; + } + + Py_SETREF(ctx->ctx_vars, new_vars); + + var->var_cached = val; /* borrow */ + var->var_cached_tsid = ts->id; + var->var_cached_tsver = ts->context_ver; + return 0; +} + +static int +contextvar_del(PyContextVar *var) +{ + var->var_cached = NULL; + + PyContext *ctx = context_get(); + if (ctx == NULL) { + return -1; + } + + PyHamtObject *vars = ctx->ctx_vars; + PyHamtObject *new_vars = _PyHamt_Without(vars, (PyObject *)var); + if (new_vars == NULL) { + return -1; + } + + if (vars == new_vars) { + Py_DECREF(new_vars); + PyErr_SetObject(PyExc_LookupError, (PyObject *)var); + return -1; + } + + Py_SETREF(ctx->ctx_vars, new_vars); + return 0; +} + +static Py_hash_t +contextvar_generate_hash(void *addr, PyObject *name) +{ + /* Take hash of `name` and XOR it with the object's addr. + + The structure of the tree is encoded in objects' hashes, which + means that sufficiently similar hashes would result in tall trees + with many Collision nodes. Which would, in turn, result in slower + get and set operations. + + The XORing helps to ensure that: + + (1) sequentially allocated ContextVar objects have + different hashes; + + (2) context variables with equal names have + different hashes. + */ + + Py_hash_t name_hash = PyObject_Hash(name); + if (name_hash == -1) { + return -1; + } + + Py_hash_t res = _Py_HashPointer(addr) ^ name_hash; + return res == -1 ? -2 : res; +} + +static PyContextVar * +contextvar_new(PyObject *name, PyObject *def) +{ + if (!PyUnicode_Check(name)) { + PyErr_SetString(PyExc_TypeError, + "context variable name must be a str"); + return NULL; + } + + PyContextVar *var = PyObject_GC_New(PyContextVar, &PyContextVar_Type); + if (var == NULL) { + return NULL; + } + + var->var_hash = contextvar_generate_hash(var, name); + if (var->var_hash == -1) { + Py_DECREF(var); + return NULL; + } + + Py_INCREF(name); + var->var_name = name; + + Py_XINCREF(def); + var->var_default = def; + + var->var_cached = NULL; + var->var_cached_tsid = 0; + var->var_cached_tsver = 0; + + if (_PyObject_GC_IS_TRACKED(name) || + (def != NULL && _PyObject_GC_IS_TRACKED(def))) + { + PyObject_GC_Track(var); + } + return var; +} + + +/*[clinic input] +class _contextvars.ContextVar "PyContextVar *" "&PyContextVar_Type" +[clinic start generated code]*/ +/*[clinic end generated code: output=da39a3ee5e6b4b0d input=445da935fa8883c3]*/ + + +static PyObject * +contextvar_tp_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + static char *kwlist[] = {"", "default", NULL}; + PyObject *name; + PyObject *def = NULL; + + if (!PyArg_ParseTupleAndKeywords( + args, kwds, "O|$O:ContextVar", kwlist, &name, &def)) + { + return NULL; + } + + return (PyObject *)contextvar_new(name, def); +} + +static int +contextvar_tp_clear(PyContextVar *self) +{ + Py_CLEAR(self->var_name); + Py_CLEAR(self->var_default); + self->var_cached = NULL; + self->var_cached_tsid = 0; + self->var_cached_tsver = 0; + return 0; +} + +static int +contextvar_tp_traverse(PyContextVar *self, visitproc visit, void *arg) +{ + Py_VISIT(self->var_name); + Py_VISIT(self->var_default); + return 0; +} + +static void +contextvar_tp_dealloc(PyContextVar *self) +{ + PyObject_GC_UnTrack(self); + (void)contextvar_tp_clear(self); + Py_TYPE(self)->tp_free(self); +} + +static Py_hash_t +contextvar_tp_hash(PyContextVar *self) +{ + return self->var_hash; +} + +static PyObject * +contextvar_tp_repr(PyContextVar *self) +{ + _PyUnicodeWriter writer; + + _PyUnicodeWriter_Init(&writer); + + if (_PyUnicodeWriter_WriteASCIIString( + &writer, "", self); + if (addr == NULL) { + goto error; + } + if (_PyUnicodeWriter_WriteStr(&writer, addr) < 0) { + Py_DECREF(addr); + goto error; + } + Py_DECREF(addr); + + return _PyUnicodeWriter_Finish(&writer); + +error: + _PyUnicodeWriter_Dealloc(&writer); + return NULL; +} + + +/*[clinic input] +_contextvars.ContextVar.get + default: object = NULL + / +[clinic start generated code]*/ + +static PyObject * +_contextvars_ContextVar_get_impl(PyContextVar *self, PyObject *default_value) +/*[clinic end generated code: output=0746bd0aa2ced7bf input=8d002b02eebbb247]*/ +{ + if (!PyContextVar_CheckExact(self)) { + PyErr_SetString( + PyExc_TypeError, "an instance of ContextVar was expected"); + return NULL; + } + + PyObject *val; + if (PyContextVar_Get(self, default_value, &val) < 0) { + return NULL; + } + + if (val == NULL) { + PyErr_SetObject(PyExc_LookupError, (PyObject *)self); + return NULL; + } + + return val; +} + +/*[clinic input] +_contextvars.ContextVar.set + value: object + / +[clinic start generated code]*/ + +static PyObject * +_contextvars_ContextVar_set(PyContextVar *self, PyObject *value) +/*[clinic end generated code: output=446ed5e820d6d60b input=a2d88f57c6d86f7c]*/ +{ + return (PyObject *)PyContextVar_Set(self, value); +} + +/*[clinic input] +_contextvars.ContextVar.reset + token: object + / +[clinic start generated code]*/ + +static PyObject * +_contextvars_ContextVar_reset(PyContextVar *self, PyObject *token) +/*[clinic end generated code: output=d4ee34d0742d62ee input=4c871b6f1f31a65f]*/ +{ + if (!PyContextToken_CheckExact(token)) { + PyErr_Format(PyExc_TypeError, + "expected an instance of Token, got %R", token); + return NULL; + } + + if (PyContextVar_Reset(self, (PyContextToken *)token)) { + return NULL; + } + + Py_RETURN_NONE; +} + + +static PyObject * +contextvar_cls_getitem(PyObject *self, PyObject *args) +{ + Py_RETURN_NONE; +} + + +static PyMethodDef PyContextVar_methods[] = { + _CONTEXTVARS_CONTEXTVAR_GET_METHODDEF + _CONTEXTVARS_CONTEXTVAR_SET_METHODDEF + _CONTEXTVARS_CONTEXTVAR_RESET_METHODDEF + {"__class_getitem__", contextvar_cls_getitem, + METH_VARARGS | METH_STATIC, NULL}, + {NULL, NULL} +}; + +PyTypeObject PyContextVar_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + "ContextVar", + sizeof(PyContextVar), + .tp_methods = PyContextVar_methods, + .tp_dealloc = (destructor)contextvar_tp_dealloc, + .tp_getattro = PyObject_GenericGetAttr, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, + .tp_traverse = (traverseproc)contextvar_tp_traverse, + .tp_clear = (inquiry)contextvar_tp_clear, + .tp_new = contextvar_tp_new, + .tp_free = PyObject_GC_Del, + .tp_hash = (hashfunc)contextvar_tp_hash, + .tp_repr = (reprfunc)contextvar_tp_repr, +}; + + +/////////////////////////// Token + +static PyObject * get_token_missing(void); + + +/*[clinic input] +class _contextvars.Token "PyContextToken *" "&PyContextToken_Type" +[clinic start generated code]*/ +/*[clinic end generated code: output=da39a3ee5e6b4b0d input=338a5e2db13d3f5b]*/ + + +static PyObject * +token_tp_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + PyErr_SetString(PyExc_RuntimeError, + "Tokens can only be created by ContextVars"); + return NULL; +} + +static int +token_tp_clear(PyContextToken *self) +{ + Py_CLEAR(self->tok_ctx); + Py_CLEAR(self->tok_var); + Py_CLEAR(self->tok_oldval); + return 0; +} + +static int +token_tp_traverse(PyContextToken *self, visitproc visit, void *arg) +{ + Py_VISIT(self->tok_ctx); + Py_VISIT(self->tok_var); + Py_VISIT(self->tok_oldval); + return 0; +} + +static void +token_tp_dealloc(PyContextToken *self) +{ + PyObject_GC_UnTrack(self); + (void)token_tp_clear(self); + Py_TYPE(self)->tp_free(self); +} + +static PyObject * +token_tp_repr(PyContextToken *self) +{ + _PyUnicodeWriter writer; + + _PyUnicodeWriter_Init(&writer); + + if (_PyUnicodeWriter_WriteASCIIString(&writer, "tok_used) { + if (_PyUnicodeWriter_WriteASCIIString(&writer, " used", 5) < 0) { + goto error; + } + } + + if (_PyUnicodeWriter_WriteASCIIString(&writer, " var=", 5) < 0) { + goto error; + } + + PyObject *var = PyObject_Repr((PyObject *)self->tok_var); + if (var == NULL) { + goto error; + } + if (_PyUnicodeWriter_WriteStr(&writer, var) < 0) { + Py_DECREF(var); + goto error; + } + Py_DECREF(var); + + PyObject *addr = PyUnicode_FromFormat(" at %p>", self); + if (addr == NULL) { + goto error; + } + if (_PyUnicodeWriter_WriteStr(&writer, addr) < 0) { + Py_DECREF(addr); + goto error; + } + Py_DECREF(addr); + + return _PyUnicodeWriter_Finish(&writer); + +error: + _PyUnicodeWriter_Dealloc(&writer); + return NULL; +} + +static PyObject * +token_get_var(PyContextToken *self) +{ + Py_INCREF(self->tok_var); + return (PyObject *)self->tok_var; +} + +static PyObject * +token_get_old_value(PyContextToken *self) +{ + if (self->tok_oldval == NULL) { + return get_token_missing(); + } + + Py_INCREF(self->tok_oldval); + return self->tok_oldval; +} + +static PyGetSetDef PyContextTokenType_getsetlist[] = { + {"var", (getter)token_get_var, NULL, NULL}, + {"old_value", (getter)token_get_old_value, NULL, NULL}, + {NULL} +}; + +PyTypeObject PyContextToken_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + "Token", + sizeof(PyContextToken), + .tp_getset = PyContextTokenType_getsetlist, + .tp_dealloc = (destructor)token_tp_dealloc, + .tp_getattro = PyObject_GenericGetAttr, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, + .tp_traverse = (traverseproc)token_tp_traverse, + .tp_clear = (inquiry)token_tp_clear, + .tp_new = token_tp_new, + .tp_free = PyObject_GC_Del, + .tp_hash = PyObject_HashNotImplemented, + .tp_repr = (reprfunc)token_tp_repr, +}; + +static PyContextToken * +token_new(PyContext *ctx, PyContextVar *var, PyObject *val) +{ + PyContextToken *tok = PyObject_GC_New(PyContextToken, &PyContextToken_Type); + if (tok == NULL) { + return NULL; + } + + Py_INCREF(ctx); + tok->tok_ctx = ctx; + + Py_INCREF(var); + tok->tok_var = var; + + Py_XINCREF(val); + tok->tok_oldval = val; + + tok->tok_used = 0; + + PyObject_GC_Track(tok); + return tok; +} + + +/////////////////////////// Token.MISSING + + +static PyObject *_token_missing; + + +typedef struct { + PyObject_HEAD +} PyContextTokenMissing; + + +static PyObject * +context_token_missing_tp_repr(PyObject *self) +{ + return PyUnicode_FromString(""); +} + + +PyTypeObject PyContextTokenMissing_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + "Token.MISSING", + sizeof(PyContextTokenMissing), + .tp_getattro = PyObject_GenericGetAttr, + .tp_flags = Py_TPFLAGS_DEFAULT, + .tp_repr = context_token_missing_tp_repr, +}; + + +static PyObject * +get_token_missing(void) +{ + if (_token_missing != NULL) { + Py_INCREF(_token_missing); + return _token_missing; + } + + _token_missing = (PyObject *)PyObject_New( + PyContextTokenMissing, &PyContextTokenMissing_Type); + if (_token_missing == NULL) { + return NULL; + } + + Py_INCREF(_token_missing); + return _token_missing; +} + + +/////////////////////////// + + +int +PyContext_ClearFreeList(void) +{ + int size = ctx_freelist_len; + while (ctx_freelist_len) { + PyContext *ctx = ctx_freelist; + ctx_freelist = (PyContext *)ctx->ctx_weakreflist; + ctx->ctx_weakreflist = NULL; + PyObject_GC_Del(ctx); + ctx_freelist_len--; + } + return size; +} + + +void +_PyContext_Fini(void) +{ + Py_CLEAR(_token_missing); + (void)PyContext_ClearFreeList(); + (void)_PyHamt_Fini(); +} + + +int +_PyContext_Init(void) +{ + if (!_PyHamt_Init()) { + return 0; + } + + if ((PyType_Ready(&PyContext_Type) < 0) || + (PyType_Ready(&PyContextVar_Type) < 0) || + (PyType_Ready(&PyContextToken_Type) < 0) || + (PyType_Ready(&PyContextTokenMissing_Type) < 0)) + { + return 0; + } + + PyObject *missing = get_token_missing(); + if (PyDict_SetItemString( + PyContextToken_Type.tp_dict, "MISSING", missing)) + { + Py_DECREF(missing); + return 0; + } + Py_DECREF(missing); + + return 1; +} diff --git a/Python/hamt.c b/Python/hamt.c new file mode 100644 index 00000000000..8ba5082026f --- /dev/null +++ b/Python/hamt.c @@ -0,0 +1,2982 @@ +#include "Python.h" + +#include "structmember.h" +#include "internal/pystate.h" +#include "internal/hamt.h" + +/* popcnt support in Visual Studio */ +#ifdef _MSC_VER +#include +#endif + +/* +This file provides an implemention of an immutable mapping using the +Hash Array Mapped Trie (or HAMT) datastructure. + +This design allows to have: + +1. Efficient copy: immutable mappings can be copied by reference, + making it an O(1) operation. + +2. Efficient mutations: due to structural sharing, only a portion of + the trie needs to be copied when the collection is mutated. The + cost of set/delete operations is O(log N). + +3. Efficient lookups: O(log N). + +(where N is number of key/value items in the immutable mapping.) + + +HAMT +==== + +The core idea of HAMT is that the shape of the trie is encoded into the +hashes of keys. + +Say we want to store a K/V pair in our mapping. First, we calculate the +hash of K, let's say it's 19830128, or in binary: + + 0b1001011101001010101110000 = 19830128 + +Now let's partition this bit representation of the hash into blocks of +5 bits each: + + 0b00_00000_10010_11101_00101_01011_10000 = 19830128 + (6) (5) (4) (3) (2) (1) + +Each block of 5 bits represents a number betwen 0 and 31. So if we have +a tree that consists of nodes, each of which is an array of 32 pointers, +those 5-bit blocks will encode a position on a single tree level. + +For example, storing the key K with hash 19830128, results in the following +tree structure: + + (array of 32 pointers) + +---+ -- +----+----+----+ -- +----+ + root node | 0 | .. | 15 | 16 | 17 | .. | 31 | 0b10000 = 16 (1) + (level 1) +---+ -- +----+----+----+ -- +----+ + | + +---+ -- +----+----+----+ -- +----+ + a 2nd level node | 0 | .. | 10 | 11 | 12 | .. | 31 | 0b01011 = 11 (2) + +---+ -- +----+----+----+ -- +----+ + | + +---+ -- +----+----+----+ -- +----+ + a 3rd level node | 0 | .. | 04 | 05 | 06 | .. | 31 | 0b01011 = 5 (3) + +---+ -- +----+----+----+ -- +----+ + | + +---+ -- +----+----+----+----+ + a 4th level node | 0 | .. | 04 | 29 | 30 | 31 | 0b11101 = 29 (4) + +---+ -- +----+----+----+----+ + | + +---+ -- +----+----+----+ -- +----+ + a 5th level node | 0 | .. | 17 | 18 | 19 | .. | 31 | 0b10010 = 18 (5) + +---+ -- +----+----+----+ -- +----+ + | + +--------------+ + | + +---+ -- +----+----+----+ -- +----+ + a 6th level node | 0 | .. | 15 | 16 | 17 | .. | 31 | 0b00000 = 0 (6) + +---+ -- +----+----+----+ -- +----+ + | + V -- our value (or collision) + +To rehash: for a K/V pair, the hash of K encodes where in the tree V will +be stored. + +To optimize memory footprint and handle hash collisions, our implementation +uses three different types of nodes: + + * A Bitmap node; + * An Array node; + * A Collision node. + +Because we implement an immutable dictionary, our nodes are also +immutable. Therefore, when we need to modify a node, we copy it, and +do that modification to the copy. + + +Array Nodes +----------- + +These nodes are very simple. Essentially they are arrays of 32 pointers +we used to illustrate the high-level idea in the previous section. + +We use Array nodes only when we need to store more than 16 pointers +in a single node. + +Array nodes do not store key objects or value objects. They are used +only as an indirection level - their pointers point to other nodes in +the tree. + + +Bitmap Node +----------- + +Allocating a new 32-pointers array for every node of our tree would be +very expensive. Unless we store millions of keys, most of tree nodes would +be very sparse. + +When we have less than 16 elements in a node, we don't want to use the +Array node, that would mean that we waste a lot of memory. Instead, +we can use bitmap compression and can have just as many pointers +as we need! + +Bitmap nodes consist of two fields: + +1. An array of pointers. If a Bitmap node holds N elements, the + array will be of N pointers. + +2. A 32bit integer -- a bitmap field. If an N-th bit is set in the + bitmap, it means that the node has an N-th element. + +For example, say we need to store a 3 elements sparse array: + + +---+ -- +---+ -- +----+ -- +----+ + | 0 | .. | 4 | .. | 11 | .. | 17 | + +---+ -- +---+ -- +----+ -- +----+ + | | | + o1 o2 o3 + +We allocate a three-pointer Bitmap node. Its bitmap field will be +then set to: + + 0b_00100_00010_00000_10000 == (1 << 17) | (1 << 11) | (1 << 4) + +To check if our Bitmap node has an I-th element we can do: + + bitmap & (1 << I) + + +And here's a formula to calculate a position in our pointer array +which would correspond to an I-th element: + + popcount(bitmap & ((1 << I) - 1)) + + +Let's break it down: + + * `popcount` is a function that returns a number of bits set to 1; + + * `((1 << I) - 1)` is a mask to filter the bitmask to contain bits + set to the *right* of our bit. + + +So for our 17, 11, and 4 indexes: + + * bitmap & ((1 << 17) - 1) == 0b100000010000 => 2 bits are set => index is 2. + + * bitmap & ((1 << 11) - 1) == 0b10000 => 1 bit is set => index is 1. + + * bitmap & ((1 << 4) - 1) == 0b0 => 0 bits are set => index is 0. + + +To conclude: Bitmap nodes are just like Array nodes -- they can store +a number of pointers, but use bitmap compression to eliminate unused +pointers. + + +Bitmap nodes have two pointers for each item: + + +----+----+----+----+ -- +----+----+ + | k1 | v1 | k2 | v2 | .. | kN | vN | + +----+----+----+----+ -- +----+----+ + +When kI == NULL, vI points to another tree level. + +When kI != NULL, the actual key object is stored in kI, and its +value is stored in vI. + + +Collision Nodes +--------------- + +Collision nodes are simple arrays of pointers -- two pointers per +key/value. When there's a hash collision, say for k1/v1 and k2/v2 +we have `hash(k1)==hash(k2)`. Then our collision node will be: + + +----+----+----+----+ + | k1 | v1 | k2 | v2 | + +----+----+----+----+ + + +Tree Structure +-------------- + +All nodes are PyObjects. + +The `PyHamtObject` object has a pointer to the root node (h_root), +and has a length field (h_count). + +High-level functions accept a PyHamtObject object and dispatch to +lower-level functions depending on what kind of node h_root points to. + + +Operations +========== + +There are three fundamental operations on an immutable dictionary: + +1. "o.assoc(k, v)" will return a new immutable dictionary, that will be + a copy of "o", but with the "k/v" item set. + + Functions in this file: + + hamt_node_assoc, hamt_node_bitmap_assoc, + hamt_node_array_assoc, hamt_node_collision_assoc + + `hamt_node_assoc` function accepts a node object, and calls + other functions depending on its actual type. + +2. "o.find(k)" will lookup key "k" in "o". + + Functions: + + hamt_node_find, hamt_node_bitmap_find, + hamt_node_array_find, hamt_node_collision_find + +3. "o.without(k)" will return a new immutable dictionary, that will be + a copy of "o", buth without the "k" key. + + Functions: + + hamt_node_without, hamt_node_bitmap_without, + hamt_node_array_without, hamt_node_collision_without + + +Further Reading +=============== + +1. http://blog.higher-order.net/2009/09/08/understanding-clojures-persistenthashmap-deftwice.html + +2. http://blog.higher-order.net/2010/08/16/assoc-and-clojures-persistenthashmap-part-ii.html + +3. Clojure's PersistentHashMap implementation: + https://github.com/clojure/clojure/blob/master/src/jvm/clojure/lang/PersistentHashMap.java + + +Debug +===== + +The HAMT datatype is accessible for testing purposes under the +`_testcapi` module: + + >>> from _testcapi import hamt + >>> h = hamt() + >>> h2 = h.set('a', 2) + >>> h3 = h2.set('b', 3) + >>> list(h3) + ['a', 'b'] + +When CPython is built in debug mode, a '__dump__()' method is available +to introspect the tree: + + >>> print(h3.__dump__()) + HAMT(len=2): + BitmapNode(size=4 count=2 bitmap=0b110 id=0x10eb9d9e8): + 'a': 2 + 'b': 3 +*/ + + +#define IS_ARRAY_NODE(node) (Py_TYPE(node) == &_PyHamt_ArrayNode_Type) +#define IS_BITMAP_NODE(node) (Py_TYPE(node) == &_PyHamt_BitmapNode_Type) +#define IS_COLLISION_NODE(node) (Py_TYPE(node) == &_PyHamt_CollisionNode_Type) + + +/* Return type for 'find' (lookup a key) functions. + + * F_ERROR - an error occurred; + * F_NOT_FOUND - the key was not found; + * F_FOUND - the key was found. +*/ +typedef enum {F_ERROR, F_NOT_FOUND, F_FOUND} hamt_find_t; + + +/* Return type for 'without' (delete a key) functions. + + * W_ERROR - an error occurred; + * W_NOT_FOUND - the key was not found: there's nothing to delete; + * W_EMPTY - the key was found: the node/tree would be empty + if the key is deleted; + * W_NEWNODE - the key was found: a new node/tree is returned + without that key. +*/ +typedef enum {W_ERROR, W_NOT_FOUND, W_EMPTY, W_NEWNODE} hamt_without_t; + + +/* Low-level iterator protocol type. + + * I_ITEM - a new item has been yielded; + * I_END - the whole tree was visited (similar to StopIteration). +*/ +typedef enum {I_ITEM, I_END} hamt_iter_t; + + +#define HAMT_ARRAY_NODE_SIZE 32 + + +typedef struct { + PyObject_HEAD + PyHamtNode *a_array[HAMT_ARRAY_NODE_SIZE]; + Py_ssize_t a_count; +} PyHamtNode_Array; + + +typedef struct { + PyObject_VAR_HEAD + uint32_t b_bitmap; + PyObject *b_array[1]; +} PyHamtNode_Bitmap; + + +typedef struct { + PyObject_VAR_HEAD + int32_t c_hash; + PyObject *c_array[1]; +} PyHamtNode_Collision; + + +static PyHamtNode_Bitmap *_empty_bitmap_node; +static PyHamtObject *_empty_hamt; + + +static PyHamtObject * +hamt_alloc(void); + +static PyHamtNode * +hamt_node_assoc(PyHamtNode *node, + uint32_t shift, int32_t hash, + PyObject *key, PyObject *val, int* added_leaf); + +static hamt_without_t +hamt_node_without(PyHamtNode *node, + uint32_t shift, int32_t hash, + PyObject *key, + PyHamtNode **new_node); + +static hamt_find_t +hamt_node_find(PyHamtNode *node, + uint32_t shift, int32_t hash, + PyObject *key, PyObject **val); + +#ifdef Py_DEBUG +static int +hamt_node_dump(PyHamtNode *node, + _PyUnicodeWriter *writer, int level); +#endif + +static PyHamtNode * +hamt_node_array_new(Py_ssize_t); + +static PyHamtNode * +hamt_node_collision_new(int32_t hash, Py_ssize_t size); + +static inline Py_ssize_t +hamt_node_collision_count(PyHamtNode_Collision *node); + + +#ifdef Py_DEBUG +static void +_hamt_node_array_validate(void *o) +{ + assert(IS_ARRAY_NODE(o)); + PyHamtNode_Array *node = (PyHamtNode_Array*)(o); + Py_ssize_t i = 0, count = 0; + for (; i < HAMT_ARRAY_NODE_SIZE; i++) { + if (node->a_array[i] != NULL) { + count++; + } + } + assert(count == node->a_count); +} + +#define VALIDATE_ARRAY_NODE(NODE) \ + do { _hamt_node_array_validate(NODE); } while (0); +#else +#define VALIDATE_ARRAY_NODE(NODE) +#endif + + +/* Returns -1 on error */ +static inline int32_t +hamt_hash(PyObject *o) +{ + Py_hash_t hash = PyObject_Hash(o); + +#if SIZEOF_PY_HASH_T <= 4 + return hash; +#else + if (hash == -1) { + /* exception */ + return -1; + } + + /* While it's suboptimal to reduce Python's 64 bit hash to + 32 bits via XOR, it seems that the resulting hash function + is good enough (this is also how Long type is hashed in Java.) + Storing 10, 100, 1000 Python strings results in a relatively + shallow and uniform tree structure. + + Please don't change this hashing algorithm, as there are many + tests that test some exact tree shape to cover all code paths. + */ + int32_t xored = (int32_t)(hash & 0xffffffffl) ^ (int32_t)(hash >> 32); + return xored == -1 ? -2 : xored; +#endif +} + +static inline uint32_t +hamt_mask(int32_t hash, uint32_t shift) +{ + return (((uint32_t)hash >> shift) & 0x01f); +} + +static inline uint32_t +hamt_bitpos(int32_t hash, uint32_t shift) +{ + return (uint32_t)1 << hamt_mask(hash, shift); +} + +static inline uint32_t +hamt_bitcount(uint32_t i) +{ +#if defined(__GNUC__) && (__GNUC__ > 4) + return (uint32_t)__builtin_popcountl(i); +#elif defined(__clang__) && (__clang_major__ > 3) + return (uint32_t)__builtin_popcountl(i); +#elif defined(_MSC_VER) + return (uint32_t)__popcnt(i); +#else + /* https://graphics.stanford.edu/~seander/bithacks.html */ + i = i - ((i >> 1) & 0x55555555); + i = (i & 0x33333333) + ((i >> 2) & 0x33333333); + return ((i + (i >> 4) & 0xF0F0F0F) * 0x1010101) >> 24; +#endif +} + +static inline uint32_t +hamt_bitindex(uint32_t bitmap, uint32_t bit) +{ + return hamt_bitcount(bitmap & (bit - 1)); +} + + +/////////////////////////////////// Dump Helpers +#ifdef Py_DEBUG + +static int +_hamt_dump_ident(_PyUnicodeWriter *writer, int level) +{ + /* Write `' ' * level` to the `writer` */ + PyObject *str = NULL; + PyObject *num = NULL; + PyObject *res = NULL; + int ret = -1; + + str = PyUnicode_FromString(" "); + if (str == NULL) { + goto error; + } + + num = PyLong_FromLong((long)level); + if (num == NULL) { + goto error; + } + + res = PyNumber_Multiply(str, num); + if (res == NULL) { + goto error; + } + + ret = _PyUnicodeWriter_WriteStr(writer, res); + +error: + Py_XDECREF(res); + Py_XDECREF(str); + Py_XDECREF(num); + return ret; +} + +static int +_hamt_dump_format(_PyUnicodeWriter *writer, const char *format, ...) +{ + /* A convenient helper combining _PyUnicodeWriter_WriteStr and + PyUnicode_FromFormatV. + */ + PyObject* msg; + int ret; + + va_list vargs; +#ifdef HAVE_STDARG_PROTOTYPES + va_start(vargs, format); +#else + va_start(vargs); +#endif + msg = PyUnicode_FromFormatV(format, vargs); + va_end(vargs); + + if (msg == NULL) { + return -1; + } + + ret = _PyUnicodeWriter_WriteStr(writer, msg); + Py_DECREF(msg); + return ret; +} + +#endif /* Py_DEBUG */ +/////////////////////////////////// Bitmap Node + + +static PyHamtNode * +hamt_node_bitmap_new(Py_ssize_t size) +{ + /* Create a new bitmap node of size 'size' */ + + PyHamtNode_Bitmap *node; + Py_ssize_t i; + + assert(size >= 0); + assert(size % 2 == 0); + + if (size == 0 && _empty_bitmap_node != NULL) { + Py_INCREF(_empty_bitmap_node); + return (PyHamtNode *)_empty_bitmap_node; + } + + /* No freelist; allocate a new bitmap node */ + node = PyObject_GC_NewVar( + PyHamtNode_Bitmap, &_PyHamt_BitmapNode_Type, size); + if (node == NULL) { + return NULL; + } + + Py_SIZE(node) = size; + + for (i = 0; i < size; i++) { + node->b_array[i] = NULL; + } + + node->b_bitmap = 0; + + _PyObject_GC_TRACK(node); + + if (size == 0 && _empty_bitmap_node == NULL) { + /* Since bitmap nodes are immutable, we can cache the instance + for size=0 and reuse it whenever we need an empty bitmap node. + */ + _empty_bitmap_node = node; + Py_INCREF(_empty_bitmap_node); + } + + return (PyHamtNode *)node; +} + +static inline Py_ssize_t +hamt_node_bitmap_count(PyHamtNode_Bitmap *node) +{ + return Py_SIZE(node) / 2; +} + +static PyHamtNode_Bitmap * +hamt_node_bitmap_clone(PyHamtNode_Bitmap *node) +{ + /* Clone a bitmap node; return a new one with the same child notes. */ + + PyHamtNode_Bitmap *clone; + Py_ssize_t i; + + clone = (PyHamtNode_Bitmap *)hamt_node_bitmap_new(Py_SIZE(node)); + if (clone == NULL) { + return NULL; + } + + for (i = 0; i < Py_SIZE(node); i++) { + Py_XINCREF(node->b_array[i]); + clone->b_array[i] = node->b_array[i]; + } + + clone->b_bitmap = node->b_bitmap; + return clone; +} + +static PyHamtNode_Bitmap * +hamt_node_bitmap_clone_without(PyHamtNode_Bitmap *o, uint32_t bit) +{ + assert(bit & o->b_bitmap); + assert(hamt_node_bitmap_count(o) > 1); + + PyHamtNode_Bitmap *new = (PyHamtNode_Bitmap *)hamt_node_bitmap_new( + Py_SIZE(o) - 2); + if (new == NULL) { + return NULL; + } + + uint32_t idx = hamt_bitindex(o->b_bitmap, bit); + uint32_t key_idx = 2 * idx; + uint32_t val_idx = key_idx + 1; + uint32_t i; + + for (i = 0; i < key_idx; i++) { + Py_XINCREF(o->b_array[i]); + new->b_array[i] = o->b_array[i]; + } + + for (i = val_idx + 1; i < Py_SIZE(o); i++) { + Py_XINCREF(o->b_array[i]); + new->b_array[i - 2] = o->b_array[i]; + } + + new->b_bitmap = o->b_bitmap & ~bit; + return new; +} + +static PyHamtNode * +hamt_node_new_bitmap_or_collision(uint32_t shift, + PyObject *key1, PyObject *val1, + int32_t key2_hash, + PyObject *key2, PyObject *val2) +{ + /* Helper method. Creates a new node for key1/val and key2/val2 + pairs. + + If key1 hash is equal to the hash of key2, a Collision node + will be created. If they are not equal, a Bitmap node is + created. + */ + + int32_t key1_hash = hamt_hash(key1); + if (key1_hash == -1) { + return NULL; + } + + if (key1_hash == key2_hash) { + PyHamtNode_Collision *n; + n = (PyHamtNode_Collision *)hamt_node_collision_new(key1_hash, 4); + if (n == NULL) { + return NULL; + } + + Py_INCREF(key1); + n->c_array[0] = key1; + Py_INCREF(val1); + n->c_array[1] = val1; + + Py_INCREF(key2); + n->c_array[2] = key2; + Py_INCREF(val2); + n->c_array[3] = val2; + + return (PyHamtNode *)n; + } + else { + int added_leaf = 0; + PyHamtNode *n = hamt_node_bitmap_new(0); + if (n == NULL) { + return NULL; + } + + PyHamtNode *n2 = hamt_node_assoc( + n, shift, key1_hash, key1, val1, &added_leaf); + Py_DECREF(n); + if (n2 == NULL) { + return NULL; + } + + n = hamt_node_assoc(n2, shift, key2_hash, key2, val2, &added_leaf); + Py_DECREF(n2); + if (n == NULL) { + return NULL; + } + + return n; + } +} + +static PyHamtNode * +hamt_node_bitmap_assoc(PyHamtNode_Bitmap *self, + uint32_t shift, int32_t hash, + PyObject *key, PyObject *val, int* added_leaf) +{ + /* assoc operation for bitmap nodes. + + Return: a new node, or self if key/val already is in the + collection. + + 'added_leaf' is later used in '_PyHamt_Assoc' to determine if + `hamt.set(key, val)` increased the size of the collection. + */ + + uint32_t bit = hamt_bitpos(hash, shift); + uint32_t idx = hamt_bitindex(self->b_bitmap, bit); + + /* Bitmap node layout: + + +------+------+------+------+ --- +------+------+ + | key1 | val1 | key2 | val2 | ... | keyN | valN | + +------+------+------+------+ --- +------+------+ + where `N < Py_SIZE(node)`. + + The `node->b_bitmap` field is a bitmap. For a given + `(shift, hash)` pair we can determine: + + - If this node has the corresponding key/val slots. + - The index of key/val slots. + */ + + if (self->b_bitmap & bit) { + /* The key is set in this node */ + + uint32_t key_idx = 2 * idx; + uint32_t val_idx = key_idx + 1; + + assert(val_idx < Py_SIZE(self)); + + PyObject *key_or_null = self->b_array[key_idx]; + PyObject *val_or_node = self->b_array[val_idx]; + + if (key_or_null == NULL) { + /* key is NULL. This means that we have a few keys + that have the same (hash, shift) pair. */ + + assert(val_or_node != NULL); + + PyHamtNode *sub_node = hamt_node_assoc( + (PyHamtNode *)val_or_node, + shift + 5, hash, key, val, added_leaf); + if (sub_node == NULL) { + return NULL; + } + + if (val_or_node == (PyObject *)sub_node) { + Py_DECREF(sub_node); + Py_INCREF(self); + return (PyHamtNode *)self; + } + + PyHamtNode_Bitmap *ret = hamt_node_bitmap_clone(self); + if (ret == NULL) { + return NULL; + } + Py_SETREF(ret->b_array[val_idx], (PyObject*)sub_node); + return (PyHamtNode *)ret; + } + + assert(key != NULL); + /* key is not NULL. This means that we have only one other + key in this collection that matches our hash for this shift. */ + + int comp_err = PyObject_RichCompareBool(key, key_or_null, Py_EQ); + if (comp_err < 0) { /* exception in __eq__ */ + return NULL; + } + if (comp_err == 1) { /* key == key_or_null */ + if (val == val_or_node) { + /* we already have the same key/val pair; return self. */ + Py_INCREF(self); + return (PyHamtNode *)self; + } + + /* We're setting a new value for the key we had before. + Make a new bitmap node with a replaced value, and return it. */ + PyHamtNode_Bitmap *ret = hamt_node_bitmap_clone(self); + if (ret == NULL) { + return NULL; + } + Py_INCREF(val); + Py_SETREF(ret->b_array[val_idx], val); + return (PyHamtNode *)ret; + } + + /* It's a new key, and it has the same index as *one* another key. + We have a collision. We need to create a new node which will + combine the existing key and the key we're adding. + + `hamt_node_new_bitmap_or_collision` will either create a new + Collision node if the keys have identical hashes, or + a new Bitmap node. + */ + PyHamtNode *sub_node = hamt_node_new_bitmap_or_collision( + shift + 5, + key_or_null, val_or_node, /* existing key/val */ + hash, + key, val /* new key/val */ + ); + if (sub_node == NULL) { + return NULL; + } + + PyHamtNode_Bitmap *ret = hamt_node_bitmap_clone(self); + if (ret == NULL) { + Py_DECREF(sub_node); + return NULL; + } + Py_SETREF(ret->b_array[key_idx], NULL); + Py_SETREF(ret->b_array[val_idx], (PyObject *)sub_node); + + *added_leaf = 1; + return (PyHamtNode *)ret; + } + else { + /* There was no key before with the same (shift,hash). */ + + uint32_t n = hamt_bitcount(self->b_bitmap); + + if (n >= 16) { + /* When we have a situation where we want to store more + than 16 nodes at one level of the tree, we no longer + want to use the Bitmap node with bitmap encoding. + + Instead we start using an Array node, which has + simpler (faster) implementation at the expense of + having prealocated 32 pointers for its keys/values + pairs. + + Small hamt objects (<30 keys) usually don't have any + Array nodes at all. Betwen ~30 and ~400 keys hamt + objects usually have one Array node, and usually it's + a root node. + */ + + uint32_t jdx = hamt_mask(hash, shift); + /* 'jdx' is the index of where the new key should be added + in the new Array node we're about to create. */ + + PyHamtNode *empty = NULL; + PyHamtNode_Array *new_node = NULL; + PyHamtNode *res = NULL; + + /* Create a new Array node. */ + new_node = (PyHamtNode_Array *)hamt_node_array_new(n + 1); + if (new_node == NULL) { + goto fin; + } + + /* Create an empty bitmap node for the next + hamt_node_assoc call. */ + empty = hamt_node_bitmap_new(0); + if (empty == NULL) { + goto fin; + } + + /* Make a new bitmap node for the key/val we're adding. + Set that bitmap node to new-array-node[jdx]. */ + new_node->a_array[jdx] = hamt_node_assoc( + empty, shift + 5, hash, key, val, added_leaf); + if (new_node->a_array[jdx] == NULL) { + goto fin; + } + + /* Copy existing key/value pairs from the current Bitmap + node to the new Array node we've just created. */ + Py_ssize_t i, j; + for (i = 0, j = 0; i < HAMT_ARRAY_NODE_SIZE; i++) { + if (((self->b_bitmap >> i) & 1) != 0) { + /* Ensure we don't accidentally override `jdx` element + we set few lines above. + */ + assert(new_node->a_array[i] == NULL); + + if (self->b_array[j] == NULL) { + new_node->a_array[i] = + (PyHamtNode *)self->b_array[j + 1]; + Py_INCREF(new_node->a_array[i]); + } + else { + int32_t rehash = hamt_hash(self->b_array[j]); + if (rehash == -1) { + goto fin; + } + + new_node->a_array[i] = hamt_node_assoc( + empty, shift + 5, + rehash, + self->b_array[j], + self->b_array[j + 1], + added_leaf); + + if (new_node->a_array[i] == NULL) { + goto fin; + } + } + j += 2; + } + } + + VALIDATE_ARRAY_NODE(new_node) + + /* That's it! */ + res = (PyHamtNode *)new_node; + + fin: + Py_XDECREF(empty); + if (res == NULL) { + Py_XDECREF(new_node); + } + return res; + } + else { + /* We have less than 16 keys at this level; let's just + create a new bitmap node out of this node with the + new key/val pair added. */ + + uint32_t key_idx = 2 * idx; + uint32_t val_idx = key_idx + 1; + Py_ssize_t i; + + *added_leaf = 1; + + /* Allocate new Bitmap node which can have one more key/val + pair in addition to what we have already. */ + PyHamtNode_Bitmap *new_node = + (PyHamtNode_Bitmap *)hamt_node_bitmap_new(2 * (n + 1)); + if (new_node == NULL) { + return NULL; + } + + /* Copy all keys/values that will be before the new key/value + we are adding. */ + for (i = 0; i < key_idx; i++) { + Py_XINCREF(self->b_array[i]); + new_node->b_array[i] = self->b_array[i]; + } + + /* Set the new key/value to the new Bitmap node. */ + Py_INCREF(key); + new_node->b_array[key_idx] = key; + Py_INCREF(val); + new_node->b_array[val_idx] = val; + + /* Copy all keys/values that will be after the new key/value + we are adding. */ + for (i = key_idx; i < Py_SIZE(self); i++) { + Py_XINCREF(self->b_array[i]); + new_node->b_array[i + 2] = self->b_array[i]; + } + + new_node->b_bitmap = self->b_bitmap | bit; + return (PyHamtNode *)new_node; + } + } +} + +static hamt_without_t +hamt_node_bitmap_without(PyHamtNode_Bitmap *self, + uint32_t shift, int32_t hash, + PyObject *key, + PyHamtNode **new_node) +{ + uint32_t bit = hamt_bitpos(hash, shift); + if ((self->b_bitmap & bit) == 0) { + return W_NOT_FOUND; + } + + uint32_t idx = hamt_bitindex(self->b_bitmap, bit); + + uint32_t key_idx = 2 * idx; + uint32_t val_idx = key_idx + 1; + + PyObject *key_or_null = self->b_array[key_idx]; + PyObject *val_or_node = self->b_array[val_idx]; + + if (key_or_null == NULL) { + /* key == NULL means that 'value' is another tree node. */ + + PyHamtNode *sub_node = NULL; + + hamt_without_t res = hamt_node_without( + (PyHamtNode *)val_or_node, + shift + 5, hash, key, &sub_node); + + switch (res) { + case W_EMPTY: + /* It's impossible for us to receive a W_EMPTY here: + + - Array nodes are converted to Bitmap nodes when + we delete 16th item from them; + + - Collision nodes are converted to Bitmap when + there is one item in them; + + - Bitmap node's without() inlines single-item + sub-nodes. + + So in no situation we can have a single-item + Bitmap child of another Bitmap node. + */ + Py_UNREACHABLE(); + + case W_NEWNODE: { + assert(sub_node != NULL); + + if (IS_BITMAP_NODE(sub_node)) { + PyHamtNode_Bitmap *sub_tree = (PyHamtNode_Bitmap *)sub_node; + if (hamt_node_bitmap_count(sub_tree) == 1 && + sub_tree->b_array[0] != NULL) + { + /* A bitmap node with one key/value pair. Just + merge it into this node. + + Note that we don't inline Bitmap nodes that + have a NULL key -- those nodes point to another + tree level, and we cannot simply move tree levels + up or down. + */ + + PyHamtNode_Bitmap *clone = hamt_node_bitmap_clone(self); + if (clone == NULL) { + Py_DECREF(sub_node); + return W_ERROR; + } + + PyObject *key = sub_tree->b_array[0]; + PyObject *val = sub_tree->b_array[1]; + + Py_INCREF(key); + Py_XSETREF(clone->b_array[key_idx], key); + Py_INCREF(val); + Py_SETREF(clone->b_array[val_idx], val); + + Py_DECREF(sub_tree); + + *new_node = (PyHamtNode *)clone; + return W_NEWNODE; + } + } + +#ifdef Py_DEBUG + /* Ensure that Collision.without implementation + converts to Bitmap nodes itself. + */ + if (IS_COLLISION_NODE(sub_node)) { + assert(hamt_node_collision_count( + (PyHamtNode_Collision*)sub_node) > 1); + } +#endif + + PyHamtNode_Bitmap *clone = hamt_node_bitmap_clone(self); + Py_SETREF(clone->b_array[val_idx], + (PyObject *)sub_node); /* borrow */ + + *new_node = (PyHamtNode *)clone; + return W_NEWNODE; + } + + case W_ERROR: + case W_NOT_FOUND: + assert(sub_node == NULL); + return res; + + default: + Py_UNREACHABLE(); + } + } + else { + /* We have a regular key/value pair */ + + int cmp = PyObject_RichCompareBool(key_or_null, key, Py_EQ); + if (cmp < 0) { + return W_ERROR; + } + if (cmp == 0) { + return W_NOT_FOUND; + } + + if (hamt_node_bitmap_count(self) == 1) { + return W_EMPTY; + } + + *new_node = (PyHamtNode *) + hamt_node_bitmap_clone_without(self, bit); + if (*new_node == NULL) { + return W_ERROR; + } + + return W_NEWNODE; + } +} + +static hamt_find_t +hamt_node_bitmap_find(PyHamtNode_Bitmap *self, + uint32_t shift, int32_t hash, + PyObject *key, PyObject **val) +{ + /* Lookup a key in a Bitmap node. */ + + uint32_t bit = hamt_bitpos(hash, shift); + uint32_t idx; + uint32_t key_idx; + uint32_t val_idx; + PyObject *key_or_null; + PyObject *val_or_node; + int comp_err; + + if ((self->b_bitmap & bit) == 0) { + return F_NOT_FOUND; + } + + idx = hamt_bitindex(self->b_bitmap, bit); + assert(idx >= 0); + key_idx = idx * 2; + val_idx = key_idx + 1; + + assert(val_idx < Py_SIZE(self)); + + key_or_null = self->b_array[key_idx]; + val_or_node = self->b_array[val_idx]; + + if (key_or_null == NULL) { + /* There are a few keys that have the same hash at the current shift + that match our key. Dispatch the lookup further down the tree. */ + assert(val_or_node != NULL); + return hamt_node_find((PyHamtNode *)val_or_node, + shift + 5, hash, key, val); + } + + /* We have only one key -- a potential match. Let's compare if the + key we are looking at is equal to the key we are looking for. */ + assert(key != NULL); + comp_err = PyObject_RichCompareBool(key, key_or_null, Py_EQ); + if (comp_err < 0) { /* exception in __eq__ */ + return F_ERROR; + } + if (comp_err == 1) { /* key == key_or_null */ + *val = val_or_node; + return F_FOUND; + } + + return F_NOT_FOUND; +} + +static int +hamt_node_bitmap_traverse(PyHamtNode_Bitmap *self, visitproc visit, void *arg) +{ + /* Bitmap's tp_traverse */ + + Py_ssize_t i; + + for (i = Py_SIZE(self); --i >= 0; ) { + Py_VISIT(self->b_array[i]); + } + + return 0; +} + +static void +hamt_node_bitmap_dealloc(PyHamtNode_Bitmap *self) +{ + /* Bitmap's tp_dealloc */ + + Py_ssize_t len = Py_SIZE(self); + Py_ssize_t i; + + PyObject_GC_UnTrack(self); + Py_TRASHCAN_SAFE_BEGIN(self) + + if (len > 0) { + i = len; + while (--i >= 0) { + Py_XDECREF(self->b_array[i]); + } + } + + Py_TYPE(self)->tp_free((PyObject *)self); + Py_TRASHCAN_SAFE_END(self) +} + +#ifdef Py_DEBUG +static int +hamt_node_bitmap_dump(PyHamtNode_Bitmap *node, + _PyUnicodeWriter *writer, int level) +{ + /* Debug build: __dump__() method implementation for Bitmap nodes. */ + + Py_ssize_t i; + PyObject *tmp1; + PyObject *tmp2; + + if (_hamt_dump_ident(writer, level + 1)) { + goto error; + } + + if (_hamt_dump_format(writer, "BitmapNode(size=%zd count=%zd ", + Py_SIZE(node), Py_SIZE(node) / 2)) + { + goto error; + } + + tmp1 = PyLong_FromUnsignedLong(node->b_bitmap); + if (tmp1 == NULL) { + goto error; + } + tmp2 = _PyLong_Format(tmp1, 2); + Py_DECREF(tmp1); + if (tmp2 == NULL) { + goto error; + } + if (_hamt_dump_format(writer, "bitmap=%S id=%p):\n", tmp2, node)) { + Py_DECREF(tmp2); + goto error; + } + Py_DECREF(tmp2); + + for (i = 0; i < Py_SIZE(node); i += 2) { + PyObject *key_or_null = node->b_array[i]; + PyObject *val_or_node = node->b_array[i + 1]; + + if (_hamt_dump_ident(writer, level + 2)) { + goto error; + } + + if (key_or_null == NULL) { + if (_hamt_dump_format(writer, "NULL:\n")) { + goto error; + } + + if (hamt_node_dump((PyHamtNode *)val_or_node, + writer, level + 2)) + { + goto error; + } + } + else { + if (_hamt_dump_format(writer, "%R: %R", key_or_null, + val_or_node)) + { + goto error; + } + } + + if (_hamt_dump_format(writer, "\n")) { + goto error; + } + } + + return 0; +error: + return -1; +} +#endif /* Py_DEBUG */ + + +/////////////////////////////////// Collision Node + + +static PyHamtNode * +hamt_node_collision_new(int32_t hash, Py_ssize_t size) +{ + /* Create a new Collision node. */ + + PyHamtNode_Collision *node; + Py_ssize_t i; + + assert(size >= 4); + assert(size % 2 == 0); + + node = PyObject_GC_NewVar( + PyHamtNode_Collision, &_PyHamt_CollisionNode_Type, size); + if (node == NULL) { + return NULL; + } + + for (i = 0; i < size; i++) { + node->c_array[i] = NULL; + } + + Py_SIZE(node) = size; + node->c_hash = hash; + + _PyObject_GC_TRACK(node); + + return (PyHamtNode *)node; +} + +static hamt_find_t +hamt_node_collision_find_index(PyHamtNode_Collision *self, PyObject *key, + Py_ssize_t *idx) +{ + /* Lookup `key` in the Collision node `self`. Set the index of the + found key to 'idx'. */ + + Py_ssize_t i; + PyObject *el; + + for (i = 0; i < Py_SIZE(self); i += 2) { + el = self->c_array[i]; + + assert(el != NULL); + int cmp = PyObject_RichCompareBool(key, el, Py_EQ); + if (cmp < 0) { + return F_ERROR; + } + if (cmp == 1) { + *idx = i; + return F_FOUND; + } + } + + return F_NOT_FOUND; +} + +static PyHamtNode * +hamt_node_collision_assoc(PyHamtNode_Collision *self, + uint32_t shift, int32_t hash, + PyObject *key, PyObject *val, int* added_leaf) +{ + /* Set a new key to this level (currently a Collision node) + of the tree. */ + + if (hash == self->c_hash) { + /* The hash of the 'key' we are adding matches the hash of + other keys in this Collision node. */ + + Py_ssize_t key_idx = -1; + hamt_find_t found; + PyHamtNode_Collision *new_node; + Py_ssize_t i; + + /* Let's try to lookup the new 'key', maybe we already have it. */ + found = hamt_node_collision_find_index(self, key, &key_idx); + switch (found) { + case F_ERROR: + /* Exception. */ + return NULL; + + case F_NOT_FOUND: + /* This is a totally new key. Clone the current node, + add a new key/value to the cloned node. */ + + new_node = (PyHamtNode_Collision *)hamt_node_collision_new( + self->c_hash, Py_SIZE(self) + 2); + if (new_node == NULL) { + return NULL; + } + + for (i = 0; i < Py_SIZE(self); i++) { + Py_INCREF(self->c_array[i]); + new_node->c_array[i] = self->c_array[i]; + } + + Py_INCREF(key); + new_node->c_array[i] = key; + Py_INCREF(val); + new_node->c_array[i + 1] = val; + + *added_leaf = 1; + return (PyHamtNode *)new_node; + + case F_FOUND: + /* There's a key which is equal to the key we are adding. */ + + assert(key_idx >= 0); + assert(key_idx < Py_SIZE(self)); + Py_ssize_t val_idx = key_idx + 1; + + if (self->c_array[val_idx] == val) { + /* We're setting a key/value pair that's already set. */ + Py_INCREF(self); + return (PyHamtNode *)self; + } + + /* We need to replace old value for the key + with a new value. Create a new Collision node.*/ + new_node = (PyHamtNode_Collision *)hamt_node_collision_new( + self->c_hash, Py_SIZE(self)); + if (new_node == NULL) { + return NULL; + } + + /* Copy all elements of the old node to the new one. */ + for (i = 0; i < Py_SIZE(self); i++) { + Py_INCREF(self->c_array[i]); + new_node->c_array[i] = self->c_array[i]; + } + + /* Replace the old value with the new value for the our key. */ + Py_DECREF(new_node->c_array[val_idx]); + Py_INCREF(val); + new_node->c_array[val_idx] = val; + + return (PyHamtNode *)new_node; + + default: + Py_UNREACHABLE(); + } + } + else { + /* The hash of the new key is different from the hash that + all keys of this Collision node have. + + Create a Bitmap node inplace with two children: + key/value pair that we're adding, and the Collision node + we're replacing on this tree level. + */ + + PyHamtNode_Bitmap *new_node; + PyHamtNode *assoc_res; + + new_node = (PyHamtNode_Bitmap *)hamt_node_bitmap_new(2); + if (new_node == NULL) { + return NULL; + } + new_node->b_bitmap = hamt_bitpos(self->c_hash, shift); + Py_INCREF(self); + new_node->b_array[1] = (PyObject*) self; + + assoc_res = hamt_node_bitmap_assoc( + new_node, shift, hash, key, val, added_leaf); + Py_DECREF(new_node); + return assoc_res; + } +} + +static inline Py_ssize_t +hamt_node_collision_count(PyHamtNode_Collision *node) +{ + return Py_SIZE(node) / 2; +} + +static hamt_without_t +hamt_node_collision_without(PyHamtNode_Collision *self, + uint32_t shift, int32_t hash, + PyObject *key, + PyHamtNode **new_node) +{ + if (hash != self->c_hash) { + return W_NOT_FOUND; + } + + Py_ssize_t key_idx = -1; + hamt_find_t found = hamt_node_collision_find_index(self, key, &key_idx); + + switch (found) { + case F_ERROR: + return W_ERROR; + + case F_NOT_FOUND: + return W_NOT_FOUND; + + case F_FOUND: + assert(key_idx >= 0); + assert(key_idx < Py_SIZE(self)); + + Py_ssize_t new_count = hamt_node_collision_count(self) - 1; + + if (new_count == 0) { + /* The node has only one key/value pair and it's for the + key we're trying to delete. So a new node will be empty + after the removal. + */ + return W_EMPTY; + } + + if (new_count == 1) { + /* The node has two keys, and after deletion the + new Collision node would have one. Collision nodes + with one key shouldn't exist, co convert it to a + Bitmap node. + */ + PyHamtNode_Bitmap *node = (PyHamtNode_Bitmap *) + hamt_node_bitmap_new(2); + if (node == NULL) { + return W_ERROR; + } + + if (key_idx == 0) { + Py_INCREF(self->c_array[2]); + node->b_array[0] = self->c_array[2]; + Py_INCREF(self->c_array[3]); + node->b_array[1] = self->c_array[3]; + } + else { + assert(key_idx == 2); + Py_INCREF(self->c_array[0]); + node->b_array[0] = self->c_array[0]; + Py_INCREF(self->c_array[1]); + node->b_array[1] = self->c_array[1]; + } + + node->b_bitmap = hamt_bitpos(hash, shift); + + *new_node = (PyHamtNode *)node; + return W_NEWNODE; + } + + /* Allocate a new Collision node with capacity for one + less key/value pair */ + PyHamtNode_Collision *new = (PyHamtNode_Collision *) + hamt_node_collision_new( + self->c_hash, Py_SIZE(self) - 2); + + /* Copy all other keys from `self` to `new` */ + Py_ssize_t i; + for (i = 0; i < key_idx; i++) { + Py_INCREF(self->c_array[i]); + new->c_array[i] = self->c_array[i]; + } + for (i = key_idx + 2; i < Py_SIZE(self); i++) { + Py_INCREF(self->c_array[i]); + new->c_array[i - 2] = self->c_array[i]; + } + + *new_node = (PyHamtNode*)new; + return W_NEWNODE; + + default: + Py_UNREACHABLE(); + } +} + +static hamt_find_t +hamt_node_collision_find(PyHamtNode_Collision *self, + uint32_t shift, int32_t hash, + PyObject *key, PyObject **val) +{ + /* Lookup `key` in the Collision node `self`. Set the value + for the found key to 'val'. */ + + Py_ssize_t idx = -1; + hamt_find_t res; + + res = hamt_node_collision_find_index(self, key, &idx); + if (res == F_ERROR || res == F_NOT_FOUND) { + return res; + } + + assert(idx >= 0); + assert(idx + 1 < Py_SIZE(self)); + + *val = self->c_array[idx + 1]; + assert(*val != NULL); + + return F_FOUND; +} + + +static int +hamt_node_collision_traverse(PyHamtNode_Collision *self, + visitproc visit, void *arg) +{ + /* Collision's tp_traverse */ + + Py_ssize_t i; + + for (i = Py_SIZE(self); --i >= 0; ) { + Py_VISIT(self->c_array[i]); + } + + return 0; +} + +static void +hamt_node_collision_dealloc(PyHamtNode_Collision *self) +{ + /* Collision's tp_dealloc */ + + Py_ssize_t len = Py_SIZE(self); + + PyObject_GC_UnTrack(self); + Py_TRASHCAN_SAFE_BEGIN(self) + + if (len > 0) { + + while (--len >= 0) { + Py_XDECREF(self->c_array[len]); + } + } + + Py_TYPE(self)->tp_free((PyObject *)self); + Py_TRASHCAN_SAFE_END(self) +} + +#ifdef Py_DEBUG +static int +hamt_node_collision_dump(PyHamtNode_Collision *node, + _PyUnicodeWriter *writer, int level) +{ + /* Debug build: __dump__() method implementation for Collision nodes. */ + + Py_ssize_t i; + + if (_hamt_dump_ident(writer, level + 1)) { + goto error; + } + + if (_hamt_dump_format(writer, "CollisionNode(size=%zd id=%p):\n", + Py_SIZE(node), node)) + { + goto error; + } + + for (i = 0; i < Py_SIZE(node); i += 2) { + PyObject *key = node->c_array[i]; + PyObject *val = node->c_array[i + 1]; + + if (_hamt_dump_ident(writer, level + 2)) { + goto error; + } + + if (_hamt_dump_format(writer, "%R: %R\n", key, val)) { + goto error; + } + } + + return 0; +error: + return -1; +} +#endif /* Py_DEBUG */ + + +/////////////////////////////////// Array Node + + +static PyHamtNode * +hamt_node_array_new(Py_ssize_t count) +{ + Py_ssize_t i; + + PyHamtNode_Array *node = PyObject_GC_New( + PyHamtNode_Array, &_PyHamt_ArrayNode_Type); + if (node == NULL) { + return NULL; + } + + for (i = 0; i < HAMT_ARRAY_NODE_SIZE; i++) { + node->a_array[i] = NULL; + } + + node->a_count = count; + + _PyObject_GC_TRACK(node); + return (PyHamtNode *)node; +} + +static PyHamtNode_Array * +hamt_node_array_clone(PyHamtNode_Array *node) +{ + PyHamtNode_Array *clone; + Py_ssize_t i; + + VALIDATE_ARRAY_NODE(node) + + /* Create a new Array node. */ + clone = (PyHamtNode_Array *)hamt_node_array_new(node->a_count); + if (clone == NULL) { + return NULL; + } + + /* Copy all elements from the current Array node to the new one. */ + for (i = 0; i < HAMT_ARRAY_NODE_SIZE; i++) { + Py_XINCREF(node->a_array[i]); + clone->a_array[i] = node->a_array[i]; + } + + VALIDATE_ARRAY_NODE(clone) + return clone; +} + +static PyHamtNode * +hamt_node_array_assoc(PyHamtNode_Array *self, + uint32_t shift, int32_t hash, + PyObject *key, PyObject *val, int* added_leaf) +{ + /* Set a new key to this level (currently a Collision node) + of the tree. + + Array nodes don't store values, they can only point to + other nodes. They are simple arrays of 32 BaseNode pointers/ + */ + + uint32_t idx = hamt_mask(hash, shift); + PyHamtNode *node = self->a_array[idx]; + PyHamtNode *child_node; + PyHamtNode_Array *new_node; + Py_ssize_t i; + + if (node == NULL) { + /* There's no child node for the given hash. Create a new + Bitmap node for this key. */ + + PyHamtNode_Bitmap *empty = NULL; + + /* Get an empty Bitmap node to work with. */ + empty = (PyHamtNode_Bitmap *)hamt_node_bitmap_new(0); + if (empty == NULL) { + return NULL; + } + + /* Set key/val to the newly created empty Bitmap, thus + creating a new Bitmap node with our key/value pair. */ + child_node = hamt_node_bitmap_assoc( + empty, + shift + 5, hash, key, val, added_leaf); + Py_DECREF(empty); + if (child_node == NULL) { + return NULL; + } + + /* Create a new Array node. */ + new_node = (PyHamtNode_Array *)hamt_node_array_new(self->a_count + 1); + if (new_node == NULL) { + Py_DECREF(child_node); + return NULL; + } + + /* Copy all elements from the current Array node to the + new one. */ + for (i = 0; i < HAMT_ARRAY_NODE_SIZE; i++) { + Py_XINCREF(self->a_array[i]); + new_node->a_array[i] = self->a_array[i]; + } + + assert(new_node->a_array[idx] == NULL); + new_node->a_array[idx] = child_node; /* borrow */ + VALIDATE_ARRAY_NODE(new_node) + } + else { + /* There's a child node for the given hash. + Set the key to it./ */ + child_node = hamt_node_assoc( + node, shift + 5, hash, key, val, added_leaf); + if (child_node == (PyHamtNode *)self) { + Py_DECREF(child_node); + return (PyHamtNode *)self; + } + + new_node = hamt_node_array_clone(self); + if (new_node == NULL) { + Py_DECREF(child_node); + return NULL; + } + + Py_SETREF(new_node->a_array[idx], child_node); /* borrow */ + VALIDATE_ARRAY_NODE(new_node) + } + + return (PyHamtNode *)new_node; +} + +static hamt_without_t +hamt_node_array_without(PyHamtNode_Array *self, + uint32_t shift, int32_t hash, + PyObject *key, + PyHamtNode **new_node) +{ + uint32_t idx = hamt_mask(hash, shift); + PyHamtNode *node = self->a_array[idx]; + + if (node == NULL) { + return W_NOT_FOUND; + } + + PyHamtNode *sub_node = NULL; + hamt_without_t res = hamt_node_without( + (PyHamtNode *)node, + shift + 5, hash, key, &sub_node); + + switch (res) { + case W_NOT_FOUND: + case W_ERROR: + assert(sub_node == NULL); + return res; + + case W_NEWNODE: { + /* We need to replace a node at the `idx` index. + Clone this node and replace. + */ + assert(sub_node != NULL); + + PyHamtNode_Array *clone = hamt_node_array_clone(self); + if (clone == NULL) { + Py_DECREF(sub_node); + return W_ERROR; + } + + Py_SETREF(clone->a_array[idx], sub_node); /* borrow */ + *new_node = (PyHamtNode*)clone; /* borrow */ + return W_NEWNODE; + } + + case W_EMPTY: { + assert(sub_node == NULL); + /* We need to remove a node at the `idx` index. + Calculate the size of the replacement Array node. + */ + Py_ssize_t new_count = self->a_count - 1; + + if (new_count == 0) { + return W_EMPTY; + } + + if (new_count >= 16) { + /* We convert Bitmap nodes to Array nodes, when a + Bitmap node needs to store more than 15 key/value + pairs. So we will create a new Array node if we + the number of key/values after deletion is still + greater than 15. + */ + + PyHamtNode_Array *new = hamt_node_array_clone(self); + if (new == NULL) { + return W_ERROR; + } + new->a_count = new_count; + Py_CLEAR(new->a_array[idx]); + + *new_node = (PyHamtNode*)new; /* borrow */ + return W_NEWNODE; + } + + /* New Array node would have less than 16 key/value + pairs. We need to create a replacement Bitmap node. */ + + Py_ssize_t bitmap_size = new_count * 2; + uint32_t bitmap = 0; + + PyHamtNode_Bitmap *new = (PyHamtNode_Bitmap *) + hamt_node_bitmap_new(bitmap_size); + if (new == NULL) { + return W_ERROR; + } + + Py_ssize_t new_i = 0; + for (uint32_t i = 0; i < HAMT_ARRAY_NODE_SIZE; i++) { + if (i == idx) { + /* Skip the node we are deleting. */ + continue; + } + + PyHamtNode *node = self->a_array[i]; + if (node == NULL) { + /* Skip any missing nodes. */ + continue; + } + + bitmap |= 1 << i; + + if (IS_BITMAP_NODE(node)) { + PyHamtNode_Bitmap *child = (PyHamtNode_Bitmap *)node; + + if (hamt_node_bitmap_count(child) == 1 && + child->b_array[0] != NULL) + { + /* node is a Bitmap with one key/value pair, just + merge it into the new Bitmap node we're building. + + Note that we don't inline Bitmap nodes that + have a NULL key -- those nodes point to another + tree level, and we cannot simply move tree levels + up or down. + */ + PyObject *key = child->b_array[0]; + PyObject *val = child->b_array[1]; + + Py_INCREF(key); + new->b_array[new_i] = key; + Py_INCREF(val); + new->b_array[new_i + 1] = val; + } + else { + new->b_array[new_i] = NULL; + Py_INCREF(node); + new->b_array[new_i + 1] = (PyObject*)node; + } + } + else { + +#ifdef Py_DEBUG + if (IS_COLLISION_NODE(node)) { + Py_ssize_t child_count = hamt_node_collision_count( + (PyHamtNode_Collision*)node); + assert(child_count > 1); + } + else if (IS_ARRAY_NODE(node)) { + assert(((PyHamtNode_Array*)node)->a_count >= 16); + } +#endif + + /* Just copy the node into our new Bitmap */ + new->b_array[new_i] = NULL; + Py_INCREF(node); + new->b_array[new_i + 1] = (PyObject*)node; + } + + new_i += 2; + } + + new->b_bitmap = bitmap; + *new_node = (PyHamtNode*)new; /* borrow */ + return W_NEWNODE; + } + + default: + Py_UNREACHABLE(); + } +} + +static hamt_find_t +hamt_node_array_find(PyHamtNode_Array *self, + uint32_t shift, int32_t hash, + PyObject *key, PyObject **val) +{ + /* Lookup `key` in the Array node `self`. Set the value + for the found key to 'val'. */ + + uint32_t idx = hamt_mask(hash, shift); + PyHamtNode *node; + + node = self->a_array[idx]; + if (node == NULL) { + return F_NOT_FOUND; + } + + /* Dispatch to the generic hamt_node_find */ + return hamt_node_find(node, shift + 5, hash, key, val); +} + +static int +hamt_node_array_traverse(PyHamtNode_Array *self, + visitproc visit, void *arg) +{ + /* Array's tp_traverse */ + + Py_ssize_t i; + + for (i = 0; i < HAMT_ARRAY_NODE_SIZE; i++) { + Py_VISIT(self->a_array[i]); + } + + return 0; +} + +static void +hamt_node_array_dealloc(PyHamtNode_Array *self) +{ + /* Array's tp_dealloc */ + + Py_ssize_t i; + + PyObject_GC_UnTrack(self); + Py_TRASHCAN_SAFE_BEGIN(self) + + for (i = 0; i < HAMT_ARRAY_NODE_SIZE; i++) { + Py_XDECREF(self->a_array[i]); + } + + Py_TYPE(self)->tp_free((PyObject *)self); + Py_TRASHCAN_SAFE_END(self) +} + +#ifdef Py_DEBUG +static int +hamt_node_array_dump(PyHamtNode_Array *node, + _PyUnicodeWriter *writer, int level) +{ + /* Debug build: __dump__() method implementation for Array nodes. */ + + Py_ssize_t i; + + if (_hamt_dump_ident(writer, level + 1)) { + goto error; + } + + if (_hamt_dump_format(writer, "ArrayNode(id=%p):\n", node)) { + goto error; + } + + for (i = 0; i < HAMT_ARRAY_NODE_SIZE; i++) { + if (node->a_array[i] == NULL) { + continue; + } + + if (_hamt_dump_ident(writer, level + 2)) { + goto error; + } + + if (_hamt_dump_format(writer, "%d::\n", i)) { + goto error; + } + + if (hamt_node_dump(node->a_array[i], writer, level + 1)) { + goto error; + } + + if (_hamt_dump_format(writer, "\n")) { + goto error; + } + } + + return 0; +error: + return -1; +} +#endif /* Py_DEBUG */ + + +/////////////////////////////////// Node Dispatch + + +static PyHamtNode * +hamt_node_assoc(PyHamtNode *node, + uint32_t shift, int32_t hash, + PyObject *key, PyObject *val, int* added_leaf) +{ + /* Set key/value to the 'node' starting with the given shift/hash. + Return a new node, or the same node if key/value already + set. + + added_leaf will be set to 1 if key/value wasn't in the + tree before. + + This method automatically dispatches to the suitable + hamt_node_{nodetype}_assoc method. + */ + + if (IS_BITMAP_NODE(node)) { + return hamt_node_bitmap_assoc( + (PyHamtNode_Bitmap *)node, + shift, hash, key, val, added_leaf); + } + else if (IS_ARRAY_NODE(node)) { + return hamt_node_array_assoc( + (PyHamtNode_Array *)node, + shift, hash, key, val, added_leaf); + } + else { + assert(IS_COLLISION_NODE(node)); + return hamt_node_collision_assoc( + (PyHamtNode_Collision *)node, + shift, hash, key, val, added_leaf); + } +} + +static hamt_without_t +hamt_node_without(PyHamtNode *node, + uint32_t shift, int32_t hash, + PyObject *key, + PyHamtNode **new_node) +{ + if (IS_BITMAP_NODE(node)) { + return hamt_node_bitmap_without( + (PyHamtNode_Bitmap *)node, + shift, hash, key, + new_node); + } + else if (IS_ARRAY_NODE(node)) { + return hamt_node_array_without( + (PyHamtNode_Array *)node, + shift, hash, key, + new_node); + } + else { + assert(IS_COLLISION_NODE(node)); + return hamt_node_collision_without( + (PyHamtNode_Collision *)node, + shift, hash, key, + new_node); + } +} + +static hamt_find_t +hamt_node_find(PyHamtNode *node, + uint32_t shift, int32_t hash, + PyObject *key, PyObject **val) +{ + /* Find the key in the node starting with the given shift/hash. + + If a value is found, the result will be set to F_FOUND, and + *val will point to the found value object. + + If a value wasn't found, the result will be set to F_NOT_FOUND. + + If an exception occurs during the call, the result will be F_ERROR. + + This method automatically dispatches to the suitable + hamt_node_{nodetype}_find method. + */ + + if (IS_BITMAP_NODE(node)) { + return hamt_node_bitmap_find( + (PyHamtNode_Bitmap *)node, + shift, hash, key, val); + + } + else if (IS_ARRAY_NODE(node)) { + return hamt_node_array_find( + (PyHamtNode_Array *)node, + shift, hash, key, val); + } + else { + assert(IS_COLLISION_NODE(node)); + return hamt_node_collision_find( + (PyHamtNode_Collision *)node, + shift, hash, key, val); + } +} + +#ifdef Py_DEBUG +static int +hamt_node_dump(PyHamtNode *node, + _PyUnicodeWriter *writer, int level) +{ + /* Debug build: __dump__() method implementation for a node. + + This method automatically dispatches to the suitable + hamt_node_{nodetype})_dump method. + */ + + if (IS_BITMAP_NODE(node)) { + return hamt_node_bitmap_dump( + (PyHamtNode_Bitmap *)node, writer, level); + } + else if (IS_ARRAY_NODE(node)) { + return hamt_node_array_dump( + (PyHamtNode_Array *)node, writer, level); + } + else { + assert(IS_COLLISION_NODE(node)); + return hamt_node_collision_dump( + (PyHamtNode_Collision *)node, writer, level); + } +} +#endif /* Py_DEBUG */ + + +/////////////////////////////////// Iterators: Machinery + + +static hamt_iter_t +hamt_iterator_next(PyHamtIteratorState *iter, PyObject **key, PyObject **val); + + +static void +hamt_iterator_init(PyHamtIteratorState *iter, PyHamtNode *root) +{ + for (uint32_t i = 0; i < _Py_HAMT_MAX_TREE_DEPTH; i++) { + iter->i_nodes[i] = NULL; + iter->i_pos[i] = 0; + } + + iter->i_level = 0; + + /* Note: we don't incref/decref nodes in i_nodes. */ + iter->i_nodes[0] = root; +} + +static hamt_iter_t +hamt_iterator_bitmap_next(PyHamtIteratorState *iter, + PyObject **key, PyObject **val) +{ + int8_t level = iter->i_level; + + PyHamtNode_Bitmap *node = (PyHamtNode_Bitmap *)(iter->i_nodes[level]); + Py_ssize_t pos = iter->i_pos[level]; + + if (pos + 1 >= Py_SIZE(node)) { +#ifdef Py_DEBUG + assert(iter->i_level >= 0); + iter->i_nodes[iter->i_level] = NULL; +#endif + iter->i_level--; + return hamt_iterator_next(iter, key, val); + } + + if (node->b_array[pos] == NULL) { + iter->i_pos[level] = pos + 2; + + int8_t next_level = level + 1; + assert(next_level < _Py_HAMT_MAX_TREE_DEPTH); + iter->i_level = next_level; + iter->i_pos[next_level] = 0; + iter->i_nodes[next_level] = (PyHamtNode *) + node->b_array[pos + 1]; + + return hamt_iterator_next(iter, key, val); + } + + *key = node->b_array[pos]; + *val = node->b_array[pos + 1]; + iter->i_pos[level] = pos + 2; + return I_ITEM; +} + +static hamt_iter_t +hamt_iterator_collision_next(PyHamtIteratorState *iter, + PyObject **key, PyObject **val) +{ + int8_t level = iter->i_level; + + PyHamtNode_Collision *node = (PyHamtNode_Collision *)(iter->i_nodes[level]); + Py_ssize_t pos = iter->i_pos[level]; + + if (pos + 1 >= Py_SIZE(node)) { +#ifdef Py_DEBUG + assert(iter->i_level >= 0); + iter->i_nodes[iter->i_level] = NULL; +#endif + iter->i_level--; + return hamt_iterator_next(iter, key, val); + } + + *key = node->c_array[pos]; + *val = node->c_array[pos + 1]; + iter->i_pos[level] = pos + 2; + return I_ITEM; +} + +static hamt_iter_t +hamt_iterator_array_next(PyHamtIteratorState *iter, + PyObject **key, PyObject **val) +{ + int8_t level = iter->i_level; + + PyHamtNode_Array *node = (PyHamtNode_Array *)(iter->i_nodes[level]); + Py_ssize_t pos = iter->i_pos[level]; + + if (pos >= HAMT_ARRAY_NODE_SIZE) { +#ifdef Py_DEBUG + assert(iter->i_level >= 0); + iter->i_nodes[iter->i_level] = NULL; +#endif + iter->i_level--; + return hamt_iterator_next(iter, key, val); + } + + for (Py_ssize_t i = pos; i < HAMT_ARRAY_NODE_SIZE; i++) { + if (node->a_array[i] != NULL) { + iter->i_pos[level] = i + 1; + + int8_t next_level = level + 1; + assert(next_level < _Py_HAMT_MAX_TREE_DEPTH); + iter->i_pos[next_level] = 0; + iter->i_nodes[next_level] = node->a_array[i]; + iter->i_level = next_level; + + return hamt_iterator_next(iter, key, val); + } + } + +#ifdef Py_DEBUG + assert(iter->i_level >= 0); + iter->i_nodes[iter->i_level] = NULL; +#endif + + iter->i_level--; + return hamt_iterator_next(iter, key, val); +} + +static hamt_iter_t +hamt_iterator_next(PyHamtIteratorState *iter, PyObject **key, PyObject **val) +{ + if (iter->i_level < 0) { + return I_END; + } + + assert(iter->i_level < _Py_HAMT_MAX_TREE_DEPTH); + + PyHamtNode *current = iter->i_nodes[iter->i_level]; + + if (IS_BITMAP_NODE(current)) { + return hamt_iterator_bitmap_next(iter, key, val); + } + else if (IS_ARRAY_NODE(current)) { + return hamt_iterator_array_next(iter, key, val); + } + else { + assert(IS_COLLISION_NODE(current)); + return hamt_iterator_collision_next(iter, key, val); + } +} + + +/////////////////////////////////// HAMT high-level functions + + +PyHamtObject * +_PyHamt_Assoc(PyHamtObject *o, PyObject *key, PyObject *val) +{ + int32_t key_hash; + int added_leaf = 0; + PyHamtNode *new_root; + PyHamtObject *new_o; + + key_hash = hamt_hash(key); + if (key_hash == -1) { + return NULL; + } + + new_root = hamt_node_assoc( + (PyHamtNode *)(o->h_root), + 0, key_hash, key, val, &added_leaf); + if (new_root == NULL) { + return NULL; + } + + if (new_root == o->h_root) { + Py_DECREF(new_root); + Py_INCREF(o); + return o; + } + + new_o = hamt_alloc(); + if (new_o == NULL) { + Py_DECREF(new_root); + return NULL; + } + + new_o->h_root = new_root; /* borrow */ + new_o->h_count = added_leaf ? o->h_count + 1 : o->h_count; + + return new_o; +} + +PyHamtObject * +_PyHamt_Without(PyHamtObject *o, PyObject *key) +{ + int32_t key_hash = hamt_hash(key); + if (key_hash == -1) { + return NULL; + } + + PyHamtNode *new_root; + + hamt_without_t res = hamt_node_without( + (PyHamtNode *)(o->h_root), + 0, key_hash, key, + &new_root); + + switch (res) { + case W_ERROR: + return NULL; + case W_EMPTY: + return _PyHamt_New(); + case W_NOT_FOUND: + Py_INCREF(o); + return o; + case W_NEWNODE: { + PyHamtObject *new_o = hamt_alloc(); + if (new_o == NULL) { + Py_DECREF(new_root); + return NULL; + } + + new_o->h_root = new_root; /* borrow */ + new_o->h_count = o->h_count - 1; + assert(new_o->h_count >= 0); + return new_o; + } + default: + Py_UNREACHABLE(); + } +} + +static hamt_find_t +hamt_find(PyHamtObject *o, PyObject *key, PyObject **val) +{ + if (o->h_count == 0) { + return F_NOT_FOUND; + } + + int32_t key_hash = hamt_hash(key); + if (key_hash == -1) { + return F_ERROR; + } + + return hamt_node_find(o->h_root, 0, key_hash, key, val); +} + + +int +_PyHamt_Find(PyHamtObject *o, PyObject *key, PyObject **val) +{ + hamt_find_t res = hamt_find(o, key, val); + switch (res) { + case F_ERROR: + return -1; + case F_NOT_FOUND: + return 0; + case F_FOUND: + return 1; + default: + Py_UNREACHABLE(); + } +} + + +int +_PyHamt_Eq(PyHamtObject *v, PyHamtObject *w) +{ + if (v == w) { + return 1; + } + + if (v->h_count != w->h_count) { + return 0; + } + + PyHamtIteratorState iter; + hamt_iter_t iter_res; + hamt_find_t find_res; + PyObject *v_key; + PyObject *v_val; + PyObject *w_val; + + hamt_iterator_init(&iter, v->h_root); + + do { + iter_res = hamt_iterator_next(&iter, &v_key, &v_val); + if (iter_res == I_ITEM) { + find_res = hamt_find(w, v_key, &w_val); + switch (find_res) { + case F_ERROR: + return -1; + + case F_NOT_FOUND: + return 0; + + case F_FOUND: { + int cmp = PyObject_RichCompareBool(v_val, w_val, Py_EQ); + if (cmp < 0) { + return -1; + } + if (cmp == 0) { + return 0; + } + } + } + } + } while (iter_res != I_END); + + return 1; +} + +Py_ssize_t +_PyHamt_Len(PyHamtObject *o) +{ + return o->h_count; +} + +static PyHamtObject * +hamt_alloc(void) +{ + PyHamtObject *o; + o = PyObject_GC_New(PyHamtObject, &_PyHamt_Type); + if (o == NULL) { + return NULL; + } + o->h_weakreflist = NULL; + PyObject_GC_Track(o); + return o; +} + +PyHamtObject * +_PyHamt_New(void) +{ + if (_empty_hamt != NULL) { + /* HAMT is an immutable object so we can easily cache an + empty instance. */ + Py_INCREF(_empty_hamt); + return _empty_hamt; + } + + PyHamtObject *o = hamt_alloc(); + if (o == NULL) { + return NULL; + } + + o->h_root = hamt_node_bitmap_new(0); + if (o->h_root == NULL) { + Py_DECREF(o); + return NULL; + } + + o->h_count = 0; + + if (_empty_hamt == NULL) { + Py_INCREF(o); + _empty_hamt = o; + } + + return o; +} + +#ifdef Py_DEBUG +static PyObject * +hamt_dump(PyHamtObject *self) +{ + _PyUnicodeWriter writer; + + _PyUnicodeWriter_Init(&writer); + + if (_hamt_dump_format(&writer, "HAMT(len=%zd):\n", self->h_count)) { + goto error; + } + + if (hamt_node_dump(self->h_root, &writer, 0)) { + goto error; + } + + return _PyUnicodeWriter_Finish(&writer); + +error: + _PyUnicodeWriter_Dealloc(&writer); + return NULL; +} +#endif /* Py_DEBUG */ + + +/////////////////////////////////// Iterators: Shared Iterator Implementation + + +static int +hamt_baseiter_tp_clear(PyHamtIterator *it) +{ + Py_CLEAR(it->hi_obj); + return 0; +} + +static void +hamt_baseiter_tp_dealloc(PyHamtIterator *it) +{ + PyObject_GC_UnTrack(it); + (void)hamt_baseiter_tp_clear(it); + PyObject_GC_Del(it); +} + +static int +hamt_baseiter_tp_traverse(PyHamtIterator *it, visitproc visit, void *arg) +{ + Py_VISIT(it->hi_obj); + return 0; +} + +static PyObject * +hamt_baseiter_tp_iternext(PyHamtIterator *it) +{ + PyObject *key; + PyObject *val; + hamt_iter_t res = hamt_iterator_next(&it->hi_iter, &key, &val); + + switch (res) { + case I_END: + PyErr_SetNone(PyExc_StopIteration); + return NULL; + + case I_ITEM: { + return (*(it->hi_yield))(key, val); + } + + default: { + Py_UNREACHABLE(); + } + } +} + +static Py_ssize_t +hamt_baseiter_tp_len(PyHamtIterator *it) +{ + return it->hi_obj->h_count; +} + +static PyMappingMethods PyHamtIterator_as_mapping = { + (lenfunc)hamt_baseiter_tp_len, +}; + +static PyObject * +hamt_baseiter_new(PyTypeObject *type, binaryfunc yield, PyHamtObject *o) +{ + PyHamtIterator *it = PyObject_GC_New(PyHamtIterator, type); + if (it == NULL) { + return NULL; + } + + Py_INCREF(o); + it->hi_obj = o; + it->hi_yield = yield; + + hamt_iterator_init(&it->hi_iter, o->h_root); + + return (PyObject*)it; +} + +#define ITERATOR_TYPE_SHARED_SLOTS \ + .tp_basicsize = sizeof(PyHamtIterator), \ + .tp_itemsize = 0, \ + .tp_as_mapping = &PyHamtIterator_as_mapping, \ + .tp_dealloc = (destructor)hamt_baseiter_tp_dealloc, \ + .tp_getattro = PyObject_GenericGetAttr, \ + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, \ + .tp_traverse = (traverseproc)hamt_baseiter_tp_traverse, \ + .tp_clear = (inquiry)hamt_baseiter_tp_clear, \ + .tp_iter = PyObject_SelfIter, \ + .tp_iternext = (iternextfunc)hamt_baseiter_tp_iternext, + + +/////////////////////////////////// _PyHamtItems_Type + + +PyTypeObject _PyHamtItems_Type = { + PyVarObject_HEAD_INIT(NULL, 0) + "items", + ITERATOR_TYPE_SHARED_SLOTS +}; + +static PyObject * +hamt_iter_yield_items(PyObject *key, PyObject *val) +{ + return PyTuple_Pack(2, key, val); +} + +PyObject * +_PyHamt_NewIterItems(PyHamtObject *o) +{ + return hamt_baseiter_new( + &_PyHamtItems_Type, hamt_iter_yield_items, o); +} + + +/////////////////////////////////// _PyHamtKeys_Type + + +PyTypeObject _PyHamtKeys_Type = { + PyVarObject_HEAD_INIT(NULL, 0) + "keys", + ITERATOR_TYPE_SHARED_SLOTS +}; + +static PyObject * +hamt_iter_yield_keys(PyObject *key, PyObject *val) +{ + Py_INCREF(key); + return key; +} + +PyObject * +_PyHamt_NewIterKeys(PyHamtObject *o) +{ + return hamt_baseiter_new( + &_PyHamtKeys_Type, hamt_iter_yield_keys, o); +} + + +/////////////////////////////////// _PyHamtValues_Type + + +PyTypeObject _PyHamtValues_Type = { + PyVarObject_HEAD_INIT(NULL, 0) + "values", + ITERATOR_TYPE_SHARED_SLOTS +}; + +static PyObject * +hamt_iter_yield_values(PyObject *key, PyObject *val) +{ + Py_INCREF(val); + return val; +} + +PyObject * +_PyHamt_NewIterValues(PyHamtObject *o) +{ + return hamt_baseiter_new( + &_PyHamtValues_Type, hamt_iter_yield_values, o); +} + + +/////////////////////////////////// _PyHamt_Type + + +#ifdef Py_DEBUG +static PyObject * +hamt_dump(PyHamtObject *self); +#endif + + +static PyObject * +hamt_tp_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + return (PyObject*)_PyHamt_New(); +} + +static int +hamt_tp_clear(PyHamtObject *self) +{ + Py_CLEAR(self->h_root); + return 0; +} + + +static int +hamt_tp_traverse(PyHamtObject *self, visitproc visit, void *arg) +{ + Py_VISIT(self->h_root); + return 0; +} + +static void +hamt_tp_dealloc(PyHamtObject *self) +{ + PyObject_GC_UnTrack(self); + if (self->h_weakreflist != NULL) { + PyObject_ClearWeakRefs((PyObject*)self); + } + (void)hamt_tp_clear(self); + Py_TYPE(self)->tp_free(self); +} + + +static PyObject * +hamt_tp_richcompare(PyObject *v, PyObject *w, int op) +{ + if (!PyHamt_Check(v) || !PyHamt_Check(w) || (op != Py_EQ && op != Py_NE)) { + Py_RETURN_NOTIMPLEMENTED; + } + + int res = _PyHamt_Eq((PyHamtObject *)v, (PyHamtObject *)w); + if (res < 0) { + return NULL; + } + + if (op == Py_NE) { + res = !res; + } + + if (res) { + Py_RETURN_TRUE; + } + else { + Py_RETURN_FALSE; + } +} + +static int +hamt_tp_contains(PyHamtObject *self, PyObject *key) +{ + PyObject *val; + return _PyHamt_Find(self, key, &val); +} + +static PyObject * +hamt_tp_subscript(PyHamtObject *self, PyObject *key) +{ + PyObject *val; + hamt_find_t res = hamt_find(self, key, &val); + switch (res) { + case F_ERROR: + return NULL; + case F_FOUND: + Py_INCREF(val); + return val; + case F_NOT_FOUND: + PyErr_SetObject(PyExc_KeyError, key); + return NULL; + default: + Py_UNREACHABLE(); + } +} + +static Py_ssize_t +hamt_tp_len(PyHamtObject *self) +{ + return _PyHamt_Len(self); +} + +static PyObject * +hamt_tp_iter(PyHamtObject *self) +{ + return _PyHamt_NewIterKeys(self); +} + +static PyObject * +hamt_py_set(PyHamtObject *self, PyObject *args) +{ + PyObject *key; + PyObject *val; + + if (!PyArg_UnpackTuple(args, "set", 2, 2, &key, &val)) { + return NULL; + } + + return (PyObject *)_PyHamt_Assoc(self, key, val); +} + +static PyObject * +hamt_py_get(PyHamtObject *self, PyObject *args) +{ + PyObject *key; + PyObject *def = NULL; + + if (!PyArg_UnpackTuple(args, "get", 1, 2, &key, &def)) { + return NULL; + } + + PyObject *val = NULL; + hamt_find_t res = hamt_find(self, key, &val); + switch (res) { + case F_ERROR: + return NULL; + case F_FOUND: + Py_INCREF(val); + return val; + case F_NOT_FOUND: + if (def == NULL) { + Py_RETURN_NONE; + } + Py_INCREF(def); + return def; + default: + Py_UNREACHABLE(); + } +} + +static PyObject * +hamt_py_delete(PyHamtObject *self, PyObject *key) +{ + return (PyObject *)_PyHamt_Without(self, key); +} + +static PyObject * +hamt_py_items(PyHamtObject *self, PyObject *args) +{ + return _PyHamt_NewIterItems(self); +} + +static PyObject * +hamt_py_values(PyHamtObject *self, PyObject *args) +{ + return _PyHamt_NewIterValues(self); +} + +static PyObject * +hamt_py_keys(PyHamtObject *self, PyObject *args) +{ + return _PyHamt_NewIterKeys(self); +} + +#ifdef Py_DEBUG +static PyObject * +hamt_py_dump(PyHamtObject *self, PyObject *args) +{ + return hamt_dump(self); +} +#endif + + +static PyMethodDef PyHamt_methods[] = { + {"set", (PyCFunction)hamt_py_set, METH_VARARGS, NULL}, + {"get", (PyCFunction)hamt_py_get, METH_VARARGS, NULL}, + {"delete", (PyCFunction)hamt_py_delete, METH_O, NULL}, + {"items", (PyCFunction)hamt_py_items, METH_NOARGS, NULL}, + {"keys", (PyCFunction)hamt_py_keys, METH_NOARGS, NULL}, + {"values", (PyCFunction)hamt_py_values, METH_NOARGS, NULL}, +#ifdef Py_DEBUG + {"__dump__", (PyCFunction)hamt_py_dump, METH_NOARGS, NULL}, +#endif + {NULL, NULL} +}; + +static PySequenceMethods PyHamt_as_sequence = { + 0, /* sq_length */ + 0, /* sq_concat */ + 0, /* sq_repeat */ + 0, /* sq_item */ + 0, /* sq_slice */ + 0, /* sq_ass_item */ + 0, /* sq_ass_slice */ + (objobjproc)hamt_tp_contains, /* sq_contains */ + 0, /* sq_inplace_concat */ + 0, /* sq_inplace_repeat */ +}; + +static PyMappingMethods PyHamt_as_mapping = { + (lenfunc)hamt_tp_len, /* mp_length */ + (binaryfunc)hamt_tp_subscript, /* mp_subscript */ +}; + +PyTypeObject _PyHamt_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + "hamt", + sizeof(PyHamtObject), + .tp_methods = PyHamt_methods, + .tp_as_mapping = &PyHamt_as_mapping, + .tp_as_sequence = &PyHamt_as_sequence, + .tp_iter = (getiterfunc)hamt_tp_iter, + .tp_dealloc = (destructor)hamt_tp_dealloc, + .tp_getattro = PyObject_GenericGetAttr, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, + .tp_richcompare = hamt_tp_richcompare, + .tp_traverse = (traverseproc)hamt_tp_traverse, + .tp_clear = (inquiry)hamt_tp_clear, + .tp_new = hamt_tp_new, + .tp_weaklistoffset = offsetof(PyHamtObject, h_weakreflist), + .tp_hash = PyObject_HashNotImplemented, +}; + + +/////////////////////////////////// Tree Node Types + + +PyTypeObject _PyHamt_ArrayNode_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + "hamt_array_node", + sizeof(PyHamtNode_Array), + 0, + .tp_dealloc = (destructor)hamt_node_array_dealloc, + .tp_getattro = PyObject_GenericGetAttr, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, + .tp_traverse = (traverseproc)hamt_node_array_traverse, + .tp_free = PyObject_GC_Del, + .tp_hash = PyObject_HashNotImplemented, +}; + +PyTypeObject _PyHamt_BitmapNode_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + "hamt_bitmap_node", + sizeof(PyHamtNode_Bitmap) - sizeof(PyObject *), + sizeof(PyObject *), + .tp_dealloc = (destructor)hamt_node_bitmap_dealloc, + .tp_getattro = PyObject_GenericGetAttr, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, + .tp_traverse = (traverseproc)hamt_node_bitmap_traverse, + .tp_free = PyObject_GC_Del, + .tp_hash = PyObject_HashNotImplemented, +}; + +PyTypeObject _PyHamt_CollisionNode_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + "hamt_collision_node", + sizeof(PyHamtNode_Collision) - sizeof(PyObject *), + sizeof(PyObject *), + .tp_dealloc = (destructor)hamt_node_collision_dealloc, + .tp_getattro = PyObject_GenericGetAttr, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, + .tp_traverse = (traverseproc)hamt_node_collision_traverse, + .tp_free = PyObject_GC_Del, + .tp_hash = PyObject_HashNotImplemented, +}; + + +int +_PyHamt_Init(void) +{ + if ((PyType_Ready(&_PyHamt_Type) < 0) || + (PyType_Ready(&_PyHamt_ArrayNode_Type) < 0) || + (PyType_Ready(&_PyHamt_BitmapNode_Type) < 0) || + (PyType_Ready(&_PyHamt_CollisionNode_Type) < 0) || + (PyType_Ready(&_PyHamtKeys_Type) < 0) || + (PyType_Ready(&_PyHamtValues_Type) < 0) || + (PyType_Ready(&_PyHamtItems_Type) < 0)) + { + return 0; + } + + return 1; +} + +void +_PyHamt_Fini(void) +{ + Py_CLEAR(_empty_hamt); + Py_CLEAR(_empty_bitmap_node); +} diff --git a/Python/pylifecycle.c b/Python/pylifecycle.c index 2f61db0d826..d46784a2f6b 100644 --- a/Python/pylifecycle.c +++ b/Python/pylifecycle.c @@ -4,6 +4,8 @@ #include "Python-ast.h" #undef Yield /* undefine macro conflicting with winbase.h */ +#include "internal/context.h" +#include "internal/hamt.h" #include "internal/pystate.h" #include "grammar.h" #include "node.h" @@ -758,6 +760,9 @@ _Py_InitializeCore(const _PyCoreConfig *core_config) return _Py_INIT_ERR("can't initialize warnings"); } + if (!_PyContext_Init()) + return _Py_INIT_ERR("can't init context"); + /* This call sets up builtin and frozen import support */ if (!interp->core_config._disable_importlib) { err = initimport(interp, sysmod); @@ -1176,6 +1181,7 @@ Py_FinalizeEx(void) _Py_HashRandomization_Fini(); _PyArg_Fini(); PyAsyncGen_Fini(); + _PyContext_Fini(); /* Cleanup Unicode implementation */ _PyUnicode_Fini(); diff --git a/Python/pystate.c b/Python/pystate.c index 9c25a26460e..909d831465d 100644 --- a/Python/pystate.c +++ b/Python/pystate.c @@ -173,6 +173,8 @@ PyInterpreterState_New(void) } HEAD_UNLOCK(); + interp->tstate_next_unique_id = 0; + return interp; } @@ -313,6 +315,11 @@ new_threadstate(PyInterpreterState *interp, int init) tstate->async_gen_firstiter = NULL; tstate->async_gen_finalizer = NULL; + tstate->context = NULL; + tstate->context_ver = 1; + + tstate->id = ++interp->tstate_next_unique_id; + if (init) _PyThreadState_Init(tstate); @@ -499,6 +506,8 @@ PyThreadState_Clear(PyThreadState *tstate) Py_CLEAR(tstate->coroutine_wrapper); Py_CLEAR(tstate->async_gen_firstiter); Py_CLEAR(tstate->async_gen_finalizer); + + Py_CLEAR(tstate->context); } diff --git a/Tools/msi/lib/lib_files.wxs b/Tools/msi/lib/lib_files.wxs index 5a72612c6a5..46ddcb41e9a 100644 --- a/Tools/msi/lib/lib_files.wxs +++ b/Tools/msi/lib/lib_files.wxs @@ -1,6 +1,6 @@  - + diff --git a/setup.py b/setup.py index 1da40a426a2..258094e3ada 100644 --- a/setup.py +++ b/setup.py @@ -644,6 +644,9 @@ class PyBuildExt(build_ext): # array objects exts.append( Extension('array', ['arraymodule.c']) ) + # Context Variables + exts.append( Extension('_contextvars', ['_contextvarsmodule.c']) ) + shared_math = 'Modules/_math.o' # complex math library functions exts.append( Extension('cmath', ['cmathmodule.c'],