Don't decrement below zero. And add more tests.

This commit is contained in:
Raymond Hettinger 2006-03-25 12:15:04 +00:00
parent ccc7bb4ef2
commit c4e94b90a8
2 changed files with 19 additions and 4 deletions

View File

@ -56,11 +56,12 @@ class Queue:
""" """
self.all_tasks_done.acquire() self.all_tasks_done.acquire()
try: try:
self.unfinished_tasks = unfinished = self.unfinished_tasks - 1 unfinished = self.unfinished_tasks - 1
if unfinished <= 0: if unfinished <= 0:
if unfinished < 0: if unfinished < 0:
raise ValueError('task_done() called too many times') raise ValueError('task_done() called too many times')
self.all_tasks_done.notifyAll() self.all_tasks_done.notifyAll()
self.unfinished_tasks = unfinished
finally: finally:
self.all_tasks_done.release() self.all_tasks_done.release()

View File

@ -228,6 +228,9 @@ def worker(q):
global cum global cum
while True: while True:
x = q.get() x = q.get()
if x is None:
q.task_done()
return
cumlock.acquire() cumlock.acquire()
try: try:
cum += x cum += x
@ -239,18 +242,29 @@ def QueueJoinTest(q):
global cum global cum
cum = 0 cum = 0
for i in (0,1): for i in (0,1):
t = threading.Thread(target=worker, args=(q,)) threading.Thread(target=worker, args=(q,)).start()
t.setDaemon(True)
t.start()
for i in xrange(100): for i in xrange(100):
q.put(i) q.put(i)
q.join() q.join()
verify(cum==sum(range(100)), "q.join() did not block until all tasks were done") verify(cum==sum(range(100)), "q.join() did not block until all tasks were done")
for i in (0,1):
q.put(None) # instruct the threads to close
q.join() # verify that you can join twice
def QueueTaskDoneTest(q)
try:
q.task_done()
except ValueError:
pass
else:
raise TestFailed("Did not detect task count going negative")
def test(): def test():
q = Queue.Queue() q = Queue.Queue()
QueueTaskDoneTest(q)
QueueJoinTest(q) QueueJoinTest(q)
QueueJoinTest(q) QueueJoinTest(q)
QueueTaskDoneTest(q)
q = Queue.Queue(QUEUE_SIZE) q = Queue.Queue(QUEUE_SIZE)
# Do it a couple of times on the same queue # Do it a couple of times on the same queue