Index: Doc/library/fileinput.rst =================================================================== --- Doc/library/fileinput.rst (revision 76339) +++ Doc/library/fileinput.rst (working copy) @@ -53,7 +53,14 @@ as global state for the functions of this module, and is also returned to use during iteration. The parameters to this function will be passed along to the constructor of the :class:`FileInput` class. + + The :func:`fileinput.input` function can be used as a context manager via + the :keyword:`with` statement. In this example, *input* is closed after the + :keyword:`with` statement is exited---even if an exception occurs:: + with fileinput.input(files=('spam.txt', 'eggs.txt')) as input: + process(input) + .. versionchanged:: 2.5 Added the *mode* and *openhook* parameters. @@ -135,7 +142,14 @@ The *openhook*, when given, must be a function that takes two arguments, *filename* and *mode*, and returns an accordingly opened file-like object. You cannot use *inplace* and *openhook* together. + + FileInput is a context manager and supports the :keyword:`with` statement. + In this example, *input* is closed after the :keyword:`with` statement is + exited---even if an exception occurs:: + with FileInput(files=('spam.txt', 'eggs.txt')) as input: + process(input) + .. versionchanged:: 2.5 Added the *mode* and *openhook* parameters. Index: Lib/fileinput.py =================================================================== --- Lib/fileinput.py (revision 76339) +++ Lib/fileinput.py (working copy) @@ -110,7 +110,13 @@ _state = None if state: state.close() + +def __enter__(self): + return _state +def __exit__(self, type, value, traceback): + close() + def nextfile(): """ Close the current file so that the next iteration will read the first @@ -237,6 +243,12 @@ self.nextfile() self._files = () + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + self.close() + def __iter__(self): return self Index: Lib/test/test_fileinput.py =================================================================== --- Lib/test/test_fileinput.py (revision 76339) +++ Lib/test/test_fileinput.py (working copy) @@ -8,11 +8,12 @@ from test.test_support import unlink as safe_unlink import sys, re from StringIO import StringIO +import fileinput from fileinput import FileInput, hook_encoded # The fileinput module has 2 interfaces: the FileInput class which does # all the work, and a few functions (input, etc.) that use a global _state -# variable. We only test the FileInput class, since the other functions +# variable. We mostly test the FileInput class, since the other functions # only provide a thin facade over FileInput. # Write lines (a list of lines) to temp file number i, and return the @@ -118,7 +119,7 @@ self.assertEqual(int(m.group(1)), fi.filelineno()) fi.close() -class FileInputTests(unittest.TestCase): +class FileInputClassTests(unittest.TestCase): def test_zero_byte_files(self): try: t1 = writeTmp(1, [""]) @@ -217,9 +218,58 @@ self.assertEqual(lines, ["N\n", "O"]) finally: remove_tempfiles(t1) + + def test_context_manager(self): + try: + t1 = writeTmp(1, ["A\nB\nC"]) + t2 = writeTmp(2, ["D\nE\nF"]) + with FileInput(files=(t1, t2)) as fi: + lines = list(fi) + self.assertEqual(lines, ["A\n", "B\n", "C", "D\n", "E\n", "F"]) + self.assertEqual(fi.filelineno(), 3) + self.assertEqual(fi.lineno(), 6) + self.assertEqual(fi._files, ()) + finally: + remove_tempfiles(t1, t2) + + def test_close_on_exception(self): + try: + t1 = writeTmp(1, [""]) + with FileInput(files=(t1)) as fi: + raise IOError + except IOError: + self.assertEqual(fi._files, ()) + finally: + remove_tempfiles(t1) + +class FileInputFunctionTests(unittest.TestCase): + def test_context_manager(self): + try: + t1 = writeTmp(1, ["A\nB\nC"]) + t2 = writeTmp(2, ["D\nE\nF"]) + with fileinput.input(files=(t1, t2)) as fi: + lines = list(fi) + self.assertEqual(lines, ["A\n", "B\n", "C", "D\n", "E\n", "F"]) + self.assertEqual(fi.filelineno(), 3) + self.assertEqual(fi.lineno(), 6) + self.assertEqual(fi._files, ()) + finally: + remove_tempfiles(t1, t2) + + def test_close_on_exception(self): + try: + t1 = writeTmp(1, [""]) + with fileinput.input(files=(t1)) as fi: + raise IOError + except IOError: + self.assertEqual(fi._files, ()) + finally: + remove_tempfiles(t1) + + def test_main(): - run_unittest(BufferSizesTests, FileInputTests) + run_unittest(BufferSizesTests, FileInputClassTests, FileInputFunctionTests) if __name__ == "__main__": test_main()