issue 8777

Add threading.Barrier
This commit is contained in:
Kristján Valur Jónsson 2010-10-28 09:43:10 +00:00
parent 65ffae0aa3
commit 3be00037d6
4 changed files with 469 additions and 0 deletions

View File

@ -768,6 +768,110 @@ For example::
only work if the timer is still in its waiting stage.
.. _barrier-objects
Barrier Objects
---------------
This class provides a simple synchronization primitive for use by a fixed
number of threads that need to wait for each other. Each of the threads
tries to pass the barrier by calling the :meth:`wait` method and will block
until all of the threads have made the call.
At this points, the threads are released simultanously.
The barrier can be reused any number of times for the same number of threads.
As an example, here is a simple way to synchronize a client and server thread::
b = Barrier(2, timeout=5)
server():
start_server()
b.wait()
while True:
connection = accept_connection()
process_server_connection(connection)
client():
b.wait()
while True:
connection = make_connection()
process_client_connection(connection)
.. class:: Barrier(parties, action=None, timeout=None)
Create a barrier object for *parties* number of threads. An *action*,
when provided, is a callable to be called by one of the threads when
they are released. *timeout* is the default timeout value if none
is specified for the :meth:`wait` method.
.. method:: wait(timeout=None)
Pass the barrier. When all the threads party to the barrier have called
this function, they are all released simultaneously. If a *timeout*
is provided, is is used in preference to any that was supplied to the
class constructor.
The return value is an integer in the range 0 to *parties*-1, different
for each thrad. This can be used to select a thread to do some special
housekeeping, eg:
i = barrier.wait()
if i == 0:
# Only one thread needs to print this
print("passed the barrier")
If an *action* was
provided to the constructor, one of the threads will have called it
prior to being released. Should this call raise an error, the barrier
is put into the broken state.
If the call times out, the barrier is put into the broken state.
This method may raise a :class:`BrokenBarrierError` exception if the
barrier is broken or reset while a thread is waiting
.. method:: reset()
Return the barrier to the default, empty state. Any threads waiting on
it will receive the :class:`BrokenBarrierError` exception.
Note that using this function may can require some external
synchronization if there are other threads whose state is unknown.
If a barrier is broken it may be better to just leave it and create a
new one.
.. method:: abort()
Put the barrier into a broken state. This causes any active or future
calls to :meth:`wait` to fail with the :class:`BrokenBarrierError`.
Use this for example if one of the needs to abort, to avoid deadlocking
the application.
It may be preferable to simply create the barrier with a sensible
*timeout* value to automatically guard against one of the threads
going awry.
.. attribute:: parties
The number of threads required to pass the barrier.
.. attribute:: n_waiting
The number of threads currently waiting in the barrier.
.. attribute:: broken
A boolean that is ``True`` if the barrier is in the broken state.
.. versionadded:: 3.2
.. class:: BrokenBarrierError(RuntimeError)
The exception raised when the :class:`Barrier` object is reset or broken.
.. versionadded:: 3.2
.. _with-locks:
Using locks, conditions, and semaphores in the :keyword:`with` statement

View File

@ -597,3 +597,193 @@ class BoundedSemaphoreTests(BaseSemaphoreTests):
sem.acquire()
sem.release()
self.assertRaises(ValueError, sem.release)
class BarrierTests(BaseTestCase):
"""
Tests for Barrier objects.
"""
N = 5
def setUp(self):
self.barrier = self.barriertype(self.N, timeout=0.1)
def tearDown(self):
self.barrier.abort()
def run_threads(self, f):
b = Bunch(f, self.N-1)
f()
b.wait_for_finished()
def multipass(self, results, n):
m = self.barrier.parties
self.assertEqual(m, self.N)
for i in range(n):
results[0].append(True)
self.assertEqual(len(results[1]), i * m)
self.barrier.wait()
results[1].append(True)
self.assertEqual(len(results[0]), (i + 1) * m)
self.barrier.wait()
self.assertEqual(self.barrier.n_waiting, 0)
self.assertFalse(self.barrier.broken)
def test_barrier(self, passes=1):
"""
Test that a barrier is passed in lockstep
"""
results = [[],[]]
def f():
self.multipass(results, passes)
self.run_threads(f)
def test_barrier_10(self):
"""
Test that a barrier works for 10 consecutive runs
"""
return self.test_barrier(10)
def test_wait_return(self):
"""
test the return value from barrier.wait
"""
results = []
def f():
r = self.barrier.wait()
results.append(r)
self.run_threads(f)
self.assertEqual(sum(results), sum(range(self.N)))
def test_action(self):
"""
Test the 'action' callback
"""
results = []
def action():
results.append(True)
barrier = self.barriertype(self.N, action)
def f():
barrier.wait()
self.assertEqual(len(results), 1)
self.run_threads(f)
def test_abort(self):
"""
Test that an abort will put the barrier in a broken state
"""
results1 = []
results2 = []
def f():
try:
i = self.barrier.wait()
if i == self.N//2:
raise RuntimeError
self.barrier.wait()
results1.append(True)
except threading.BrokenBarrierError:
results2.append(True)
except RuntimeError:
self.barrier.abort()
pass
self.run_threads(f)
self.assertEqual(len(results1), 0)
self.assertEqual(len(results2), self.N-1)
self.assertTrue(self.barrier.broken)
def test_reset(self):
"""
Test that a 'reset' on a barrier frees the waiting threads
"""
results1 = []
results2 = []
results3 = []
def f():
i = self.barrier.wait()
if i == self.N//2:
# Wait until the other threads are all in the barrier.
while self.barrier.n_waiting < self.N-1:
time.sleep(0.001)
self.barrier.reset()
else:
try:
self.barrier.wait()
results1.append(True)
except threading.BrokenBarrierError:
results2.append(True)
# Now, pass the barrier again
self.barrier.wait()
results3.append(True)
self.run_threads(f)
self.assertEqual(len(results1), 0)
self.assertEqual(len(results2), self.N-1)
self.assertEqual(len(results3), self.N)
def test_abort_and_reset(self):
"""
Test that a barrier can be reset after being broken.
"""
results1 = []
results2 = []
results3 = []
barrier2 = self.barriertype(self.N)
def f():
try:
i = self.barrier.wait()
if i == self.N//2:
raise RuntimeError
self.barrier.wait()
results1.append(True)
except threading.BrokenBarrierError:
results2.append(True)
except RuntimeError:
self.barrier.abort()
pass
# Synchronize and reset the barrier. Must synchronize first so
# that everyone has left it when we reset, and after so that no
# one enters it before the reset.
if barrier2.wait() == self.N//2:
self.barrier.reset()
barrier2.wait()
self.barrier.wait()
results3.append(True)
self.run_threads(f)
self.assertEqual(len(results1), 0)
self.assertEqual(len(results2), self.N-1)
self.assertEqual(len(results3), self.N)
def test_timeout(self):
"""
Test wait(timeout)
"""
def f():
i = self.barrier.wait()
if i == self.N // 2:
# One thread is late!
time.sleep(0.1)
# Default timeout is 0.1, so this is shorter.
self.assertRaises(threading.BrokenBarrierError,
self.barrier.wait, 0.05)
self.run_threads(f)
def test_default_timeout(self):
"""
Test the barrier's default timeout
"""
def f():
i = self.barrier.wait()
if i == self.N // 2:
# One thread is later than the default timeout of 0.1s.
time.sleep(0.15)
self.assertRaises(threading.BrokenBarrierError, self.barrier.wait)
self.run_threads(f)
def test_single_thread(self):
b = self.barriertype(1)
b.wait()
b.wait()

View File

@ -555,6 +555,8 @@ class SemaphoreTests(lock_tests.SemaphoreTests):
class BoundedSemaphoreTests(lock_tests.BoundedSemaphoreTests):
semtype = staticmethod(threading.BoundedSemaphore)
class BarrierTests(lock_tests.BarrierTests):
barriertype = staticmethod(threading.Barrier)
def test_main():
test.support.run_unittest(LockTests, PyRLockTests, CRLockTests, EventTests,
@ -563,6 +565,7 @@ def test_main():
ThreadTests,
ThreadJoinOnShutdown,
ThreadingExceptionTests,
BarrierTests
)
if __name__ == "__main__":

View File

@ -392,6 +392,178 @@ class _Event(_Verbose):
finally:
self._cond.release()
# A barrier class. Inspired in part by the pthread_barrier_* api and
# the CyclicBarrier class from Java. See
# http://sourceware.org/pthreads-win32/manual/pthread_barrier_init.html and
# http://java.sun.com/j2se/1.5.0/docs/api/java/util/concurrent/
# CyclicBarrier.html
# for information.
# We maintain two main states, 'filling' and 'draining' enabling the barrier
# to be cyclic. Threads are not allowed into it until it has fully drained
# since the previous cycle. In addition, a 'resetting' state exists which is
# similar to 'draining' except that threads leave with a BrokenBarrierError,
# and a 'broken' state in which all threads get get the exception.
class Barrier(_Verbose):
"""
Barrier. Useful for synchronizing a fixed number of threads
at known synchronization points. Threads block on 'wait()' and are
simultaneously once they have all made that call.
"""
def __init__(self, parties, action=None, timeout=None, verbose=None):
"""
Create a barrier, initialised to 'parties' threads.
'action' is a callable which, when supplied, will be called
by one of the threads after they have all entered the
barrier and just prior to releasing them all.
If a 'timeout' is provided, it is uses as the default for
all subsequent 'wait()' calls.
"""
_Verbose.__init__(self, verbose)
self._cond = Condition(Lock())
self._action = action
self._timeout = timeout
self._parties = parties
self._state = 0 #0 filling, 1, draining, -1 resetting, -2 broken
self._count = 0
def wait(self, timeout=None):
"""
Wait for the barrier. When the specified number of threads have
started waiting, they are all simultaneously awoken. If an 'action'
was provided for the barrier, one of the threads will have executed
that callback prior to returning.
Returns an individual index number from 0 to 'parties-1'.
"""
if timeout is None:
timeout = self._timeout
with self._cond:
self._enter() # Block while the barrier drains.
index = self._count
self._count += 1
try:
if index + 1 == self._parties:
# We release the barrier
self._release()
else:
# We wait until someone releases us
self._wait(timeout)
return index
finally:
self._count -= 1
# Wake up any threads waiting for barrier to drain.
self._exit()
# Block until the barrier is ready for us, or raise an exception
# if it is broken.
def _enter(self):
while self._state in (-1, 1):
# It is draining or resetting, wait until done
self._cond.wait()
#see if the barrier is in a broken state
if self._state < 0:
raise BrokenBarrierError
assert self._state == 0
# Optionally run the 'action' and release the threads waiting
# in the barrier.
def _release(self):
try:
if self._action:
self._action()
# enter draining state
self._state = 1
self._cond.notify_all()
except:
#an exception during the _action handler. Break and reraise
self._break()
raise
# Wait in the barrier until we are relased. Raise an exception
# if the barrier is reset or broken.
def _wait(self, timeout):
while self._state == 0:
if self._cond.wait(timeout) is False:
#timed out. Break the barrier
self._break()
raise BrokenBarrierError
if self._state < 0:
raise BrokenBarrierError
assert self._state == 1
# If we are the last thread to exit the barrier, signal any threads
# waiting for the barrier to drain.
def _exit(self):
if self._count == 0:
if self._state in (-1, 1):
#resetting or draining
self._state = 0
self._cond.notify_all()
def reset(self):
"""
Reset the barrier to the initial state.
Any threads currently waiting will get the BrokenBarrier exception
raised.
"""
with self._cond:
if self._count > 0:
if self._state == 0:
#reset the barrier, waking up threads
self._state = -1
elif self._state == -2:
#was broken, set it to reset state
#which clears when the last thread exits
self._state = -1
else:
self._state = 0
self._cond.notify_all()
def abort(self):
"""
Place the barrier into a 'broken' state.
Useful in case of error. Any currently waiting threads and
threads attempting to 'wait()' will have BrokenBarrierError
raised.
"""
with self._cond:
self._break()
def _break(self):
# An internal error was detected. The barrier is set to
# a broken state all parties awakened.
self._state = -2
self._cond.notify_all()
@property
def parties(self):
"""
Return the number of threads required to trip the barrier.
"""
return self._parties
@property
def n_waiting(self):
"""
Return the number of threads that are currently waiting at the barrier.
"""
# We don't need synchronization here since this is an ephemeral result
# anyway. It returns the correct value in the steady state.
if self._state == 0:
return self._count
return 0
@property
def broken(self):
"""
Return True if the barrier is in a broken state
"""
return self._state == -2
#exception raised by the Barrier class
class BrokenBarrierError(RuntimeError): pass
# Helper to generate new thread names
_counter = 0
def _newname(template="Thread-%d"):