Tweak recipes and tests

This commit is contained in:
Raymond Hettinger 2008-03-07 01:33:20 +00:00
parent a1ca94a102
commit 38fb9bee6c
2 changed files with 36 additions and 13 deletions

View File

@ -662,15 +662,15 @@ which incur interpreter overhead. ::
def pairwise(iterable):
"s -> (s0,s1), (s1,s2), (s2, s3), ..."
a, b = tee(iterable)
try:
b.next()
except StopIteration:
pass
for elem in b:
break
return izip(a, b)
def grouper(n, iterable, padvalue=None):
def grouper(n, iterable, fillvalue=None):
"grouper(3, 'abcdefg', 'x') --> ('a','b','c'), ('d','e','f'), ('g','x','x')"
return izip(*[chain(iterable, repeat(padvalue, n-1))]*n)
args = [iter(iterable)] * n
kwds = dict(fillvalue=fillvalue)
return izip_longest(*args, **kwds)
def roundrobin(*iterables):
"roundrobin('abc', 'd', 'ef') --> 'a', 'd', 'e', 'b', 'f', 'c'"

View File

@ -410,6 +410,28 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(len(list(product(*[range(7)]*6))), 7**6)
self.assertRaises(TypeError, product, range(6), None)
def product1(*args, **kwds):
pools = map(tuple, args) * kwds.get('repeat', 1)
n = len(pools)
if n == 0:
yield ()
return
if any(len(pool) == 0 for pool in pools):
return
indices = [0] * n
yield tuple(pool[i] for pool, i in zip(pools, indices))
while 1:
for i in reversed(range(n)): # right to left
if indices[i] == len(pools[i]) - 1:
continue
indices[i] += 1
for j in range(i+1, n):
indices[j] = 0
yield tuple(pool[i] for pool, i in zip(pools, indices))
break
else:
return
def product2(*args, **kwds):
'Pure python version used in docs'
pools = map(tuple, args) * kwds.get('repeat', 1)
@ -425,6 +447,7 @@ class TestBasicOps(unittest.TestCase):
args = [random.choice(argtypes) for j in range(random.randrange(5))]
expected_len = prod(map(len, args))
self.assertEqual(len(list(product(*args))), expected_len)
self.assertEqual(list(product(*args)), list(product1(*args)))
self.assertEqual(list(product(*args)), list(product2(*args)))
args = map(iter, args)
self.assertEqual(len(list(product(*args))), expected_len)
@ -1213,7 +1236,7 @@ Samuele
... return sum(imap(operator.mul, vec1, vec2))
>>> def flatten(listOfLists):
... return list(chain(*listOfLists))
... return list(chain.from_iterable(listOfLists))
>>> def repeatfunc(func, times=None, *args):
... "Repeat calls to func with specified arguments."
@ -1226,15 +1249,15 @@ Samuele
>>> def pairwise(iterable):
... "s -> (s0,s1), (s1,s2), (s2, s3), ..."
... a, b = tee(iterable)
... try:
... b.next()
... except StopIteration:
... pass
... for elem in b:
... break
... return izip(a, b)
>>> def grouper(n, iterable, padvalue=None):
>>> def grouper(n, iterable, fillvalue=None):
... "grouper(3, 'abcdefg', 'x') --> ('a','b','c'), ('d','e','f'), ('g','x','x')"
... return izip(*[chain(iterable, repeat(padvalue, n-1))]*n)
... args = [iter(iterable)] * n
... kwds = dict(fillvalue=fillvalue)
... return izip_longest(*args, **kwds)
>>> def roundrobin(*iterables):
... "roundrobin('abc', 'd', 'ef') --> 'a', 'd', 'e', 'b', 'f', 'c'"