From 81a5fc38e81b424869f4710f48e9371dfa2d3b77 Mon Sep 17 00:00:00 2001 From: Raymond Hettinger Date: Fri, 8 May 2020 07:53:15 -0700 Subject: [PATCH] bpo-40541: Add optional *counts* parameter to random.sample() (GH-19970) --- Doc/library/random.rst | 21 ++++-- Lib/random.py | 34 +++++++-- Lib/test/test_random.py | 73 ++++++++++++++++++- .../2020-05-06-15-36-47.bpo-40541.LlYghL.rst | 1 + 4 files changed, 116 insertions(+), 13 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2020-05-06-15-36-47.bpo-40541.LlYghL.rst diff --git a/Doc/library/random.rst b/Doc/library/random.rst index f37bc2a111d..90366f499ca 100644 --- a/Doc/library/random.rst +++ b/Doc/library/random.rst @@ -217,7 +217,7 @@ Functions for sequences The optional parameter *random*. -.. function:: sample(population, k) +.. function:: sample(population, k, *, counts=None) Return a *k* length list of unique elements chosen from the population sequence or set. Used for random sampling without replacement. @@ -231,6 +231,11 @@ Functions for sequences Members of the population need not be :term:`hashable` or unique. If the population contains repeats, then each occurrence is a possible selection in the sample. + Repeated elements can be specified one at a time or with the optional + keyword-only *counts* parameter. For example, ``sample(['red', 'blue'], + counts=[4, 2], k=5)`` is equivalent to ``sample(['red', 'red', 'red', 'red', + 'blue', 'blue'], k=5)``. + To choose a sample from a range of integers, use a :func:`range` object as an argument. This is especially fast and space efficient for sampling from a large population: ``sample(range(10000000), k=60)``. @@ -238,6 +243,9 @@ Functions for sequences If the sample size is larger than the population size, a :exc:`ValueError` is raised. + .. versionchanged:: 3.9 + Added the *counts* parameter. + .. deprecated:: 3.9 In the future, the *population* must be a sequence. Instances of :class:`set` are no longer supported. The set must first be converted @@ -420,12 +428,11 @@ Simulations:: >>> choices(['red', 'black', 'green'], [18, 18, 2], k=6) ['red', 'green', 'black', 'black', 'red', 'black'] - >>> # Deal 20 cards without replacement from a deck of 52 playing cards - >>> # and determine the proportion of cards with a ten-value - >>> # (a ten, jack, queen, or king). - >>> deck = collections.Counter(tens=16, low_cards=36) - >>> seen = sample(list(deck.elements()), k=20) - >>> seen.count('tens') / 20 + >>> # Deal 20 cards without replacement from a deck + >>> # of 52 playing cards, and determine the proportion of cards + >>> # with a ten-value: ten, jack, queen, or king. + >>> dealt = sample(['tens', 'low cards'], counts=[16, 36], k=20) + >>> dealt.count('tens') / 20 0.15 >>> # Estimate the probability of getting 5 or more heads from 7 spins diff --git a/Lib/random.py b/Lib/random.py index f2c4f39fb60..75f70d5d699 100644 --- a/Lib/random.py +++ b/Lib/random.py @@ -331,7 +331,7 @@ class Random(_random.Random): j = _int(random() * (i+1)) x[i], x[j] = x[j], x[i] - def sample(self, population, k): + def sample(self, population, k, *, counts=None): """Chooses k unique random elements from a population sequence or set. Returns a new list containing elements from the population while @@ -344,9 +344,21 @@ class Random(_random.Random): population contains repeats, then each occurrence is a possible selection in the sample. - To choose a sample in a range of integers, use range as an argument. - This is especially fast and space efficient for sampling from a - large population: sample(range(10000000), 60) + Repeated elements can be specified one at a time or with the optional + counts parameter. For example: + + sample(['red', 'blue'], counts=[4, 2], k=5) + + is equivalent to: + + sample(['red', 'red', 'red', 'red', 'blue', 'blue'], k=5) + + To choose a sample from a range of integers, use range() for the + population argument. This is especially fast and space efficient + for sampling from a large population: + + sample(range(10000000), 60) + """ # Sampling without replacement entails tracking either potential @@ -379,8 +391,20 @@ class Random(_random.Random): population = tuple(population) if not isinstance(population, _Sequence): raise TypeError("Population must be a sequence. For dicts or sets, use sorted(d).") - randbelow = self._randbelow n = len(population) + if counts is not None: + cum_counts = list(_accumulate(counts)) + if len(cum_counts) != n: + raise ValueError('The number of counts does not match the population') + total = cum_counts.pop() + if not isinstance(total, int): + raise TypeError('Counts must be integers') + if total <= 0: + raise ValueError('Total of counts must be greater than zero') + selections = sample(range(total), k=k) + bisect = _bisect + return [population[bisect(cum_counts, s)] for s in selections] + randbelow = self._randbelow if not 0 <= k <= n: raise ValueError("Sample larger than population or is negative") result = [None] * k diff --git a/Lib/test/test_random.py b/Lib/test/test_random.py index bb95ca0884a..a3710f4aa48 100644 --- a/Lib/test/test_random.py +++ b/Lib/test/test_random.py @@ -9,7 +9,7 @@ from functools import partial from math import log, exp, pi, fsum, sin, factorial from test import support from fractions import Fraction - +from collections import Counter class TestBasicOps: # Superclass with tests common to all generators. @@ -161,6 +161,77 @@ class TestBasicOps: population = {10, 20, 30, 40, 50, 60, 70} self.gen.sample(population, k=5) + def test_sample_with_counts(self): + sample = self.gen.sample + + # General case + colors = ['red', 'green', 'blue', 'orange', 'black', 'brown', 'amber'] + counts = [500, 200, 20, 10, 5, 0, 1 ] + k = 700 + summary = Counter(sample(colors, counts=counts, k=k)) + self.assertEqual(sum(summary.values()), k) + for color, weight in zip(colors, counts): + self.assertLessEqual(summary[color], weight) + self.assertNotIn('brown', summary) + + # Case that exhausts the population + k = sum(counts) + summary = Counter(sample(colors, counts=counts, k=k)) + self.assertEqual(sum(summary.values()), k) + for color, weight in zip(colors, counts): + self.assertLessEqual(summary[color], weight) + self.assertNotIn('brown', summary) + + # Case with population size of 1 + summary = Counter(sample(['x'], counts=[10], k=8)) + self.assertEqual(summary, Counter(x=8)) + + # Case with all counts equal. + nc = len(colors) + summary = Counter(sample(colors, counts=[10]*nc, k=10*nc)) + self.assertEqual(summary, Counter(10*colors)) + + # Test error handling + with self.assertRaises(TypeError): + sample(['red', 'green', 'blue'], counts=10, k=10) # counts not iterable + with self.assertRaises(ValueError): + sample(['red', 'green', 'blue'], counts=[-3, -7, -8], k=2) # counts are negative + with self.assertRaises(ValueError): + sample(['red', 'green', 'blue'], counts=[0, 0, 0], k=2) # counts are zero + with self.assertRaises(ValueError): + sample(['red', 'green'], counts=[10, 10], k=21) # population too small + with self.assertRaises(ValueError): + sample(['red', 'green', 'blue'], counts=[1, 2], k=2) # too few counts + with self.assertRaises(ValueError): + sample(['red', 'green', 'blue'], counts=[1, 2, 3, 4], k=2) # too many counts + + def test_sample_counts_equivalence(self): + # Test the documented strong equivalence to a sample with repeated elements. + # We run this test on random.Random() which makes deterministic selections + # for a given seed value. + sample = random.sample + seed = random.seed + + colors = ['red', 'green', 'blue', 'orange', 'black', 'amber'] + counts = [500, 200, 20, 10, 5, 1 ] + k = 700 + seed(8675309) + s1 = sample(colors, counts=counts, k=k) + seed(8675309) + expanded = [color for (color, count) in zip(colors, counts) for i in range(count)] + self.assertEqual(len(expanded), sum(counts)) + s2 = sample(expanded, k=k) + self.assertEqual(s1, s2) + + pop = 'abcdefghi' + counts = [10, 9, 8, 7, 6, 5, 4, 3, 2] + seed(8675309) + s1 = ''.join(sample(pop, counts=counts, k=30)) + expanded = ''.join([letter for (letter, count) in zip(pop, counts) for i in range(count)]) + seed(8675309) + s2 = ''.join(sample(expanded, k=30)) + self.assertEqual(s1, s2) + def test_choices(self): choices = self.gen.choices data = ['red', 'green', 'blue', 'yellow'] diff --git a/Misc/NEWS.d/next/Library/2020-05-06-15-36-47.bpo-40541.LlYghL.rst b/Misc/NEWS.d/next/Library/2020-05-06-15-36-47.bpo-40541.LlYghL.rst new file mode 100644 index 00000000000..a2e694ac1ad --- /dev/null +++ b/Misc/NEWS.d/next/Library/2020-05-06-15-36-47.bpo-40541.LlYghL.rst @@ -0,0 +1 @@ +Added an optional *counts* parameter to random.sample().