mirror of https://github.com/python/cpython
Derive an industrial-strength conjoin() via cross-recursion loop unrolling,
and fiddle the conjoin tests to exercise all the new possible paths.
This commit is contained in:
parent
4efb6e9643
commit
c468fd28b6
|
@ -776,6 +776,62 @@ def conjoin(gs):
|
|||
for x in gen(0):
|
||||
yield x
|
||||
|
||||
# That works fine, but recursing a level and checking i against len(gs) for
|
||||
# each item produced is inefficient. By doing manual loop unrolling across
|
||||
# generator boundaries, it's possible to eliminate most of that overhead.
|
||||
# This isn't worth the bother *in general* for generators, but conjoin() is
|
||||
# a core building block for some CPU-intensive generator applications.
|
||||
|
||||
def conjoin(gs):
|
||||
|
||||
n = len(gs)
|
||||
values = [None] * n
|
||||
|
||||
# Do one loop nest at time recursively, until the # of loop nests
|
||||
# remaining is divisible by 3.
|
||||
|
||||
def gen(i, values=values):
|
||||
if i >= n:
|
||||
yield values
|
||||
|
||||
elif (n-i) % 3:
|
||||
ip1 = i+1
|
||||
for values[i] in gs[i]():
|
||||
for x in gen(ip1):
|
||||
yield x
|
||||
|
||||
else:
|
||||
for x in _gen3(i):
|
||||
yield x
|
||||
|
||||
# Do three loop nests at a time, recursing only if at least three more
|
||||
# remain. Don't call directly: this is an internal optimization for
|
||||
# gen's use.
|
||||
|
||||
def _gen3(i, values=values):
|
||||
assert i < n and (n-i) % 3 == 0
|
||||
ip1, ip2, ip3 = i+1, i+2, i+3
|
||||
g, g1, g2 = gs[i : ip3]
|
||||
|
||||
if ip3 >= n:
|
||||
# These are the last three, so we can yield values directly.
|
||||
for values[i] in g():
|
||||
for values[ip1] in g1():
|
||||
for values[ip2] in g2():
|
||||
yield values
|
||||
|
||||
else:
|
||||
# At least 6 loop nests remain; peel off 3 and recurse for the
|
||||
# rest.
|
||||
for values[i] in g():
|
||||
for values[ip1] in g1():
|
||||
for values[ip2] in g2():
|
||||
for x in _gen3(ip3):
|
||||
yield x
|
||||
|
||||
for x in gen(0):
|
||||
yield x
|
||||
|
||||
# A conjoin-based N-Queens solver.
|
||||
|
||||
class Queens:
|
||||
|
@ -804,8 +860,7 @@ class Queens:
|
|||
def rowgen(rowuses=rowuses):
|
||||
for j in rangen:
|
||||
uses = rowuses[j]
|
||||
if uses & self.used:
|
||||
continue
|
||||
if uses & self.used == 0:
|
||||
self.used |= uses
|
||||
yield j
|
||||
self.used &= ~uses
|
||||
|
@ -834,10 +889,7 @@ conjoin_tests = """
|
|||
Generate the 3-bit binary numbers in order. This illustrates dumbest-
|
||||
possible use of conjoin, just to generate the full cross-product.
|
||||
|
||||
>>> def g():
|
||||
... return [0, 1]
|
||||
|
||||
>>> for c in conjoin([g] * 3):
|
||||
>>> for c in conjoin([lambda: (0, 1)] * 3):
|
||||
... print c
|
||||
[0, 0, 0]
|
||||
[0, 0, 1]
|
||||
|
@ -848,6 +900,28 @@ possible use of conjoin, just to generate the full cross-product.
|
|||
[1, 1, 0]
|
||||
[1, 1, 1]
|
||||
|
||||
For efficiency in typical backtracking apps, conjoin() yields the same list
|
||||
object each time. So if you want to save away a full account of its
|
||||
generated sequence, you need to copy its results.
|
||||
|
||||
>>> def gencopy(iterator):
|
||||
... for x in iterator:
|
||||
... yield x[:]
|
||||
|
||||
>>> for n in range(10):
|
||||
... all = list(gencopy(conjoin([lambda: (0, 1)] * n)))
|
||||
... print n, len(all), all[0] == [0] * n, all[-1] == [1] * n
|
||||
0 1 1 1
|
||||
1 2 1 1
|
||||
2 4 1 1
|
||||
3 8 1 1
|
||||
4 16 1 1
|
||||
5 32 1 1
|
||||
6 64 1 1
|
||||
7 128 1 1
|
||||
8 256 1 1
|
||||
9 512 1 1
|
||||
|
||||
And run an 8-queens solver.
|
||||
|
||||
>>> q = Queens(8)
|
||||
|
|
Loading…
Reference in New Issue