diff --git a/Modules/_decimal/_decimal.c b/Modules/_decimal/_decimal.c index 68d1da9faab..a33c9793b5a 100644 --- a/Modules/_decimal/_decimal.c +++ b/Modules/_decimal/_decimal.c @@ -123,6 +123,7 @@ get_module_state(PyObject *mod) static struct PyModuleDef _decimal_module; static PyType_Spec dec_spec; +static PyType_Spec context_spec; static inline decimal_state * get_module_state_by_def(PyTypeObject *tp) @@ -190,6 +191,7 @@ typedef struct PyDecContextObject { PyObject *flags; int capitals; PyThreadState *tstate; + decimal_state *modstate; } PyDecContextObject; typedef struct { @@ -210,6 +212,15 @@ typedef struct { #define CTX(v) (&((PyDecContextObject *)v)->ctx) #define CtxCaps(v) (((PyDecContextObject *)v)->capitals) +static inline decimal_state * +get_module_state_from_ctx(PyObject *v) +{ + assert(PyType_GetBaseByToken(Py_TYPE(v), &context_spec, NULL) == 1); + decimal_state *state = ((PyDecContextObject *)v)->modstate; + assert(state != NULL); + return state; +} + Py_LOCAL_INLINE(PyObject *) incr_true(void) @@ -564,7 +575,7 @@ static int dec_addstatus(PyObject *context, uint32_t status) { mpd_context_t *ctx = CTX(context); - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = get_module_state_from_ctx(context); ctx->status |= status; if (status & (ctx->traps|MPD_Malloc_error)) { @@ -859,7 +870,7 @@ static PyObject * context_getround(PyObject *self, void *Py_UNUSED(closure)) { int i = mpd_getround(CTX(self)); - decimal_state *state = get_module_state_by_def(Py_TYPE(self)); + decimal_state *state = get_module_state_from_ctx(self); return Py_NewRef(state->round_map[i]); } @@ -1018,7 +1029,7 @@ context_setround(PyObject *self, PyObject *value, void *Py_UNUSED(closure)) mpd_context_t *ctx; int x; - decimal_state *state = get_module_state_by_def(Py_TYPE(self)); + decimal_state *state = get_module_state_from_ctx(self); x = getround(state, value); if (x == -1) { return -1; @@ -1077,7 +1088,7 @@ context_settraps_list(PyObject *self, PyObject *value) { mpd_context_t *ctx; uint32_t flags; - decimal_state *state = get_module_state_by_def(Py_TYPE(self)); + decimal_state *state = get_module_state_from_ctx(self); flags = list_as_flags(state, value); if (flags & DEC_ERRORS) { return -1; @@ -1097,7 +1108,7 @@ context_settraps_dict(PyObject *self, PyObject *value) mpd_context_t *ctx; uint32_t flags; - decimal_state *state = get_module_state_by_def(Py_TYPE(self)); + decimal_state *state = get_module_state_from_ctx(self); if (PyDecSignalDict_Check(state, value)) { flags = SdFlags(value); } @@ -1142,7 +1153,7 @@ context_setstatus_list(PyObject *self, PyObject *value) { mpd_context_t *ctx; uint32_t flags; - decimal_state *state = get_module_state_by_def(Py_TYPE(self)); + decimal_state *state = get_module_state_from_ctx(self); flags = list_as_flags(state, value); if (flags & DEC_ERRORS) { @@ -1163,7 +1174,7 @@ context_setstatus_dict(PyObject *self, PyObject *value) mpd_context_t *ctx; uint32_t flags; - decimal_state *state = get_module_state_by_def(Py_TYPE(self)); + decimal_state *state = get_module_state_from_ctx(self); if (PyDecSignalDict_Check(state, value)) { flags = SdFlags(value); } @@ -1393,6 +1404,7 @@ context_new(PyTypeObject *type, CtxCaps(self) = 1; self->tstate = NULL; + self->modstate = state; if (type == state->PyDecContext_Type) { PyObject_GC_Track(self); @@ -1470,7 +1482,7 @@ context_repr(PyDecContextObject *self) int n, mem; #ifdef Py_DEBUG - decimal_state *state = get_module_state_by_def(Py_TYPE(self)); + decimal_state *state = get_module_state_from_ctx((PyObject *)self); assert(PyDecContext_Check(state, self)); #endif ctx = CTX(self); @@ -1561,7 +1573,7 @@ context_copy(PyObject *self, PyObject *Py_UNUSED(dummy)) { PyObject *copy; - decimal_state *state = get_module_state_by_def(Py_TYPE(self)); + decimal_state *state = get_module_state_from_ctx(self); copy = PyObject_CallObject((PyObject *)state->PyDecContext_Type, NULL); if (copy == NULL) { return NULL; @@ -1581,7 +1593,7 @@ context_reduce(PyObject *self, PyObject *Py_UNUSED(dummy)) PyObject *traps; PyObject *ret; mpd_context_t *ctx; - decimal_state *state = get_module_state_by_def(Py_TYPE(self)); + decimal_state *state = get_module_state_from_ctx(self); ctx = CTX(self); @@ -2022,11 +2034,10 @@ static PyType_Spec ctxmanager_spec = { /******************************************************************************/ static PyObject * -PyDecType_New(PyTypeObject *type) +PyDecType_New(decimal_state *state, PyTypeObject *type) { PyDecObject *dec; - decimal_state *state = get_module_state_by_def(type); if (type == state->PyDec_Type) { dec = PyObject_GC_New(PyDecObject, state->PyDec_Type); } @@ -2052,7 +2063,7 @@ PyDecType_New(PyTypeObject *type) assert(PyObject_GC_IsTracked((PyObject *)dec)); return (PyObject *)dec; } -#define dec_alloc(st) PyDecType_New((st)->PyDec_Type) +#define dec_alloc(st) PyDecType_New(st, (st)->PyDec_Type) static int dec_traverse(PyObject *dec, visitproc visit, void *arg) @@ -2155,7 +2166,8 @@ PyDecType_FromCString(PyTypeObject *type, const char *s, PyObject *dec; uint32_t status = 0; - dec = PyDecType_New(type); + decimal_state *state = get_module_state_from_ctx(context); + dec = PyDecType_New(state, type); if (dec == NULL) { return NULL; } @@ -2179,7 +2191,8 @@ PyDecType_FromCStringExact(PyTypeObject *type, const char *s, uint32_t status = 0; mpd_context_t maxctx; - dec = PyDecType_New(type); + decimal_state *state = get_module_state_from_ctx(context); + dec = PyDecType_New(state, type); if (dec == NULL) { return NULL; } @@ -2266,7 +2279,8 @@ PyDecType_FromSsize(PyTypeObject *type, mpd_ssize_t v, PyObject *context) PyObject *dec; uint32_t status = 0; - dec = PyDecType_New(type); + decimal_state *state = get_module_state_from_ctx(context); + dec = PyDecType_New(state, type); if (dec == NULL) { return NULL; } @@ -2287,7 +2301,8 @@ PyDecType_FromSsizeExact(PyTypeObject *type, mpd_ssize_t v, PyObject *context) uint32_t status = 0; mpd_context_t maxctx; - dec = PyDecType_New(type); + decimal_state *state = get_module_state_from_ctx(context); + dec = PyDecType_New(state, type); if (dec == NULL) { return NULL; } @@ -2305,13 +2320,13 @@ PyDecType_FromSsizeExact(PyTypeObject *type, mpd_ssize_t v, PyObject *context) /* Convert from a PyLongObject. The context is not modified; flags set during conversion are accumulated in the status parameter. */ static PyObject * -dec_from_long(PyTypeObject *type, PyObject *v, +dec_from_long(decimal_state *state, PyTypeObject *type, PyObject *v, const mpd_context_t *ctx, uint32_t *status) { PyObject *dec; PyLongObject *l = (PyLongObject *)v; - dec = PyDecType_New(type); + dec = PyDecType_New(state, type); if (dec == NULL) { return NULL; } @@ -2356,7 +2371,8 @@ PyDecType_FromLong(PyTypeObject *type, PyObject *v, PyObject *context) return NULL; } - dec = dec_from_long(type, v, CTX(context), &status); + decimal_state *state = get_module_state_from_ctx(context); + dec = dec_from_long(state, type, v, CTX(context), &status); if (dec == NULL) { return NULL; } @@ -2385,7 +2401,8 @@ PyDecType_FromLongExact(PyTypeObject *type, PyObject *v, } mpd_maxcontext(&maxctx); - dec = dec_from_long(type, v, &maxctx, &status); + decimal_state *state = get_module_state_from_ctx(context); + dec = dec_from_long(state, type, v, &maxctx, &status); if (dec == NULL) { return NULL; } @@ -2417,7 +2434,7 @@ PyDecType_FromFloatExact(PyTypeObject *type, PyObject *v, mpd_t *d1, *d2; uint32_t status = 0; mpd_context_t maxctx; - decimal_state *state = get_module_state_by_def(type); + decimal_state *state = get_module_state_from_ctx(context); #ifdef Py_DEBUG assert(PyType_IsSubtype(type, state->PyDec_Type)); @@ -2438,7 +2455,7 @@ PyDecType_FromFloatExact(PyTypeObject *type, PyObject *v, sign = (copysign(1.0, x) == 1.0) ? 0 : 1; if (isnan(x) || isinf(x)) { - dec = PyDecType_New(type); + dec = PyDecType_New(state, type); if (dec == NULL) { return NULL; } @@ -2555,12 +2572,12 @@ PyDecType_FromDecimalExact(PyTypeObject *type, PyObject *v, PyObject *context) PyObject *dec; uint32_t status = 0; - decimal_state *state = get_module_state_by_def(type); + decimal_state *state = get_module_state_from_ctx(context); if (type == state->PyDec_Type && PyDec_CheckExact(state, v)) { return Py_NewRef(v); } - dec = PyDecType_New(type); + dec = PyDecType_New(state, type); if (dec == NULL) { return NULL; } @@ -2844,7 +2861,7 @@ dec_from_float(PyObject *type, PyObject *pyfloat) static PyObject * ctx_from_float(PyObject *context, PyObject *v) { - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = get_module_state_from_ctx(context); return PyDec_FromFloat(state, v, context); } @@ -2855,7 +2872,7 @@ dec_apply(PyObject *v, PyObject *context) PyObject *result; uint32_t status = 0; - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = get_module_state_from_ctx(context); result = dec_alloc(state); if (result == NULL) { return NULL; @@ -2882,7 +2899,7 @@ dec_apply(PyObject *v, PyObject *context) static PyObject * PyDecType_FromObjectExact(PyTypeObject *type, PyObject *v, PyObject *context) { - decimal_state *state = get_module_state_by_def(type); + decimal_state *state = get_module_state_from_ctx(context); if (v == NULL) { return PyDecType_FromSsizeExact(type, 0, context); } @@ -2917,7 +2934,7 @@ PyDecType_FromObjectExact(PyTypeObject *type, PyObject *v, PyObject *context) static PyObject * PyDec_FromObject(PyObject *v, PyObject *context) { - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = get_module_state_from_ctx(context); if (v == NULL) { return PyDec_FromSsize(state, 0, context); } @@ -3004,7 +3021,7 @@ ctx_create_decimal(PyObject *context, PyObject *args) Py_LOCAL_INLINE(int) convert_op(int type_err, PyObject **conv, PyObject *v, PyObject *context) { - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = get_module_state_from_ctx(context); if (PyDec_Check(state, v)) { *conv = Py_NewRef(v); return 1; @@ -3107,7 +3124,7 @@ multiply_by_denominator(PyObject *v, PyObject *r, PyObject *context) if (tmp == NULL) { return NULL; } - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = get_module_state_from_ctx(context); denom = PyDec_FromLongExact(state, tmp, context); Py_DECREF(tmp); if (denom == NULL) { @@ -3162,7 +3179,7 @@ numerator_as_decimal(PyObject *r, PyObject *context) return NULL; } - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = get_module_state_from_ctx(context); num = PyDec_FromLongExact(state, tmp, context); Py_DECREF(tmp); return num; @@ -3181,7 +3198,7 @@ convert_op_cmp(PyObject **vcmp, PyObject **wcmp, PyObject *v, PyObject *w, *vcmp = v; - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = get_module_state_from_ctx(context); if (PyDec_Check(state, w)) { *wcmp = Py_NewRef(w); } @@ -4421,12 +4438,11 @@ dec_conjugate(PyObject *self, PyObject *Py_UNUSED(dummy)) return Py_NewRef(self); } -static PyObject * -dec_mpd_radix(PyObject *self, PyObject *Py_UNUSED(dummy)) +static inline PyObject * +_dec_mpd_radix(decimal_state *state) { PyObject *result; - decimal_state *state = get_module_state_by_def(Py_TYPE(self)); result = dec_alloc(state); if (result == NULL) { return NULL; @@ -4436,6 +4452,13 @@ dec_mpd_radix(PyObject *self, PyObject *Py_UNUSED(dummy)) return result; } +static PyObject * +dec_mpd_radix(PyObject *self, PyObject *Py_UNUSED(dummy)) +{ + decimal_state *state = get_module_state_by_def(Py_TYPE(self)); + return _dec_mpd_radix(state); +} + static PyObject * dec_mpd_qcopy_abs(PyObject *self, PyObject *Py_UNUSED(dummy)) { @@ -5138,7 +5161,7 @@ ctx_##MPDFUNC(PyObject *context, PyObject *v) \ \ CONVERT_OP_RAISE(&a, v, context); \ decimal_state *state = \ - get_module_state_by_def(Py_TYPE(context)); \ + get_module_state_from_ctx(context); \ if ((result = dec_alloc(state)) == NULL) { \ Py_DECREF(a); \ return NULL; \ @@ -5170,7 +5193,7 @@ ctx_##MPDFUNC(PyObject *context, PyObject *args) \ \ CONVERT_BINOP_RAISE(&a, &b, v, w, context); \ decimal_state *state = \ - get_module_state_by_def(Py_TYPE(context)); \ + get_module_state_from_ctx(context); \ if ((result = dec_alloc(state)) == NULL) { \ Py_DECREF(a); \ Py_DECREF(b); \ @@ -5206,7 +5229,7 @@ ctx_##MPDFUNC(PyObject *context, PyObject *args) \ \ CONVERT_BINOP_RAISE(&a, &b, v, w, context); \ decimal_state *state = \ - get_module_state_by_def(Py_TYPE(context)); \ + get_module_state_from_ctx(context); \ if ((result = dec_alloc(state)) == NULL) { \ Py_DECREF(a); \ Py_DECREF(b); \ @@ -5235,7 +5258,7 @@ ctx_##MPDFUNC(PyObject *context, PyObject *args) \ } \ \ CONVERT_TERNOP_RAISE(&a, &b, &c, v, w, x, context); \ - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); \ + decimal_state *state = get_module_state_from_ctx(context); \ if ((result = dec_alloc(state)) == NULL) { \ Py_DECREF(a); \ Py_DECREF(b); \ @@ -5301,7 +5324,7 @@ ctx_mpd_qdivmod(PyObject *context, PyObject *args) } CONVERT_BINOP_RAISE(&a, &b, v, w, context); - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = get_module_state_from_ctx(context); q = dec_alloc(state); if (q == NULL) { Py_DECREF(a); @@ -5356,7 +5379,7 @@ ctx_mpd_qpow(PyObject *context, PyObject *args, PyObject *kwds) } } - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = get_module_state_from_ctx(context); result = dec_alloc(state); if (result == NULL) { Py_DECREF(a); @@ -5391,7 +5414,8 @@ DecCtx_TernaryFunc(mpd_qfma) static PyObject * ctx_mpd_radix(PyObject *context, PyObject *dummy) { - return dec_mpd_radix(context, dummy); + decimal_state *state = get_module_state_from_ctx(context); + return _dec_mpd_radix(state); } /* Boolean functions: single decimal argument */ @@ -5408,7 +5432,7 @@ DecCtx_BoolFunc_NO_CTX(mpd_iszero) static PyObject * ctx_iscanonical(PyObject *context, PyObject *v) { - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = get_module_state_from_ctx(context); if (!PyDec_Check(state, v)) { PyErr_SetString(PyExc_TypeError, "argument must be a Decimal"); @@ -5434,7 +5458,7 @@ PyDecContext_Apply(PyObject *context, PyObject *v) static PyObject * ctx_canonical(PyObject *context, PyObject *v) { - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = get_module_state_from_ctx(context); if (!PyDec_Check(state, v)) { PyErr_SetString(PyExc_TypeError, "argument must be a Decimal"); @@ -5451,7 +5475,7 @@ ctx_mpd_qcopy_abs(PyObject *context, PyObject *v) uint32_t status = 0; CONVERT_OP_RAISE(&a, v, context); - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = get_module_state_from_ctx(context); result = dec_alloc(state); if (result == NULL) { Py_DECREF(a); @@ -5484,7 +5508,7 @@ ctx_mpd_qcopy_negate(PyObject *context, PyObject *v) uint32_t status = 0; CONVERT_OP_RAISE(&a, v, context); - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = get_module_state_from_ctx(context); result = dec_alloc(state); if (result == NULL) { Py_DECREF(a); @@ -5581,7 +5605,7 @@ ctx_mpd_qcopy_sign(PyObject *context, PyObject *args) } CONVERT_BINOP_RAISE(&a, &b, v, w, context); - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = get_module_state_from_ctx(context); result = dec_alloc(state); if (result == NULL) { Py_DECREF(a); @@ -5736,6 +5760,7 @@ static PyMethodDef context_methods [] = }; static PyType_Slot context_slots[] = { + {Py_tp_token, Py_TP_USE_SPEC}, {Py_tp_dealloc, context_dealloc}, {Py_tp_traverse, context_traverse}, {Py_tp_clear, context_clear},