diff -r c1fc6b6d1cfc Lib/test/test_sched.py --- a/Lib/test/test_sched.py Sat Jan 05 06:26:39 2013 -0800 +++ b/Lib/test/test_sched.py Sun Jan 06 12:29:11 2013 +0200 @@ -9,6 +9,98 @@ except ImportError: threading = None + +class Timer: + def __init__(self): + self._sleep_cond = threading.Condition() + self._parties = 0 + self._sleepers = 0 + self._awake_cond = threading.Condition() + self._time = 0 + self._awake_time = float('inf') + + def time(self): + with self._awake_cond: + return self._time + + def sleep(self, t): + if t < 0: + raise ValueError('sleep length must be non-negative') + ident = threading.get_ident() + with self._awake_cond: + t += self._time + self._awake_time = min(self._awake_time, t) + with self._sleep_cond: + self._sleepers += 1 + self._sleep_cond.notify_all() + self._awake_cond.wait() + while self._time < t: + self._awake_cond.wait() + + def advance(self, t=None): + ident = threading.get_ident() + if t is not None: + if t < 0: + raise ValueError('advance length must be non-negative') + t += self._time + while True: + if not self._wait(): + return + with self._awake_cond: + if t is None: + self._awake() + break + elif t < self._awake_time: + self._time = t + return + else: + self._awake() + if t <= self._time: + break + self._wait() + + def close(self): + ident = threading.get_ident() + while True: + if not self._wait(): + return + with self._awake_cond: + self._awake() + + # wait until all the participants will fall asleep + def _wait(self): + with self._sleep_cond: + while self._sleepers < self._parties: + self._sleep_cond.wait() + return self._parties + + # awake all the sleepers + def _awake(self): + self._time = self._awake_time + self._awake_time = float('inf') + with self._sleep_cond: + self._sleepers = 0 + self._awake_cond.notify_all() + + def new_thread(self, *args, **kwargs): + t = threading.Thread(*args, **kwargs) + with self._sleep_cond: + self._parties += 1 + oldrun = t.run + def run(): + nonlocal oldrun + try: + oldrun() + finally: + with self._sleep_cond: + self._parties -= 1 + self._sleep_cond.notify_all() + # avoid a refcycle + del oldrun + t.run = run + return t + + class TestCase(unittest.TestCase): def test_enter(self): @@ -33,15 +125,29 @@ def test_enter_concurrent(self): l = [] fun = lambda x: l.append(x) - scheduler = sched.scheduler(time.time, time.sleep) - scheduler.enter(0.03, 1, fun, (0.03,)) - t = threading.Thread(target=scheduler.run) + timer = Timer() + scheduler = sched.scheduler(timer.time, timer.sleep) + scheduler.enter(1, 1, fun, (1,)) + scheduler.enter(3, 1, fun, (3,)) + t = timer.new_thread(target=scheduler.run) t.start() - for x in [0.05, 0.04, 0.02, 0.01]: - z = scheduler.enter(x, 1, fun, (x,)) - scheduler.run() - t.join() - self.assertEqual(l, [0.01, 0.02, 0.03, 0.04, 0.05]) + try: + timer.advance(1) + self.assertEqual(l, [1]) + for x in [4, 5, 2]: + z = scheduler.enter(x - 1, 1, fun, (x,)) + timer.advance(1) + self.assertEqual(l, [1, 2]) + timer.advance(1) + self.assertEqual(l, [1, 2, 3]) + timer.advance(1) + self.assertEqual(l, [1, 2, 3, 4]) + timer.advance(1) + self.assertEqual(l, [1, 2, 3, 4, 5]) + finally: + timer.close() + t.join() + self.assertEqual(l, [1, 2, 3, 4, 5]) def test_priority(self): l = [] @@ -71,19 +177,33 @@ def test_cancel_concurrent(self): l = [] fun = lambda x: l.append(x) - scheduler = sched.scheduler(time.time, time.sleep) - now = time.time() - event1 = scheduler.enterabs(now + 0.01, 1, fun, (0.01,)) - event2 = scheduler.enterabs(now + 0.02, 1, fun, (0.02,)) - event3 = scheduler.enterabs(now + 0.03, 1, fun, (0.03,)) - event4 = scheduler.enterabs(now + 0.04, 1, fun, (0.04,)) - event5 = scheduler.enterabs(now + 0.05, 1, fun, (0.05,)) - t = threading.Thread(target=scheduler.run) + timer = Timer() + scheduler = sched.scheduler(timer.time, timer.sleep) + now = timer.time() + event1 = scheduler.enterabs(now + 1, 1, fun, (1,)) + event2 = scheduler.enterabs(now + 2, 1, fun, (2,)) + event3 = scheduler.enterabs(now + 3, 1, fun, (3,)) + event4 = scheduler.enterabs(now + 4, 1, fun, (4,)) + event5 = scheduler.enterabs(now + 5, 1, fun, (5,)) + t = timer.new_thread(target=scheduler.run) t.start() - scheduler.cancel(event1) - scheduler.cancel(event5) - t.join() - self.assertEqual(l, [0.02, 0.03, 0.04]) + try: + timer.advance(1.5) + self.assertEqual(l, [1]) + scheduler.cancel(event2) + scheduler.cancel(event5) + timer.advance(1) + self.assertEqual(l, [1]) + timer.advance(1) + self.assertEqual(l, [1, 3]) + timer.advance(1) + self.assertEqual(l, [1, 3, 4]) + timer.advance(1) + self.assertEqual(l, [1, 3, 4]) + finally: + timer.close() + t.join() + self.assertEqual(l, [1, 3, 4]) def test_empty(self): l = []