Clean-up and simplify median_grouped(). Vastly improve its docstring. (#92324)

This commit is contained in:
Raymond Hettinger 2022-05-05 03:01:07 -05:00 committed by GitHub
parent b885b8f4be
commit 5212cbc261
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 49 additions and 91 deletions

View File

@ -348,22 +348,6 @@ def _convert(value, T):
raise
def _find_lteq(a, x):
'Locate the leftmost value exactly equal to x'
i = bisect_left(a, x)
if i != len(a) and a[i] == x:
return i
raise ValueError
def _find_rteq(a, l, x):
'Locate the rightmost value exactly equal to x'
i = bisect_right(a, x, lo=l)
if i != (len(a) + 1) and a[i - 1] == x:
return i - 1
raise ValueError
def _fail_neg(values, errmsg='negative value'):
"""Iterate over values, failing if any are less than zero."""
for x in values:
@ -628,30 +612,44 @@ def median_high(data):
def median_grouped(data, interval=1):
"""Return the 50th percentile (median) of grouped continuous data.
"""Estimates the median for numeric data binned around the midpoints
of consecutive, fixed-width intervals.
>>> median_grouped([1, 2, 2, 3, 4, 4, 4, 4, 4, 5])
3.7
>>> median_grouped([52, 52, 53, 54])
52.5
The *data* can be any iterable of numeric data with each value being
exactly the midpoint of a bin. At least one value must be present.
This calculates the median as the 50th percentile, and should be
used when your data is continuous and grouped. In the above example,
the values 1, 2, 3, etc. actually represent the midpoint of classes
0.5-1.5, 1.5-2.5, 2.5-3.5, etc. The middle value falls somewhere in
class 3.5-4.5, and interpolation is used to estimate it.
The *interval* is the width of each bin.
Optional argument ``interval`` represents the class interval, and
defaults to 1. Changing the class interval naturally will change the
interpolated 50th percentile value:
For example, demographic information may have been summarized into
consecutive ten-year age groups with each group being represented
by the 5-year midpoints of the intervals:
>>> median_grouped([1, 3, 3, 5, 7], interval=1)
3.25
>>> median_grouped([1, 3, 3, 5, 7], interval=2)
3.5
>>> demographics = Counter({
... 25: 172, # 20 to 30 years old
... 35: 484, # 30 to 40 years old
... 45: 387, # 40 to 50 years old
... 55: 22, # 50 to 60 years old
... 65: 6, # 60 to 70 years old
... })
The 50th percentile (median) is the 536th person out of the 1071
member cohort. That person is in the 30 to 40 year old age group.
The regular median() function would assume that everyone in the
tricenarian age group was exactly 35 years old. A more tenable
assumption is that the 484 members of that age group are evenly
distributed between 30 and 40. For that, we use median_grouped().
>>> data = list(demographics.elements())
>>> median(data)
35
>>> round(median_grouped(data, interval=10), 1)
37.5
The caller is responsible for making sure the data points are separated
by exact multiples of *interval*. This is essential for getting a
correct result. The function does not check this precondition.
This function does not check whether the data points are at least
``interval`` apart.
"""
data = sorted(data)
n = len(data)
@ -659,26 +657,30 @@ def median_grouped(data, interval=1):
raise StatisticsError("no median for empty data")
elif n == 1:
return data[0]
# Find the value at the midpoint. Remember this corresponds to the
# centre of the class interval.
# midpoint of the class interval.
x = data[n // 2]
# Generate a clear error message for non-numeric data
for obj in (x, interval):
if isinstance(obj, (str, bytes)):
raise TypeError('expected number but got %r' % obj)
raise TypeError(f'expected a number but got {obj!r}')
# Using O(log n) bisection, find where all the x values occur in the data.
# All x will lie within data[i:j].
i = bisect_left(data, x)
j = bisect_right(data, x, lo=i)
# Interpolate the median using the formula found at:
# https://www.cuemath.com/data/median-of-grouped-data/
try:
L = x - interval / 2 # The lower limit of the median interval.
except TypeError:
# Mixed type. For now we just coerce to float.
# Coerce mixed types to float.
L = float(x) - float(interval) / 2
# Uses bisection search to search for x in data with log(n) time complexity
# Find the position of leftmost occurrence of x in data
l1 = _find_lteq(data, x)
# Find the position of rightmost occurrence of x in data[l1...len(data)]
# Assuming always l1 <= l2
l2 = _find_rteq(data, l1, x)
cf = l1
f = l2 - l1 + 1
cf = i # Cumulative frequency of the preceding interval
f = j - i # Number of elements in the median interval
return L + interval * (n / 2 - cf) / f

View File

@ -1040,50 +1040,6 @@ class FailNegTest(unittest.TestCase):
self.assertEqual(errmsg, msg)
class FindLteqTest(unittest.TestCase):
    """Tests for the private helper statistics._find_lteq()."""

    def test_invalid_input_values(self):
        # In every case x does not occur in a, so ValueError is expected.
        absent_cases = (
            ([], 1),
            ([1, 2], 3),
            ([1, 3], 2),
        )
        for a, x in absent_cases:
            with self.subTest(a=a, x=x), self.assertRaises(ValueError):
                statistics._find_lteq(a, x)

    def test_locate_successfully(self):
        # Each case: (data, value to find, index of its leftmost occurrence).
        found_cases = (
            ([1, 1, 1, 2, 3], 1, 0),
            ([0, 1, 1, 1, 2, 3], 1, 1),
            ([1, 2, 3, 3, 3], 3, 2),
        )
        for a, x, expected_i in found_cases:
            with self.subTest(a=a, x=x):
                actual_i = statistics._find_lteq(a, x)
                self.assertEqual(expected_i, actual_i)
class FindRteqTest(unittest.TestCase):
    """Tests for the private helper statistics._find_rteq()."""

    def test_invalid_input_values(self):
        # Each case: (data, start index, value); the value cannot be found
        # at or after the start index, so ValueError is expected.
        for a, l, x in [
            ([1], 2, 1),
            ([1, 3], 0, 2),
        ]:
            # subTest reports which case failed, matching the other tests.
            with self.subTest(a=a, l=l, x=x):
                with self.assertRaises(ValueError):
                    statistics._find_rteq(a, l, x)

    def test_locate_successfully(self):
        # Each case: (data, start index, value, index of rightmost match).
        for a, l, x, expected_i in [
            ([1, 1, 1, 2, 3], 0, 1, 2),
            ([0, 1, 1, 1, 2, 3], 0, 1, 3),
            ([1, 2, 3, 3, 3], 0, 3, 4),
        ]:
            with self.subTest(a=a, l=l, x=x):
                self.assertEqual(expected_i, statistics._find_rteq(a, l, x))
# === Tests for public functions ===
class UnivariateCommonMixin: