diff --git a/Doc/library/random.rst b/Doc/library/random.rst index 6dc54d28776..330cce15b8c 100644 --- a/Doc/library/random.rst +++ b/Doc/library/random.rst @@ -124,6 +124,27 @@ Functions for sequences: Return a random element from the non-empty sequence *seq*. If *seq* is empty, raises :exc:`IndexError`. +.. function:: weighted_choices(k, population, weights=None, *, cum_weights=None) + + Return a *k* sized list of elements chosen from the *population* with replacement. + If the *population* is empty, raises :exc:`IndexError`. + + If a *weights* sequence is specified, selections are made according to the + relative weights. Alternatively, if a *cum_weights* sequence is given, the + selections are made according to the cumulative weights. For example, the + relative weights ``[10, 5, 30, 5]`` are equivalent to the cumulative + weights ``[10, 15, 45, 50]``. Internally, the relative weights are + converted to cumulative weights before making selections, so supplying the + cumulative weights saves work. + + If neither *weights* nor *cum_weights* are specified, selections are made + with equal probability. If a weights sequence is supplied, it must be + the same length as the *population* sequence. It is a :exc:`TypeError` + to specify both *weights* and *cum_weights*. + + The *weights* or *cum_weights* can use any numeric type that interoperates + with the :class:`float` values returned by :func:`random` (that includes + integers, floats, and fractions but excludes decimals). .. function:: shuffle(x[, random]) diff --git a/Lib/random.py b/Lib/random.py index 82f6013b1fe..136395e938f 100644 --- a/Lib/random.py +++ b/Lib/random.py @@ -8,6 +8,7 @@ --------- pick random element pick random sample + pick weighted random sample generate random permutation distributions on the real line: @@ -43,12 +44,14 @@ from math import sqrt as _sqrt, acos as _acos, cos as _cos, sin as _sin from os import urandom as _urandom from _collections_abc import Set as _Set, Sequence as _Sequence from hashlib import sha512 as _sha512 +import itertools as _itertools +import bisect as _bisect __all__ = ["Random","seed","random","uniform","randint","choice","sample", "randrange","shuffle","normalvariate","lognormvariate", "expovariate","vonmisesvariate","gammavariate","triangular", "gauss","betavariate","paretovariate","weibullvariate", - "getstate","setstate", "getrandbits", + "getstate","setstate", "getrandbits", "weighted_choices", "SystemRandom"] NV_MAGICCONST = 4 * _exp(-0.5)/_sqrt(2.0) @@ -334,6 +337,28 @@ class Random(_random.Random): result[i] = population[j] return result + def weighted_choices(self, k, population, weights=None, *, cum_weights=None): + """Return a k sized list of population elements chosen with replacement. + + If the relative weights or cumulative weights are not specified, + the selections are made with equal probability. + + """ + if cum_weights is None: + if weights is None: + choice = self.choice + return [choice(population) for i in range(k)] + else: + cum_weights = list(_itertools.accumulate(weights)) + elif weights is not None: + raise TypeError('Cannot specify both weights and cumulative_weights') + if len(cum_weights) != len(population): + raise ValueError('The number of weights does not match the population') + bisect = _bisect.bisect + random = self.random + total = cum_weights[-1] + return [population[bisect(cum_weights, random() * total)] for i in range(k)] + ## -------------------- real-valued distributions ------------------- ## -------------------- uniform distribution ------------------- @@ -724,6 +749,7 @@ choice = _inst.choice randrange = _inst.randrange sample = _inst.sample shuffle = _inst.shuffle +weighted_choices = _inst.weighted_choices normalvariate = _inst.normalvariate lognormvariate = _inst.lognormvariate expovariate = _inst.expovariate diff --git a/Lib/test/test_random.py b/Lib/test/test_random.py index e80ed17a8cb..b3741a8845a 100644 --- a/Lib/test/test_random.py +++ b/Lib/test/test_random.py @@ -7,6 +7,7 @@ import warnings from functools import partial from math import log, exp, pi, fsum, sin from test import support +from fractions import Fraction class TestBasicOps: # Superclass with tests common to all generators. @@ -141,6 +142,73 @@ class TestBasicOps: def test_sample_on_dicts(self): self.assertRaises(TypeError, self.gen.sample, dict.fromkeys('abcdef'), 2) + def test_weighted_choices(self): + weighted_choices = self.gen.weighted_choices + data = ['red', 'green', 'blue', 'yellow'] + str_data = 'abcd' + range_data = range(4) + set_data = set(range(4)) + + # basic functionality + for sample in [ + weighted_choices(5, data), + weighted_choices(5, data, range(4)), + weighted_choices(k=5, population=data, weights=range(4)), + weighted_choices(k=5, population=data, cum_weights=range(4)), + ]: + self.assertEqual(len(sample), 5) + self.assertEqual(type(sample), list) + self.assertTrue(set(sample) <= set(data)) + + # test argument handling + with self.assertRaises(TypeError): # missing arguments + weighted_choices(2) + + self.assertEqual(weighted_choices(0, data), []) # k == 0 + self.assertEqual(weighted_choices(-1, data), []) # negative k behaves like ``[0] * -1`` + with self.assertRaises(TypeError): + weighted_choices(2.5, data) # k is a float + + self.assertTrue(set(weighted_choices(5, str_data)) <= set(str_data)) # population is a string sequence + self.assertTrue(set(weighted_choices(5, range_data)) <= set(range_data)) # population is a range + with self.assertRaises(TypeError): + weighted_choices(2.5, set_data) # population is not a sequence + + self.assertTrue(set(weighted_choices(5, data, None)) <= set(data)) # weights is None + self.assertTrue(set(weighted_choices(5, data, weights=None)) <= set(data)) + with self.assertRaises(ValueError): + weighted_choices(5, data, [1,2]) # len(weights) != len(population) + with self.assertRaises(IndexError): + weighted_choices(5, data, [0]*4) # weights sum to zero + with self.assertRaises(TypeError): + weighted_choices(5, data, 10) # non-iterable weights + with self.assertRaises(TypeError): + weighted_choices(5, data, [None]*4) # non-numeric weights + for weights in [ + [15, 10, 25, 30], # integer weights + [15.1, 10.2, 25.2, 30.3], # float weights + [Fraction(1, 3), Fraction(2, 6), Fraction(3, 6), Fraction(4, 6)], # fractional weights + [True, False, True, False] # booleans (include / exclude) + ]: + self.assertTrue(set(weighted_choices(5, data, weights)) <= set(data)) + + with self.assertRaises(ValueError): + weighted_choices(5, data, cum_weights=[1,2]) # len(weights) != len(population) + with self.assertRaises(IndexError): + weighted_choices(5, data, cum_weights=[0]*4) # cum_weights sum to zero + with self.assertRaises(TypeError): + weighted_choices(5, data, cum_weights=10) # non-iterable cum_weights + with self.assertRaises(TypeError): + weighted_choices(5, data, cum_weights=[None]*4) # non-numeric cum_weights + with self.assertRaises(TypeError): + weighted_choices(5, data, range(4), cum_weights=range(4)) # both weights and cum_weights + for weights in [ + [15, 10, 25, 30], # integer cum_weights + [15.1, 10.2, 25.2, 30.3], # float cum_weights + [Fraction(1, 3), Fraction(2, 6), Fraction(3, 6), Fraction(4, 6)], # fractional cum_weights + ]: + self.assertTrue(set(weighted_choices(5, data, cum_weights=weights)) <= set(data)) + def test_gauss(self): # Ensure that the seed() method initializes all the hidden state. In # particular, through 2.2.1 it failed to reset a piece of state used diff --git a/Misc/NEWS b/Misc/NEWS index e913ef81d28..fbf7b2b9764 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -101,6 +101,8 @@ Library - Issue #27691: Fix ssl module's parsing of GEN_RID subject alternative name fields in X.509 certs. +- Issue #18844: Add random.weighted_choices(). + - Issue #25761: Improved error reporting about truncated pickle data in C implementation of unpickler. UnpicklingError is now raised instead of AttributeError and ValueError in some cases.