From 3be00037d65178644b20a826f68eb3d0b25ccb5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Thu, 28 Oct 2010 09:43:10 +0000 Subject: [PATCH] issue 8777 Add threading.Barrier --- Doc/library/threading.rst | 104 ++++++++++++++++++++ Lib/test/lock_tests.py | 190 +++++++++++++++++++++++++++++++++++++ Lib/test/test_threading.py | 3 + Lib/threading.py | 172 +++++++++++++++++++++++++++++++++ 4 files changed, 469 insertions(+) diff --git a/Doc/library/threading.rst b/Doc/library/threading.rst index 7c8f709d7a4..64aa14b2d19 100644 --- a/Doc/library/threading.rst +++ b/Doc/library/threading.rst @@ -768,6 +768,110 @@ For example:: only work if the timer is still in its waiting stage. +.. _barrier-objects + +Barrier Objects +--------------- + +This class provides a simple synchronization primitive for use by a fixed +number of threads that need to wait for each other. Each of the threads +tries to pass the barrier by calling the :meth:`wait` method and will block +until all of the threads have made the call. +At this points, the threads are released simultanously. + +The barrier can be reused any number of times for the same number of threads. + +As an example, here is a simple way to synchronize a client and server thread:: + + b = Barrier(2, timeout=5) + server(): + start_server() + b.wait() + while True: + connection = accept_connection() + process_server_connection(connection) + + client(): + b.wait() + while True: + connection = make_connection() + process_client_connection(connection) + +.. class:: Barrier(parties, action=None, timeout=None) + + Create a barrier object for *parties* number of threads. An *action*, + when provided, is a callable to be called by one of the threads when + they are released. *timeout* is the default timeout value if none + is specified for the :meth:`wait` method. + + .. method:: wait(timeout=None) + + Pass the barrier. When all the threads party to the barrier have called + this function, they are all released simultaneously. If a *timeout* + is provided, is is used in preference to any that was supplied to the + class constructor. + + The return value is an integer in the range 0 to *parties*-1, different + for each thrad. This can be used to select a thread to do some special + housekeeping, eg: + + i = barrier.wait() + if i == 0: + # Only one thread needs to print this + print("passed the barrier") + + If an *action* was + provided to the constructor, one of the threads will have called it + prior to being released. Should this call raise an error, the barrier + is put into the broken state. + + If the call times out, the barrier is put into the broken state. + + This method may raise a :class:`BrokenBarrierError` exception if the + barrier is broken or reset while a thread is waiting + + .. method:: reset() + + Return the barrier to the default, empty state. Any threads waiting on + it will receive the :class:`BrokenBarrierError` exception. + + Note that using this function may can require some external + synchronization if there are other threads whose state is unknown. + If a barrier is broken it may be better to just leave it and create a + new one. + + .. method:: abort() + + Put the barrier into a broken state. This causes any active or future + calls to :meth:`wait` to fail with the :class:`BrokenBarrierError`. + Use this for example if one of the needs to abort, to avoid deadlocking + the application. + + It may be preferable to simply create the barrier with a sensible + *timeout* value to automatically guard against one of the threads + going awry. + + .. attribute:: parties + + The number of threads required to pass the barrier. + + .. attribute:: n_waiting + + The number of threads currently waiting in the barrier. + + .. attribute:: broken + + A boolean that is ``True`` if the barrier is in the broken state. + + .. versionadded:: 3.2 + +.. class:: BrokenBarrierError(RuntimeError) + + The exception raised when the :class:`Barrier` object is reset or broken. + + .. versionadded:: 3.2 + + .. _with-locks: Using locks, conditions, and semaphores in the :keyword:`with` statement diff --git a/Lib/test/lock_tests.py b/Lib/test/lock_tests.py index 1ff6af0a7ab..f256a807fad 100644 --- a/Lib/test/lock_tests.py +++ b/Lib/test/lock_tests.py @@ -597,3 +597,193 @@ class BoundedSemaphoreTests(BaseSemaphoreTests): sem.acquire() sem.release() self.assertRaises(ValueError, sem.release) + + +class BarrierTests(BaseTestCase): + """ + Tests for Barrier objects. + """ + N = 5 + + def setUp(self): + self.barrier = self.barriertype(self.N, timeout=0.1) + def tearDown(self): + self.barrier.abort() + + def run_threads(self, f): + b = Bunch(f, self.N-1) + f() + b.wait_for_finished() + + def multipass(self, results, n): + m = self.barrier.parties + self.assertEqual(m, self.N) + for i in range(n): + results[0].append(True) + self.assertEqual(len(results[1]), i * m) + self.barrier.wait() + results[1].append(True) + self.assertEqual(len(results[0]), (i + 1) * m) + self.barrier.wait() + self.assertEqual(self.barrier.n_waiting, 0) + self.assertFalse(self.barrier.broken) + + def test_barrier(self, passes=1): + """ + Test that a barrier is passed in lockstep + """ + results = [[],[]] + def f(): + self.multipass(results, passes) + self.run_threads(f) + + def test_barrier_10(self): + """ + Test that a barrier works for 10 consecutive runs + """ + return self.test_barrier(10) + + def test_wait_return(self): + """ + test the return value from barrier.wait + """ + results = [] + def f(): + r = self.barrier.wait() + results.append(r) + + self.run_threads(f) + self.assertEqual(sum(results), sum(range(self.N))) + + def test_action(self): + """ + Test the 'action' callback + """ + results = [] + def action(): + results.append(True) + barrier = self.barriertype(self.N, action) + def f(): + barrier.wait() + self.assertEqual(len(results), 1) + + self.run_threads(f) + + def test_abort(self): + """ + Test that an abort will put the barrier in a broken state + """ + results1 = [] + results2 = [] + def f(): + try: + i = self.barrier.wait() + if i == self.N//2: + raise RuntimeError + self.barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + except RuntimeError: + self.barrier.abort() + pass + + self.run_threads(f) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertTrue(self.barrier.broken) + + def test_reset(self): + """ + Test that a 'reset' on a barrier frees the waiting threads + """ + results1 = [] + results2 = [] + results3 = [] + def f(): + i = self.barrier.wait() + if i == self.N//2: + # Wait until the other threads are all in the barrier. + while self.barrier.n_waiting < self.N-1: + time.sleep(0.001) + self.barrier.reset() + else: + try: + self.barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + # Now, pass the barrier again + self.barrier.wait() + results3.append(True) + + self.run_threads(f) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertEqual(len(results3), self.N) + + + def test_abort_and_reset(self): + """ + Test that a barrier can be reset after being broken. + """ + results1 = [] + results2 = [] + results3 = [] + barrier2 = self.barriertype(self.N) + def f(): + try: + i = self.barrier.wait() + if i == self.N//2: + raise RuntimeError + self.barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + except RuntimeError: + self.barrier.abort() + pass + # Synchronize and reset the barrier. Must synchronize first so + # that everyone has left it when we reset, and after so that no + # one enters it before the reset. + if barrier2.wait() == self.N//2: + self.barrier.reset() + barrier2.wait() + self.barrier.wait() + results3.append(True) + + self.run_threads(f) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertEqual(len(results3), self.N) + + def test_timeout(self): + """ + Test wait(timeout) + """ + def f(): + i = self.barrier.wait() + if i == self.N // 2: + # One thread is late! + time.sleep(0.1) + # Default timeout is 0.1, so this is shorter. + self.assertRaises(threading.BrokenBarrierError, + self.barrier.wait, 0.05) + self.run_threads(f) + + def test_default_timeout(self): + """ + Test the barrier's default timeout + """ + def f(): + i = self.barrier.wait() + if i == self.N // 2: + # One thread is later than the default timeout of 0.1s. + time.sleep(0.15) + self.assertRaises(threading.BrokenBarrierError, self.barrier.wait) + self.run_threads(f) + + def test_single_thread(self): + b = self.barriertype(1) + b.wait() + b.wait() diff --git a/Lib/test/test_threading.py b/Lib/test/test_threading.py index 62ad4af7ec5..a453ccc4909 100644 --- a/Lib/test/test_threading.py +++ b/Lib/test/test_threading.py @@ -555,6 +555,8 @@ class SemaphoreTests(lock_tests.SemaphoreTests): class BoundedSemaphoreTests(lock_tests.BoundedSemaphoreTests): semtype = staticmethod(threading.BoundedSemaphore) +class BarrierTests(lock_tests.BarrierTests): + barriertype = staticmethod(threading.Barrier) def test_main(): test.support.run_unittest(LockTests, PyRLockTests, CRLockTests, EventTests, @@ -563,6 +565,7 @@ def test_main(): ThreadTests, ThreadJoinOnShutdown, ThreadingExceptionTests, + BarrierTests ) if __name__ == "__main__": diff --git a/Lib/threading.py b/Lib/threading.py index 238a5c4508f..41956edce7a 100644 --- a/Lib/threading.py +++ b/Lib/threading.py @@ -392,6 +392,178 @@ class _Event(_Verbose): finally: self._cond.release() + +# A barrier class. Inspired in part by the pthread_barrier_* api and +# the CyclicBarrier class from Java. See +# http://sourceware.org/pthreads-win32/manual/pthread_barrier_init.html and +# http://java.sun.com/j2se/1.5.0/docs/api/java/util/concurrent/ +# CyclicBarrier.html +# for information. +# We maintain two main states, 'filling' and 'draining' enabling the barrier +# to be cyclic. Threads are not allowed into it until it has fully drained +# since the previous cycle. In addition, a 'resetting' state exists which is +# similar to 'draining' except that threads leave with a BrokenBarrierError, +# and a 'broken' state in which all threads get get the exception. +class Barrier(_Verbose): + """ + Barrier. Useful for synchronizing a fixed number of threads + at known synchronization points. Threads block on 'wait()' and are + simultaneously once they have all made that call. + """ + def __init__(self, parties, action=None, timeout=None, verbose=None): + """ + Create a barrier, initialised to 'parties' threads. + 'action' is a callable which, when supplied, will be called + by one of the threads after they have all entered the + barrier and just prior to releasing them all. + If a 'timeout' is provided, it is uses as the default for + all subsequent 'wait()' calls. + """ + _Verbose.__init__(self, verbose) + self._cond = Condition(Lock()) + self._action = action + self._timeout = timeout + self._parties = parties + self._state = 0 #0 filling, 1, draining, -1 resetting, -2 broken + self._count = 0 + + def wait(self, timeout=None): + """ + Wait for the barrier. When the specified number of threads have + started waiting, they are all simultaneously awoken. If an 'action' + was provided for the barrier, one of the threads will have executed + that callback prior to returning. + Returns an individual index number from 0 to 'parties-1'. + """ + if timeout is None: + timeout = self._timeout + with self._cond: + self._enter() # Block while the barrier drains. + index = self._count + self._count += 1 + try: + if index + 1 == self._parties: + # We release the barrier + self._release() + else: + # We wait until someone releases us + self._wait(timeout) + return index + finally: + self._count -= 1 + # Wake up any threads waiting for barrier to drain. + self._exit() + + # Block until the barrier is ready for us, or raise an exception + # if it is broken. + def _enter(self): + while self._state in (-1, 1): + # It is draining or resetting, wait until done + self._cond.wait() + #see if the barrier is in a broken state + if self._state < 0: + raise BrokenBarrierError + assert self._state == 0 + + # Optionally run the 'action' and release the threads waiting + # in the barrier. + def _release(self): + try: + if self._action: + self._action() + # enter draining state + self._state = 1 + self._cond.notify_all() + except: + #an exception during the _action handler. Break and reraise + self._break() + raise + + # Wait in the barrier until we are relased. Raise an exception + # if the barrier is reset or broken. + def _wait(self, timeout): + while self._state == 0: + if self._cond.wait(timeout) is False: + #timed out. Break the barrier + self._break() + raise BrokenBarrierError + if self._state < 0: + raise BrokenBarrierError + assert self._state == 1 + + # If we are the last thread to exit the barrier, signal any threads + # waiting for the barrier to drain. + def _exit(self): + if self._count == 0: + if self._state in (-1, 1): + #resetting or draining + self._state = 0 + self._cond.notify_all() + + def reset(self): + """ + Reset the barrier to the initial state. + Any threads currently waiting will get the BrokenBarrier exception + raised. + """ + with self._cond: + if self._count > 0: + if self._state == 0: + #reset the barrier, waking up threads + self._state = -1 + elif self._state == -2: + #was broken, set it to reset state + #which clears when the last thread exits + self._state = -1 + else: + self._state = 0 + self._cond.notify_all() + + def abort(self): + """ + Place the barrier into a 'broken' state. + Useful in case of error. Any currently waiting threads and + threads attempting to 'wait()' will have BrokenBarrierError + raised. + """ + with self._cond: + self._break() + + def _break(self): + # An internal error was detected. The barrier is set to + # a broken state all parties awakened. + self._state = -2 + self._cond.notify_all() + + @property + def parties(self): + """ + Return the number of threads required to trip the barrier. + """ + return self._parties + + @property + def n_waiting(self): + """ + Return the number of threads that are currently waiting at the barrier. + """ + # We don't need synchronization here since this is an ephemeral result + # anyway. It returns the correct value in the steady state. + if self._state == 0: + return self._count + return 0 + + @property + def broken(self): + """ + Return True if the barrier is in a broken state + """ + return self._state == -2 + +#exception raised by the Barrier class +class BrokenBarrierError(RuntimeError): pass + + # Helper to generate new thread names _counter = 0 def _newname(template="Thread-%d"):