[Python-checkins] CVS: python/dist/src/Lib/test test_generators.py,1.14,1.15

Tim Peters tim_one@users.sourceforge.net
30 Jun 2001 00:29:46 -0700


Update of /cvsroot/python/python/dist/src/Lib/test
In directory usw-pr-cvs1:/tmp/cvs-serv20431/python/dist/src/Lib/test
Modified Files:
	test_generators.py 
Log Message:
Derive an industrial-strength conjoin() via cross-recursion loop unrolling,
and fiddle the conjoin tests to exercise all the new possible paths.
Index: test_generators.py
===================================================================
RCS file: /cvsroot/python/python/dist/src/Lib/test/test_generators.py,v
retrieving revision 1.14
retrieving revision 1.15
diff -C2 -r1.14 -r1.15
*** test_generators.py	2001/06/29 02:41:16	1.14
--- test_generators.py	2001/06/30 07:29:44	1.15
***************
*** 777,780 ****
--- 777,836 ----
 yield x
 
+ # That works fine, but recursing a level and checking i against len(gs) for
+ # each item produced is inefficient. By doing manual loop unrolling across
+ # generator boundaries, it's possible to eliminate most of that overhead.
+ # This isn't worth the bother *in general* for generators, but conjoin() is
+ # a core building block for some CPU-intensive generator applications.
+ 
def conjoin(gs):
    """Generate the full cross-product of the iterables produced by gs.

    gs is a sequence of zero-argument callables.  Slot k of the yielded
    list runs over the values of gs[k](), with the last slot varying
    fastest -- the same order as a nest of len(gs) for-loops.  gs[k] is
    called anew each time its loop restarts.  An empty gs yields a single
    empty list.

    For speed, the *same* list object is yielded every time and updated in
    place between yields, so callers that keep results must copy them.

    Implementation: loop nests are peeled off one per recursion level until
    the number remaining is a multiple of 3; from then on three nests are
    handled per level, which cuts recursion depth and per-item bookkeeping.
    """
    count = len(gs)
    slots = [None] * count

    def _triple(k, slots=slots):
        # Handle nests k, k+1, k+2, recursing only if more nests remain.
        # Internal helper: only reached when count - k is a positive
        # multiple of 3.
        assert k < count and (count - k) % 3 == 0
        a, b, c = k, k + 1, k + 2
        first, second, third = gs[a : a + 3]
        if c + 1 >= count:
            # The innermost three nests: yield the shared list directly.
            for slots[a] in first():
                for slots[b] in second():
                    for slots[c] in third():
                        yield slots
        else:
            # At least six nests remain: fill three slots, then recurse.
            for slots[a] in first():
                for slots[b] in second():
                    for slots[c] in third():
                        for result in _triple(c + 1):
                            yield result

    def _single(k, slots=slots):
        # Peel off one nest at a time until the number of nests remaining
        # is divisible by 3, then hand the rest to _triple.
        if k >= count:
            yield slots
        elif (count - k) % 3 == 0:
            for result in _triple(k):
                yield result
        else:
            nxt = k + 1
            for slots[k] in gs[k]():
                for result in _single(nxt):
                    yield result

    for result in _single(0):
        yield result
+ 
 # A conjoin-based N-Queens solver.
 
***************
*** 805,813 ****
 for j in rangen:
 uses = rowuses[j]
! if uses & self.used:
! continue
! self.used |= uses
! yield j
! self.used &= ~uses
 
 self.rowgenerators.append(rowgen)
--- 861,868 ----
 for j in rangen:
 uses = rowuses[j]
! if uses & self.used == 0:
! self.used |= uses
! yield j
! self.used &= ~uses
 
 self.rowgenerators.append(rowgen)
***************
*** 835,842 ****
 possible use of conjoin, just to generate the full cross-product.
 
! >>> def g():
! ... return [0, 1]
! 
! >>> for c in conjoin([g] * 3):
 ... print c
 [0, 0, 0]
--- 890,894 ----
 possible use of conjoin, just to generate the full cross-product.
 
! >>> for c in conjoin([lambda: (0, 1)] * 3):
 ... print c
 [0, 0, 0]
***************
*** 848,851 ****
--- 900,925 ----
 [1, 1, 0]
 [1, 1, 1]
+ 
+ For efficiency in typical backtracking apps, conjoin() yields the same list
+ object each time. So if you want to save away a full account of its
+ generated sequence, you need to copy its results.
+ 
+ >>> def gencopy(iterator):
+ ... for x in iterator:
+ ... yield x[:]
+ 
+ >>> for n in range(10):
+ ... all = list(gencopy(conjoin([lambda: (0, 1)] * n)))
+ ... print n, len(all), all[0] == [0] * n, all[-1] == [1] * n
+ 0 1 1 1
+ 1 2 1 1
+ 2 4 1 1
+ 3 8 1 1
+ 4 16 1 1
+ 5 32 1 1
+ 6 64 1 1
+ 7 128 1 1
+ 8 256 1 1
+ 9 512 1 1
 
 And run an 8-queens solver.

AltStyle によって変換されたページ (->オリジナル) /