Improved clarity and thoroughness of docstring.

Added design notes in comments.
Used better variable names.
Eliminated the unsavory "pool[-k:]" which was an aspiring bug (for k==0).
Used if/else to show the two algorithms in parallel style.
Added one more test assertion.
This commit is contained in:
Raymond Hettinger 2002-11-13 15:26:37 +00:00
parent 674dae245a
commit c0b4034b81
1 changed files with 41 additions and 20 deletions

View File

@ -377,39 +377,59 @@ class Random:
def sample(self, population, k, random=None, int=int):
"""Chooses k unique random elements from a population sequence.
Returns a new list containing elements from the population. The
list itself is in random order so that all sub-slices are also
random samples. The original sequence is left undisturbed.
Returns a new list containing elements from the population while
leaving the original population unchanged. The resulting list is
in selection order so that all sub-slices will also be valid random
samples. This allows raffle winners (the sample) to be partitioned
into grand prize and second place winners (the subslices).
If the population has repeated elements, then each occurrence is
a possible selection in the sample.
Members of the population need not be hashable or unique. If the
population contains repeats, then each occurrence is a possible
selection in the sample.
If indices are needed for a large population, use xrange as an
argument: sample(xrange(10000000), 60)
To choose a sample in a range of integers, use xrange as an argument.
This is especially fast and space efficient for sampling from a
large population: sample(xrange(10000000), 60)
Optional arg random is a 0-argument function returning a random
float in [0.0, 1.0); by default, the standard random.random.
"""
# Sampling without replacement entails tracking either potential
# selections (the pool) or previous selections.
# Pools are stored in lists which provide __getitem__ for selection
# and provide a way to remove selections. But each list.remove()
# rebuilds the entire list, so it is better to rearrange the list,
# placing non-selected elements at the head of the list. Tracking
# the selection pool is only space efficient with small populations.
# Previous selections are stored in dictionaries which provide
# __contains__ for detecting repeat selections. Discarding repeats
# is efficient unless most of the population has already been chosen.
# So, tracking selections is useful when sample sizes are much
# smaller than the total population.
n = len(population)
if not 0 <= k <= n:
raise ValueError, "sample larger than population"
if random is None:
random = self.random
result = [None] * k
if n < 6 * k: # if n len list takes less space than a k len dict
pool = list(population)
for i in xrange(n-1, n-k-1, -1):
j = int(random() * (i+1))
pool[i], pool[j] = pool[j], pool[i]
return pool[-k:]
inorder = [None] * k
selections = {}
for i in xrange(k):
j = int(random() * n)
while j in selections:
pool = list(population) # track potential selections
for i in xrange(k):
j = int(random() * (n-i)) # non-selected at [0,n-i)
result[i] = pool[j] # save selected element
pool[j] = pool[n-i-1] # non-selected to head of list
else:
selected = {} # track previous selections
for i in xrange(k):
j = int(random() * n)
selections[j] = inorder[i] = population[j]
return inorder # return selections in the order they were picked
while j in selected: # discard and replace repeats
j = int(random() * n)
result[i] = selected[j] = population[j]
return result # return selections in the order they were picked
## -------------------- real-valued distributions -------------------
@ -756,6 +776,7 @@ def _test_sample(n):
for k in xrange(n+1):
s = sample(population, k)
assert len(dict([(elem,True) for elem in s])) == len(s) == k
assert None not in s
def _sample_generator(n, k):
# Return a fixed element from the sample. Validates random ordering.
@ -787,7 +808,7 @@ def _test(N=2000):
_test_generator(N, 'weibullvariate(1.0, 1.0)')
_test_generator(N, '_sample_generator(50, 5)') # expected s.d.: 14.4
_test_generator(N, '_sample_generator(50, 45)') # expected s.d.: 14.4
_test_sample(1000)
_test_sample(500)
# Test jumpahead.
s = getstate()