diff -r 06cf4044a11a Lib/unittest/case.py --- a/Lib/unittest/case.py Sat Aug 09 09:34:25 2014 +0300 +++ b/Lib/unittest/case.py Sun Aug 10 10:36:26 2014 +0300 @@ -3,6 +3,8 @@ import sys import functools import difflib +import io +import itertools import logging import pprint import re @@ -18,8 +20,8 @@ __unittest = True -DIFF_OMITTED = ('\nDiff is %s characters long. ' - 'Set self.maxDiff to None to see it.') +DIFF_OMITTED = ('*** Diff is %s characters long. ' + 'Set self.maxDiff to None to see it.') class SkipTest(Exception): """ @@ -350,6 +352,8 @@ longMessage = True + maxDiffItems = 32 + maxDiff = 80*8 # If a string is longer than _diffThreshold, use normal comparison instead @@ -968,19 +972,38 @@ differing += ('Unable to index element %d ' 'of second %s\n' % (len1, seq_type_name)) standardMsg = differing - diffMsg = '\n' + '\n'.join( - difflib.ndiff(pprint.pformat(seq1).splitlines(), - pprint.pformat(seq2).splitlines())) + diffMsg = self._truncateDiffs(pprint.pformat(seq1).splitlines(), + pprint.pformat(seq2).splitlines()) + msg = self._formatMessage(msg, standardMsg + '\n' + diffMsg) + self.fail(msg) - standardMsg = self._truncateMessage(standardMsg, diffMsg) - msg = self._formatMessage(msg, standardMsg) - self.fail(msg) + def _truncateDiffs(self, a, b, sep='\n'): + max_lines = self.maxDiffItems + truncated = False + if max_lines is not None: + start = min(len(a), len(b)) + for i in range(start): + if a[i] != b[i]: + start = i + break + end = start + max_lines + truncated = end < max(len(a), len(b)) + a = a[start: end] + b = b[start: end] + diff = sep.join(difflib.ndiff(a, b)) + max_diff = self.maxDiff + if max_diff is not None and len(diff) > max_diff: + diff = DIFF_OMITTED % len(diff) + elif max_lines is not None and truncated: + diff += ('\n...\n*** Diff is truncated. ' + 'Set self.maxDiffItems to None to see it all.') + return diff def _truncateMessage(self, message, diff): max_diff = self.maxDiff if max_diff is None or len(diff) <= max_diff: return message + diff - return message + (DIFF_OMITTED % len(diff)) + return message + '\n' + (DIFF_OMITTED % len(diff)) def assertListEqual(self, list1, list2, msg=None): """A list-specific equality assertion. @@ -1081,11 +1104,9 @@ if d1 != d2: standardMsg = '%s != %s' % _common_shorten_repr(d1, d2) - diff = ('\n' + '\n'.join(difflib.ndiff( - pprint.pformat(d1).splitlines(), - pprint.pformat(d2).splitlines()))) - standardMsg = self._truncateMessage(standardMsg, diff) - self.fail(self._formatMessage(msg, standardMsg)) + diff = self._truncateDiffs(pprint.pformat(d1).splitlines(), + pprint.pformat(d2).splitlines()) + self.fail(self._formatMessage(msg, standardMsg + '\n' + diff)) def assertDictContainsSubset(self, subset, dictionary, msg=None): """Checks whether dictionary is a superset of subset.""" @@ -1165,9 +1186,8 @@ firstlines = [first + '\n'] secondlines = [second + '\n'] standardMsg = '%s != %s' % _common_shorten_repr(first, second) - diff = '\n' + ''.join(difflib.ndiff(firstlines, secondlines)) - standardMsg = self._truncateMessage(standardMsg, diff) - self.fail(self._formatMessage(msg, standardMsg)) + diff = self._truncateDiffs(firstlines, secondlines, '') + self.fail(self._formatMessage(msg, standardMsg + '\n' + diff)) def assertLess(self, a, b, msg=None): """Just like self.assertTrue(a < b), but with a nicer default message.""" diff -r 06cf4044a11a Lib/unittest/test/test_case.py --- a/Lib/unittest/test/test_case.py Sat Aug 09 09:34:25 2014 +0300 +++ b/Lib/unittest/test/test_case.py Sun Aug 10 10:36:26 2014 +0300 @@ -745,49 +745,46 @@ def testAssertSequenceEqualMaxDiff(self): self.assertEqual(self.maxDiff, 80*8) - seq1 = 'a' + 'x' * 80**2 - seq2 = 'b' + 'x' * 80**2 - diff = '\n'.join(difflib.ndiff(pprint.pformat(seq1).splitlines(), - pprint.pformat(seq2).splitlines())) - # the +1 is the leading \n added by assertSequenceEqual - omitted = unittest.case.DIFF_OMITTED % (len(diff) + 1,) + # keep the number of elements under the threshold or + # the diff won't be created + elements_num = 30 + self.assertLess(elements_num + 1, self.maxDiffItems) + seq1 = list('a' + 'x' * elements_num) + seq2 = list('b' + 'x' * elements_num) - self.maxDiff = len(diff)//2 - try: + # use the default maxDiff + with self.assertRaises(self.failureException) as cm: + self.assertSequenceEqual(seq1, seq2) + msg = str(cm.exception) + # check that the diff marker is included in the msg + self.assertIn('^', msg) + self.assertNotIn('Set self.maxDiff to None to see it.', msg) + + # set the maxDiff to a low value + self.maxDiff = 10 + with self.assertRaises(self.failureException) as cm: self.assertSequenceEqual(seq1, seq2) - except self.failureException as e: - msg = e.args[0] - else: - self.fail('assertSequenceEqual did not fail.') - self.assertLess(len(msg), len(diff)) - self.assertIn(omitted, msg) + msg = str(cm.exception) - self.maxDiff = len(diff) * 2 - try: + # check that the diff marker is not included in the msg + self.assertNotIn('^', msg) + self.assertIn('Set self.maxDiff to None to see it.', msg) + + # set the maxDiff to None (all the diff is included) + self.maxDiff = None + with self.assertRaises(self.failureException) as cm: self.assertSequenceEqual(seq1, seq2) - except self.failureException as e: - msg = e.args[0] - else: - self.fail('assertSequenceEqual did not fail.') - self.assertGreater(len(msg), len(diff)) - self.assertNotIn(omitted, msg) + msg = str(cm.exception) + self.assertIn('^', msg) + self.assertNotIn('Set self.maxDiff to None to see it.', msg) - self.maxDiff = None - try: - self.assertSequenceEqual(seq1, seq2) - except self.failureException as e: - msg = e.args[0] - else: - self.fail('assertSequenceEqual did not fail.') - self.assertGreater(len(msg), len(diff)) - self.assertNotIn(omitted, msg) def testTruncateMessage(self): self.maxDiff = 1 message = self._truncateMessage('foo', 'bar') omitted = unittest.case.DIFF_OMITTED % len('bar') - self.assertEqual(message, 'foo' + omitted) + self.assertEqual(message, 'foo' + '\n' + omitted) self.maxDiff = None message = self._truncateMessage('foo', 'bar') @@ -799,63 +796,95 @@ def testAssertDictEqualTruncates(self): test = unittest.TestCase('assertEqual') - def truncate(msg, diff): - return 'foo' - test._truncateMessage = truncate try: test.assertDictEqual({}, {1: 0}) except self.failureException as e: - self.assertEqual(str(e), 'foo') + self.assertEqual(str(e), '{} != {1: 0}\n' + '- {}\n' + '+ {1: 0}') else: self.fail('assertDictEqual did not fail') def testAssertMultiLineEqualTruncates(self): test = unittest.TestCase('assertEqual') - def truncate(msg, diff): - return 'foo' - test._truncateMessage = truncate try: test.assertMultiLineEqual('foo', 'bar') except self.failureException as e: - self.assertEqual(str(e), 'foo') + self.assertEqual(str(e), "'foo' != 'bar'\n" + '- foo\n' + '+ bar\n') else: self.fail('assertMultiLineEqual did not fail') - def testAssertEqual_diffThreshold(self): - # check threshold value - self.assertEqual(self._diffThreshold, 2**16) + + def check_diff_threshold(self, test, below_threshold, above_threshold): + # get two pairs of unequal values (below and above the threshold) + # and check that the diff is generated only if the values are below + # the threshold + v1_below, v2_below = below_threshold + v1_above, v2_above = above_threshold # disable madDiff to get diff markers - self.maxDiff = None + test.maxDiff = None - # set a lower threshold value and add a cleanup to restore it - old_threshold = self._diffThreshold - self._diffThreshold = 2**5 - self.addCleanup(lambda: setattr(self, '_diffThreshold', old_threshold)) + # values below the threshold: diff should be generated + with self.assertRaises(test.failureException) as cm: + test.assertEqual(v1_below, v2_below) + msg = str(cm.exception) + # check that the "x != y" and the diff marker (^) + # are included in the error message + self.assertIn(' != ', msg) + self.assertIn('^', msg) - # under the threshold: diff marker (^) in error message - s = 'x' * (2**4) - with self.assertRaises(self.failureException) as cm: - self.assertEqual(s + 'a', s + 'b') - self.assertIn('^', str(cm.exception)) - self.assertEqual(s + 'a', s + 'a') - - # over the threshold: diff not used and marker (^) not in error message - s = 'x' * (2**6) # if the path that uses difflib is taken, _truncateMessage will be # called -- replace it with explodingTruncation to verify that this # doesn't happen def explodingTruncation(message, diff): raise SystemError('this should not be raised') - old_truncate = self._truncateMessage - self._truncateMessage = explodingTruncation - self.addCleanup(lambda: setattr(self, '_truncateMessage', old_truncate)) - s1, s2 = s + 'a', s + 'b' - with self.assertRaises(self.failureException) as cm: - self.assertEqual(s1, s2) - self.assertNotIn('^', str(cm.exception)) - self.assertEqual(str(cm.exception), '%r != %r' % (s1, s2)) - self.assertEqual(s + 'a', s + 'a') + test._truncateMessage = explodingTruncation + + # values above the threshold: diff should not be generated + with self.assertRaises(test.failureException) as cm: + test.assertEqual(v1_above, v2_above) + msg = str(cm.exception) + # check that the "x != y" is still included in the error message + self.assertIn(' != ', msg) + # make sure that no diffing/truncation happens with equal values + self.assertEqual(v1_above, v1_above) + + + def testAssertEqual_diffThreshold(self): + test = unittest.TestCase() + test._diffThreshold = 32 + below_threshold = 'x' * 16 + above_threshold = 'x' * 64 + self.check_diff_threshold(test, + (below_threshold + 'a', below_threshold + 'b'), + (above_threshold + 'a', above_threshold + 'b') + ) + + + def testAssertEqual_maxDiffItems(self): + test = unittest.TestCase() + test.maxDiffItems = 32 + below_threshold = (1,) * 16 + above_threshold = (1,) * 64 + self.check_diff_threshold(test, + (below_threshold + (2,), below_threshold + (3,)), + (above_threshold + (2,), above_threshold + (3,)) + ) + + + def testAssertEqual_dictDiffThreshold(self): + test = unittest.TestCase() + test.maxDiffItems = 32 + below_threshold = {i: 1 for i in range(16)} + above_threshold = {i: 1 for i in range(64)} + self.check_diff_threshold(test, + (dict(below_threshold, x=2), dict(below_threshold, x=3)), + (dict(above_threshold, x=2), dict(above_threshold, x=3)) + ) + def testAssertEqual_shorten(self): # set a lower threshold value and add a cleanup to restore it