diff --git a/Lib/test/lock_tests.py b/Lib/test/lock_tests.py --- a/Lib/test/lock_tests.py +++ b/Lib/test/lock_tests.py @@ -874,3 +874,128 @@ b = self.barriertype(1) b.wait() b.wait() + +class SharableLockTests(RLockTests): + """ + Tests for SharableLock objects + """ + def test_many_readers(self): + lock = self.shlocktype() + N = 5 + locked = [] + nlocked = [] + def f(): + with lock.shared_lock(): + locked.append(1) + _wait() + nlocked.append(len(locked)) + _wait() + locked.pop(-1) + Bunch(f, N).wait_for_finished() + self.assertTrue(max(nlocked) > 1) + + def test_shared_recursion(self): + lock = self.shlocktype() + N = 5 + locked = [] + nlocked = [] + def f(): + with lock.shared_lock(): + with lock.shared_lock(): + locked.append(1) + _wait() + nlocked.append(len(locked)) + _wait() + locked.pop(-1) + Bunch(f, N).wait_for_finished() + self.assertTrue(max(nlocked) > 1) + + def test_exclusive_recursion(self): + lock = self.shlocktype() + N = 5 + locked = [] + nlocked = [] + def f(): + with lock.exclusive_lock(): + with lock.shared_lock(): + locked.append(1) + _wait() + nlocked.append(len(locked)) + _wait() + locked.pop(-1) + Bunch(f, N).wait_for_finished() + self.assertEqual(max(nlocked), 1) + + def test_exclusive_recursionfail(self): + lock = self.shlocktype() + N = 5 + locked = [] + def f(): + with lock.shared_lock(): + self.assertRaises(RuntimeError, lock.acquire_exclusive) + locked.append(1) + Bunch(f, N).wait_for_finished() + self.assertEqual(len(locked), N) + + def test_readers_writers(self): + lock = self.shlocktype() + N = 5 + rlocked = [] + wlocked = [] + nlocked = [] + def r(): + with lock.shared_lock(): + rlocked.append(1) + _wait() + nlocked.append((len(rlocked), len(wlocked))) + _wait() + rlocked.pop(-1) + def w(): + with lock.exclusive_lock(): + wlocked.append(1) + _wait() + nlocked.append((len(rlocked), len(wlocked))) + _wait() + wlocked.pop(-1) + b1 = Bunch(r, N) + b2 = Bunch(w, N) + b1.wait_for_finished() + b2.wait_for_finished() + r, w, = zip(*nlocked) + self.assertTrue(max(r) > 1) + self.assertEqual(max(w), 1) + for r, w in nlocked: + if w: + self.assertEqual(r, 0) + if r: + self.assertEqual(w, 0) + + def test_writer_success(self): + """Verify that a writer can get access""" + lock = self.shlocktype() + N = 5 + reads = 0 + writes = 0 + def r(): + # read until we achive write successes + nonlocal reads, writes + while writes < 2: + with lock.shared_lock(): + reads += 1 + def w(): + nonlocal reads, writes + while reads == 0: + _wait() + for i in range(2): + _wait() + with lock.exclusive_lock(): + writes += 1 + + b1 = Bunch(r, N) + b2 = Bunch(w, 1) + b1.wait_for_finished() + b2.wait_for_finished() + self.assertEqual(writes, 2) + # uncomment this to view performance + #print(writes, reads) + diff --git a/Lib/test/test_threading.py b/Lib/test/test_threading.py --- a/Lib/test/test_threading.py +++ b/Lib/test/test_threading.py @@ -815,6 +815,33 @@ class BarrierTests(lock_tests.BarrierTests): barriertype = staticmethod(threading.Barrier) +class SharableLockTests(lock_tests.SharableLockTests): + shlocktype = staticmethod(threading.SharableLock) + def locktype(self): + return self.shlocktype().exclusive_lock() + +class SharableConditionTests(lock_tests.ConditionTests): + shlocktype = staticmethod(threading.SharableLock) + def condtype(self, lock=None): + if lock: + return threading.Condition(lock) + return threading.Condition(self.shlocktype().exclusive_lock()) + +class SharableConditionAsRLockTests(lock_tests.RLockTests): + shlocktype = staticmethod(threading.SharableLock) + def locktype(self): + return threading.Condition(self.shlocktype().exclusive_lock()) + +class SimpleSharableLockMixin: + shlocktype = staticmethod(threading.SimpleSharableLock) + +class SimpleSharableLockTests(SharableLockTests, SimpleSharableLockMixin): pass + +class SimpleSharableConditionTests(SharableConditionTests, SimpleSharableLockMixin): pass + +class SimpleSharableConditionAsRLockTests(SharableConditionAsRLockTests, SimpleSharableLockMixin): pass + + def test_main(): test.support.run_unittest(LockTests, PyRLockTests, CRLockTests, EventTests, @@ -824,6 +851,12 @@ ThreadJoinOnShutdown, ThreadingExceptionTests, BarrierTests, + SharableLockTests, + SharableConditionTests, + SharableConditionAsRLockTests, + SimpleSharableLockTests, + SimpleSharableConditionTests, + SimpleSharableConditionAsRLockTests, ) if __name__ == "__main__": diff --git a/Lib/threading.py b/Lib/threading.py --- a/Lib/threading.py +++ b/Lib/threading.py @@ -502,6 +502,262 @@ class BrokenBarrierError(RuntimeError): pass + +class SharableLockBase(object): + """ + A sharable lock, that can be acquired either in shared mode or exclusive mode. + In shared mode, the lock may be held by any number of threads. + In exclusive mode, only one thread can hold the lock. + This is typically used to implement a reader/writer pattern, where multiple + "readers" may have hold a shared lock to a resource, but a "writer" will need + exclusive access to it. + A SharableLock is reentrant in a limited fashion: A thread holding the lock + can always get another shared lock. And a thread holding an exclusive lock + can get another lock (shared or exclusive.) + But in general, a thread holding a shared lock cannot recursively acquire an + exclusive lock. + Any acquire_shared() or acquire_exclusive must be matched with a release(). + """ + + + # Proxy classes that do either shared or exclusive locking + class _SharedLock: + def __init__(self, lock): + self.lock = lock + + @staticmethod + def _timeout(blocking, timeout): + # A few sanity checks to satisfy the unittests. + if timeout < 0 and timeout != -1: + raise ValueError("invalid timeout") + if timeout > TIMEOUT_MAX: + raise OverflowError + if blocking: + return timeout if timeout >= 0 else None + if timeout > 0: + raise ValueError("can't specify a timeout when non-blocking") + return 0 + + def acquire(self, blocking=True, timeout=-1): + return self.lock.acquire_shared(self._timeout(blocking, timeout)) + + def release(self): + self.lock.release() + + def __enter__(self): + self.acquire() + + def __exit__(self, exc, val, tb): + self.release() + + class _ExclusiveLock(_SharedLock): + def acquire(self, blocking=True, timeout=-1): + return self.lock.acquire_exclusive(self._timeout(blocking, timeout)) + + def _is_owned(self): + return self.lock._is_owned() + + def _release_save(self): + return self.lock._release_save() + + def _acquire_restore(self, arg): + return self.lock._acquire_restore(arg) + + def shared_lock(self): + """ + Return a proxy object that acquires and releases the lock in shared mode + """ + return self._SharedLock(self) + + def exclusive_lock(self): + """ + Return a proxy object that acquires and releases the lock in exclusive mode + """ + return self._ExclusiveLock(self) + +class SharableLock(SharableLockBase): + def __init__(self): + self.lock = Lock() # internal synchronization + self.cond_shr = Condition(self.lock) + self.cond_exc = Condition(self.lock) + self.nw = 0 # number of waiting threads + self.state = 0 # positive is shared count, negative exclusive count + self.owning = [] # threads will be few, so a list is not inefficient + + def acquire_exclusive(self, timeout=None): + """ + Acquire the lock in exclusive mode + """ + with self.lock: + self.nw += 1 + try: + return self.cond_exc.wait_for(self._acquire_exclusive, timeout) + finally: + self.nw -= 1 + + def _acquire_exclusive(self): + #we can only take the write lock if no one is there, or we already hold the lock + me = get_ident() + if self.state == 0 or (self.state < 0 and me in self.owning): + self.state -= 1 + self.owning.append(me) + return True + if self.state > 0 and me in self.owning: + raise RuntimeError("cannot recursively wrlock a rdlocked lock") + return False + + def acquire_shared(self, timeout=None): + """ + Acquire the lock in shared mode + """ + with self.lock: + return self.cond_shr.wait_for(self._acquire_shared, timeout) + + def _acquire_shared(self): + if self.state < 0: + # lock is in exclusive mode. See if it is ours and we can recurse + return self._acquire_exclusive() + + me = get_ident() + # Implement "exclusive bias" giving exclusive lock priority. + if not self.nw: + ok = True # no exclusive acquires waiting. + else: + # Recursion must have the highest priority + ok = me in self.owning + + if ok: + self.state += 1 + self.owning.append(me) + return ok + + def release(self): + """ + Release the lock + """ + me = get_ident() + with self.lock: + try: + self.owning.remove(me) + except ValueError: + raise RuntimeError("cannot upgrade lock from shared to exclusive") + + if self.state > 0: + self.state -= 1 + else: + self.state += 1 + if not self.state: + if self.nw: + self.cond_exc.notify() + else: + self.cond_shr.notify_all() + + # Interface for condition variable. Must hold an exclusive lock since the + # condition variable's state may be protected by the lock + def _is_owned(self): + return self.state < 0 and get_ident() in self.owning + + def _release_save(self): + # In a exlusively locked state, get the recursion level and free the lock + with self.lock: + if get_ident() not in self.owning: + raise RuntimeError("cannot release an un-acquired lock") + r = self.owning + self.owning = [] + self.state = 0 + if self.nw: + self.cond_exc.notify() + else: + self.cond_shr.notify_all() + return r + + def _acquire_restore(self, x): + # Reclaim the exclusive lock at the old recursion level + self.acquire_exclusive() + with self.lock: + self.owning = x + self.state = -len(x) + +class SimpleSharableLock(SharableLockBase): + def __init__(self): + self.cond = Condition() + self.state = 0 # positive is shared count, negative exclusive count + self.owning = [] # threads will be few, so a list is not inefficient + + def acquire_exclusive(self, timeout=None): + """ + Acquire the lock in exclusive mode + """ + with self.cond: + return self.cond.wait_for(self._acquire_exclusive, timeout) + + def _acquire_exclusive(self): + # we can only take the exclusive lock if no one is there, or we already hold the lock + me = get_ident() + if self.state == 0 or (self.state < 0 and me in self.owning): + self.state -= 1 + self.owning.append(me) + return True + if self.state > 0 and me in self.owning: + raise RuntimeError("cannot upgrade lock from shared to exclusive") + return False + + def acquire_shared(self, timeout=False): + """ + Acquire the lock in shared mode + """ + with self.cond: + return self.cond.wait_for(self._acquire_shared, timeout) + + def _acquire_shared(self): + if self.state < 0: + # We are write locked. See if we re-acquire + return self._acquire_exclusive() + + me = get_ident() + self.state += 1 + self.owning.append(me) + return True + + def release(self): + """ + Release the lock + """ + me = get_ident() + with self.cond: + try: + self.owning.remove(me) + except ValueError: + raise RuntimeError("cannot release un-acquired lock") + + if self.state > 0: + self.state -= 1 + else: + self.state += 1 + if not self.state: + self.cond.notify_all() + + # Interface for condition variable. Must hold an exclusive lock since the + # condition variable's state may be protected by the lock + def _is_owned(self): + return self.state < 0 and get_ident() in self.owning + + def _release_save(self): + # In a exlusively locked state, get the recursion level and free the lock + with self.cond: + r = self.owning + self.owning = [] + self.state = 0 + self.cond.notify_all() + return r + + def _acquire_restore(self, x): + # Reclaim the exclusive lock at the old recursion level + self.acquire_exclusive() + with self.cond: + self.owning = x + self.state = -len(x) + # Helper to generate new thread names _counter = 0 def _newname(template="Thread-%d"):