diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py index ea18c6393e8..dd4c73cdf14 100644 --- a/Lib/test/test_decimal.py +++ b/Lib/test/test_decimal.py @@ -2029,7 +2029,7 @@ class UsabilityTest(unittest.TestCase): Decimal = self.decimal.Decimal class MyDecimal(Decimal): - pass + y = None d1 = MyDecimal(1) d2 = MyDecimal(2) @@ -2047,14 +2047,29 @@ class UsabilityTest(unittest.TestCase): self.assertIs(type(d), MyDecimal) self.assertEqual(d, d1) - a = Decimal('1.0') - b = MyDecimal(a) - self.assertIs(type(b), MyDecimal) - self.assertEqual(a, b) + # Decimal(Decimal) + d = Decimal('1.0') + x = Decimal(d) + self.assertIs(type(x), Decimal) + self.assertEqual(x, d) - c = Decimal(b) - self.assertIs(type(c), Decimal) - self.assertEqual(a, c) + # MyDecimal(Decimal) + m = MyDecimal(d) + self.assertIs(type(m), MyDecimal) + self.assertEqual(m, d) + self.assertIs(m.y, None) + + # Decimal(MyDecimal) + x = Decimal(m) + self.assertIs(type(x), Decimal) + self.assertEqual(x, d) + + # MyDecimal(MyDecimal) + m.y = 9 + x = MyDecimal(m) + self.assertIs(type(x), MyDecimal) + self.assertEqual(x, d) + self.assertIs(x.y, None) def test_implicit_context(self): Decimal = self.decimal.Decimal diff --git a/Modules/_decimal/_decimal.c b/Modules/_decimal/_decimal.c index 0e1d3044cad..e951ded5fff 100644 --- a/Modules/_decimal/_decimal.c +++ b/Modules/_decimal/_decimal.c @@ -2338,14 +2338,14 @@ PyDecType_FromFloat(PyTypeObject *type, PyObject *v, return dec; } -/* Return a new PyDecObject (subtype) from a Decimal. */ +/* Return a new PyDecObject or a subtype from a Decimal. */ static PyObject * PyDecType_FromDecimalExact(PyTypeObject *type, PyObject *v, PyObject *context) { PyObject *dec; uint32_t status = 0; - if (type == Py_TYPE(v)) { + if (type == &PyDec_Type && PyDec_CheckExact(v)) { Py_INCREF(v); return v; }