Finish the work on __round__ and __trunc__.

With Alex Martelli and Keir Mierle.
This commit is contained in:
Guido van Rossum 2007-08-23 22:07:24 +00:00
parent 0f3cff58b2
commit 2fa33db12b
4 changed files with 173 additions and 60 deletions

View File

@ -1440,6 +1440,7 @@ class BuiltinTest(unittest.TestCase):
def test_round(self): def test_round(self):
self.assertEqual(round(0.0), 0.0) self.assertEqual(round(0.0), 0.0)
self.assertEqual(type(round(0.0)), int)
self.assertEqual(round(1.0), 1.0) self.assertEqual(round(1.0), 1.0)
self.assertEqual(round(10.0), 10.0) self.assertEqual(round(10.0), 10.0)
self.assertEqual(round(1000000000.0), 1000000000.0) self.assertEqual(round(1000000000.0), 1000000000.0)
@ -1468,6 +1469,25 @@ class BuiltinTest(unittest.TestCase):
self.assertEqual(round(-999999999.9), -1000000000.0) self.assertEqual(round(-999999999.9), -1000000000.0)
self.assertEqual(round(-8.0, -1), -10.0) self.assertEqual(round(-8.0, -1), -10.0)
self.assertEqual(type(round(-8.0, -1)), float)
self.assertEqual(type(round(-8.0, 0)), float)
self.assertEqual(type(round(-8.0, 1)), float)
# Check even / odd rounding behaviour
self.assertEqual(round(5.5), 6)
self.assertEqual(round(6.5), 6)
self.assertEqual(round(-5.5), -6)
self.assertEqual(round(-6.5), -6)
# Check behavior on ints
self.assertEqual(round(0), 0)
self.assertEqual(round(8), 8)
self.assertEqual(round(-8), -8)
self.assertEqual(type(round(0)), int)
self.assertEqual(type(round(-8, -1)), float)
self.assertEqual(type(round(-8, 0)), float)
self.assertEqual(type(round(-8, 1)), float)
# test new kwargs # test new kwargs
self.assertEqual(round(number=-8.0, ndigits=-1), -10.0) self.assertEqual(round(number=-8.0, ndigits=-1), -10.0)
@ -1487,6 +1507,11 @@ class BuiltinTest(unittest.TestCase):
self.assertRaises(TypeError, round, 1, 2, 3) self.assertRaises(TypeError, round, 1, 2, 3)
self.assertRaises(TypeError, round, TestNoRound()) self.assertRaises(TypeError, round, TestNoRound())
t = TestNoRound()
t.__round__ = lambda *args: args
self.assertRaises(TypeError, round, t)
self.assertRaises(TypeError, round, t, 0)
def test_setattr(self): def test_setattr(self):
setattr(sys, 'spam', 1) setattr(sys, 'spam', 1)
self.assertEqual(sys.spam, 1) self.assertEqual(sys.spam, 1)
@ -1529,6 +1554,18 @@ class BuiltinTest(unittest.TestCase):
self.assertRaises(ValueError, sum, BadSeq()) self.assertRaises(ValueError, sum, BadSeq())
def test_trunc(self): def test_trunc(self):
self.assertEqual(trunc(1), 1)
self.assertEqual(trunc(-1), -1)
self.assertEqual(type(trunc(1)), int)
self.assertEqual(type(trunc(1.5)), int)
self.assertEqual(trunc(1.5), 1)
self.assertEqual(trunc(-1.5), -1)
self.assertEqual(trunc(1.999999), 1)
self.assertEqual(trunc(-1.999999), -1)
self.assertEqual(trunc(-0.999999), -0)
self.assertEqual(trunc(-100.999), -100)
class TestTrunc: class TestTrunc:
def __trunc__(self): def __trunc__(self):
return 23 return 23
@ -1542,6 +1579,11 @@ class BuiltinTest(unittest.TestCase):
self.assertRaises(TypeError, trunc, 1, 2) self.assertRaises(TypeError, trunc, 1, 2)
self.assertRaises(TypeError, trunc, TestNoTrunc()) self.assertRaises(TypeError, trunc, TestNoTrunc())
t = TestNoTrunc()
t.__trunc__ = lambda *args: args
self.assertRaises(TypeError, trunc, t)
self.assertRaises(TypeError, trunc, t, 0)
def test_tuple(self): def test_tuple(self):
self.assertEqual(tuple(()), ()) self.assertEqual(tuple(()), ())
t0_3 = (0, 1, 2, 3) t0_3 = (0, 1, 2, 3)

View File

@ -743,14 +743,7 @@ float_bool(PyFloatObject *v)
} }
static PyObject * static PyObject *
float_long(PyObject *v) float_trunc(PyObject *v)
{
double x = PyFloat_AsDouble(v);
return PyLong_FromDouble(x);
}
static PyObject *
float_int(PyObject *v)
{ {
double x = PyFloat_AsDouble(v); double x = PyFloat_AsDouble(v);
double wholepart; /* integral portion of x, rounded toward 0 */ double wholepart; /* integral portion of x, rounded toward 0 */
@ -775,6 +768,55 @@ float_int(PyObject *v)
return PyLong_FromDouble(wholepart); return PyLong_FromDouble(wholepart);
} }
static PyObject *
float_round(PyObject *v, PyObject *args)
{
#define UNDEF_NDIGITS (-0x7fffffff) /* Unlikely ndigits value */
double x;
double f;
double flr, cil;
double rounded;
int i;
int ndigits = UNDEF_NDIGITS;
if (!PyArg_ParseTuple(args, "|i", &ndigits))
return NULL;
x = PyFloat_AsDouble(v);
if (ndigits != UNDEF_NDIGITS) {
f = 1.0;
i = abs(ndigits);
while (--i >= 0)
f = f*10.0;
if (ndigits < 0)
x /= f;
else
x *= f;
}
flr = floor(x);
cil = ceil(x);
if (x-flr > 0.5)
rounded = cil;
else if (x-flr == 0.5)
rounded = fmod(flr, 2) == 0 ? flr : cil;
else
rounded = flr;
if (ndigits != UNDEF_NDIGITS) {
if (ndigits < 0)
rounded *= f;
else
rounded /= f;
return PyFloat_FromDouble(rounded);
}
return PyLong_FromDouble(rounded);
#undef UNDEF_NDIGITS
}
static PyObject * static PyObject *
float_float(PyObject *v) float_float(PyObject *v)
{ {
@ -976,6 +1018,11 @@ float_getzero(PyObject *v, void *closure)
static PyMethodDef float_methods[] = { static PyMethodDef float_methods[] = {
{"conjugate", (PyCFunction)float_float, METH_NOARGS, {"conjugate", (PyCFunction)float_float, METH_NOARGS,
"Returns self, the complex conjugate of any float."}, "Returns self, the complex conjugate of any float."},
{"__trunc__", (PyCFunction)float_trunc, METH_NOARGS,
"Returns the Integral closest to x between 0 and x."},
{"__round__", (PyCFunction)float_round, METH_VARARGS,
"Returns the Integral closest to x, rounding half toward even.\n"
"When an argument is passed, works like built-in round(x, ndigits)."},
{"__getnewargs__", (PyCFunction)float_getnewargs, METH_NOARGS}, {"__getnewargs__", (PyCFunction)float_getnewargs, METH_NOARGS},
{"__getformat__", (PyCFunction)float_getformat, {"__getformat__", (PyCFunction)float_getformat,
METH_O|METH_CLASS, float_getformat_doc}, METH_O|METH_CLASS, float_getformat_doc},
@ -1020,8 +1067,8 @@ static PyNumberMethods float_as_number = {
0, /*nb_xor*/ 0, /*nb_xor*/
0, /*nb_or*/ 0, /*nb_or*/
(coercion)0, /*nb_coerce*/ (coercion)0, /*nb_coerce*/
float_int, /*nb_int*/ float_trunc, /*nb_int*/
float_long, /*nb_long*/ float_trunc, /*nb_long*/
float_float, /*nb_float*/ float_float, /*nb_float*/
0, /* nb_oct */ 0, /* nb_oct */
0, /* nb_hex */ 0, /* nb_hex */

View File

@ -3592,9 +3592,45 @@ long_getN(PyLongObject *v, void *context) {
return PyLong_FromLong((intptr_t)context); return PyLong_FromLong((intptr_t)context);
} }
static PyObject *
long_round(PyObject *self, PyObject *args)
{
#define UNDEF_NDIGITS (-0x7fffffff) /* Unlikely ndigits value */
int ndigits = UNDEF_NDIGITS;
double x;
PyObject *res;
if (!PyArg_ParseTuple(args, "|i", &ndigits))
return NULL;
if (ndigits == UNDEF_NDIGITS)
return long_long(self);
/* If called with two args, defer to float.__round__(). */
x = PyLong_AsDouble(self);
if (x == -1.0 && PyErr_Occurred())
return NULL;
self = PyFloat_FromDouble(x);
if (self == NULL)
return NULL;
res = PyObject_CallMethod(self, "__round__", "i", ndigits);
Py_DECREF(self);
return res;
#undef UNDEF_NDIGITS
}
static PyMethodDef long_methods[] = { static PyMethodDef long_methods[] = {
{"conjugate", (PyCFunction)long_long, METH_NOARGS, {"conjugate", (PyCFunction)long_long, METH_NOARGS,
"Returns self, the complex conjugate of any int."}, "Returns self, the complex conjugate of any int."},
{"__trunc__", (PyCFunction)long_long, METH_NOARGS,
"Truncating an Integral returns itself."},
{"__floor__", (PyCFunction)long_long, METH_NOARGS,
"Flooring an Integral returns itself."},
{"__ceil__", (PyCFunction)long_long, METH_NOARGS,
"Ceiling of an Integral returns itself."},
{"__round__", (PyCFunction)long_round, METH_VARARGS,
"Rounding an Integral returns itself.\n"
"Rounding with an ndigits arguments defers to float.__round__."},
{"__getnewargs__", (PyCFunction)long_getnewargs, METH_NOARGS}, {"__getnewargs__", (PyCFunction)long_getnewargs, METH_NOARGS},
{NULL, NULL} /* sentinel */ {NULL, NULL} /* sentinel */
}; };

View File

@ -1373,63 +1373,44 @@ For most object types, eval(repr(object)) == object.");
static PyObject * static PyObject *
builtin_round(PyObject *self, PyObject *args, PyObject *kwds) builtin_round(PyObject *self, PyObject *args, PyObject *kwds)
{ {
double number; #define UNDEF_NDIGITS (-0x7fffffff) /* Unlikely ndigits value */
double f; static PyObject *round_str = NULL;
int ndigits = 0; int ndigits = UNDEF_NDIGITS;
int i;
static char *kwlist[] = {"number", "ndigits", 0}; static char *kwlist[] = {"number", "ndigits", 0};
PyObject* real; PyObject *number, *round;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|i:round", if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|i:round",
kwlist, &real, &ndigits)) kwlist, &number, &ndigits))
return NULL; return NULL;
if (ndigits == 0) { if (round_str == NULL) {
PyObject *res; round_str = PyUnicode_FromString("__round__");
PyObject *d = PyObject_GetAttrString(real, "__round__"); if (round_str == NULL)
if (d == NULL && !PyFloat_Check(real)) {
PyErr_SetString(PyExc_TypeError,
"round() argument must have __round__ attribute or be a float");
return NULL; return NULL;
} }
if (d == NULL) {
PyErr_Clear(); round = _PyType_Lookup(Py_Type(number), round_str);
} else { if (round == NULL) {
res = PyObject_CallFunction(d, ""); PyErr_Format(PyExc_TypeError,
Py_DECREF(d); "type %.100s doesn't define __round__ method",
return res; Py_Type(number)->tp_name);
}
} else if (!PyFloat_Check(real)) {
PyErr_SetString(PyExc_TypeError,
"round() argument must have __round__ attribute or be a float");
return NULL; return NULL;
} }
number = PyFloat_AsDouble(real); if (ndigits == UNDEF_NDIGITS)
f = 1.0; return PyObject_CallFunction(round, "O", number);
i = abs(ndigits);
while (--i >= 0)
f = f*10.0;
if (ndigits < 0)
number /= f;
else else
number *= f; return PyObject_CallFunction(round, "Oi", number, ndigits);
if (number >= 0.0) #undef UNDEF_NDIGITS
number = floor(number + 0.5);
else
number = ceil(number - 0.5);
if (ndigits < 0)
number *= f;
else
number /= f;
return PyFloat_FromDouble(number);
} }
PyDoc_STRVAR(round_doc, PyDoc_STRVAR(round_doc,
"round(number[, ndigits]) -> floating point number\n\ "round(number[, ndigits]) -> floating point number\n\
\n\ \n\
Round a number to a given precision in decimal digits (default 0 digits).\n\ Round a number to a given precision in decimal digits (default 0 digits).\n\
This always returns a floating point number. Precision may be negative."); This returns an int when called with one argument, otherwise a float.\n\
Precision may be negative.");
static PyObject * static PyObject *
builtin_sorted(PyObject *self, PyObject *args, PyObject *kwds) builtin_sorted(PyObject *self, PyObject *args, PyObject *kwds)
@ -1511,18 +1492,25 @@ Without arguments, equivalent to locals().\n\
With an argument, equivalent to object.__dict__."); With an argument, equivalent to object.__dict__.");
static PyObject * static PyObject *
builtin_trunc(PyObject *self, PyObject *v) builtin_trunc(PyObject *self, PyObject *number)
{ {
PyObject *res; static PyObject *trunc_str = NULL;
PyObject *d = PyObject_GetAttrString(v, "__trunc__"); PyObject *trunc;
if (d == NULL) {
PyErr_SetString(PyExc_TypeError, if (trunc_str == NULL) {
"trunc() argument must have __trunc__ attribute"); trunc_str = PyUnicode_FromString("__trunc__");
if (trunc_str == NULL)
return NULL;
}
trunc = _PyType_Lookup(Py_Type(number), trunc_str);
if (trunc == NULL) {
PyErr_Format(PyExc_TypeError,
"type %.100s doesn't define __trunc__ method",
Py_Type(number)->tp_name);
return NULL; return NULL;
} }
res = PyObject_CallFunction(d, ""); return PyObject_CallFunction(trunc, "O", number);
Py_DECREF(d);
return res;
} }
PyDoc_STRVAR(trunc_doc, PyDoc_STRVAR(trunc_doc,