diff --git a/Lib/unittest/mock.py b/Lib/unittest/mock.py --- a/Lib/unittest/mock.py +++ b/Lib/unittest/mock.py @@ -2278,6 +2278,24 @@ `read_data` is a string for the `read` methoddline`, and `readlines` of the file handle to return. This is an empty string by default. """ + def _readlines_side_effect(*args, **kwargs): + if handle.readlines.return_value is not None: + return handle.readlines.return_value + return list(_state[0]) + + def _read_side_effect(*args, **kwargs): + if handle.read.return_value is not None: + return handle.read.return_value + return ''.join(_state[0]) + + def _readline_side_effect(): + if handle.readline.return_value is not None: + while True: + yield handle.readline.return_value + for line in _state[0]: + yield line + + global file_spec if file_spec is None: import _io @@ -2286,42 +2304,31 @@ if mock is None: mock = MagicMock(name='open', spec=open) - def make_handle(*args, **kwargs): - # Arg checking is handled by __call__ - def _readlines_side_effect(*args, **kwargs): - if handle.readlines.return_value is not None: - return handle.readlines.return_value - return list(_data) - - def _read_side_effect(*args, **kwargs): - if handle.read.return_value is not None: - return handle.read.return_value - return ''.join(_data) - - def _readline_side_effect(): - if handle.readline.return_value is not None: - while True: - yield handle.readline.return_value - for line in _data: - yield line - - handle = MagicMock(spec=file_spec) - handle.__enter__.return_value = handle - - _data = _iterate_read_data(read_data) - - handle.write.return_value = None - handle.read.return_value = None - handle.readline.return_value = None - handle.readlines.return_value = None - - handle.read.side_effect = _read_side_effect - handle.readline.side_effect = _readline_side_effect() - handle.readlines.side_effect = _readlines_side_effect - _check_and_set_parent(mock, handle, None, '()') - return handle - - mock.side_effect = make_handle + handle = MagicMock(spec=file_spec) + handle.__enter__.return_value = handle + + _state = [_iterate_read_data(read_data), None] + + handle.write.return_value = None + handle.read.return_value = None + handle.readline.return_value = None + handle.readlines.return_value = None + + handle.read.side_effect = _read_side_effect + _state[1] = _readline_side_effect() + handle.readline.side_effect = _state[1] + handle.readlines.side_effect = _readlines_side_effect + + def reset_data(*args, **kwargs): + _state[0] = _iterate_read_data(read_data) + if handle.readline.side_effect == _state[1]: + # Only reset the side effect if the user hasn't overridden it. + _state[1] = _readline_side_effect() + handle.readline.side_effect = _state[1] + return DEFAULT + + mock.side_effect = reset_data + mock.return_value = handle return mock diff --git a/Lib/unittest/test/testmock/testmock.py b/Lib/unittest/test/testmock/testmock.py --- a/Lib/unittest/test/testmock/testmock.py +++ b/Lib/unittest/test/testmock/testmock.py @@ -1,5 +1,6 @@ import copy import sys +import tempfile import unittest from unittest.test.testmock.support import is_instance @@ -1329,8 +1330,29 @@ def test_mock_open_reuse_issue_21750(self): mocked_open = mock.mock_open(read_data='data') f1 = mocked_open('a-name') + f1_data = f1.read() f2 = mocked_open('another-name') - self.assertEqual(f1.read(), f2.read()) + f2_data = f2.read() + self.assertEqual(f1_data, f2_data) + + def test_mock_open_write(self): + # Test exception in file writing write() + mock_namedtemp = mock.mock_open(mock.MagicMock(name='JLV')) + with mock.patch('tempfile.NamedTemporaryFile', mock_namedtemp): + mock_filehandle = mock_namedtemp.return_value + mock_write = mock_filehandle.write + mock_write.side_effect = OSError('Test 2 Error') + def attempt(): + tempfile.NamedTemporaryFile().write('asd') + self.assertRaises(OSError, attempt) + + def test_mock_open_alter_readline(self): + mopen = mock.mock_open(read_data='foo\nbarn') + mopen.return_value.readline.side_effect = lambda *args:'abc' + first = mopen().readline() + second = mopen().readline() + self.assertEqual('abc', first) + self.assertEqual('abc', second) def test_mock_parents(self): for Klass in Mock, MagicMock: diff --git a/Lib/unittest/test/testmock/testwith.py b/Lib/unittest/test/testmock/testwith.py --- a/Lib/unittest/test/testmock/testwith.py +++ b/Lib/unittest/test/testmock/testwith.py @@ -141,6 +141,7 @@ def test_mock_open_context_manager(self): mock = mock_open() + handle = mock.return_value with patch('%s.open' % __name__, mock, create=True): with open('foo') as f: f.read() @@ -148,8 +149,7 @@ expected_calls = [call('foo'), call().__enter__(), call().read(), call().__exit__(None, None, None)] self.assertEqual(mock.mock_calls, expected_calls) - # mock_open.return_value is no longer static, because - # readline support requires that it mutate state + self.assertIs(f, handle) def test_mock_open_context_manager_multiple_times(self): mock = mock_open()