mirror of https://github.com/python/cpython
bpo-45876: Correctly rounded stdev() and pstdev() for the Decimal case (GH-29828)
This commit is contained in:
parent
8a45ca542a
commit
a39f46afde
|
@ -137,7 +137,7 @@ from decimal import Decimal
|
|||
from itertools import groupby, repeat
|
||||
from bisect import bisect_left, bisect_right
|
||||
from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum
|
||||
from operator import itemgetter, mul
|
||||
from operator import mul
|
||||
from collections import Counter, namedtuple
|
||||
|
||||
_SQRT2 = sqrt(2.0)
|
||||
|
@ -248,6 +248,28 @@ def _exact_ratio(x):
|
|||
|
||||
x is expected to be an int, Fraction, Decimal or float.
|
||||
"""
|
||||
|
||||
# XXX We should revisit whether using fractions to accumulate exact
|
||||
# ratios is the right way to go.
|
||||
|
||||
# The integer ratios for binary floats can have numerators or
|
||||
# denominators with over 300 decimal digits. The problem is more
|
||||
# acute with decimal floats where the the default decimal context
|
||||
# supports a huge range of exponents from Emin=-999999 to
|
||||
# Emax=999999. When expanded with as_integer_ratio(), numbers like
|
||||
# Decimal('3.14E+5000') and Decimal('3.14E-5000') have large
|
||||
# numerators or denominators that will slow computation.
|
||||
|
||||
# When the integer ratios are accumulated as fractions, the size
|
||||
# grows to cover the full range from the smallest magnitude to the
|
||||
# largest. For example, Fraction(3.14E+300) + Fraction(3.14E-300),
|
||||
# has a 616 digit numerator. Likewise,
|
||||
# Fraction(Decimal('3.14E+5000')) + Fraction(Decimal('3.14E-5000'))
|
||||
# has 10,003 digit numerator.
|
||||
|
||||
# This doesn't seem to have been problem in practice, but it is a
|
||||
# potential pitfall.
|
||||
|
||||
try:
|
||||
return x.as_integer_ratio()
|
||||
except AttributeError:
|
||||
|
@ -305,28 +327,60 @@ def _fail_neg(values, errmsg='negative value'):
|
|||
raise StatisticsError(errmsg)
|
||||
yield x
|
||||
|
||||
def _isqrt_frac_rto(n: int, m: int) -> float:
|
||||
|
||||
def _integer_sqrt_of_frac_rto(n: int, m: int) -> int:
|
||||
"""Square root of n/m, rounded to the nearest integer using round-to-odd."""
|
||||
# Reference: https://www.lri.fr/~melquion/doc/05-imacs17_1-expose.pdf
|
||||
a = math.isqrt(n // m)
|
||||
return a | (a*a*m != n)
|
||||
|
||||
# For 53 bit precision floats, the _sqrt_frac() shift is 109.
|
||||
_sqrt_shift: int = 2 * sys.float_info.mant_dig + 3
|
||||
|
||||
def _sqrt_frac(n: int, m: int) -> float:
|
||||
# For 53 bit precision floats, the bit width used in
|
||||
# _float_sqrt_of_frac() is 109.
|
||||
_sqrt_bit_width: int = 2 * sys.float_info.mant_dig + 3
|
||||
|
||||
|
||||
def _float_sqrt_of_frac(n: int, m: int) -> float:
|
||||
"""Square root of n/m as a float, correctly rounded."""
|
||||
# See principle and proof sketch at: https://bugs.python.org/msg407078
|
||||
q = (n.bit_length() - m.bit_length() - _sqrt_shift) // 2
|
||||
q = (n.bit_length() - m.bit_length() - _sqrt_bit_width) // 2
|
||||
if q >= 0:
|
||||
numerator = _isqrt_frac_rto(n, m << 2 * q) << q
|
||||
numerator = _integer_sqrt_of_frac_rto(n, m << 2 * q) << q
|
||||
denominator = 1
|
||||
else:
|
||||
numerator = _isqrt_frac_rto(n << -2 * q, m)
|
||||
numerator = _integer_sqrt_of_frac_rto(n << -2 * q, m)
|
||||
denominator = 1 << -q
|
||||
return numerator / denominator # Convert to float
|
||||
|
||||
|
||||
def _decimal_sqrt_of_frac(n: int, m: int) -> Decimal:
|
||||
"""Square root of n/m as a Decimal, correctly rounded."""
|
||||
# Premise: For decimal, computing (n/m).sqrt() can be off
|
||||
# by 1 ulp from the correctly rounded result.
|
||||
# Method: Check the result, moving up or down a step if needed.
|
||||
if n <= 0:
|
||||
if not n:
|
||||
return Decimal('0.0')
|
||||
n, m = -n, -m
|
||||
|
||||
root = (Decimal(n) / Decimal(m)).sqrt()
|
||||
nr, dr = root.as_integer_ratio()
|
||||
|
||||
plus = root.next_plus()
|
||||
np, dp = plus.as_integer_ratio()
|
||||
# test: n / m > ((root + plus) / 2) ** 2
|
||||
if 4 * n * (dr*dp)**2 > m * (dr*np + dp*nr)**2:
|
||||
return plus
|
||||
|
||||
minus = root.next_minus()
|
||||
nm, dm = minus.as_integer_ratio()
|
||||
# test: n / m < ((root + minus) / 2) ** 2
|
||||
if 4 * n * (dr*dm)**2 < m * (dr*nm + dm*nr)**2:
|
||||
return minus
|
||||
|
||||
return root
|
||||
|
||||
|
||||
# === Measures of central tendency (averages) ===
|
||||
|
||||
def mean(data):
|
||||
|
@ -869,7 +923,7 @@ def stdev(data, xbar=None):
|
|||
if hasattr(T, 'sqrt'):
|
||||
var = _convert(mss, T)
|
||||
return var.sqrt()
|
||||
return _sqrt_frac(mss.numerator, mss.denominator)
|
||||
return _float_sqrt_of_frac(mss.numerator, mss.denominator)
|
||||
|
||||
|
||||
def pstdev(data, mu=None):
|
||||
|
@ -888,10 +942,9 @@ def pstdev(data, mu=None):
|
|||
raise StatisticsError('pstdev requires at least one data point')
|
||||
T, ss = _ss(data, mu)
|
||||
mss = ss / n
|
||||
if hasattr(T, 'sqrt'):
|
||||
var = _convert(mss, T)
|
||||
return var.sqrt()
|
||||
return _sqrt_frac(mss.numerator, mss.denominator)
|
||||
if issubclass(T, Decimal):
|
||||
return _decimal_sqrt_of_frac(mss.numerator, mss.denominator)
|
||||
return _float_sqrt_of_frac(mss.numerator, mss.denominator)
|
||||
|
||||
|
||||
# === Statistics for relations between two inputs ===
|
||||
|
|
|
@ -2164,9 +2164,9 @@ class TestPStdev(VarianceStdevMixin, NumericTestCase):
|
|||
|
||||
class TestSqrtHelpers(unittest.TestCase):
|
||||
|
||||
def test_isqrt_frac_rto(self):
|
||||
def test_integer_sqrt_of_frac_rto(self):
|
||||
for n, m in itertools.product(range(100), range(1, 1000)):
|
||||
r = statistics._isqrt_frac_rto(n, m)
|
||||
r = statistics._integer_sqrt_of_frac_rto(n, m)
|
||||
self.assertIsInstance(r, int)
|
||||
if r*r*m == n:
|
||||
# Root is exact
|
||||
|
@ -2177,7 +2177,7 @@ class TestSqrtHelpers(unittest.TestCase):
|
|||
self.assertTrue(m * (r - 1)**2 < n < m * (r + 1)**2)
|
||||
|
||||
@requires_IEEE_754
|
||||
def test_sqrt_frac(self):
|
||||
def test_float_sqrt_of_frac(self):
|
||||
|
||||
def is_root_correctly_rounded(x: Fraction, root: float) -> bool:
|
||||
if not x:
|
||||
|
@ -2204,22 +2204,59 @@ class TestSqrtHelpers(unittest.TestCase):
|
|||
denonimator: int = randrange(10 ** randrange(50)) + 1
|
||||
with self.subTest(numerator=numerator, denonimator=denonimator):
|
||||
x: Fraction = Fraction(numerator, denonimator)
|
||||
root: float = statistics._sqrt_frac(numerator, denonimator)
|
||||
root: float = statistics._float_sqrt_of_frac(numerator, denonimator)
|
||||
self.assertTrue(is_root_correctly_rounded(x, root))
|
||||
|
||||
# Verify that corner cases and error handling match math.sqrt()
|
||||
self.assertEqual(statistics._sqrt_frac(0, 1), 0.0)
|
||||
self.assertEqual(statistics._float_sqrt_of_frac(0, 1), 0.0)
|
||||
with self.assertRaises(ValueError):
|
||||
statistics._sqrt_frac(-1, 1)
|
||||
statistics._float_sqrt_of_frac(-1, 1)
|
||||
with self.assertRaises(ValueError):
|
||||
statistics._sqrt_frac(1, -1)
|
||||
statistics._float_sqrt_of_frac(1, -1)
|
||||
|
||||
# Error handling for zero denominator matches that for Fraction(1, 0)
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
statistics._sqrt_frac(1, 0)
|
||||
statistics._float_sqrt_of_frac(1, 0)
|
||||
|
||||
# The result is well defined if both inputs are negative
|
||||
self.assertAlmostEqual(statistics._sqrt_frac(-2, -1), math.sqrt(2.0))
|
||||
self.assertEqual(statistics._float_sqrt_of_frac(-2, -1), statistics._float_sqrt_of_frac(2, 1))
|
||||
|
||||
def test_decimal_sqrt_of_frac(self):
|
||||
root: Decimal
|
||||
numerator: int
|
||||
denominator: int
|
||||
|
||||
for root, numerator, denominator in [
|
||||
(Decimal('0.4481904599041192673635338663'), 200874688349065940678243576378, 1000000000000000000000000000000), # No adj
|
||||
(Decimal('0.7924949131383786609961759598'), 628048187350206338833590574929, 1000000000000000000000000000000), # Adj up
|
||||
(Decimal('0.8500554152289934068192208727'), 722594208960136395984391238251, 1000000000000000000000000000000), # Adj down
|
||||
]:
|
||||
with decimal.localcontext(decimal.DefaultContext):
|
||||
self.assertEqual(statistics._decimal_sqrt_of_frac(numerator, denominator), root)
|
||||
|
||||
# Confirm expected root with a quad precision decimal computation
|
||||
with decimal.localcontext(decimal.DefaultContext) as ctx:
|
||||
ctx.prec *= 4
|
||||
high_prec_ratio = Decimal(numerator) / Decimal(denominator)
|
||||
ctx.rounding = decimal.ROUND_05UP
|
||||
high_prec_root = high_prec_ratio.sqrt()
|
||||
with decimal.localcontext(decimal.DefaultContext):
|
||||
target_root = +high_prec_root
|
||||
self.assertEqual(root, target_root)
|
||||
|
||||
# Verify that corner cases and error handling match Decimal.sqrt()
|
||||
self.assertEqual(statistics._decimal_sqrt_of_frac(0, 1), 0.0)
|
||||
with self.assertRaises(decimal.InvalidOperation):
|
||||
statistics._decimal_sqrt_of_frac(-1, 1)
|
||||
with self.assertRaises(decimal.InvalidOperation):
|
||||
statistics._decimal_sqrt_of_frac(1, -1)
|
||||
|
||||
# Error handling for zero denominator matches that for Fraction(1, 0)
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
statistics._decimal_sqrt_of_frac(1, 0)
|
||||
|
||||
# The result is well defined if both inputs are negative
|
||||
self.assertEqual(statistics._decimal_sqrt_of_frac(-2, -1), statistics._decimal_sqrt_of_frac(2, 1))
|
||||
|
||||
|
||||
class TestStdev(VarianceStdevMixin, NumericTestCase):
|
||||
|
|
Loading…
Reference in New Issue