Rework integer overflow path in math.prod and add more tests (GH-11809)

The overflow check was relying on undefined behaviour as it was using the result of the multiplication to do the check, and once the overflow has already happened, any operation on the result is undefined behaviour.

Some extra checks that exercise code paths related to this are also added.
This commit is contained in:
Pablo Galindo 2019-03-09 19:18:08 +00:00 committed by GitHub
parent 62fa51f121
commit 0411411c6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 137 additions and 40 deletions

View File

@ -1595,6 +1595,92 @@ class MathTests(unittest.TestCase):
self.fail('Failures in test_mtestfile:\n ' +
'\n '.join(failures))
def test_prod(self):
prod = math.prod
self.assertEqual(prod([]), 1)
self.assertEqual(prod([], start=5), 5)
self.assertEqual(prod(list(range(2,8))), 5040)
self.assertEqual(prod(iter(list(range(2,8)))), 5040)
self.assertEqual(prod(range(1, 10), start=10), 3628800)
self.assertEqual(prod([1, 2, 3, 4, 5]), 120)
self.assertEqual(prod([1.0, 2.0, 3.0, 4.0, 5.0]), 120.0)
self.assertEqual(prod([1, 2, 3, 4.0, 5.0]), 120.0)
self.assertEqual(prod([1.0, 2.0, 3.0, 4, 5]), 120.0)
# Test overflow in fast-path for integers
self.assertEqual(prod([1, 1, 2**32, 1, 1]), 2**32)
# Test overflow in fast-path for floats
self.assertEqual(prod([1.0, 1.0, 2**32, 1, 1]), float(2**32))
self.assertRaises(TypeError, prod)
self.assertRaises(TypeError, prod, 42)
self.assertRaises(TypeError, prod, ['a', 'b', 'c'])
self.assertRaises(TypeError, prod, ['a', 'b', 'c'], '')
self.assertRaises(TypeError, prod, [b'a', b'c'], b'')
values = [bytearray(b'a'), bytearray(b'b')]
self.assertRaises(TypeError, prod, values, bytearray(b''))
self.assertRaises(TypeError, prod, [[1], [2], [3]])
self.assertRaises(TypeError, prod, [{2:3}])
self.assertRaises(TypeError, prod, [{2:3}]*2, {2:3})
self.assertRaises(TypeError, prod, [[1], [2], [3]], [])
with self.assertRaises(TypeError):
prod([10, 20], [30, 40]) # start is a keyword-only argument
self.assertEqual(prod([0, 1, 2, 3]), 0)
self.assertEqual(prod([1, 0, 2, 3]), 0)
self.assertEqual(prod([1, 2, 3, 0]), 0)
def _naive_prod(iterable, start=1):
for elem in iterable:
start *= elem
return start
# Big integers
iterable = range(1, 10000)
self.assertEqual(prod(iterable), _naive_prod(iterable))
iterable = range(-10000, -1)
self.assertEqual(prod(iterable), _naive_prod(iterable))
iterable = range(-1000, 1000)
self.assertEqual(prod(iterable), 0)
# Big floats
iterable = [float(x) for x in range(1, 1000)]
self.assertEqual(prod(iterable), _naive_prod(iterable))
iterable = [float(x) for x in range(-1000, -1)]
self.assertEqual(prod(iterable), _naive_prod(iterable))
iterable = [float(x) for x in range(-1000, 1000)]
self.assertIsNaN(prod(iterable))
# Float tests
self.assertIsNaN(prod([1, 2, 3, float("nan"), 2, 3]))
self.assertIsNaN(prod([1, 0, float("nan"), 2, 3]))
self.assertIsNaN(prod([1, float("nan"), 0, 3]))
self.assertIsNaN(prod([1, float("inf"), float("nan"),3]))
self.assertIsNaN(prod([1, float("-inf"), float("nan"),3]))
self.assertIsNaN(prod([1, float("nan"), float("inf"),3]))
self.assertIsNaN(prod([1, float("nan"), float("-inf"),3]))
self.assertEqual(prod([1, 2, 3, float('inf'),-3,4]), float('-inf'))
self.assertEqual(prod([1, 2, 3, float('-inf'),-3,4]), float('inf'))
self.assertIsNaN(prod([1,2,0,float('inf'), -3, 4]))
self.assertIsNaN(prod([1,2,0,float('-inf'), -3, 4]))
self.assertIsNaN(prod([1, 2, 3, float('inf'), -3, 0, 3]))
self.assertIsNaN(prod([1, 2, 3, float('-inf'), -3, 0, 2]))
# Type preservation
self.assertEqual(type(prod([1, 2, 3, 4, 5, 6])), int)
self.assertEqual(type(prod([1, 2.0, 3, 4, 5, 6])), float)
self.assertEqual(type(prod(range(1, 10000))), int)
self.assertEqual(type(prod(range(1, 10000), start=1.0)), float)
self.assertEqual(type(prod([1, decimal.Decimal(2.0), 3, 4, 5, 6])),
decimal.Decimal)
# Custom assertions.
def assertIsNaN(self, value):
@ -1724,41 +1810,6 @@ class IsCloseTests(unittest.TestCase):
self.assertAllClose(fraction_examples, rel_tol=1e-8)
self.assertAllNotClose(fraction_examples, rel_tol=1e-9)
def test_prod(self):
prod = math.prod
self.assertEqual(prod([]), 1)
self.assertEqual(prod([], start=5), 5)
self.assertEqual(prod(list(range(2,8))), 5040)
self.assertEqual(prod(iter(list(range(2,8)))), 5040)
self.assertEqual(prod(range(1, 10), start=10), 3628800)
self.assertEqual(prod([1, 2, 3, 4, 5]), 120)
self.assertEqual(prod([1.0, 2.0, 3.0, 4.0, 5.0]), 120.0)
self.assertEqual(prod([1, 2, 3, 4.0, 5.0]), 120.0)
self.assertEqual(prod([1.0, 2.0, 3.0, 4, 5]), 120.0)
# Test overflow in fast-path for integers
self.assertEqual(prod([1, 1, 2**32, 1, 1]), 2**32)
# Test overflow in fast-path for floats
self.assertEqual(prod([1.0, 1.0, 2**32, 1, 1]), float(2**32))
self.assertRaises(TypeError, prod)
self.assertRaises(TypeError, prod, 42)
self.assertRaises(TypeError, prod, ['a', 'b', 'c'])
self.assertRaises(TypeError, prod, ['a', 'b', 'c'], '')
self.assertRaises(TypeError, prod, [b'a', b'c'], b'')
values = [bytearray(b'a'), bytearray(b'b')]
self.assertRaises(TypeError, prod, values, bytearray(b''))
self.assertRaises(TypeError, prod, [[1], [2], [3]])
self.assertRaises(TypeError, prod, [{2:3}])
self.assertRaises(TypeError, prod, [{2:3}]*2, {2:3})
self.assertRaises(TypeError, prod, [[1], [2], [3]], [])
with self.assertRaises(TypeError):
prod([10, 20], [30, 40]) # start is a keyword-only argument
self.assertEqual(prod([0, 1, 2, 3]), 0)
self.assertEqual(prod([1, 0, 2, 3]), 0)
self.assertEqual(prod(range(10)), 0)
def test_main():
from doctest import DocFileSuite

View File

@ -2493,6 +2493,55 @@ math_isclose_impl(PyObject *module, double a, double b, double rel_tol,
(diff <= abs_tol));
}
static inline int
_check_long_mult_overflow(long a, long b) {
/* From Python2's int_mul code:
Integer overflow checking for * is painful: Python tried a couple ways, but
they didn't work on all platforms, or failed in endcases (a product of
-sys.maxint-1 has been a particular pain).
Here's another way:
The native long product x*y is either exactly right or *way* off, being
just the last n bits of the true product, where n is the number of bits
in a long (the delivered product is the true product plus i*2**n for
some integer i).
The native double product (double)x * (double)y is subject to three
rounding errors: on a sizeof(long)==8 box, each cast to double can lose
info, and even on a sizeof(long)==4 box, the multiplication can lose info.
But, unlike the native long product, it's not in *range* trouble: even
if sizeof(long)==32 (256-bit longs), the product easily fits in the
dynamic range of a double. So the leading 50 (or so) bits of the double
product are correct.
We check these two ways against each other, and declare victory if they're
approximately the same. Else, because the native long product is the only
one that can lose catastrophic amounts of information, it's the native long
product that must have overflowed.
*/
long longprod = (long)((unsigned long)a * b);
double doubleprod = (double)a * (double)b;
double doubled_longprod = (double)longprod;
if (doubled_longprod == doubleprod) {
return 0;
}
const double diff = doubled_longprod - doubleprod;
const double absdiff = diff >= 0.0 ? diff : -diff;
const double absprod = doubleprod >= 0.0 ? doubleprod : -doubleprod;
if (32.0 * absdiff <= absprod) {
return 0;
}
return 1;
}
/*[clinic input]
math.prod
@ -2558,11 +2607,8 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start)
}
if (PyLong_CheckExact(item)) {
long b = PyLong_AsLongAndOverflow(item, &overflow);
long x = i_result * b;
/* Continue if there is no overflow */
if (overflow == 0
&& x < LONG_MAX && x > LONG_MIN
&& !(b != 0 && x / b != i_result)) {
if (overflow == 0 && !_check_long_mult_overflow(i_result, b)) {
long x = i_result * b;
i_result = x;
Py_DECREF(item);
continue;