Index: Doc/library/csv.rst
===================================================================
--- Doc/library/csv.rst (revision 69457)
+++ Doc/library/csv.rst (working copy)
@@ -161,7 +161,35 @@
 
 The :mod:`csv` module defines the following classes:
 
+.. class:: NamedTupleReader(csvfile[, fieldnames=None[, restkey=None[, restval=None[, dialect='excel'[, *args, **kwds]]]]])
 
+   Create an object which operates like a regular reader but maps the information
+   read into a *namedtuple* whose fields can be accessed using attribute lookup.
+   The contents of *fieldnames* are passed directly to be used as the
+   namedtuple fieldnames. If *fieldnames* is None the values in the
+   first row of the *csvfile* will be used as the fieldnames.
+   If the row read has fewer fields than the fieldnames sequence, the value of
+   *restval* will be used as the default value. If the row read has more fields
+   than the fieldnames sequence, then the extra fields will be clipped unless
+   *restkey* has been specified. If it has, then the extra fields are stored
+   as a single list in the last field, named *restkey*. Any other optional or
+   keyword arguments are passed to the underlying :class:`reader` instance.
+
+
+.. class:: NamedTupleWriter(csvfile[, fieldnames=None[, restval=''[, extrasaction='raise'[, dialect='excel'[, *args, **kwds]]]]])
+
+   Create an object which operates like a regular writer but maps namedtuples onto
+   output rows. The *fieldnames* parameter identifies the valid fieldnames that will
+   be written from a *namedtuple* passed to the :meth:`writerow` method to the *csvfile*.
+   The optional *restval* parameter specifies the value to be written if the *namedtuple*
+   is missing a field listed in *fieldnames*. If the *namedtuple* passed to the
+   :meth:`writerow` method contains a field not found in *fieldnames*, the optional
+   *extrasaction* parameter indicates what action to take. If it is set to ``'raise'`` a
+   :exc:`ValueError` is raised. If it is set to ``'ignore'``, extra fields in the
+   *namedtuple* are ignored. Any other optional or keyword arguments are passed to
+   the underlying :class:`writer` instance.
+
+
 .. class:: DictReader(csvfile[, fieldnames=None[, restkey=None[, restval=None[, dialect='excel'[, *args, **kwds]]]]])
 
    Create an object which operates like a regular reader but maps the information
Index: Lib/csv.py
===================================================================
--- Lib/csv.py (revision 69457)
+++ Lib/csv.py (working copy)
@@ -11,6 +11,7 @@
                  QUOTE_MINIMAL, QUOTE_ALL, QUOTE_NONNUMERIC, QUOTE_NONE, \
                  __doc__
 from _csv import Dialect as _Dialect
+from collections import namedtuple as _namedtuple
 
 try:
     from cStringIO import StringIO
@@ -69,7 +70,166 @@
     delimiter = '\t'
 register_dialect("excel-tab", excel_tab)
 
 
+class NamedTupleReader(object):
+    """Read rows from a CSV file as namedtuple instances.
+
+    *fieldnames* may be a namedtuple class or anything accepted as the
+    field names argument of collections.namedtuple().  If it is None, the
+    names are taken from the first row of *f*.
+
+    Rows shorter than the fieldnames are padded with *restval*.  Longer
+    rows are clipped unless *restkey* is given; in that case every row gets
+    an extra last field named *restkey* that holds the surplus values as a
+    list (empty when there are none).  Any other arguments are passed to
+    the underlying reader.
+    """
+
+    # This must be a new-style class: classic classes ignore property
+    # setters, so assigning to ``fieldnames`` would shadow the property.
+
+    def __init__(self, f, fieldnames=None, restkey=None, restval=None,
+                 dialect="excel", *args, **kwds):
+        # sequence or string of fieldnames, a namedtuple class, or None.
+        self._fieldnames = fieldnames
+        # class name for generated namedtuples; reuse the caller's if a
+        # namedtuple class was passed in.
+        self._name = getattr(fieldnames, '__name__', 'Fields')
+        self.restkey = restkey      # name of the field catching long rows
+        self.restval = restval      # default value for short rows
+        self.reader = reader(f, dialect, *args, **kwds)
+        self.dialect = dialect
+        self.line_num = 0
+        # cached namedtuple class extended with restkey, and the
+        # (base class, restkey) pair it was built from.
+        self._rest_nt = None
+        self._rest_nt_key = None
+
+    def __iter__(self):
+        return self
+
+    @property
+    def fieldnames(self):
+        """Return the field names as a list.
+
+        The names come from the stored namedtuple class, which is created
+        on first use (reading the first row if no fieldnames were given).
+        Returns None if the names must come from an empty file.
+        """
+        # short circuit if we already have a namedtuple class stored.
+        if hasattr(self._fieldnames, '_fields'):
+            return list(self._fieldnames._fields)
+
+        # if no fieldnames were passed, attempt to read from first row.
+        if self._fieldnames is None:
+            try:
+                self._fieldnames = self.reader.next()
+            except StopIteration:
+                return None
+            finally:
+                self.line_num = self.reader.line_num
+
+        self._fieldnames = _namedtuple(self._name, self._fieldnames)
+        return list(self._fieldnames._fields)
+
+    @fieldnames.setter
+    def fieldnames(self, value):
+        """Set the fieldnames.
+
+        *value* can be a sequence of fieldnames or a namedtuple class.
+        """
+        if hasattr(value, '_fields'):
+            # keep the generated class name in step with the new class.
+            self._name = value.__name__
+        self._fieldnames = value
+
+    def _rest_class(self):
+        """Return the namedtuple class extended with the restkey field."""
+        # namedtuple() is expensive, so build the class once per
+        # (fieldnames, restkey) combination rather than once per row.
+        key = (self._fieldnames, self.restkey)
+        if self._rest_nt_key != key:
+            names = self._fieldnames._fields + (self.restkey,)
+            self._rest_nt = _namedtuple(self._name, names)
+            self._rest_nt_key = key
+        return self._rest_nt
+
+    def next(self):
+        """Return the next non-blank row as a namedtuple."""
+        # evaluated first so that a header row, if any, is consumed
+        # before the first data row is read.
+        names = self.fieldnames
+        try:
+            row = self.reader.next()
+            # unlike the basic reader, we prefer not to return blanks,
+            # because we would wind up with a namedtuple full of restval.
+            while row == []:
+                row = self.reader.next()
+        finally:
+            # updated after skipping blanks so it matches the row returned.
+            self.line_num = self.reader.line_num
+
+        n = len(names)
+        # pad missing fields with restval.
+        if len(row) < n:
+            row += [self.restval] * (n - len(row))
+
+        # either clip or collect the surplus under restkey.
+        if self.restkey is None:
+            return self._fieldnames._make(row[:n])
+        row[n:] = [row[n:]]
+        return self._rest_class()._make(row)
+
+
+class NamedTupleWriter(object):
+    """Write namedtuple instances to a CSV file.
+
+    *fieldnames* lists the fields to write, in output order; if it is None
+    each namedtuple is written as is.  *restval* is written for any field
+    a namedtuple lacks.  *extrasaction* ('raise' or 'ignore', case
+    insensitive) selects what happens when a namedtuple has fields that
+    are not in *fieldnames*.  Any other arguments are passed to the
+    underlying writer.
+    """
+
+    def __init__(self, f, fieldnames=None, restval="", extrasaction="raise",
+                 dialect="excel", *args, **kwds):
+        self.fieldnames = fieldnames    # fields to write, in output order
+        self.restval = restval          # for writing short namedtuples
+
+        # normalise once: otherwise e.g. "RAISE" passes validation but is
+        # then treated like "ignore" by the comparison in _nt_to_list().
+        action = extrasaction.lower()
+        if action not in ("raise", "ignore"):
+            raise ValueError("extrasaction (%s) must be 'raise' or 'ignore'"
+                             % extrasaction)
+        self.extrasaction = action
+        self.writer = writer(f, dialect, *args, **kwds)
+
+    def _nt_to_list(self, row_nt):
+        """Return the values of *row_nt* in fieldnames order."""
+        if self.fieldnames is None:
+            return row_nt
+
+        if self.extrasaction == "raise":
+            wrong_fields = [name for name in row_nt._fields
+                            if name not in self.fieldnames]
+            if wrong_fields:
+                raise ValueError("namedtuple contains fields not in "
+                                 "fieldnames: " + ", ".join(wrong_fields))
+        return [getattr(row_nt, name, self.restval)
+                for name in self.fieldnames]
+
+    def writerow(self, row_nt):
+        return self.writer.writerow(self._nt_to_list(row_nt))
+
+    def writerows(self, row_nts):
+        # convert every row first so that a bad row is reported before
+        # anything is written.
+        rows = [self._nt_to_list(row_nt) for row_nt in row_nts]
+        return self.writer.writerows(rows)
+
+
 class DictReader:
     def __init__(self, f, fieldnames=None, restkey=None, restval=None,
                  dialect="excel", *args, **kwds):
Index: Lib/test/test_csv.py
===================================================================
--- Lib/test/test_csv.py (revision 69457)
+++ Lib/test/test_csv.py (working copy)
@@ -8,6 +8,7 @@
 from StringIO import StringIO
 import tempfile
 import csv
+import collections
 import gc
 from test import test_support
 
@@ -572,6 +573,278 @@
     def test_read_escape_fieldsep(self):
         self.readerAssertEqual('"abc\\,def"\r\n', [['abc,def']])
 
+#---------------------------------------------------------------------------------
+
+class TestNamedTupleFields(unittest.TestCase):
+    ### "long" means the row is longer than the number of fieldnames
+    ### "short" means there are fewer elements in the row than fieldnames
+    def test_write_simple_nt(self):
+        fd, name = tempfile.mkstemp()
+        fileobj = os.fdopen(fd, "w+b")
+        try:
+            writer = csv.NamedTupleWriter(fileobj, fieldnames=["f1", "f2", "f3"])
+
+            Fields = collections.namedtuple('mynt', 'f1 f3')
+            fieldnames = Fields(f1=10, f3='abc')
+
+            writer.writerow(fieldnames)
+            fileobj.seek(0)
+            self.assertEqual(fileobj.read(), "10,,abc\r\n")
+        finally:
+            fileobj.close()
+            os.unlink(name)
+
+    def test_write_nt_no_fields(self):
+        fd, name = tempfile.mkstemp()
+        fileobj = os.fdopen(fd, "w+b")
+        try:
+            writer = csv.NamedTupleWriter(fileobj)
+
+            Fields = collections.namedtuple('mynt', 'f1 f2 f3')
+            fieldnames = Fields(f1=10, f2=None, f3='abc')
+
+            writer.writerow(fieldnames)
+            fileobj.seek(0)
+            self.assertEqual(fileobj.read(), "10,,abc\r\n")
+        finally:
+            fileobj.close()
+            os.unlink(name)
+
+    def test_read_namedtuple_fields(self):
+        fd, name = tempfile.mkstemp()
+        fileobj = os.fdopen(fd, "w+b")
+        try:
+            fileobj.write("1,2,abc\r\n")
+            fileobj.seek(0)
+            fieldnames = collections.namedtuple('fieldnames', 'f1 f2 f3')
+            reader = csv.NamedTupleReader(fileobj,
+                                          fieldnames=fieldnames)
+
+            Fields = collections.namedtuple('fieldnames', 'f1 f2 f3')
+            fieldnames = Fields(f1='1', f2='2', f3='abc')
+
+            self.assertEqual(reader.next(), fieldnames)
+        finally:
+            fileobj.close()
+            os.unlink(name)
+
+    def test_read_nt_no_fieldnames(self):
+        fd, name = tempfile.mkstemp()
+        fileobj = os.fdopen(fd, "w+b")
+        try:
+            fileobj.write("f1,f2,f3\r\n1,2,abc\r\n")
+            fileobj.seek(0)
+            reader = csv.NamedTupleReader(fileobj)
+            self.assertEqual(reader.fieldnames, ["f1", "f2", "f3"])
+
+            Fields = collections.namedtuple('fieldnames', 'f1 f2 f3')
+            fieldnames = Fields(f1='1', f2='2', f3='abc')
+
+            self.assertEqual(reader.next(), fieldnames)
+        finally:
+            fileobj.close()
+            os.unlink(name)
+
+    # Two test cases to make sure existing ways of implicitly setting
+    # fieldnames continue to work. Both arise from discussion in issue3436.
+    def test_read_nt_fieldnames_from_file(self):
+        fd, name = tempfile.mkstemp()
+        f = os.fdopen(fd, "w+b")
+        try:
+            f.write("f1,f2,f3\r\n1,2,abc\r\n")
+            f.seek(0)
+            reader = csv.NamedTupleReader(f, fieldnames=csv.reader(f).next())
+            self.assertEqual(reader.fieldnames, ["f1", "f2", "f3"])
+
+            Fields = collections.namedtuple('fieldnames', 'f1 f2 f3')
+            fieldnames = Fields(f1='1', f2='2', f3='abc')
+
+            self.assertEqual(reader.next(), fieldnames)
+        finally:
+            f.close()
+            os.unlink(name)
+
+    def test_read_nt_fieldnames_chain(self):
+        import itertools
+        fd, name = tempfile.mkstemp()
+        f = os.fdopen(fd, "w+b")
+        try:
+            f.write("f1,f2,f3\r\n1,2,abc\r\n")
+            f.seek(0)
+            reader = csv.NamedTupleReader(f)
+            first = next(reader)
+            for row in itertools.chain([first], reader):
+                self.assertEqual(reader.fieldnames, ["f1", "f2", "f3"])
+
+                Fields = collections.namedtuple('fieldnames', 'f1 f2 f3')
+                fieldnames = Fields(f1='1', f2='2', f3='abc')
+
+                self.assertEqual(row, fieldnames)
+        finally:
+            f.close()
+            os.unlink(name)
+
+    def test_read_long_clipped(self):
+        fd, name = tempfile.mkstemp()
+        fileobj = os.fdopen(fd, "w+b")
+        try:
+            fileobj.write("1,2,abc,4,5,6\r\n")
+            fileobj.seek(0)
+            reader = csv.NamedTupleReader(fileobj,
+                                          fieldnames=["f1", "f2"])
+
+            Fields = collections.namedtuple('fieldnames', 'f1 f2')
+            fieldnames = Fields(f1='1', f2='2')
+
+            self.assertEqual(reader.next(), fieldnames)
+        finally:
+            fileobj.close()
+            os.unlink(name)
+
+    def test_read_long_rest(self):
+        fd, name = tempfile.mkstemp()
+        fileobj = os.fdopen(fd, "w+b")
+        try:
+            fileobj.write("1,2,abc,4,5,6\r\n")
+            fileobj.seek(0)
+            reader = csv.NamedTupleReader(fileobj,
+                                          fieldnames=["f1", "f2"], restkey='DEFAULT')
+
+            Fields = collections.namedtuple('fieldnames', 'f1 f2 DEFAULT')
+            fieldnames = Fields(f1='1', f2='2', DEFAULT=['abc', '4', '5', '6'])
+
+            self.assertEqual(reader.next(), fieldnames)
+        finally:
+            fileobj.close()
+            os.unlink(name)
+
+
+    def test_read_long_with_rest(self):
+        fd, name = tempfile.mkstemp()
+        fileobj = os.fdopen(fd, "w+b")
+        try:
+            fileobj.write("1,2,abc,4,5,6\r\n")
+            fileobj.seek(0)
+            reader = csv.NamedTupleReader(fileobj,
+                                          fieldnames=["f1", "f2"], restkey="rest")
+
+            Fields = collections.namedtuple('fieldnames', 'f1 f2 rest')
+            fieldnames = Fields(f1='1', f2='2', rest=['abc', '4', '5', '6'])
+
+            self.assertEqual(reader.next(), fieldnames)
+        finally:
+            fileobj.close()
+            os.unlink(name)
+
+    def test_read_long_with_rest_no_fieldnames(self):
+        fd, name = tempfile.mkstemp()
+        fileobj = os.fdopen(fd, "w+b")
+        try:
+            fileobj.write("f1,f2\r\n1,2,abc,4,5,6\r\n")
+            fileobj.seek(0)
+            reader = csv.NamedTupleReader(fileobj, restkey="rest")
+            self.assertEqual(reader.fieldnames, ["f1", "f2"])
+            self.assertEqual(reader.next()._asdict(), {"f1": '1', "f2": '2',
+                                                       "rest": ["abc", "4", "5", "6"]})
+        finally:
+            fileobj.close()
+            os.unlink(name)
+
+    def test_read_short(self):
+        fd, name = tempfile.mkstemp()
+        fileobj = os.fdopen(fd, "w+b")
+        try:
+            fileobj.write("1,2,abc,4,5,6\r\n1,2,abc\r\n")
+            fileobj.seek(0)
+            reader = csv.NamedTupleReader(fileobj,
+                                          fieldnames="one two three four five six",
+                                          restval="DEFAULT")
+
+            self.assertEqual(reader.next()._asdict(),
+                             dict(one='1', two='2', three='abc',
+                                  four='4', five='5', six='6'))
+
+            self.assertEqual(reader.next()._asdict(),
+                             dict(one='1', two='2', three='abc',
+                                  four='DEFAULT', five='DEFAULT',
+                                  six='DEFAULT'))
+        finally:
+            fileobj.close()
+            os.unlink(name)
+
+    def test_read_multi(self):
+        sample = [
+            '2147483648,43.0e12,17,abc,def\r\n',
+            '147483648,43.0e2,17,abc,def\r\n',
+            '47483648,43.0,170,abc,def\r\n'
+            ]
+
+        reader = csv.NamedTupleReader(sample,
+                                      fieldnames="i1 float i2 s1 s2".split())
+        self.assertEqual(reader.next()._asdict(), {"i1": '2147483648',
+                                                   "float": '43.0e12',
+                                                   "i2": '17',
+                                                   "s1": 'abc',
+                                                   "s2": 'def'})
+
+    def test_read_with_blanks(self):
+        reader = csv.NamedTupleReader(["1,2,abc,4,5,6\r\n","\r\n",
+                                       "1,2,abc,4,5,6\r\n"],
+                                      fieldnames="one two three four five six")
+
+        Fields = collections.namedtuple('Fields', 'one two three four five six')
+        fieldnames = Fields(one='1', two='2', three='abc', four='4', five='5',
+                            six='6')
+
+        self.assertEqual(reader.next(), fieldnames)
+        self.assertEqual(reader.next(), fieldnames)
+
+    def test_read_semi_sep(self):
+        reader = csv.NamedTupleReader(["1;2;abc;4;5;6\r\n"],
+                                      fieldnames="one two three four five six",
+                                      delimiter=';')
+        self.assertEqual(reader.next()._asdict(), dict(one='1', two='2',
+                         three='abc', four='4', five='5', six='6'))
+
+    def test_read_set_fieldnames(self):
+        reader = csv.NamedTupleReader(["1,2\r\n"], fieldnames=["a", "b"])
+        reader.fieldnames = ["x", "y"]
+        self.assertEqual(reader.fieldnames, ["x", "y"])
+        self.assertEqual(reader.next()._asdict(), {"x": '1', "y": '2'})
+
+    def test_read_line_num_with_blanks(self):
+        reader = csv.NamedTupleReader(["1,2\r\n", "\r\n", "3,4\r\n"],
+                                      fieldnames="a b")
+        reader.next()
+        self.assertEqual(reader.line_num, 1)
+        reader.next()
+        self.assertEqual(reader.line_num, 3)
+
+    def test_write_extrasaction_case(self):
+        Fields = collections.namedtuple('mynt', 'f1 f2')
+        fileobj = StringIO()
+        writer = csv.NamedTupleWriter(fileobj, fieldnames=["f1"],
+                                      extrasaction="IGNORE")
+        writer.writerow(Fields(f1=1, f2=2))
+        self.assertEqual(fileobj.getvalue(), "1\r\n")
+
+        writer = csv.NamedTupleWriter(StringIO(), fieldnames=["f1"],
+                                      extrasaction="RAISE")
+        self.assertRaises(ValueError, writer.writerow, Fields(f1=1, f2=2))
+
+    def test_write_extra_fields_message(self):
+        writer = csv.NamedTupleWriter(StringIO(), fieldnames=["f1"])
+        Fields = collections.namedtuple('mynt', 'f1 f2 f3')
+        try:
+            writer.writerow(Fields(f1=1, f2=2, f3=3))
+        except ValueError as e:
+            self.assertEqual(str(e), "namedtuple contains fields not in "
+                                     "fieldnames: f2, f3")
+        else:
+            self.fail("ValueError not raised")
+
+#---------------------------------------------------------------------------------
+
 class TestDictFields(unittest.TestCase):
     ### "long" means the row is longer than the number of fieldnames
     ### "short" means there are fewer elements in the row than fieldnames