Simplify choice()'s interaction with the private _randbelow() method (GH-19831)

This commit is contained in:
Raymond Hettinger 2020-05-01 10:34:19 -07:00 committed by GitHub
parent 03b7642265
commit 4168f1e460
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 12 deletions

View File

@ -265,10 +265,10 @@ class Random(_random.Random):
return self.randrange(a, b+1) return self.randrange(a, b+1)
def _randbelow_with_getrandbits(self, n): def _randbelow_with_getrandbits(self, n):
"Return a random int in the range [0,n). Raises ValueError if n==0." "Return a random int in the range [0,n). Returns 0 if n==0."
if not n: if not n:
raise ValueError("Boundary cannot be zero") return 0
getrandbits = self.getrandbits getrandbits = self.getrandbits
k = n.bit_length() # don't use (n-1) here because n can be 1 k = n.bit_length() # don't use (n-1) here because n can be 1
r = getrandbits(k) # 0 <= r < 2**k r = getrandbits(k) # 0 <= r < 2**k
@ -277,7 +277,7 @@ class Random(_random.Random):
return r return r
def _randbelow_without_getrandbits(self, n, int=int, maxsize=1<<BPF): def _randbelow_without_getrandbits(self, n, int=int, maxsize=1<<BPF):
"""Return a random int in the range [0,n). Raises ValueError if n==0. """Return a random int in the range [0,n). Returns 0 if n==0.
The implementation does not use getrandbits, but only random. The implementation does not use getrandbits, but only random.
""" """
@ -289,7 +289,7 @@ class Random(_random.Random):
"To remove the range limitation, add a getrandbits() method.") "To remove the range limitation, add a getrandbits() method.")
return int(random() * n) return int(random() * n)
if n == 0: if n == 0:
raise ValueError("Boundary cannot be zero") return 0
rem = maxsize % n rem = maxsize % n
limit = (maxsize - rem) / maxsize # int(limit * maxsize) % n == 0 limit = (maxsize - rem) / maxsize # int(limit * maxsize) % n == 0
r = random() r = random()
@ -303,11 +303,7 @@ class Random(_random.Random):
def choice(self, seq): def choice(self, seq):
"""Choose a random element from a non-empty sequence.""" """Choose a random element from a non-empty sequence."""
try: return seq[self._randbelow(len(seq))] # raises IndexError if seq is empty
i = self._randbelow(len(seq))
except ValueError:
raise IndexError('Cannot choose from an empty sequence') from None
return seq[i]
def shuffle(self, x, random=None): def shuffle(self, x, random=None):
"""Shuffle list x in place, and return None. """Shuffle list x in place, and return None.

View File

@ -688,10 +688,10 @@ class MersenneTwister_TestBasicOps(TestBasicOps, unittest.TestCase):
maxsize+1, maxsize=maxsize maxsize+1, maxsize=maxsize
) )
self.gen._randbelow_without_getrandbits(5640, maxsize=maxsize) self.gen._randbelow_without_getrandbits(5640, maxsize=maxsize)
# issue 33203: test that _randbelow raises ValueError on # issue 33203: test that _randbelow returns zero on
# n == 0 also in its getrandbits-independent branch. # n == 0 also in its getrandbits-independent branch.
with self.assertRaises(ValueError): x = self.gen._randbelow_without_getrandbits(0, maxsize=maxsize)
self.gen._randbelow_without_getrandbits(0, maxsize=maxsize) self.assertEqual(x, 0)
# This might be going too far to test a single line, but because of our # This might be going too far to test a single line, but because of our
# noble aim of achieving 100% test coverage we need to write a case in # noble aim of achieving 100% test coverage we need to write a case in