Index: Lib/test/lock_tests.py =================================================================== --- Lib/test/lock_tests.py (revision 81379) +++ Lib/test/lock_tests.py (working copy) @@ -544,3 +544,95 @@ sem.acquire() sem.release() self.assertRaises(ValueError, sem.release) + + +class BarrierTests(BaseTestCase): + """ + Tests for Barrier objects. + """ + + def test_barrier(self): + # Pass a barrier once + N = 5 + results1 = [] + results2 = [] + barrier = self.barriertype(N) + def f(): + results1.append(True) + self.assertEqual(results2, [True] * 0) + barrier.enter() + results2.append(True) + self.assertEqual(results1, [True] * N) + + b = Bunch(f, N-1) + f() + b.wait_for_finished() + self.assertEqual(results1, [True] * N) + self.assertEqual(results2, [True] * N) + + def test_double_barrier(self): + # pass a barrier twice + N = 5 + results1 = [] + results2 = [] + results3 = [] + barrier = self.barriertype(N) + def f(): + results1.append(True) + self.assertEqual(results2, [True] * 0) + self.assertEqual(results3, [True] * 0) + barrier.enter() + results2.append(True) + self.assertEqual(results1, [True] * N) + self.assertEqual(results3, [True] * 0) + barrier.enter() + results3.append(True) + self.assertEqual(results1, [True] * N) + self.assertEqual(results2, [True] * N) + + b = Bunch(f, N-1) + f() + b.wait_for_finished() + self.assertEqual(results1, [True] * N) + self.assertEqual(results2, [True] * N) + self.assertEqual(results3, [True] * N) + + def test_exception(self): + # pass a barrier twice, but one thread dies after first pass + N = 5 + results1 = [] + results2 = [] + results3 = [] + barrier = self.barriertype(N) + def f(): + try: + results1.append(True) + self.assertEqual(results2, [True] * 0) + self.assertEqual(results3, [True] * 0) + barrier.enter() + results2.append(True) + if len(results2) == N/2: + raise RuntimeError + self.assertEqual(results1, [True] * N) + self.assertEqual(results3, [True] * 0) + barrier.enter() + results3.append(True) + self.assertEqual(results1, [True] * N) + self.assertEqual(results2, [True] * N) + except RuntimeError: + barrier.adjust_count(-1) + pass + + b = Bunch(f, N-1) + f() + b.wait_for_finished() + self.assertEqual(results1, [True] * N) + self.assertEqual(results2, [True] * N) + self.assertEqual(results3, [True] * (N-1)) + + def test_count(self): + N = 5 + barrier = self.barriertype(N) + self.assertEqual(barrier.adjust_count(0), 5) + self.assertEqual(barrier.adjust_count(1), 6) + self.assertEqual(barrier.adjust_count(-2), 4) Index: Lib/test/test_threading.py =================================================================== --- Lib/test/test_threading.py (revision 81379) +++ Lib/test/test_threading.py (working copy) @@ -534,6 +534,8 @@ class BoundedSemaphoreTests(lock_tests.BoundedSemaphoreTests): semtype = staticmethod(threading.BoundedSemaphore) +class BarrierTests(lock_tests.BarrierTests): + barriertype = staticmethod(threading.Barrier) def test_main(): test.test_support.run_unittest(LockTests, RLockTests, EventTests, @@ -542,6 +544,7 @@ ThreadTests, ThreadJoinOnShutdown, ThreadingExceptionTests, + BarrierTests ) if __name__ == "__main__": Index: Lib/threading.py =================================================================== --- Lib/threading.py (revision 81379) +++ Lib/threading.py (working copy) @@ -29,7 +29,8 @@ __all__ = ['activeCount', 'active_count', 'Condition', 'currentThread', 'current_thread', 'enumerate', 'Event', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', 'Thread', - 'Timer', 'setprofile', 'settrace', 'local', 'stack_size'] + 'Timer', 'setprofile', 'settrace', 'local', 'stack_size', + 'Barrier'] _start_new_thread = thread.start_new_thread _allocate_lock = thread.allocate_lock @@ -396,6 +397,52 @@ finally: self.__cond.release() + +def Barrier(*args, **kwargs): + return _Barrier(*args, **kwargs) + +class _Barrier(_Verbose): + def __init__(self, count, verbose=None): + """ + Create a barrier, initialised to 'count' threads + """ + _Verbose.__init__(self, verbose) + self.__cond = Condition(Lock()) + self.__count = count + self.__entered = -1 + self.__released = -1 + + def enter(self): + """ + Enter the barrier. When 'count' threads have entered, they are all + released simultaneously. + """ + with self.__cond: + me = self.__entered + 1 + self.__entered = me + self._release() + while me > self.__released: + self.__cond.wait() + + def _release(self): + while self.__released + self.__count <= self.__entered: + self.__released += self.__count + self.__cond.notify_all() + + def adjust_count(self, adjust): + """ + Adjust the count of the barrier by the given integer. + Use this when a thread unexpectedly dies, to avoid deadlocking + the remaining threads. Returns the new count. + """ + with self.__cond: + self.__count += adjust + if self.__count: + #sometimes all threads fail and count drops to zero + self._release() + return self.__count + + # Helper to generate new thread names _counter = 0 def _newname(template="Thread-%d"):