Take Tim's advice and have random.sample() support only sequences and sets.

This commit is contained in:
Raymond Hettinger 2008-01-14 01:00:53 +00:00
parent 28de64fd0f
commit 1acde190b2
4 changed files with 22 additions and 49 deletions

View File

@ -111,8 +111,8 @@ Functions for sequences:
.. function:: sample(population, k) .. function:: sample(population, k)
Return a *k* length list of unique elements chosen from the population sequence. Return a *k* length list of unique elements chosen from the population sequence
Used for random sampling without replacement. or set. Used for random sampling without replacement.
Returns a new list containing elements from the population while leaving the Returns a new list containing elements from the population while leaving the
original population unchanged. The resulting list is in selection order so that original population unchanged. The resulting list is in selection order so that

View File

@ -267,7 +267,7 @@ class Random(_random.Random):
x[i], x[j] = x[j], x[i] x[i], x[j] = x[j], x[i]
def sample(self, population, k): def sample(self, population, k):
"""Chooses k unique random elements from a population sequence. """Chooses k unique random elements from a population sequence or set.
Returns a new list containing elements from the population while Returns a new list containing elements from the population while
leaving the original population unchanged. The resulting list is leaving the original population unchanged. The resulting list is
@ -284,15 +284,6 @@ class Random(_random.Random):
large population: sample(range(10000000), 60) large population: sample(range(10000000), 60)
""" """
# XXX Although the documentation says `population` is "a sequence",
# XXX attempts are made to cater to any iterable with a __len__
# XXX method. This has had mixed success. Examples from both
# XXX sides: sets work fine, and should become officially supported;
# XXX dicts are much harder, and have failed in various subtle
# XXX ways across attempts. Support for mapping types should probably
# XXX be dropped (and users should pass mapping.keys() or .values()
# XXX explicitly).
# Sampling without replacement entails tracking either potential # Sampling without replacement entails tracking either potential
# selections (the pool) in a list or previous selections in a set. # selections (the pool) in a list or previous selections in a set.
@ -303,37 +294,35 @@ class Random(_random.Random):
# preferred since the list takes less space than the # preferred since the list takes less space than the
# set and it doesn't suffer from frequent reselections. # set and it doesn't suffer from frequent reselections.
if isinstance(population, (set, frozenset)):
population = tuple(population)
if not hasattr(population, '__getitem__') or hasattr(population, 'keys'):
raise TypeError("Population must be a sequence or set. For dicts, use dict.keys().")
random = self.random
n = len(population) n = len(population)
if not 0 <= k <= n: if not 0 <= k <= n:
raise ValueError("sample larger than population") raise ValueError("Sample larger than population")
random = self.random
_int = int _int = int
result = [None] * k result = [None] * k
setsize = 21 # size of a small set minus size of an empty list setsize = 21 # size of a small set minus size of an empty list
if k > 5: if k > 5:
setsize += 4 ** _ceil(_log(k * 3, 4)) # table size for big sets setsize += 4 ** _ceil(_log(k * 3, 4)) # table size for big sets
if n <= setsize or hasattr(population, "keys"): if n <= setsize:
# An n-length list is smaller than a k-length set, or this is a # An n-length list is smaller than a k-length set
# mapping type so the other algorithm wouldn't work.
pool = list(population) pool = list(population)
for i in range(k): # invariant: non-selected at [0,n-i) for i in range(k): # invariant: non-selected at [0,n-i)
j = _int(random() * (n-i)) j = _int(random() * (n-i))
result[i] = pool[j] result[i] = pool[j]
pool[j] = pool[n-i-1] # move non-selected item into vacancy pool[j] = pool[n-i-1] # move non-selected item into vacancy
else: else:
try: selected = set()
selected = set() selected_add = selected.add
selected_add = selected.add for i in range(k):
for i in range(k): j = _int(random() * n)
while j in selected:
j = _int(random() * n) j = _int(random() * n)
while j in selected: selected_add(j)
j = _int(random() * n) result[i] = population[j]
selected_add(j)
result[i] = population[j]
except (TypeError, KeyError): # handle (at least) sets
if isinstance(population, list):
raise
return self.sample(tuple(population), k)
return result return result
## -------------------- real-valued distributions ------------------- ## -------------------- real-valued distributions -------------------

View File

@ -84,26 +84,7 @@ class TestBasicOps(unittest.TestCase):
self.gen.sample(tuple('abcdefghijklmnopqrst'), 2) self.gen.sample(tuple('abcdefghijklmnopqrst'), 2)
def test_sample_on_dicts(self): def test_sample_on_dicts(self):
self.gen.sample(dict.fromkeys('abcdefghijklmnopqrst'), 2) self.assertRaises(TypeError, self.gen.sample, dict.fromkeys('abcdef'), 2)
# SF bug #1460340 -- random.sample can raise KeyError
a = dict.fromkeys(list(range(10)) +
list(range(10,100,2)) +
list(range(100,110)))
self.gen.sample(a, 3)
# A followup to bug #1460340: sampling from a dict could return
# a subset of its keys or of its values, depending on the size of
# the subset requested.
N = 30
d = dict((i, complex(i, i)) for i in range(N))
for k in range(N+1):
samp = self.gen.sample(d, k)
# Verify that we got ints back (keys); the values are complex.
for x in samp:
self.assert_(type(x) is int)
samp.sort()
self.assertEqual(samp, list(range(N)))
def test_gauss(self): def test_gauss(self):
# Ensure that the seed() method initializes all the hidden state. In # Ensure that the seed() method initializes all the hidden state. In

View File

@ -355,6 +355,9 @@ Library
- Removed defunct parts of the random module (the Wichmann-Hill generator - Removed defunct parts of the random module (the Wichmann-Hill generator
and the jumpahead() method). and the jumpahead() method).
- random.sample() now explicitly supports all sequences and sets while
explicitly excluding mappings.
- Patch #467924: add ZipFile.extract() and ZipFile.extractall() in the - Patch #467924: add ZipFile.extract() and ZipFile.extractall() in the
zipfile module. zipfile module.