diff --git a/Lib/test/lock_tests.py b/Lib/test/lock_tests.py --- a/Lib/test/lock_tests.py +++ b/Lib/test/lock_tests.py @@ -19,7 +19,7 @@ """ A bunch of threads. """ - def __init__(self, f, n, wait_before_exit=False): + def __init__(self, f, n, wait_before_exit=False, index=False): """ Construct a bunch of `n` threads running the same function `f`. If `wait_before_exit` is True, the threads won't terminate until @@ -30,17 +30,20 @@ self.started = [] self.finished = [] self._can_exit = not wait_before_exit - def task(): + def task(idx): tid = threading.get_ident() self.started.append(tid) try: - f() + if index: + f(idx) + else: + f() finally: self.finished.append(tid) while not self._can_exit: _wait() for i in range(n): - start_new_thread(task, ()) + start_new_thread(task, (i,)) def wait_for_started(self): while len(self.started) < self.n: @@ -403,53 +406,49 @@ self.assertRaises(RuntimeError, cond.notify) def _check_notify(self, cond): + # Check that 'nodify' wakes up 'at least' the + # specified number. The protocol does alow + # for spurious wakeups, so we must allow for that. N = 5 - results1 = [] - results2 = [] - phase_num = 0 + signaled = alive = waiting = 0 def f(): - cond.acquire() - result = cond.wait() - cond.release() - results1.append((result, phase_num)) - cond.acquire() - result = cond.wait() - cond.release() - results2.append((result, phase_num)) + nonlocal alive, waiting, signaled + with cond: + try: + waiting += 1 + try: + # only break wait if we were actually signaled + while not signaled: + cond.wait() + signaled -= 1 + finally: + waiting -= 1 + finally: + alive -= 1 + def settle(): + # wait for the workers to settle into their wait or exit + while True: + with cond: + if signaled == 0 and waiting == alive: + return + _wait() + alive = N b = Bunch(f, N) - b.wait_for_started() - _wait() - self.assertEqual(results1, []) + settle() # Notify 3 threads at first - cond.acquire() - cond.notify(3) - _wait() - phase_num = 1 - cond.release() - while len(results1) < 3: - _wait() - self.assertEqual(results1, [(True, 1)] * 3) - self.assertEqual(results2, []) - # Notify 5 threads: they might be in their first or second wait - cond.acquire() - cond.notify(5) - _wait() - phase_num = 2 - cond.release() - while len(results1) + len(results2) < 8: - _wait() - self.assertEqual(results1, [(True, 1)] * 3 + [(True, 2)] * 2) - self.assertEqual(results2, [(True, 2)] * 3) - # Notify all threads: they are all in their second wait - cond.acquire() - cond.notify_all() - _wait() - phase_num = 3 - cond.release() - while len(results2) < 5: - _wait() - self.assertEqual(results1, [(True, 1)] * 3 + [(True,2)] * 2) - self.assertEqual(results2, [(True, 2)] * 3 + [(True, 3)] * 2) + with cond: + signaled = 3; + cond.notify(3) + settle() + self.assertEqual(waiting, N-3) + + # Notify the rest + with cond: + signaled = waiting + cond.notify_all() + settle() + self.assertEqual(waiting, 0) + self.assertEqual(alive, 0) b.wait_for_finished() def test_notify(self): @@ -520,6 +519,90 @@ b.wait_for_finished() self.assertEqual(len(success), 1) + class Queue(): + """A simple queue class""" + def __init__(self, test): + self.test = test + self.cond = test.condtype() + self.queue = [] + def put(self, item): + with self.cond: + self.queue.append(item) + self.cond.notify() + def get(self, timeout=1.0): + with self.cond: + if not self.cond.wait_for((lambda:self.queue), timeout=timeout): + return None # timeout + self.test.assertTrue(self.queue) + return self.queue.pop(0) + + def test_queue(self): + source = self.Queue(self) + sink = self.Queue(self) + def f(): + # A job dispatch function + while True: + d = source.get() + if not d: + return + sink.put(d[0](*d[1])) + + b = Bunch(f, 5) + b.wait_for_started() + # Create a bunch of jobs to double a number + for i in range(10): + source.put(((lambda a : a+a), (i,))) + r = [sink.get() for i in range(10)] + r.sort() + self.assertEqual(r, [i+i for i in range(10)]) + b.wait_for_finished() + source = sink = None + + class NRLock(): + """ A simple non-recursive lock """ + def __init__(self, test): + self.test = test + self.cond = test.condtype() + self.locked = False + def acquire(self, timeout=None): + with self.cond: + if not self.cond.wait_for(lambda: not self.locked, timeout): + return False + self.test.assertFalse(self.locked) + self.locked = True + return True + def release(self): + with self.cond: + self.test.assertTrue(self.locked) + self.locked = False + self.cond.notify() + def __del__(self): + self.test.assertFalse(self.locked) + + def test_nrlock(self): + # test two cooperating threads using a non-recursive lock + locks = [self.NRLock(self), self.NRLock(self)] + result = [] + + def f(idx): + # two threads release each other + for i in range(3): + if locks[idx].acquire(1): + result.append((idx, i)) + if len(result) < 10: + locks[1-idx].release() + else: + break + locks[0].acquire() + locks[1].acquire() + b = Bunch(f, 2, index=True) + b.wait_for_started() + locks[0].release() + b.wait_for_finished() + locks[1].release() + self.assertEqual(result, [(0,0), (1,0), (0,1), (1,1), (0,2), (1,2)]) + locks = None + class BaseSemaphoreTests(BaseTestCase): """