From a26cf9b7609fc1c08fd1a69ddf5e44dc98a70dce Mon Sep 17 00:00:00 2001 From: Jeffrey Yasskin Date: Mon, 4 Feb 2008 01:04:35 +0000 Subject: [PATCH] Make int() and long() fall back to __trunc__(). See issue 2002. --- Include/abstract.h | 13 +++++ Lib/rational.py | 2 - Lib/test/test_builtin.py | 102 +++++++++++++++++++++++++++++++++++++++ Objects/abstract.c | 94 ++++++++++++++++++++++++++++++++++++ Objects/classobject.c | 24 ++++++++- 5 files changed, 232 insertions(+), 3 deletions(-) diff --git a/Include/abstract.h b/Include/abstract.h index b7fde09d528..e6cbb7b5be4 100644 --- a/Include/abstract.h +++ b/Include/abstract.h @@ -760,6 +760,19 @@ xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx*/ PyAPI_FUNC(Py_ssize_t) PyNumber_AsSsize_t(PyObject *o, PyObject *exc); + /* + Returns the Integral instance converted to an int. The + instance is expected to be int or long or have an __int__ + method. Steals integral's reference. error_format will be + used to create the TypeError if integral isn't actually an + Integral instance. error_format should be a format string + that can accept a char* naming integral's type. + */ + + PyAPI_FUNC(PyObject *) _PyNumber_ConvertIntegralToInt( + PyObject *integral, + const char* error_format); + /* Returns the object converted to Py_ssize_t by going through PyNumber_Index first. If an overflow error occurs while diff --git a/Lib/rational.py b/Lib/rational.py index c76cba3d074..dcdaad494b7 100755 --- a/Lib/rational.py +++ b/Lib/rational.py @@ -424,8 +424,6 @@ class Rational(RationalAbc): else: return a.numerator // a.denominator - __int__ = __trunc__ - def __hash__(self): """hash(self) diff --git a/Lib/test/test_builtin.py b/Lib/test/test_builtin.py index cfc900335d9..9612a4b6872 100644 --- a/Lib/test/test_builtin.py +++ b/Lib/test/test_builtin.py @@ -934,6 +934,14 @@ class BuiltinTest(unittest.TestCase): def test_intconversion(self): # Test __int__() + class ClassicMissingMethods: + pass + self.assertRaises(AttributeError, int, ClassicMissingMethods()) + + class MissingMethods(object): + pass + self.assertRaises(TypeError, int, MissingMethods()) + class Foo0: def __int__(self): return 42 @@ -965,6 +973,49 @@ class BuiltinTest(unittest.TestCase): self.assertEqual(int(Foo4()), 42L) self.assertRaises(TypeError, int, Foo5()) + class Classic: + pass + for base in (object, Classic): + class IntOverridesTrunc(base): + def __int__(self): + return 42 + def __trunc__(self): + return -12 + self.assertEqual(int(IntOverridesTrunc()), 42) + + class JustTrunc(base): + def __trunc__(self): + return 42 + self.assertEqual(int(JustTrunc()), 42) + + for trunc_result_base in (object, Classic): + class Integral(trunc_result_base): + def __int__(self): + return 42 + + class TruncReturnsNonInt(base): + def __trunc__(self): + return Integral() + self.assertEqual(int(TruncReturnsNonInt()), 42) + + class NonIntegral(trunc_result_base): + def __trunc__(self): + # Check that we avoid infinite recursion. + return NonIntegral() + + class TruncReturnsNonIntegral(base): + def __trunc__(self): + return NonIntegral() + try: + int(TruncReturnsNonIntegral()) + except TypeError as e: + self.assertEquals(str(e), + "__trunc__ returned non-Integral" + " (type NonIntegral)") + else: + self.fail("Failed to raise TypeError with %s" % + ((base, trunc_result_base),)) + def test_intern(self): self.assertRaises(TypeError, intern) s = "never interned before" @@ -1207,6 +1258,14 @@ class BuiltinTest(unittest.TestCase): def test_longconversion(self): # Test __long__() + class ClassicMissingMethods: + pass + self.assertRaises(AttributeError, long, ClassicMissingMethods()) + + class MissingMethods(object): + pass + self.assertRaises(TypeError, long, MissingMethods()) + class Foo0: def __long__(self): return 42L @@ -1238,6 +1297,49 @@ class BuiltinTest(unittest.TestCase): self.assertEqual(long(Foo4()), 42) self.assertRaises(TypeError, long, Foo5()) + class Classic: + pass + for base in (object, Classic): + class LongOverridesTrunc(base): + def __long__(self): + return 42 + def __trunc__(self): + return -12 + self.assertEqual(long(LongOverridesTrunc()), 42) + + class JustTrunc(base): + def __trunc__(self): + return 42 + self.assertEqual(long(JustTrunc()), 42) + + for trunc_result_base in (object, Classic): + class Integral(trunc_result_base): + def __int__(self): + return 42 + + class TruncReturnsNonLong(base): + def __trunc__(self): + return Integral() + self.assertEqual(long(TruncReturnsNonLong()), 42) + + class NonIntegral(trunc_result_base): + def __trunc__(self): + # Check that we avoid infinite recursion. + return NonIntegral() + + class TruncReturnsNonIntegral(base): + def __trunc__(self): + return NonIntegral() + try: + long(TruncReturnsNonIntegral()) + except TypeError as e: + self.assertEquals(str(e), + "__trunc__ returned non-Integral" + " (type NonIntegral)") + else: + self.fail("Failed to raise TypeError with %s" % + ((base, trunc_result_base),)) + def test_map(self): self.assertEqual( map(None, 'hello world'), diff --git a/Objects/abstract.c b/Objects/abstract.c index 830fe821795..a3e159a105e 100644 --- a/Objects/abstract.c +++ b/Objects/abstract.c @@ -1034,13 +1034,65 @@ PyNumber_AsSsize_t(PyObject *item, PyObject *err) } +PyObject * +_PyNumber_ConvertIntegralToInt(PyObject *integral, const char* error_format) +{ + const char *type_name; + static PyObject *int_name = NULL; + if (int_name == NULL) { + int_name = PyString_InternFromString("__int__"); + if (int_name == NULL) + return NULL; + } + + if (integral && (!PyInt_Check(integral) && + !PyLong_Check(integral))) { + /* Don't go through tp_as_number->nb_int to avoid + hitting the classic class fallback to __trunc__. */ + PyObject *int_func = PyObject_GetAttr(integral, int_name); + if (int_func == NULL) { + PyErr_Clear(); /* Raise a different error. */ + goto non_integral_error; + } + Py_DECREF(integral); + integral = PyEval_CallObject(int_func, NULL); + Py_DECREF(int_func); + if (integral && (!PyInt_Check(integral) && + !PyLong_Check(integral))) { + goto non_integral_error; + } + } + return integral; + +non_integral_error: + if (PyInstance_Check(integral)) { + type_name = PyString_AS_STRING(((PyInstanceObject *)integral) + ->in_class->cl_name); + } + else { + type_name = integral->ob_type->tp_name; + } + PyErr_Format(PyExc_TypeError, error_format, type_name); + Py_DECREF(integral); + return NULL; +} + + PyObject * PyNumber_Int(PyObject *o) { PyNumberMethods *m; + static PyObject *trunc_name = NULL; + PyObject *trunc_func; const char *buffer; Py_ssize_t buffer_len; + if (trunc_name == NULL) { + trunc_name = PyString_InternFromString("__trunc__"); + if (trunc_name == NULL) + return NULL; + } + if (o == NULL) return null_error(); if (PyInt_CheckExact(o)) { @@ -1049,6 +1101,7 @@ PyNumber_Int(PyObject *o) } m = o->ob_type->tp_as_number; if (m && m->nb_int) { /* This should include subclasses of int */ + /* Classic classes always take this branch. */ PyObject *res = m->nb_int(o); if (res && (!PyInt_Check(res) && !PyLong_Check(res))) { PyErr_Format(PyExc_TypeError, @@ -1063,6 +1116,18 @@ PyNumber_Int(PyObject *o) PyIntObject *io = (PyIntObject*)o; return PyInt_FromLong(io->ob_ival); } + trunc_func = PyObject_GetAttr(o, trunc_name); + if (trunc_func) { + PyObject *truncated = PyEval_CallObject(trunc_func, NULL); + Py_DECREF(trunc_func); + /* __trunc__ is specified to return an Integral type, but + int() needs to return an int. */ + return _PyNumber_ConvertIntegralToInt( + truncated, + "__trunc__ returned non-Integral (type %.200s)"); + } + PyErr_Clear(); /* It's not an error if o.__trunc__ doesn't exist. */ + if (PyString_Check(o)) return int_from_string(PyString_AS_STRING(o), PyString_GET_SIZE(o)); @@ -1102,13 +1167,22 @@ PyObject * PyNumber_Long(PyObject *o) { PyNumberMethods *m; + static PyObject *trunc_name = NULL; + PyObject *trunc_func; const char *buffer; Py_ssize_t buffer_len; + if (trunc_name == NULL) { + trunc_name = PyString_InternFromString("__trunc__"); + if (trunc_name == NULL) + return NULL; + } + if (o == NULL) return null_error(); m = o->ob_type->tp_as_number; if (m && m->nb_long) { /* This should include subclasses of long */ + /* Classic classes always take this branch. */ PyObject *res = m->nb_long(o); if (res && (!PyInt_Check(res) && !PyLong_Check(res))) { PyErr_Format(PyExc_TypeError, @@ -1121,6 +1195,26 @@ PyNumber_Long(PyObject *o) } if (PyLong_Check(o)) /* A long subclass without nb_long */ return _PyLong_Copy((PyLongObject *)o); + trunc_func = PyObject_GetAttr(o, trunc_name); + if (trunc_func) { + PyObject *truncated = PyEval_CallObject(trunc_func, NULL); + PyObject *int_instance; + Py_DECREF(trunc_func); + /* __trunc__ is specified to return an Integral type, + but long() needs to return a long. */ + int_instance = _PyNumber_ConvertIntegralToInt( + truncated, + "__trunc__ returned non-Integral (type %.200s)"); + if (int_instance && PyInt_Check(int_instance)) { + /* Make sure that long() returns a long instance. */ + long value = PyInt_AS_LONG(int_instance); + Py_DECREF(int_instance); + return PyLong_FromLong(value); + } + return int_instance; + } + PyErr_Clear(); /* It's not an error if o.__trunc__ doesn't exist. */ + if (PyString_Check(o)) /* need to do extra error checking that PyLong_FromString() * doesn't do. In particular long('9.5') must raise an diff --git a/Objects/classobject.c b/Objects/classobject.c index b4b17f90777..9f364e2e88a 100644 --- a/Objects/classobject.c +++ b/Objects/classobject.c @@ -1798,7 +1798,29 @@ instance_index(PyInstanceObject *self) UNARY(instance_invert, "__invert__") -UNARY(instance_int, "__int__") +UNARY(_instance_trunc, "__trunc__") + +static PyObject * +instance_int(PyInstanceObject *self) +{ + PyObject *truncated; + static PyObject *int_name; + if (int_name == NULL) { + int_name = PyString_InternFromString("__int__"); + if (int_name == NULL) + return NULL; + } + if (PyObject_HasAttr((PyObject*)self, int_name)) + return generic_unary_op(self, int_name); + + truncated = _instance_trunc(self); + /* __trunc__ is specified to return an Integral type, but + int() needs to return an int. */ + return _PyNumber_ConvertIntegralToInt( + truncated, + "__trunc__ returned non-Integral (type %.200s)"); +} + UNARY_FB(instance_long, "__long__", instance_int) UNARY(instance_float, "__float__") UNARY(instance_oct, "__oct__")