mirror of https://github.com/python/cpython
bpo-32630: Use contextvars in decimal (GH-5278)
This commit is contained in:
parent
bc4123b0b3
commit
f13f12d8da
|
@ -433,13 +433,11 @@ _rounding_modes = (ROUND_DOWN, ROUND_HALF_UP, ROUND_HALF_EVEN, ROUND_CEILING,
|
||||||
# The getcontext() and setcontext() function manage access to a thread-local
|
# The getcontext() and setcontext() function manage access to a thread-local
|
||||||
# current context.
|
# current context.
|
||||||
|
|
||||||
import threading
|
import contextvars
|
||||||
|
|
||||||
local = threading.local()
|
_current_context_var = contextvars.ContextVar('decimal_context')
|
||||||
if hasattr(local, '__decimal_context__'):
|
|
||||||
del local.__decimal_context__
|
|
||||||
|
|
||||||
def getcontext(_local=local):
|
def getcontext():
|
||||||
"""Returns this thread's context.
|
"""Returns this thread's context.
|
||||||
|
|
||||||
If this thread does not yet have a context, returns
|
If this thread does not yet have a context, returns
|
||||||
|
@ -447,20 +445,20 @@ def getcontext(_local=local):
|
||||||
New contexts are copies of DefaultContext.
|
New contexts are copies of DefaultContext.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return _local.__decimal_context__
|
return _current_context_var.get()
|
||||||
except AttributeError:
|
except LookupError:
|
||||||
context = Context()
|
context = Context()
|
||||||
_local.__decimal_context__ = context
|
_current_context_var.set(context)
|
||||||
return context
|
return context
|
||||||
|
|
||||||
def setcontext(context, _local=local):
|
def setcontext(context):
|
||||||
"""Set this thread's context to context."""
|
"""Set this thread's context to context."""
|
||||||
if context in (DefaultContext, BasicContext, ExtendedContext):
|
if context in (DefaultContext, BasicContext, ExtendedContext):
|
||||||
context = context.copy()
|
context = context.copy()
|
||||||
context.clear_flags()
|
context.clear_flags()
|
||||||
_local.__decimal_context__ = context
|
_current_context_var.set(context)
|
||||||
|
|
||||||
del threading, local # Don't contaminate the namespace
|
del contextvars # Don't contaminate the namespace
|
||||||
|
|
||||||
def localcontext(ctx=None):
|
def localcontext(ctx=None):
|
||||||
"""Return a context manager for a copy of the supplied context
|
"""Return a context manager for a copy of the supplied context
|
||||||
|
|
|
@ -0,0 +1,29 @@
|
||||||
|
import asyncio
|
||||||
|
import decimal
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
|
||||||
|
class DecimalContextTest(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_asyncio_task_decimal_context(self):
|
||||||
|
async def fractions(t, precision, x, y):
|
||||||
|
with decimal.localcontext() as ctx:
|
||||||
|
ctx.prec = precision
|
||||||
|
a = decimal.Decimal(x) / decimal.Decimal(y)
|
||||||
|
await asyncio.sleep(t)
|
||||||
|
b = decimal.Decimal(x) / decimal.Decimal(y ** 2)
|
||||||
|
return a, b
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
r1, r2 = await asyncio.gather(
|
||||||
|
fractions(0.1, 3, 1, 3), fractions(0.2, 6, 1, 3))
|
||||||
|
|
||||||
|
return r1, r2
|
||||||
|
|
||||||
|
r1, r2 = asyncio.run(main())
|
||||||
|
|
||||||
|
self.assertEqual(str(r1[0]), '0.333')
|
||||||
|
self.assertEqual(str(r1[1]), '0.111')
|
||||||
|
|
||||||
|
self.assertEqual(str(r2[0]), '0.333333')
|
||||||
|
self.assertEqual(str(r2[1]), '0.111111')
|
|
@ -0,0 +1 @@
|
||||||
|
Refactor decimal module to use contextvars to store decimal context.
|
|
@ -122,10 +122,7 @@ incr_false(void)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/* Key for thread state dictionary */
|
static PyContextVar *current_context_var;
|
||||||
static PyObject *tls_context_key = NULL;
|
|
||||||
/* Invariant: NULL or the most recently accessed thread local context */
|
|
||||||
static PyDecContextObject *cached_context = NULL;
|
|
||||||
|
|
||||||
/* Template for creating new thread contexts, calling Context() without
|
/* Template for creating new thread contexts, calling Context() without
|
||||||
* arguments and initializing the module_context on first access. */
|
* arguments and initializing the module_context on first access. */
|
||||||
|
@ -1220,10 +1217,6 @@ context_new(PyTypeObject *type, PyObject *args UNUSED, PyObject *kwds UNUSED)
|
||||||
static void
|
static void
|
||||||
context_dealloc(PyDecContextObject *self)
|
context_dealloc(PyDecContextObject *self)
|
||||||
{
|
{
|
||||||
if (self == cached_context) {
|
|
||||||
cached_context = NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
Py_XDECREF(self->traps);
|
Py_XDECREF(self->traps);
|
||||||
Py_XDECREF(self->flags);
|
Py_XDECREF(self->flags);
|
||||||
Py_TYPE(self)->tp_free(self);
|
Py_TYPE(self)->tp_free(self);
|
||||||
|
@ -1498,69 +1491,38 @@ static PyGetSetDef context_getsets [] =
|
||||||
* operation.
|
* operation.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
/* Get the context from the thread state dictionary. */
|
|
||||||
static PyObject *
|
static PyObject *
|
||||||
current_context_from_dict(void)
|
init_current_context(void)
|
||||||
{
|
{
|
||||||
PyObject *dict;
|
PyObject *tl_context = context_copy(default_context_template, NULL);
|
||||||
PyObject *tl_context;
|
if (tl_context == NULL) {
|
||||||
PyThreadState *tstate;
|
|
||||||
|
|
||||||
dict = PyThreadState_GetDict();
|
|
||||||
if (dict == NULL) {
|
|
||||||
PyErr_SetString(PyExc_RuntimeError,
|
|
||||||
"cannot get thread state");
|
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
CTX(tl_context)->status = 0;
|
||||||
|
|
||||||
tl_context = PyDict_GetItemWithError(dict, tls_context_key);
|
PyContextToken *tok = PyContextVar_Set(current_context_var, tl_context);
|
||||||
if (tl_context != NULL) {
|
if (tok == NULL) {
|
||||||
/* We already have a thread local context. */
|
|
||||||
CONTEXT_CHECK(tl_context);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
if (PyErr_Occurred()) {
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Set up a new thread local context. */
|
|
||||||
tl_context = context_copy(default_context_template, NULL);
|
|
||||||
if (tl_context == NULL) {
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
CTX(tl_context)->status = 0;
|
|
||||||
|
|
||||||
if (PyDict_SetItem(dict, tls_context_key, tl_context) < 0) {
|
|
||||||
Py_DECREF(tl_context);
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
Py_DECREF(tl_context);
|
Py_DECREF(tl_context);
|
||||||
|
return NULL;
|
||||||
}
|
}
|
||||||
|
Py_DECREF(tok);
|
||||||
|
|
||||||
/* Cache the context of the current thread, assuming that it
|
|
||||||
* will be accessed several times before a thread switch. */
|
|
||||||
tstate = PyThreadState_GET();
|
|
||||||
if (tstate) {
|
|
||||||
cached_context = (PyDecContextObject *)tl_context;
|
|
||||||
cached_context->tstate = tstate;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Borrowed reference with refcount==1 */
|
|
||||||
return tl_context;
|
return tl_context;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Return borrowed reference to thread local context. */
|
static inline PyObject *
|
||||||
static PyObject *
|
|
||||||
current_context(void)
|
current_context(void)
|
||||||
{
|
{
|
||||||
PyThreadState *tstate;
|
PyObject *tl_context;
|
||||||
|
if (PyContextVar_Get(current_context_var, NULL, &tl_context) < 0) {
|
||||||
tstate = PyThreadState_GET();
|
return NULL;
|
||||||
if (cached_context && cached_context->tstate == tstate) {
|
|
||||||
return (PyObject *)cached_context;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return current_context_from_dict();
|
if (tl_context != NULL) {
|
||||||
|
return tl_context;
|
||||||
|
}
|
||||||
|
|
||||||
|
return init_current_context();
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ctxobj := borrowed reference to the current context */
|
/* ctxobj := borrowed reference to the current context */
|
||||||
|
@ -1568,47 +1530,22 @@ current_context(void)
|
||||||
ctxobj = current_context(); \
|
ctxobj = current_context(); \
|
||||||
if (ctxobj == NULL) { \
|
if (ctxobj == NULL) { \
|
||||||
return NULL; \
|
return NULL; \
|
||||||
}
|
} \
|
||||||
|
Py_DECREF(ctxobj);
|
||||||
/* ctx := pointer to the mpd_context_t struct of the current context */
|
|
||||||
#define CURRENT_CONTEXT_ADDR(ctx) { \
|
|
||||||
PyObject *_c_t_x_o_b_j = current_context(); \
|
|
||||||
if (_c_t_x_o_b_j == NULL) { \
|
|
||||||
return NULL; \
|
|
||||||
} \
|
|
||||||
ctx = CTX(_c_t_x_o_b_j); \
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Return a new reference to the current context */
|
/* Return a new reference to the current context */
|
||||||
static PyObject *
|
static PyObject *
|
||||||
PyDec_GetCurrentContext(PyObject *self UNUSED, PyObject *args UNUSED)
|
PyDec_GetCurrentContext(PyObject *self UNUSED, PyObject *args UNUSED)
|
||||||
{
|
{
|
||||||
PyObject *context;
|
return current_context();
|
||||||
|
|
||||||
context = current_context();
|
|
||||||
if (context == NULL) {
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
Py_INCREF(context);
|
|
||||||
return context;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Set the thread local context to a new context, decrement old reference */
|
/* Set the thread local context to a new context, decrement old reference */
|
||||||
static PyObject *
|
static PyObject *
|
||||||
PyDec_SetCurrentContext(PyObject *self UNUSED, PyObject *v)
|
PyDec_SetCurrentContext(PyObject *self UNUSED, PyObject *v)
|
||||||
{
|
{
|
||||||
PyObject *dict;
|
|
||||||
|
|
||||||
CONTEXT_CHECK(v);
|
CONTEXT_CHECK(v);
|
||||||
|
|
||||||
dict = PyThreadState_GetDict();
|
|
||||||
if (dict == NULL) {
|
|
||||||
PyErr_SetString(PyExc_RuntimeError,
|
|
||||||
"cannot get thread state");
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* If the new context is one of the templates, make a copy.
|
/* If the new context is one of the templates, make a copy.
|
||||||
* This is the current behavior of decimal.py. */
|
* This is the current behavior of decimal.py. */
|
||||||
if (v == default_context_template ||
|
if (v == default_context_template ||
|
||||||
|
@ -1624,13 +1561,13 @@ PyDec_SetCurrentContext(PyObject *self UNUSED, PyObject *v)
|
||||||
Py_INCREF(v);
|
Py_INCREF(v);
|
||||||
}
|
}
|
||||||
|
|
||||||
cached_context = NULL;
|
PyContextToken *tok = PyContextVar_Set(current_context_var, v);
|
||||||
if (PyDict_SetItem(dict, tls_context_key, v) < 0) {
|
Py_DECREF(v);
|
||||||
Py_DECREF(v);
|
if (tok == NULL) {
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
Py_DECREF(tok);
|
||||||
|
|
||||||
Py_DECREF(v);
|
|
||||||
Py_RETURN_NONE;
|
Py_RETURN_NONE;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4458,6 +4395,7 @@ _dec_hash(PyDecObject *v)
|
||||||
if (context == NULL) {
|
if (context == NULL) {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
Py_DECREF(context);
|
||||||
|
|
||||||
if (mpd_isspecial(MPD(v))) {
|
if (mpd_isspecial(MPD(v))) {
|
||||||
if (mpd_issnan(MPD(v))) {
|
if (mpd_issnan(MPD(v))) {
|
||||||
|
@ -5599,6 +5537,11 @@ PyInit__decimal(void)
|
||||||
mpd_free = PyMem_Free;
|
mpd_free = PyMem_Free;
|
||||||
mpd_setminalloc(_Py_DEC_MINALLOC);
|
mpd_setminalloc(_Py_DEC_MINALLOC);
|
||||||
|
|
||||||
|
/* Init context variable */
|
||||||
|
current_context_var = PyContextVar_New("decimal_context", NULL);
|
||||||
|
if (current_context_var == NULL) {
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
|
||||||
/* Init external C-API functions */
|
/* Init external C-API functions */
|
||||||
_py_long_multiply = PyLong_Type.tp_as_number->nb_multiply;
|
_py_long_multiply = PyLong_Type.tp_as_number->nb_multiply;
|
||||||
|
@ -5768,7 +5711,6 @@ PyInit__decimal(void)
|
||||||
CHECK_INT(PyModule_AddObject(m, "DefaultContext",
|
CHECK_INT(PyModule_AddObject(m, "DefaultContext",
|
||||||
default_context_template));
|
default_context_template));
|
||||||
|
|
||||||
ASSIGN_PTR(tls_context_key, PyUnicode_FromString("___DECIMAL_CTX__"));
|
|
||||||
Py_INCREF(Py_True);
|
Py_INCREF(Py_True);
|
||||||
CHECK_INT(PyModule_AddObject(m, "HAVE_THREADS", Py_True));
|
CHECK_INT(PyModule_AddObject(m, "HAVE_THREADS", Py_True));
|
||||||
|
|
||||||
|
@ -5827,9 +5769,9 @@ error:
|
||||||
Py_CLEAR(SignalTuple); /* GCOV_NOT_REACHED */
|
Py_CLEAR(SignalTuple); /* GCOV_NOT_REACHED */
|
||||||
Py_CLEAR(DecimalTuple); /* GCOV_NOT_REACHED */
|
Py_CLEAR(DecimalTuple); /* GCOV_NOT_REACHED */
|
||||||
Py_CLEAR(default_context_template); /* GCOV_NOT_REACHED */
|
Py_CLEAR(default_context_template); /* GCOV_NOT_REACHED */
|
||||||
Py_CLEAR(tls_context_key); /* GCOV_NOT_REACHED */
|
|
||||||
Py_CLEAR(basic_context_template); /* GCOV_NOT_REACHED */
|
Py_CLEAR(basic_context_template); /* GCOV_NOT_REACHED */
|
||||||
Py_CLEAR(extended_context_template); /* GCOV_NOT_REACHED */
|
Py_CLEAR(extended_context_template); /* GCOV_NOT_REACHED */
|
||||||
|
Py_CLEAR(current_context_var); /* GCOV_NOT_REACHED */
|
||||||
Py_CLEAR(m); /* GCOV_NOT_REACHED */
|
Py_CLEAR(m); /* GCOV_NOT_REACHED */
|
||||||
|
|
||||||
return NULL; /* GCOV_NOT_REACHED */
|
return NULL; /* GCOV_NOT_REACHED */
|
||||||
|
|
Loading…
Reference in New Issue