# HG changeset patch # Parent b2f86880517fbb04ee501b7c2bea513929e34548 diff -r b2f86880517f Doc/library/multiprocessing.rst --- a/Doc/library/multiprocessing.rst Tue Jun 05 13:15:29 2012 +0100 +++ b/Doc/library/multiprocessing.rst Tue Jun 05 21:22:30 2012 +0100 @@ -226,11 +226,11 @@ holds Python objects and allows other processes to manipulate them using proxies. - A manager returned by :func:`Manager` will support types :class:`list`, - :class:`dict`, :class:`Namespace`, :class:`Lock`, :class:`RLock`, - :class:`Semaphore`, :class:`BoundedSemaphore`, :class:`Condition`, - :class:`Event`, :class:`Queue`, :class:`Value` and :class:`Array`. For - example, :: + A manager returned by :func:`Manager` will support types + :class:`list`, :class:`dict`, :class:`Namespace`, :class:`Lock`, + :class:`RLock`, :class:`Semaphore`, :class:`BoundedSemaphore`, + :class:`Condition`, :class:`Event`, :class:`Barrier`, + :class:`Queue`, :class:`Value` and :class:`Array`. For example, :: from multiprocessing import Process, Manager @@ -885,6 +885,12 @@ Note that one can also create synchronization primitives by using a manager object -- see :ref:`multiprocessing-managers`. +.. class:: Barrier(parties[, action[, timeout]]) + + A barrier object: a clone of :class:`threading.Barrier`. + + .. versionadded:: 3.3 + .. class:: BoundedSemaphore([value]) A bounded semaphore object: a clone of :class:`threading.BoundedSemaphore`. @@ -1279,6 +1285,13 @@ It also supports creation of shared lists and dictionaries. + .. method:: Barrier(parties[, action[, timeout]]) + + Create a shared :class:`threading.Barrier` object and return a + proxy for it. + + .. versionadded:: 3.3 + .. method:: BoundedSemaphore([value]) Create a shared :class:`threading.BoundedSemaphore` object and return a diff -r b2f86880517f Lib/multiprocessing/__init__.py --- a/Lib/multiprocessing/__init__.py Tue Jun 05 13:15:29 2012 +0100 +++ b/Lib/multiprocessing/__init__.py Tue Jun 05 21:22:30 2012 +0100 @@ -23,8 +23,8 @@ 'Manager', 'Pipe', 'cpu_count', 'log_to_stderr', 'get_logger', 'allow_connection_pickling', 'BufferTooShort', 'TimeoutError', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', 'Condition', - 'Event', 'Queue', 'SimpleQueue', 'JoinableQueue', 'Pool', 'Value', 'Array', - 'RawValue', 'RawArray', 'SUBDEBUG', 'SUBWARNING', + 'Event', 'Barrier', 'Queue', 'SimpleQueue', 'JoinableQueue', 'Pool', + 'Value', 'Array', 'RawValue', 'RawArray', 'SUBDEBUG', 'SUBWARNING', ] __author__ = 'R. Oudkerk (r.m.oudkerk@gmail.com)' @@ -186,6 +186,13 @@ from multiprocessing.synchronize import Event return Event() +def Barrier(parties, action=None, timeout=None): + ''' + Returns a barrier object + ''' + from multiprocessing.synchronize import Barrier + return Barrier(parties, action, timeout) + def Queue(maxsize=0): ''' Returns a queue object diff -r b2f86880517f Lib/multiprocessing/dummy/__init__.py --- a/Lib/multiprocessing/dummy/__init__.py Tue Jun 05 13:15:29 2012 +0100 +++ b/Lib/multiprocessing/dummy/__init__.py Tue Jun 05 21:22:30 2012 +0100 @@ -35,7 +35,7 @@ __all__ = [ 'Process', 'current_process', 'active_children', 'freeze_support', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', 'Condition', - 'Event', 'Queue', 'Manager', 'Pipe', 'Pool', 'JoinableQueue' + 'Event', 'Barrier', 'Queue', 'Manager', 'Pipe', 'Pool', 'JoinableQueue' ] # @@ -49,7 +49,7 @@ from multiprocessing.dummy.connection import Pipe from threading import Lock, RLock, Semaphore, BoundedSemaphore -from threading import Event, Condition +from threading import Event, Condition, Barrier from queue import Queue # diff -r b2f86880517f Lib/multiprocessing/managers.py --- a/Lib/multiprocessing/managers.py Tue Jun 05 13:15:29 2012 +0100 +++ b/Lib/multiprocessing/managers.py Tue Jun 05 21:22:30 2012 +0100 @@ -1006,6 +1006,26 @@ def wait(self, timeout=None): return self._callmethod('wait', (timeout,)) + +class BarrierProxy(BaseProxy): + _exposed_ = ('__getattribute__', 'wait', 'abort', 'reset') + def wait(self, timeout=None): + return self._callmethod('wait', (timeout,)) + def abort(self): + return self._callmethod('abort') + def reset(self): + return self._callmethod('reset') + @property + def parties(self): + return self._callmethod('__getattribute__', ('parties',)) + @property + def n_waiting(self): + return self._callmethod('__getattribute__', ('n_waiting',)) + @property + def broken(self): + return self._callmethod('__getattribute__', ('broken',)) + + class NamespaceProxy(BaseProxy): _exposed_ = ('__getattribute__', '__setattr__', '__delattr__') def __getattr__(self, key): @@ -1097,6 +1117,7 @@ SyncManager.register('BoundedSemaphore', threading.BoundedSemaphore, AcquirerProxy) SyncManager.register('Condition', threading.Condition, ConditionProxy) +SyncManager.register('Barrier', threading.Barrier, BarrierProxy) SyncManager.register('Pool', Pool, PoolProxy) SyncManager.register('list', list, ListProxy) SyncManager.register('dict', dict, DictProxy) diff -r b2f86880517f Lib/multiprocessing/synchronize.py --- a/Lib/multiprocessing/synchronize.py Tue Jun 05 13:15:29 2012 +0100 +++ b/Lib/multiprocessing/synchronize.py Tue Jun 05 21:22:30 2012 +0100 @@ -333,3 +333,43 @@ return False finally: self._cond.release() + +# +# Barrier +# + +class Barrier(threading.Barrier): + + def __init__(self, parties, action=None, timeout=None): + import struct + from multiprocessing.heap import BufferWrapper + wrapper = BufferWrapper(struct.calcsize('i') * 2) + cond = Condition() + self.__setstate__((parties, action, timeout, cond, wrapper)) + self._state = 0 + self._count = 0 + + def __setstate__(self, state): + (self._parties, self._action, self._timeout, + self._cond, self._wrapper) = state + self._array = self._wrapper.create_memoryview().cast('i') + + def __getstate__(self): + return (self._parties, self._action, self._timeout, + self._cond, self._wrapper) + + @property + def _state(self): + return self._array[0] + + @_state.setter + def _state(self, value): + self._array[0] = value + + @property + def _count(self): + return self._array[1] + + @_count.setter + def _count(self, value): + self._array[1] = value diff -r b2f86880517f Lib/test/test_multiprocessing.py --- a/Lib/test/test_multiprocessing.py Tue Jun 05 13:15:29 2012 +0100 +++ b/Lib/test/test_multiprocessing.py Tue Jun 05 21:22:30 2012 +0100 @@ -18,6 +18,7 @@ import socket import random import logging +import struct import test.support @@ -1027,6 +1028,336 @@ self.assertEqual(wait(), True) # +# Tests for Barrier - adapted from tests in test/lock_tests.py +# + +# Many of the tests for threading.Barrier use a list as an atomic +# counter: a value is appended to increment the counter, and the +# length of the list gives the value. We use the class DummyList +# for the same purpose. + +class _DummyList(object): + + def __init__(self): + wrapper = multiprocessing.heap.BufferWrapper(struct.calcsize('i')) + lock = multiprocessing.Lock() + self.__setstate__((wrapper, lock)) + self._lengthbuf[0] = 0 + + def __setstate__(self, state): + (self._wrapper, self._lock) = state + self._lengthbuf = self._wrapper.create_memoryview().cast('i') + + def __getstate__(self): + return (self._wrapper, self._lock) + + def append(self, _): + with self._lock: + self._lengthbuf[0] += 1 + + def __len__(self): + with self._lock: + return self._lengthbuf[0] + +def _wait(): + # A crude wait/yield function not relying on synchronization primitives. + time.sleep(0.01) + + +class Bunch(object): + """ + A bunch of threads. + """ + def __init__(self, namespace, f, args, n, wait_before_exit=False): + """ + Construct a bunch of `n` threads running the same function `f`. + If `wait_before_exit` is True, the threads won't terminate until + do_finish() is called. + """ + self.f = f + self.args = args + self.n = n + self.started = namespace.DummyList() + self.finished = namespace.DummyList() + self._can_exit = namespace.Value('i', not wait_before_exit) + for i in range(n): + namespace.Process(target=self.task).start() + + def task(self): + pid = os.getpid() + self.started.append(pid) + try: + self.f(*self.args) + finally: + self.finished.append(pid) + while not self._can_exit.value: + _wait() + + def wait_for_started(self): + while len(self.started) < self.n: + _wait() + + def wait_for_finished(self): + while len(self.finished) < self.n: + _wait() + + def do_finish(self): + self._can_exit.value = True + + +class AppendTrue(object): + def __init__(self, obj): + self.obj = obj + def __call__(self): + self.obj.append(True) + + +class _TestBarrier(BaseTestCase): + """ + Tests for Barrier objects. + """ + N = 5 + defaultTimeout = 10.0 # XXX Slow Windows buildbots need generous timeout + + def setUp(self): + self.barrier = self.Barrier(self.N, timeout=self.defaultTimeout) + + def tearDown(self): + self.barrier.abort() + self.barrier = None + + def DummyList(self): + if self.TYPE == 'threads': + return [] + elif self.TYPE == 'manager': + return self.manager.list() + else: + return _DummyList() + + def run_threads(self, f, args): + b = Bunch(self, f, args, self.N-1) + f(*args) + b.wait_for_finished() + + @classmethod + def multipass(cls, barrier, results, n): + m = barrier.parties + assert m == cls.N + for i in range(n): + results[0].append(True) + assert len(results[1]) == i * m + barrier.wait() + results[1].append(True) + assert len(results[0]) == (i + 1) * m + barrier.wait() + try: + assert barrier.n_waiting == 0 + except NotImplementedError: + pass + assert not barrier.broken + + def test_barrier(self, passes=1): + """ + Test that a barrier is passed in lockstep + """ + results = [self.DummyList(), self.DummyList()] + self.run_threads(self.multipass, (self.barrier, results, passes)) + + def test_barrier_10(self): + """ + Test that a barrier works for 10 consecutive runs + """ + return self.test_barrier(10) + + @classmethod + def _test_wait_return_f(cls, barrier, queue): + res = barrier.wait() + queue.put(res) + + def test_wait_return(self): + """ + test the return value from barrier.wait + """ + queue = self.Queue() + self.run_threads(self._test_wait_return_f, (self.barrier, queue)) + results = [queue.get() for i in range(self.N)] + self.assertEqual(results.count(0), 1) + + @classmethod + def _test_action_f(cls, barrier, results): + barrier.wait() + if len(results) != 1: + raise RuntimeError + + def test_action(self): + """ + Test the 'action' callback + """ + results = self.DummyList() + barrier = self.Barrier(self.N, action=AppendTrue(results)) + self.run_threads(self._test_action_f, (barrier, results)) + self.assertEqual(len(results), 1) + + @classmethod + def _test_abort_f(cls, barrier, results1, results2): + try: + i = barrier.wait() + if i == cls.N//2: + raise RuntimeError + barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + except RuntimeError: + barrier.abort() + + def test_abort(self): + """ + Test that an abort will put the barrier in a broken state + """ + results1 = self.DummyList() + results2 = self.DummyList() + self.run_threads(self._test_abort_f, + (self.barrier, results1, results2)) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertTrue(self.barrier.broken) + + @classmethod + def _test_reset_f(cls, barrier, results1, results2, results3): + i = barrier.wait() + if i == cls.N//2: + # Wait until the other threads are all in the barrier. + while barrier.n_waiting < cls.N-1: + time.sleep(0.001) + barrier.reset() + else: + try: + barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + # Now, pass the barrier again + barrier.wait() + results3.append(True) + + def test_reset(self): + """ + Test that a 'reset' on a barrier frees the waiting threads + """ + results1 = self.DummyList() + results2 = self.DummyList() + results3 = self.DummyList() + self.run_threads(self._test_reset_f, + (self.barrier, results1, results2, results3)) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertEqual(len(results3), self.N) + + @classmethod + def _test_abort_and_reset_f(cls, barrier, barrier2, + results1, results2, results3): + try: + i = barrier.wait() + if i == cls.N//2: + raise RuntimeError + barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + except RuntimeError: + barrier.abort() + # Synchronize and reset the barrier. Must synchronize first so + # that everyone has left it when we reset, and after so that no + # one enters it before the reset. + if barrier2.wait() == cls.N//2: + barrier.reset() + barrier2.wait() + barrier.wait() + results3.append(True) + + def test_abort_and_reset(self): + """ + Test that a barrier can be reset after being broken. + """ + results1 = self.DummyList() + results2 = self.DummyList() + results3 = self.DummyList() + barrier2 = self.Barrier(self.N) + + self.run_threads(self._test_abort_and_reset_f, + (self.barrier, barrier2, results1, results2, results3)) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertEqual(len(results3), self.N) + + @classmethod + def _test_timeout_f(cls, barrier, results): + i = barrier.wait(20) + if i == cls.N//2: + # One thread is late! + time.sleep(1.0) + try: + barrier.wait(0.5) + except threading.BrokenBarrierError: + results.append(True) + + def test_timeout(self): + """ + Test wait(timeout) + """ + results = self.DummyList() + self.run_threads(self._test_timeout_f, (self.barrier, results)) + self.assertEqual(len(results), self.barrier.parties) + + @classmethod + def _test_default_timeout_f(cls, barrier, results): + i = barrier.wait(20) + if i == cls.N//2: + # One thread is later than the default timeout + time.sleep(2.0) + try: + barrier.wait() + except threading.BrokenBarrierError: + results.append(True) + + def test_default_timeout(self): + """ + Test the barrier's default timeout + """ + barrier = self.Barrier(self.N, timeout=1.0) + results = self.DummyList() + self.run_threads(self._test_default_timeout_f, (barrier, results)) + self.assertEqual(len(results), barrier.parties) + + def test_single_thread(self): + b = self.Barrier(1) + b.wait() + b.wait() + + @classmethod + def _test_thousand_f(cls, barrier, passes, conn, lock): + for i in range(passes): + barrier.wait() + with lock: + conn.send(i) + + def test_thousand(self): + if self.TYPE == 'manager': + return + passes = 1000 + lock = self.Lock() + conn, child_conn = self.Pipe(False) + for j in range(self.N): + p = self.Process(target=self._test_thousand_f, + args=(self.barrier, passes, child_conn, lock)) + p.start() + + for i in range(passes): + for j in range(self.N): + self.assertEqual(conn.recv(), i) + +# # # @@ -2485,7 +2816,7 @@ Process = multiprocessing.Process locals().update(get_attributes(multiprocessing, ( 'Queue', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', - 'Condition', 'Event', 'Value', 'Array', 'RawValue', + 'Condition', 'Event', 'Barrier', 'Value', 'Array', 'RawValue', 'RawArray', 'current_process', 'active_children', 'Pipe', 'connection', 'JoinableQueue' ))) @@ -2500,7 +2831,7 @@ manager = object.__new__(multiprocessing.managers.SyncManager) locals().update(get_attributes(manager, ( 'Queue', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', - 'Condition', 'Event', 'Value', 'Array', 'list', 'dict', + 'Condition', 'Event', 'Barrier', 'Value', 'Array', 'list', 'dict', 'Namespace', 'JoinableQueue' ))) @@ -2513,7 +2844,7 @@ Process = multiprocessing.dummy.Process locals().update(get_attributes(multiprocessing.dummy, ( 'Queue', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', - 'Condition', 'Event', 'Value', 'Array', 'current_process', + 'Condition', 'Event', 'Barrier', 'Value', 'Array', 'current_process', 'active_children', 'Pipe', 'connection', 'dict', 'list', 'Namespace', 'JoinableQueue' )))