diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py index ec5db9fa423..99faf784b1c 100644 --- a/Lib/test/test_decimal.py +++ b/Lib/test/test_decimal.py @@ -2047,6 +2047,11 @@ class UsabilityTest(unittest.TestCase): self.assertIs(type(d), MyDecimal) self.assertEqual(d, d1) + a = Decimal('1.0') + b = MyDecimal(a) + self.assertIs(type(b), MyDecimal) + self.assertEqual(a, b) + def test_implicit_context(self): Decimal = self.decimal.Decimal getcontext = self.decimal.getcontext diff --git a/Misc/NEWS b/Misc/NEWS index 2f2ed8967c1..e09e67d8ae4 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -80,6 +80,9 @@ Core and Builtins Library ------- +- Issue #16431: Use the type information when constructing a Decimal subtype + from a Decimal argument. + - Issue #16350: zlib.Decompress.decompress() now accumulates data from successive calls after EOF in unused_data, instead of only saving the argument to the last call. Patch by Serhiy Storchaka. diff --git a/Modules/_decimal/_decimal.c b/Modules/_decimal/_decimal.c index 996f9da17be..0bc484f4ab1 100644 --- a/Modules/_decimal/_decimal.c +++ b/Modules/_decimal/_decimal.c @@ -2338,6 +2338,32 @@ PyDecType_FromFloat(PyTypeObject *type, PyObject *v, return dec; } +/* Return a new PyDecObject (subtype) from a Decimal. */ +static PyObject * +PyDecType_FromDecimalExact(PyTypeObject *type, PyObject *v, PyObject *context) +{ + PyObject *dec; + uint32_t status = 0; + + if (type == &PyDec_Type) { + Py_INCREF(v); + return v; + } + + dec = PyDecType_New(type); + if (dec == NULL) { + return NULL; + } + + mpd_qcopy(MPD(dec), MPD(v), &status); + if (dec_addstatus(context, status)) { + Py_DECREF(dec); + return NULL; + } + + return dec; +} + static PyObject * sequence_as_tuple(PyObject *v, PyObject *ex, const char *mesg) { @@ -2642,8 +2668,7 @@ PyDecType_FromObjectExact(PyTypeObject *type, PyObject *v, PyObject *context) return PyDecType_FromSsizeExact(type, 0, context); } else if (PyDec_Check(v)) { - Py_INCREF(v); - return v; + return PyDecType_FromDecimalExact(type, v, context); } else if (PyUnicode_Check(v)) { return PyDecType_FromUnicodeExactWS(type, v, context);