diff --git a/Doc/library/csv.rst b/Doc/library/csv.rst index e7516b67f92..325a4219938 100644 --- a/Doc/library/csv.rst +++ b/Doc/library/csv.rst @@ -419,7 +419,7 @@ Writer Objects :class:`Writer` objects (:class:`DictWriter` instances and objects returned by the :func:`writer` function) have the following public methods. A *row* must be -a sequence of strings or numbers for :class:`Writer` objects and a dictionary +an iterable of strings or numbers for :class:`Writer` objects and a dictionary mapping fieldnames to strings or numbers (by passing them through :func:`str` first) for :class:`DictWriter` objects. Note that complex numbers are written out surrounded by parens. This may cause some problems for other programs which @@ -431,6 +431,8 @@ read CSV files (assuming they support complex numbers at all). Write the *row* parameter to the writer's file object, formatted according to the current dialect. + .. versionchanged:: 3.5 + Added support of arbitrary iterables. .. method:: csvwriter.writerows(rows) diff --git a/Lib/csv.py b/Lib/csv.py index c3c31f01fd0..ca40e5e0efc 100644 --- a/Lib/csv.py +++ b/Lib/csv.py @@ -147,16 +147,13 @@ class DictWriter: if wrong_fields: raise ValueError("dict contains fields not in fieldnames: " + ", ".join([repr(x) for x in wrong_fields])) - return [rowdict.get(key, self.restval) for key in self.fieldnames] + return (rowdict.get(key, self.restval) for key in self.fieldnames) def writerow(self, rowdict): return self.writer.writerow(self._dict_to_list(rowdict)) def writerows(self, rowdicts): - rows = [] - for rowdict in rowdicts: - rows.append(self._dict_to_list(rowdict)) - return self.writer.writerows(rows) + return self.writer.writerows(map(self._dict_to_list, rowdicts)) # Guard Sniffer's type checking against builds that exclude complex() try: diff --git a/Lib/test/test_csv.py b/Lib/test/test_csv.py index 41ef790eb25..7be3cc3c06d 100644 --- a/Lib/test/test_csv.py +++ b/Lib/test/test_csv.py @@ -186,6 +186,14 @@ class Test_Csv(unittest.TestCase): self._write_test(['a',1,'p,q'], 'a,1,p\\,q', escapechar='\\', quoting = csv.QUOTE_NONE) + def test_write_iterable(self): + self._write_test(iter(['a', 1, 'p,q']), 'a,1,"p,q"') + self._write_test(iter(['a', 1, None]), 'a,1,') + self._write_test(iter([]), '') + self._write_test(iter([None]), '""') + self._write_error_test(csv.Error, iter([None]), quoting=csv.QUOTE_NONE) + self._write_test(iter([None, None]), ',') + def test_writerows(self): class BrokenFile: def write(self, buf): diff --git a/Misc/NEWS b/Misc/NEWS index ff81056a1c2..0aae52cbd40 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -56,6 +56,8 @@ Core and Builtins Library ------- +- Issue #23171: csv.Writer.writerow() now supports arbitrary iterables. + - Issue #23745: The new email header parser now handles duplicate MIME parameter names without error, similar to how get_param behaves. diff --git a/Modules/_csv.c b/Modules/_csv.c index ade35e5bf71..eb886264d7a 100644 --- a/Modules/_csv.c +++ b/Modules/_csv.c @@ -1009,7 +1009,7 @@ join_reset(WriterObj *self) */ static Py_ssize_t join_append_data(WriterObj *self, unsigned int field_kind, void *field_data, - Py_ssize_t field_len, int quote_empty, int *quoted, + Py_ssize_t field_len, int *quoted, int copy_phase) { DialectObj *dialect = self->dialect; @@ -1071,18 +1071,6 @@ join_append_data(WriterObj *self, unsigned int field_kind, void *field_data, ADDCH(c); } - /* If field is empty check if it needs to be quoted. - */ - if (i == 0 && quote_empty) { - if (dialect->quoting == QUOTE_NONE) { - PyErr_Format(_csvstate_global->error_obj, - "single empty field record must be quoted"); - return -1; - } - else - *quoted = 1; - } - if (*quoted) { if (copy_phase) ADDCH(dialect->quotechar); @@ -1126,7 +1114,7 @@ join_check_rec_size(WriterObj *self, Py_ssize_t rec_len) } static int -join_append(WriterObj *self, PyObject *field, int *quoted, int quote_empty) +join_append(WriterObj *self, PyObject *field, int quoted) { unsigned int field_kind = -1; void *field_data = NULL; @@ -1141,7 +1129,7 @@ join_append(WriterObj *self, PyObject *field, int *quoted, int quote_empty) field_len = PyUnicode_GET_LENGTH(field); } rec_len = join_append_data(self, field_kind, field_data, field_len, - quote_empty, quoted, 0); + "ed, 0); if (rec_len < 0) return 0; @@ -1150,7 +1138,7 @@ join_append(WriterObj *self, PyObject *field, int *quoted, int quote_empty) return 0; self->rec_len = join_append_data(self, field_kind, field_data, field_len, - quote_empty, quoted, 1); + "ed, 1); self->num_fields++; return 1; @@ -1181,37 +1169,30 @@ join_append_lineterminator(WriterObj *self) } PyDoc_STRVAR(csv_writerow_doc, -"writerow(sequence)\n" +"writerow(iterable)\n" "\n" -"Construct and write a CSV record from a sequence of fields. Non-string\n" +"Construct and write a CSV record from an iterable of fields. Non-string\n" "elements will be converted to string."); static PyObject * csv_writerow(WriterObj *self, PyObject *seq) { DialectObj *dialect = self->dialect; - Py_ssize_t len, i; - PyObject *line, *result; + PyObject *iter, *field, *line, *result; - if (!PySequence_Check(seq)) - return PyErr_Format(_csvstate_global->error_obj, "sequence expected"); - - len = PySequence_Length(seq); - if (len < 0) - return NULL; + iter = PyObject_GetIter(seq); + if (iter == NULL) + return PyErr_Format(_csvstate_global->error_obj, + "iterable expected, not %.200s", + seq->ob_type->tp_name); /* Join all fields in internal buffer. */ join_reset(self); - for (i = 0; i < len; i++) { - PyObject *field; + while ((field = PyIter_Next(iter))) { int append_ok; int quoted; - field = PySequence_GetItem(seq, i); - if (field == NULL) - return NULL; - switch (dialect->quoting) { case QUOTE_NONNUMERIC: quoted = !PyNumber_Check(field); @@ -1225,11 +1206,11 @@ csv_writerow(WriterObj *self, PyObject *seq) } if (PyUnicode_Check(field)) { - append_ok = join_append(self, field, "ed, len == 1); + append_ok = join_append(self, field, quoted); Py_DECREF(field); } else if (field == Py_None) { - append_ok = join_append(self, NULL, "ed, len == 1); + append_ok = join_append(self, NULL, quoted); Py_DECREF(field); } else { @@ -1237,19 +1218,37 @@ csv_writerow(WriterObj *self, PyObject *seq) str = PyObject_Str(field); Py_DECREF(field); - if (str == NULL) + if (str == NULL) { + Py_DECREF(iter); return NULL; - append_ok = join_append(self, str, "ed, len == 1); + } + append_ok = join_append(self, str, quoted); Py_DECREF(str); } - if (!append_ok) + if (!append_ok) { + Py_DECREF(iter); + return NULL; + } + } + Py_DECREF(iter); + if (PyErr_Occurred()) + return NULL; + + if (self->num_fields > 0 && self->rec_size == 0) { + if (dialect->quoting == QUOTE_NONE) { + PyErr_Format(_csvstate_global->error_obj, + "single empty field record must be quoted"); + return NULL; + } + self->num_fields--; + if (!join_append(self, NULL, 1)) return NULL; } /* Add line terminator. */ if (!join_append_lineterminator(self)) - return 0; + return NULL; line = PyUnicode_FromKindAndData(PyUnicode_4BYTE_KIND, (void *) self->rec, self->rec_len); @@ -1261,9 +1260,9 @@ csv_writerow(WriterObj *self, PyObject *seq) } PyDoc_STRVAR(csv_writerows_doc, -"writerows(sequence of sequences)\n" +"writerows(iterable of iterables)\n" "\n" -"Construct and write a series of sequences to a csv file. Non-string\n" +"Construct and write a series of iterables to a csv file. Non-string\n" "elements will be converted to string."); static PyObject *