Issue27181 add geometric mean.

This commit is contained in:
Steven D'Aprano 2016-08-09 13:58:10 +10:00
parent e7fef52f98
commit 9a2be91c6b
2 changed files with 552 additions and 0 deletions

View File

@ -303,6 +303,230 @@ def _fail_neg(values, errmsg='negative value'):
yield x
class _nroot_NS:
"""Hands off! Don't touch!
Everything inside this namespace (class) is an even-more-private
implementation detail of the private _nth_root function.
"""
# This class exists only to be used as a namespace, for convenience
# of being able to keep the related functions together, and to
# collapse the group in an editor. If this were C# or C++, I would
# use a Namespace, but the closest Python has is a class.
#
# FIXME possibly move this out into a separate module?
# That feels like overkill, and may encourage people to treat it as
# a public feature.
def __init__(self):
raise TypeError('namespace only, do not instantiate')
def nth_root(x, n):
"""Return the positive nth root of numeric x.
This may be more accurate than ** or pow():
>>> math.pow(1000, 1.0/3) #doctest:+SKIP
9.999999999999998
>>> _nth_root(1000, 3)
10.0
>>> _nth_root(11**5, 5)
11.0
>>> _nth_root(2, 12)
1.0594630943592953
"""
if not isinstance(n, int):
raise TypeError('degree n must be an int')
if n < 2:
raise ValueError('degree n must be 2 or more')
if isinstance(x, decimal.Decimal):
return _nroot_NS.decimal_nroot(x, n)
elif isinstance(x, numbers.Real):
return _nroot_NS.float_nroot(x, n)
else:
raise TypeError('expected a number, got %s') % type(x).__name__
def float_nroot(x, n):
"""Handle nth root of Reals, treated as a float."""
assert isinstance(n, int) and n > 1
if x < 0:
if n%2 == 0:
raise ValueError('domain error: even root of negative number')
else:
return -_nroot_NS.nroot(-x, n)
elif x == 0:
return math.copysign(0.0, x)
elif x > 0:
try:
isinfinity = math.isinf(x)
except OverflowError:
return _nroot_NS.bignum_nroot(x, n)
else:
if isinfinity:
return float('inf')
else:
return _nroot_NS.nroot(x, n)
else:
assert math.isnan(x)
return float('nan')
def nroot(x, n):
"""Calculate x**(1/n), then improve the answer."""
# This uses math.pow() to calculate an initial guess for the root,
# then uses the iterated nroot algorithm to improve it.
#
# By my testing, about 8% of the time the iterated algorithm ends
# up converging to a result which is less accurate than the initial
# guess. [FIXME: is this still true?] In that case, we use the
# guess instead of the "improved" value. This way, we're never
# less accurate than math.pow().
r1 = math.pow(x, 1.0/n)
eps1 = abs(r1**n - x)
if eps1 == 0.0:
# r1 is the exact root, so we're done. By my testing, this
# occurs about 80% of the time for x < 1 and 30% of the
# time for x > 1.
return r1
else:
try:
r2 = _nroot_NS.iterated_nroot(x, n, r1)
except RuntimeError:
return r1
else:
eps2 = abs(r2**n - x)
if eps1 < eps2:
return r1
return r2
def iterated_nroot(a, n, g):
"""Return the nth root of a, starting with guess g.
This is a special case of Newton's Method.
https://en.wikipedia.org/wiki/Nth_root_algorithm
"""
np = n - 1
def iterate(r):
try:
return (np*r + a/math.pow(r, np))/n
except OverflowError:
# If r is large enough, r**np may overflow. If that
# happens, r**-np will be small, but not necessarily zero.
return (np*r + a*math.pow(r, -np))/n
# With a good guess, such as g = a**(1/n), this will converge in
# only a few iterations. However a poor guess can take thousands
# of iterations to converge, if at all. We guard against poor
# guesses by setting an upper limit to the number of iterations.
r1 = g
r2 = iterate(g)
for i in range(1000):
if r1 == r2:
break
# Use Floyd's cycle-finding algorithm to avoid being trapped
# in a cycle.
# https://en.wikipedia.org/wiki/Cycle_detection#Tortoise_and_hare
r1 = iterate(r1)
r2 = iterate(iterate(r2))
else:
# If the guess is particularly bad, the above may fail to
# converge in any reasonable time.
raise RuntimeError('nth-root failed to converge')
return r2
def decimal_nroot(x, n):
"""Handle nth root of Decimals."""
assert isinstance(x, decimal.Decimal)
assert isinstance(n, int)
if x.is_snan():
# Signalling NANs always raise.
raise decimal.InvalidOperation('nth-root of snan')
if x.is_qnan():
# Quiet NANs only raise if the context is set to raise,
# otherwise return a NAN.
ctx = decimal.getcontext()
if ctx.traps[decimal.InvalidOperation]:
raise decimal.InvalidOperation('nth-root of nan')
else:
# Preserve the input NAN.
return x
if x.is_infinite():
return x
# FIXME this hasn't had the extensive testing of the float
# version _iterated_nroot so there's possibly some buggy
# corner cases buried in here. Can it overflow? Fail to
# converge or get trapped in a cycle? Converge to a less
# accurate root?
np = n - 1
def iterate(r):
return (np*r + x/r**np)/n
r0 = x**(decimal.Decimal(1)/n)
assert isinstance(r0, decimal.Decimal)
r1 = iterate(r0)
while True:
if r1 == r0:
return r1
r0, r1 = r1, iterate(r1)
def bignum_nroot(x, n):
"""Return the nth root of a positive huge number."""
assert x > 0
# I state without proof that ⁿ√x ≈ ⁿ√2·ⁿ√(x//2)
# and that for sufficiently big x the error is acceptible.
# We now halve x until it is small enough to get the root.
m = 0
while True:
x //= 2
m += 1
try:
y = float(x)
except OverflowError:
continue
break
a = _nroot_NS.nroot(y, n)
# At this point, we want the nth-root of 2**m, or 2**(m/n).
# We can write that as 2**(q + r/n) = 2**q * ⁿ√2**r where q = m//n.
q, r = divmod(m, n)
b = 2**q * _nroot_NS.nroot(2**r, n)
return a * b
# This is the (private) function for calculating nth roots:
_nth_root = _nroot_NS.nth_root
assert type(_nth_root) is type(lambda: None)
def _product(values):
"""Return product of values as (exponent, mantissa)."""
errmsg = 'mixed Decimal and float is not supported'
prod = 1
for x in values:
if isinstance(x, float):
break
prod *= x
else:
return (0, prod)
if isinstance(prod, Decimal):
raise TypeError(errmsg)
# Since floats can overflow easily, we calculate the product as a
# sort of poor-man's BigFloat. Given that:
#
# x = 2**p * m # p == power or exponent (scale), m = mantissa
#
# we can calculate the product of two (or more) x values as:
#
# x1*x2 = 2**p1*m1 * 2**p2*m2 = 2**(p1+p2)*(m1*m2)
#
mant, scale = 1, 0 #math.frexp(prod) # FIXME
for y in chain([x], values):
if isinstance(y, Decimal):
raise TypeError(errmsg)
m1, e1 = math.frexp(y)
m2, e2 = math.frexp(mant)
scale += (e1 + e2)
mant = m1*m2
return (scale, mant)
# === Measures of central tendency (averages) ===
def mean(data):
@ -331,6 +555,49 @@ def mean(data):
return _convert(total/n, T)
def geometric_mean(data):
"""Return the geometric mean of data.
The geometric mean is appropriate when averaging quantities which
are multiplied together rather than added, for example growth rates.
Suppose an investment grows by 10% in the first year, falls by 5% in
the second, then grows by 12% in the third, what is the average rate
of growth over the three years?
>>> geometric_mean([1.10, 0.95, 1.12])
1.0538483123382172
giving an average growth of 5.385%. Using the arithmetic mean will
give approximately 5.667%, which is too high.
``StatisticsError`` will be raised if ``data`` is empty, or any
element is less than zero.
"""
if iter(data) is data:
data = list(data)
errmsg = 'geometric mean does not support negative values'
n = len(data)
if n < 1:
raise StatisticsError('geometric_mean requires at least one data point')
elif n == 1:
x = data[0]
if isinstance(g, (numbers.Real, Decimal)):
if x < 0:
raise StatisticsError(errmsg)
return x
else:
raise TypeError('unsupported type')
else:
scale, prod = _product(_fail_neg(data, errmsg))
r = _nth_root(prod, n)
if scale:
p, q = divmod(scale, n)
s = 2**p * _nth_root(2**q, n)
else:
s = 1
return s*r
def harmonic_mean(data):
"""Return the harmonic mean of data.

View File

@ -1010,6 +1010,291 @@ class FailNegTest(unittest.TestCase):
self.assertEqual(errmsg, msg)
class Test_Product(NumericTestCase):
"""Test the private _product function."""
def test_ints(self):
data = [1, 2, 5, 7, 9]
self.assertEqual(statistics._product(data), (0, 630))
self.assertEqual(statistics._product(data*100), (0, 630**100))
def test_floats(self):
data = [1.0, 2.0, 4.0, 8.0]
self.assertEqual(statistics._product(data), (8, 0.25))
def test_overflow(self):
# Test with floats that overflow.
data = [1e300]*5
self.assertEqual(statistics._product(data), (5980, 0.6928287951283193))
def test_fractions(self):
F = Fraction
data = [F(14, 23), F(69, 1), F(665, 529), F(299, 105), F(1683, 39)]
exp, mant = statistics._product(data)
self.assertEqual(exp, 0)
self.assertEqual(mant, F(2*3*7*11*17*19, 23))
self.assertTrue(isinstance(mant, F))
# Mixed Fraction and int.
data = [3, 25, F(2, 15)]
exp, mant = statistics._product(data)
self.assertEqual(exp, 0)
self.assertEqual(mant, F(10))
self.assertTrue(isinstance(mant, F))
@unittest.expectedFailure
def test_decimal(self):
D = Decimal
data = [D('24.5'), D('17.6'), D('0.025'), D('1.3')]
assert False
def test_mixed_decimal_float(self):
# Test that mixed Decimal and float raises.
self.assertRaises(TypeError, statistics._product, [1.0, Decimal(1)])
self.assertRaises(TypeError, statistics._product, [Decimal(1), 1.0])
class Test_Nth_Root(NumericTestCase):
"""Test the functionality of the private _nth_root function."""
def setUp(self):
self.nroot = statistics._nth_root
# --- Special values (infinities, NANs, zeroes) ---
def test_float_NAN(self):
# Test that the root of a float NAN is a float NAN.
NAN = float('nan')
for n in range(2, 9):
with self.subTest(n=n):
result = self.nroot(NAN, n)
self.assertTrue(math.isnan(result))
def test_decimal_QNAN(self):
# Test the behaviour when taking the root of a Decimal quiet NAN.
NAN = decimal.Decimal('nan')
with decimal.localcontext() as ctx:
ctx.traps[decimal.InvalidOperation] = 1
self.assertRaises(decimal.InvalidOperation, self.nroot, NAN, 5)
ctx.traps[decimal.InvalidOperation] = 0
self.assertTrue(self.nroot(NAN, 5).is_qnan())
def test_decimal_SNAN(self):
# Test that taking the root of a Decimal sNAN always raises.
sNAN = decimal.Decimal('snan')
with decimal.localcontext() as ctx:
ctx.traps[decimal.InvalidOperation] = 1
self.assertRaises(decimal.InvalidOperation, self.nroot, sNAN, 5)
ctx.traps[decimal.InvalidOperation] = 0
self.assertRaises(decimal.InvalidOperation, self.nroot, sNAN, 5)
def test_inf(self):
# Test that the root of infinity is infinity.
for INF in (float('inf'), decimal.Decimal('inf')):
for n in range(2, 9):
with self.subTest(n=n, inf=INF):
self.assertEqual(self.nroot(INF, n), INF)
def testNInf(self):
# Test that the root of -inf is -inf for odd n.
for NINF in (float('-inf'), decimal.Decimal('-inf')):
for n in range(3, 11, 2):
with self.subTest(n=n, inf=NINF):
self.assertEqual(self.nroot(NINF, n), NINF)
# FIXME: need to check Decimal zeroes too.
def test_zero(self):
# Test that the root of +0.0 is +0.0.
for n in range(2, 11):
with self.subTest(n=n):
result = self.nroot(+0.0, n)
self.assertEqual(result, 0.0)
self.assertEqual(sign(result), +1)
# FIXME: need to check Decimal zeroes too.
def test_neg_zero(self):
# Test that the root of -0.0 is -0.0.
for n in range(2, 11):
with self.subTest(n=n):
result = self.nroot(-0.0, n)
self.assertEqual(result, 0.0)
self.assertEqual(sign(result), -1)
# --- Test return types ---
def check_result_type(self, x, n, outtype):
self.assertIsInstance(self.nroot(x, n), outtype)
class MySubclass(type(x)):
pass
self.assertIsInstance(self.nroot(MySubclass(x), n), outtype)
def testDecimal(self):
# Test that Decimal arguments return Decimal results.
self.check_result_type(decimal.Decimal('33.3'), 3, decimal.Decimal)
def testFloat(self):
# Test that other arguments return float results.
for x in (0.2, Fraction(11, 7), 91):
self.check_result_type(x, 6, float)
# --- Test bad input ---
def testBadOrderTypes(self):
# Test that nroot raises correctly when n has the wrong type.
for n in (5.0, 2j, None, 'x', b'x', [], {}, set(), sign):
with self.subTest(n=n):
self.assertRaises(TypeError, self.nroot, 2.5, n)
def testBadOrderValues(self):
# Test that nroot raises correctly when n has a wrong value.
for n in (1, 0, -1, -2, -87):
with self.subTest(n=n):
self.assertRaises(ValueError, self.nroot, 2.5, n)
def testBadTypes(self):
# Test that nroot raises correctly when x has the wrong type.
for x in (None, 'x', b'x', [], {}, set(), sign):
with self.subTest(x=x):
self.assertRaises(TypeError, self.nroot, x, 3)
def testNegativeEvenPower(self):
# Test negative x with even n raises correctly.
x = random.uniform(-20.0, -0.1)
assert x < 0
for n in range(2, 9, 2):
with self.subTest(x=x, n=n):
self.assertRaises(ValueError, self.nroot, x, n)
# --- Test that nroot is never worse than calling math.pow() ---
def check_error_is_no_worse(self, x, n):
y = math.pow(x, n)
with self.subTest(x=x, n=n, y=y):
err1 = abs(self.nroot(y, n) - x)
err2 = abs(math.pow(y, 1.0/n) - x)
self.assertLessEqual(err1, err2)
def testCompareWithPowSmall(self):
# Compare nroot with pow for small values of x.
for i in range(200):
x = random.uniform(1e-9, 1.0-1e-9)
n = random.choice(range(2, 16))
self.check_error_is_no_worse(x, n)
def testCompareWithPowMedium(self):
# Compare nroot with pow for medium-sized values of x.
for i in range(200):
x = random.uniform(1.0, 100.0)
n = random.choice(range(2, 16))
self.check_error_is_no_worse(x, n)
def testCompareWithPowLarge(self):
# Compare nroot with pow for largish values of x.
for i in range(200):
x = random.uniform(100.0, 10000.0)
n = random.choice(range(2, 16))
self.check_error_is_no_worse(x, n)
def testCompareWithPowHuge(self):
# Compare nroot with pow for huge values of x.
for i in range(200):
x = random.uniform(1e20, 1e50)
# We restrict the order here to avoid an Overflow error.
n = random.choice(range(2, 7))
self.check_error_is_no_worse(x, n)
# --- Test for numerically correct answers ---
def testExactPowers(self):
# Test that small integer powers are calculated exactly.
for i in range(1, 51):
for n in range(2, 16):
if (i, n) == (35, 13):
# See testExpectedFailure35p13
continue
with self.subTest(i=i, n=n):
x = i**n
self.assertEqual(self.nroot(x, n), i)
def testExactPowersNegatives(self):
# Test that small negative integer powers are calculated exactly.
for i in range(-1, -51, -1):
for n in range(3, 16, 2):
if (i, n) == (-35, 13):
# See testExpectedFailure35p13
continue
with self.subTest(i=i, n=n):
x = i**n
assert sign(x) == -1
self.assertEqual(self.nroot(x, n), i)
def testExpectedFailure35p13(self):
# Test the expected failure 35**13 is almost exact.
x = 35**13
err = abs(self.nroot(x, 13) - 35)
self.assertLessEqual(err, 0.000000001)
err = abs(self.nroot(-x, 13) + 35)
self.assertLessEqual(err, 0.000000001)
def testOne(self):
# Test that the root of 1.0 is 1.0.
for n in range(2, 11):
with self.subTest(n=n):
self.assertEqual(self.nroot(1.0, n), 1.0)
def testFraction(self):
# Test Fraction results.
x = Fraction(89, 75)
self.assertEqual(self.nroot(x**12, 12), float(x))
def testInt(self):
# Test int results.
x = 276
self.assertEqual(self.nroot(x**24, 24), x)
def testBigInt(self):
# Test that ints too big to convert to floats work.
bignum = 10**20 # That's not that big...
self.assertEqual(self.nroot(bignum**280, 280), bignum)
# Can we make it bigger?
hugenum = bignum**50
# Make sure that it is too big to convert to a float.
try:
y = float(hugenum)
except OverflowError:
pass
else:
raise AssertionError('hugenum is not big enough')
self.assertEqual(self.nroot(hugenum, 50), float(bignum))
def testDecimal(self):
# Test Decimal results.
for s in '3.759 64.027 5234.338'.split():
x = decimal.Decimal(s)
with self.subTest(x=x):
a = self.nroot(x**5, 5)
self.assertEqual(a, x)
a = self.nroot(x**17, 17)
self.assertEqual(a, x)
def testFloat(self):
# Test float results.
for x in (3.04e-16, 18.25, 461.3, 1.9e17):
with self.subTest(x=x):
self.assertEqual(self.nroot(x**3, 3), x)
self.assertEqual(self.nroot(x**8, 8), x)
self.assertEqual(self.nroot(x**11, 11), x)
class Test_NthRoot_NS(unittest.TestCase):
"""Test internals of the nth_root function, hidden in _nroot_NS."""
def test_class_cannot_be_instantiated(self):
# Test that _nroot_NS cannot be instantiated.
# It should be a namespace, like in C++ or C#, but Python
# lacks that feature and so we have to make do with a class.
self.assertRaises(TypeError, statistics._nroot_NS)
# === Tests for public functions ===
class UnivariateCommonMixin: