diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py --- a/Lib/unittest/case.py +++ b/Lib/unittest/case.py @@ -350,12 +350,17 @@ longMessage = True + # If two elements are unequal, assertEqual will try to generate a diff + # only if the number of elements is below these thresholds. + # This is necessary because difflib might require a lot of time/memory + # to produce the diff if there are many elements. See #11763 and #19217. + _strDiffThreshold = 2**16 # used for strings + _seqDiffThreshold = 32 # used for sequences and containers + + # If a diff is generated, it will be displayed only if its size is + # below the maxDiff value or if maxDiff is None. maxDiff = 80*8 - # If a string is longer than _diffThreshold, use normal comparison instead - # of difflib. See #11763. - _diffThreshold = 2**16 - # Attribute used by TestSuite for classSetUp _classSetupFailed = False @@ -902,9 +907,10 @@ else: seq_type_name = "sequence" + max_seq_len = 0 differing = None try: - len1 = len(seq1) + len1 = max_seq_len = len(seq1) except (TypeError, NotImplementedError): differing = 'First %s has no length. Non-sequence?' % ( seq_type_name) @@ -915,6 +921,9 @@ except (TypeError, NotImplementedError): differing = 'Second %s has no length. Non-sequence?' 
% ( seq_type_name) + else: + if len2 > max_seq_len: + max_seq_len = len2 if differing is None: if seq1 == seq2: @@ -968,13 +977,13 @@ differing += ('Unable to index element %d ' 'of second %s\n' % (len1, seq_type_name)) standardMsg = differing - diffMsg = '\n' + '\n'.join( - difflib.ndiff(pprint.pformat(seq1).splitlines(), - pprint.pformat(seq2).splitlines())) - - standardMsg = self._truncateMessage(standardMsg, diffMsg) - msg = self._formatMessage(msg, standardMsg) - self.fail(msg) + # if the number of elements is below the threshold, use difflib + if max_seq_len < self._seqDiffThreshold: + diffMsg = '\n' + '\n'.join( + difflib.ndiff(pprint.pformat(seq1).splitlines(), + pprint.pformat(seq2).splitlines())) + standardMsg = self._truncateMessage(standardMsg, diffMsg) + self.fail(self._formatMessage(msg, standardMsg)) def _truncateMessage(self, message, diff): max_diff = self.maxDiff @@ -1078,13 +1087,15 @@ def assertDictEqual(self, d1, d2, msg=None): self.assertIsInstance(d1, dict, 'First argument is not a dictionary') self.assertIsInstance(d2, dict, 'Second argument is not a dictionary') - if d1 != d2: standardMsg = '%s != %s' % _common_shorten_repr(d1, d2) - diff = ('\n' + '\n'.join(difflib.ndiff( - pprint.pformat(d1).splitlines(), - pprint.pformat(d2).splitlines()))) - standardMsg = self._truncateMessage(standardMsg, diff) + # if the number of elements is below the threshold, use difflib + if (len(d1) < self._seqDiffThreshold and + len(d2) < self._seqDiffThreshold): + diff = ('\n' + '\n'.join( + difflib.ndiff(pprint.pformat(d1).splitlines(), + pprint.pformat(d2).splitlines()))) + standardMsg = self._truncateMessage(standardMsg, diff) self.fail(self._formatMessage(msg, standardMsg)) def assertDictContainsSubset(self, subset, dictionary, msg=None): @@ -1155,18 +1166,17 @@ self.assertIsInstance(second, str, 'Second argument is not a string') if first != second: - # don't use difflib if the strings are too long - if (len(first) > self._diffThreshold or - len(second) > 
self._diffThreshold): - self._baseAssertEqual(first, second, msg) - firstlines = first.splitlines(keepends=True) - secondlines = second.splitlines(keepends=True) - if len(firstlines) == 1 and first.strip('\r\n') == first: - firstlines = [first + '\n'] - secondlines = [second + '\n'] standardMsg = '%s != %s' % _common_shorten_repr(first, second) - diff = '\n' + ''.join(difflib.ndiff(firstlines, secondlines)) - standardMsg = self._truncateMessage(standardMsg, diff) + # if the number of elements is below the threshold, use difflib + if (len(first) < self._strDiffThreshold and + len(second) < self._strDiffThreshold): + firstlines = first.splitlines(keepends=True) + secondlines = second.splitlines(keepends=True) + if len(firstlines) == 1 and first.strip('\r\n') == first: + firstlines = [first + '\n'] + secondlines = [second + '\n'] + diff = '\n' + ''.join(difflib.ndiff(firstlines, secondlines)) + standardMsg = self._truncateMessage(standardMsg, diff) self.fail(self._formatMessage(msg, standardMsg)) def assertLess(self, a, b, msg=None): diff --git a/Lib/unittest/test/test_case.py b/Lib/unittest/test/test_case.py --- a/Lib/unittest/test/test_case.py +++ b/Lib/unittest/test/test_case.py @@ -745,43 +745,40 @@ def testAssertSequenceEqualMaxDiff(self): self.assertEqual(self.maxDiff, 80*8) - seq1 = 'a' + 'x' * 80**2 - seq2 = 'b' + 'x' * 80**2 - diff = '\n'.join(difflib.ndiff(pprint.pformat(seq1).splitlines(), - pprint.pformat(seq2).splitlines())) - # the +1 is the leading \n added by assertSequenceEqual - omitted = unittest.case.DIFF_OMITTED % (len(diff) + 1,) + # keep the number of elements under the threshold or + # the diff won't be created + elements_num = 30 + self.assertLess(elements_num + 1, self._seqDiffThreshold) + seq1 = list('a' + 'x' * elements_num) + seq2 = list('b' + 'x' * elements_num) - self.maxDiff = len(diff)//2 - try: + # use the default maxDiff + with self.assertRaises(self.failureException) as cm: + self.assertSequenceEqual(seq1, seq2) + msg = 
str(cm.exception) + # check that the diff marker is included in the msg + self.assertIn('^', msg) + self.assertNotIn('Set self.maxDiff to None to see it.', msg) + + # set the maxDiff to a low value + self.maxDiff = 10 + with self.assertRaises(self.failureException) as cm: self.assertSequenceEqual(seq1, seq2) - except self.failureException as e: - msg = e.args[0] - else: - self.fail('assertSequenceEqual did not fail.') - self.assertLess(len(msg), len(diff)) - self.assertIn(omitted, msg) + msg = str(cm.exception) - self.maxDiff = len(diff) * 2 - try: + # check that the diff marker is not included in the msg + self.assertNotIn('^', msg) + self.assertIn('Set self.maxDiff to None to see it.', msg) + + # set the maxDiff to None (all the diff is included) + self.maxDiff = None + with self.assertRaises(self.failureException) as cm: self.assertSequenceEqual(seq1, seq2) - except self.failureException as e: - msg = e.args[0] - else: - self.fail('assertSequenceEqual did not fail.') - self.assertGreater(len(msg), len(diff)) - self.assertNotIn(omitted, msg) + msg = str(cm.exception) + self.assertIn('^', msg) + self.assertNotIn('Set self.maxDiff to None to see it.', msg) - self.maxDiff = None - try: - self.assertSequenceEqual(seq1, seq2) - except self.failureException as e: - msg = e.args[0] - else: - self.fail('assertSequenceEqual did not fail.') - self.assertGreater(len(msg), len(diff)) - self.assertNotIn(omitted, msg) def testTruncateMessage(self): self.maxDiff = 1 @@ -821,47 +818,87 @@ else: self.fail('assertMultiLineEqual did not fail') - def testAssertEqual_diffThreshold(self): - # check threshold value - self.assertEqual(self._diffThreshold, 2**16) + + def check_diff_threshold(self, below_threshold, above_threshold): + # get two pairs of unequal values (below and above the threshold) + # and check that the diff is generated only if the values are below + # the threshold + v1_below, v2_below = below_threshold + v1_above, v2_above = above_threshold # disable madDiff to get 
diff markers self.maxDiff = None - # set a lower threshold value and add a cleanup to restore it - old_threshold = self._diffThreshold - self._diffThreshold = 2**5 - self.addCleanup(lambda: setattr(self, '_diffThreshold', old_threshold)) + # values below the threshold: diff should be generated + with self.assertRaises(self.failureException) as cm: + self.assertEqual(v1_below, v2_below) + msg = str(cm.exception) + # check that the "x != y" and the diff marker (^) + # are included in the error message + self.assertIn(' != ', msg) + self.assertIn('^', msg) - # under the threshold: diff marker (^) in error message - s = 'x' * (2**4) - with self.assertRaises(self.failureException) as cm: - self.assertEqual(s + 'a', s + 'b') - self.assertIn('^', str(cm.exception)) - self.assertEqual(s + 'a', s + 'a') - - # over the threshold: diff not used and marker (^) not in error message - s = 'x' * (2**6) # if the path that uses difflib is taken, _truncateMessage will be # called -- replace it with explodingTruncation to verify that this # doesn't happen def explodingTruncation(message, diff): raise SystemError('this should not be raised') + old_truncate = self._truncateMessage self._truncateMessage = explodingTruncation self.addCleanup(lambda: setattr(self, '_truncateMessage', old_truncate)) - s1, s2 = s + 'a', s + 'b' + # values above the threshold: diff should not be generated with self.assertRaises(self.failureException) as cm: - self.assertEqual(s1, s2) - self.assertNotIn('^', str(cm.exception)) - self.assertEqual(str(cm.exception), '%r != %r' % (s1, s2)) - self.assertEqual(s + 'a', s + 'a') + self.assertEqual(v1_above, v2_above) + msg = str(cm.exception) + # check that the "x != y" is still included in the error message and + # that the diff marker (^) is not + self.assertIn(' != ', msg) + self.assertNotIn('^', msg) + # make sure that no diffing/truncation happens with equal values + self.assertEqual(v1_above, v1_above) + + + def testAssertEqual_strDiffThreshold(self): + 
self.assertEqual(self._strDiffThreshold, 2**16) + # set a lower threshold value and add a cleanup to restore it + old_threshold = self._strDiffThreshold + self._strDiffThreshold = 32 + self.addCleanup(lambda: setattr(self, '_strDiffThreshold', old_threshold)) + + below_threshold = 'x' * 16 + above_threshold = 'x' * 64 + self.check_diff_threshold( + (below_threshold + 'a', below_threshold + 'b'), + (above_threshold + 'a', above_threshold + 'b') + ) + + + def testAssertEqual_seqDiffThreshold(self): + self.assertEqual(self._seqDiffThreshold, 32) + below_threshold = (1,) * 16 + above_threshold = (1,) * 64 + self.check_diff_threshold( + (below_threshold + (2,), below_threshold + (3,)), + (above_threshold + (2,), above_threshold + (3,)) + ) + + + def testAssertEqual_dictDiffThreshold(self): + self.assertEqual(self._seqDiffThreshold, 32) + below_threshold = {i: 1 for i in range(16)} + above_threshold = {i: 1 for i in range(64)} + self.check_diff_threshold( + (dict(below_threshold, x=2), dict(below_threshold, x=3)), + (dict(above_threshold, x=2), dict(above_threshold, x=3)) + ) + def testAssertEqual_shorten(self): # set a lower threshold value and add a cleanup to restore it - old_threshold = self._diffThreshold - self._diffThreshold = 0 - self.addCleanup(lambda: setattr(self, '_diffThreshold', old_threshold)) + old_threshold = self._strDiffThreshold + self._strDiffThreshold = 0 + self.addCleanup(lambda: setattr(self, '_strDiffThreshold', old_threshold)) s = 'x' * 100 s1, s2 = s + 'a', s + 'b'