Index: Lib/test/lock_tests.py =================================================================== --- Lib/test/lock_tests.py (revision 85817) +++ Lib/test/lock_tests.py (working copy) @@ -591,3 +591,183 @@ sem.acquire() sem.release() self.assertRaises(ValueError, sem.release) + + +class BarrierTests(BaseTestCase): + """ + Tests for Barrier objects. + """ + def setUp(self): + self.N = 5 + self.barrier = self.barriertype(self.N, timeout=0.1) + def tearDown(self): + self.barrier.abort() + + def run_threads(self, f): + b = Bunch(f, self.N-1) + f() + b.wait_for_finished() + + def multipass(self, results, n): + m = self.barrier.get_parties() + for i in range(n): + results[0].append(True) + self.assertEqual(len(results[1]), i * m) + self.barrier.wait() + results[1].append(True) + self.assertEqual(len(results[0]), (i + 1) * m) + self.barrier.wait() + + def test_barrier(self, passes=1): + """ + Test that a barrier is passed in lockstep + """ + results = [[],[]] + def f(): + self.multipass(results, passes) + self.run_threads(f) + + def test_barrier_10(self): + """ + Test that a barrier works for 10 consecutive runs + """ + return self.test_barrier(10) + + def test_wait_return(self): + """ + test the return value from barrier.wait + """ + results = [] + def f(): + r = self.barrier.wait() + results.append(r) + + self.run_threads(f) + self.assertEqual(sum(results), sum(range(self.N))) + + def test_action(self): + """ + Test the 'action' callback + """ + results = [] + def action(): + results.append(True) + barrier = self.barriertype(self.N, action) + def f(): + barrier.wait() + self.assertEqual(len(results), 1) + + self.run_threads(f) + + def test_abort(self): + """ + Test that an abort will put the barrier in a broken state + """ + results1 = [] + results2 = [] + def f(): + try: + i = self.barrier.wait() + if i == self.N//2: + raise RuntimeError + self.barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + except RuntimeError: + self.barrier.abort() + pass + + self.run_threads(f) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + + def test_reset(self): + """ + Test that a 'reset' on a barrier frees the waiting threads + """ + results1 = [] + results2 = [] + results3 = [] + def f(): + i = self.barrier.wait() + if i == self.N//2: + # Wait until the other threads are all in the barrier. + while self.barrier.get_waiting() < self.N-1: + time.sleep(0.001) + self.barrier.reset() + else: + try: + self.barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + # Now, pass the barrier again + self.barrier.wait() + results3.append(True) + + self.run_threads(f) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertEqual(len(results3), self.N) + + + def test_abort_and_reset(self): + """ + Test that a barrier can be reset after being broken. + """ + results1 = [] + results2 = [] + results3 = [] + barrier2 = self.barriertype(self.N) + def f(): + try: + i = self.barrier.wait() + if i == self.N//2: + raise RuntimeError + self.barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + except RuntimeError: + self.barrier.abort() + pass + # Synchronize and reset the barrier. Must synchronize first so + # that everyone has left it when we reset, and after so that no + # one enters it before the reset. + if barrier2.wait() == self.N//2: + self.barrier.reset() + barrier2.wait() + self.barrier.wait() + results3.append(True) + + self.run_threads(f) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertEqual(len(results3), self.N) + + def test_timeout(self): + """ + Test wait(timeout) + """ + def f(): + i = self.barrier.wait() + if i == self.N // 2: + # One thread is late! + time.sleep(0.1) + # Default timeout is 0.1, so this is shorter. + self.assertRaises(threading.BrokenBarrierError, + self.barrier.wait, 0.05) + self.run_threads(f) + + def test_default_timeout(self): + """ + Test the barrier's default timeout + """ + def f(): + i = self.barrier.wait() + if i == self.N // 2: + # One thread is later than the default timeout of 0.1s. + time.sleep(0.15) + self.assertRaises(threading.BrokenBarrierError, self.barrier.wait) + self.run_threads(f) Index: Lib/test/test_threading.py =================================================================== --- Lib/test/test_threading.py (revision 85817) +++ Lib/test/test_threading.py (working copy) @@ -555,6 +555,8 @@ class BoundedSemaphoreTests(lock_tests.BoundedSemaphoreTests): semtype = staticmethod(threading.BoundedSemaphore) +class BarrierTests(lock_tests.BarrierTests): + barriertype = staticmethod(threading.Barrier) def test_main(): test.support.run_unittest(LockTests, PyRLockTests, CRLockTests, EventTests, @@ -563,6 +565,7 @@ ThreadTests, ThreadJoinOnShutdown, ThreadingExceptionTests, + BarrierTests ) if __name__ == "__main__": Index: Lib/threading.py =================================================================== --- Lib/threading.py (revision 85817) +++ Lib/threading.py (working copy) @@ -232,6 +232,7 @@ try: # restore state no matter what (e.g., KeyboardInterrupt) if timeout is None: waiter.acquire() + gotit = True if __debug__: self._note("%s.wait(): got it", self) else: @@ -249,6 +250,8 @@ else: if __debug__: self._note("%s.wait(%s): got it", self, timeout) + #return true if timout occurred + return not gotit finally: self._acquire_restore(saved_state) @@ -390,6 +393,174 @@ finally: self._cond.release() + +# A barrier class. Inspired in part by the pthread_barrier_* api and +# the CyclicBarrier class from Java. See +# http://sourceware.org/pthreads-win32/manual/pthread_barrier_init.html and +# http://java.sun.com/j2se/1.5.0/docs/api/java/util/concurrent/ +# CyclicBarrier.html +# for information. +# We maintain two main states, 'filling' and 'draining' enabling the barrier +# to be cyclic. Threads are not allowed into it until it has fully drained +# since the previous cycle. In addition, a 'resetting' state exists which is +# similar to 'draining' except that threads leave with a BrokenBarrierError, +# and a 'broken' state in which all threads get get the exception. +class Barrier(_Verbose): + """ + Barrier. Useful for synchronizing a fixed number of threads + at known synchronization points. Threads block on 'wait()' and are + simultaneously once they have all made that call. + """ + def __init__(self, parties, action=None, timeout=None, verbose=None): + """ + Create a barrier, initialised to 'parties' threads. + 'action' is a callable which, when supplied, will be called + by one of the threads after they have all entered the + barrier and just prior to releasing them all. + If a 'timeout' is provided, it is uses as the default for + all subsequent 'wait()' calls. + """ + _Verbose.__init__(self, verbose) + self._cond = Condition(Lock()) + self._action = action + self._timeout = timeout + self._parties = parties + self._state = 0 #0 filling, 1, draining, -1 resetting, -2 broken + self._count = 0 + + def wait(self, timeout=None): + """ + Wait for the barrier. When the specified number of threads have + started waiting, they are all simultaneously awoken. If an 'action' + was provided for the barrier, one of the threads will have executed + that callback prior to returning. + Returns an individual index number from 0 to 'parties-1'. + """ + if timeout is None: + timeout = self._timeout + with self._cond: + self._enter() # Block while the barrier drains. + index = self._count + self._count += 1 + try: + if index + 1 == self._parties: + # We release the barrier + self._release() + else: + # We wait until someone releases us + self._wait(timeout) + return index + finally: + self._count -= 1 + # Wake up any threads waiting for barrier to drain. + self._exit() + + # Block until the barrier is ready for us, or raise an exception + # if it is broken. + def _enter(self): + while self._state in (-1, 1): + # It is draining or resetting, wait until done + self._cond.wait() + #see if the barrier is in a broken state + if self._state < 0: + raise BrokenBarrierError + assert self._state == 0 + + # Optionally run the 'action' and release the threads waiting + # in the barrier. + def _release(self): + try: + if self._action: + self._action() + # enter draining state + self._state = 1 + self._cond.notify_all() + except: + #an exception during the _action handler. Break and reraise + self._break() + raise + + # Wait in the barrier until we are relased. Raise an exception + # if the barrier is reset or broken. + def _wait(self, timeout): + while self._state == 0: + if self._cond.wait(timeout): + #timed out. Break the barrier + self._break() + raise BrokenBarrierError + if self._state < 0: + raise BrokenBarrierError + assert self._state == 1 + + # If we are the last thread to exit the barrier, signal any threads + # waiting for the barrier to drain. + def _exit(self): + if self._count == 0: + if self._state in (-1, 1): + #resetting or draining + self._state = 0 + self._cond.notify_all() + + def reset(self): + """ + Reset the barrier to the initial state. + Any threads currently waiting will get the BrokenBarrier exception + raised. + """ + with self._cond: + if self._count > 0: + if self._state == 0: + #reset the barrier, waking up threads + self._state = -1 + elif self._state == -2: + #was broken, set it to reset state + #which clears when the last thread exits + self._state = -1 + else: + self._state = 0 + self._cond.notify_all() + + def abort(self): + """ + Place the barrier into a 'broken' state. + Useful in case of error. Any currently waiting threads and + threads attempting to 'wait()' will have BrokenBarrierError + raised. + """ + with self._cond: + self._break() + + def _break(self): + # An internal error was detected. The barrier is set to + # a broken state all parties awakened. + self._state = -2 + self._cond.notify_all() + + def get_parties(self): + """ + Return the number of threads required to trip the barrier. + """ + return self._parties + + def get_waiting(self): + """ + Return the number of threads that are currently waiting at the barrier. + """ + with self._cond: + if self._state == 0: + return self._count + return 0 + + def is_broken(self): + """ + Return True if the barrier is in a broken state + """ + return self._state == -2 + +#exception raised by the Barrier class +class BrokenBarrierError(RuntimeError): pass + + # Helper to generate new thread names _counter = 0 def _newname(template="Thread-%d"):