Derive an industrial-strength conjoin() via cross-recursion loop unrolling,

and fiddle the conjoin tests to exercise all the new possible paths.
This commit is contained in:
Tim Peters 2001-06-30 07:29:44 +00:00
parent 4efb6e9643
commit c468fd28b6
1 changed files with 83 additions and 9 deletions

View File

@ -776,6 +776,62 @@ def conjoin(gs):
for x in gen(0):
yield x
# That works fine, but recursing a level and checking i against len(gs) for
# each item produced is inefficient. By doing manual loop unrolling across
# generator boundaries, it's possible to eliminate most of that overhead.
# This isn't worth the bother *in general* for generators, but conjoin() is
# a core building block for some CPU-intensive generator applications.
def conjoin(gs):
n = len(gs)
values = [None] * n
# Do one loop nest at time recursively, until the # of loop nests
# remaining is divisible by 3.
def gen(i, values=values):
if i >= n:
yield values
elif (n-i) % 3:
ip1 = i+1
for values[i] in gs[i]():
for x in gen(ip1):
yield x
else:
for x in _gen3(i):
yield x
# Do three loop nests at a time, recursing only if at least three more
# remain. Don't call directly: this is an internal optimization for
# gen's use.
def _gen3(i, values=values):
assert i < n and (n-i) % 3 == 0
ip1, ip2, ip3 = i+1, i+2, i+3
g, g1, g2 = gs[i : ip3]
if ip3 >= n:
# These are the last three, so we can yield values directly.
for values[i] in g():
for values[ip1] in g1():
for values[ip2] in g2():
yield values
else:
# At least 6 loop nests remain; peel off 3 and recurse for the
# rest.
for values[i] in g():
for values[ip1] in g1():
for values[ip2] in g2():
for x in _gen3(ip3):
yield x
for x in gen(0):
yield x
# A conjoin-based N-Queens solver.
class Queens:
@ -804,11 +860,10 @@ class Queens:
def rowgen(rowuses=rowuses):
for j in rangen:
uses = rowuses[j]
if uses & self.used:
continue
self.used |= uses
yield j
self.used &= ~uses
if uses & self.used == 0:
self.used |= uses
yield j
self.used &= ~uses
self.rowgenerators.append(rowgen)
@ -834,10 +889,7 @@ conjoin_tests = """
Generate the 3-bit binary numbers in order. This illustrates dumbest-
possible use of conjoin, just to generate the full cross-product.
>>> def g():
... return [0, 1]
>>> for c in conjoin([g] * 3):
>>> for c in conjoin([lambda: (0, 1)] * 3):
... print c
[0, 0, 0]
[0, 0, 1]
@ -848,6 +900,28 @@ possible use of conjoin, just to generate the full cross-product.
[1, 1, 0]
[1, 1, 1]
For efficiency in typical backtracking apps, conjoin() yields the same list
object each time. So if you want to save away a full account of its
generated sequence, you need to copy its results.
>>> def gencopy(iterator):
... for x in iterator:
... yield x[:]
>>> for n in range(10):
... all = list(gencopy(conjoin([lambda: (0, 1)] * n)))
... print n, len(all), all[0] == [0] * n, all[-1] == [1] * n
0 1 1 1
1 2 1 1
2 4 1 1
3 8 1 1
4 16 1 1
5 32 1 1
6 64 1 1
7 128 1 1
8 256 1 1
9 512 1 1
And run an 8-queens solver.
>>> q = Queens(8)