Issue #23171: csv.Writer.writerow() now supports arbitrary iterables.

This commit is contained in:
Serhiy Storchaka 2015-03-30 09:09:54 +03:00
parent a695f83f0d
commit 7901b48a1f
5 changed files with 54 additions and 46 deletions

View File

@ -419,7 +419,7 @@ Writer Objects
:class:`Writer` objects (:class:`DictWriter` instances and objects returned by :class:`Writer` objects (:class:`DictWriter` instances and objects returned by
the :func:`writer` function) have the following public methods. A *row* must be the :func:`writer` function) have the following public methods. A *row* must be
a sequence of strings or numbers for :class:`Writer` objects and a dictionary an iterable of strings or numbers for :class:`Writer` objects and a dictionary
mapping fieldnames to strings or numbers (by passing them through :func:`str` mapping fieldnames to strings or numbers (by passing them through :func:`str`
first) for :class:`DictWriter` objects. Note that complex numbers are written first) for :class:`DictWriter` objects. Note that complex numbers are written
out surrounded by parens. This may cause some problems for other programs which out surrounded by parens. This may cause some problems for other programs which
@ -431,6 +431,8 @@ read CSV files (assuming they support complex numbers at all).
Write the *row* parameter to the writer's file object, formatted according to Write the *row* parameter to the writer's file object, formatted according to
the current dialect. the current dialect.
.. versionchanged:: 3.5
Added support of arbitrary iterables.
.. method:: csvwriter.writerows(rows) .. method:: csvwriter.writerows(rows)

View File

@ -147,16 +147,13 @@ class DictWriter:
if wrong_fields: if wrong_fields:
raise ValueError("dict contains fields not in fieldnames: " raise ValueError("dict contains fields not in fieldnames: "
+ ", ".join([repr(x) for x in wrong_fields])) + ", ".join([repr(x) for x in wrong_fields]))
return [rowdict.get(key, self.restval) for key in self.fieldnames] return (rowdict.get(key, self.restval) for key in self.fieldnames)
def writerow(self, rowdict): def writerow(self, rowdict):
return self.writer.writerow(self._dict_to_list(rowdict)) return self.writer.writerow(self._dict_to_list(rowdict))
def writerows(self, rowdicts): def writerows(self, rowdicts):
rows = [] return self.writer.writerows(map(self._dict_to_list, rowdicts))
for rowdict in rowdicts:
rows.append(self._dict_to_list(rowdict))
return self.writer.writerows(rows)
# Guard Sniffer's type checking against builds that exclude complex() # Guard Sniffer's type checking against builds that exclude complex()
try: try:

View File

@ -186,6 +186,14 @@ class Test_Csv(unittest.TestCase):
self._write_test(['a',1,'p,q'], 'a,1,p\\,q', self._write_test(['a',1,'p,q'], 'a,1,p\\,q',
escapechar='\\', quoting = csv.QUOTE_NONE) escapechar='\\', quoting = csv.QUOTE_NONE)
def test_write_iterable(self):
self._write_test(iter(['a', 1, 'p,q']), 'a,1,"p,q"')
self._write_test(iter(['a', 1, None]), 'a,1,')
self._write_test(iter([]), '')
self._write_test(iter([None]), '""')
self._write_error_test(csv.Error, iter([None]), quoting=csv.QUOTE_NONE)
self._write_test(iter([None, None]), ',')
def test_writerows(self): def test_writerows(self):
class BrokenFile: class BrokenFile:
def write(self, buf): def write(self, buf):

View File

@ -56,6 +56,8 @@ Core and Builtins
Library Library
------- -------
- Issue #23171: csv.Writer.writerow() now supports arbitrary iterables.
- Issue #23745: The new email header parser now handles duplicate MIME - Issue #23745: The new email header parser now handles duplicate MIME
parameter names without error, similar to how get_param behaves. parameter names without error, similar to how get_param behaves.

View File

@ -1009,7 +1009,7 @@ join_reset(WriterObj *self)
*/ */
static Py_ssize_t static Py_ssize_t
join_append_data(WriterObj *self, unsigned int field_kind, void *field_data, join_append_data(WriterObj *self, unsigned int field_kind, void *field_data,
Py_ssize_t field_len, int quote_empty, int *quoted, Py_ssize_t field_len, int *quoted,
int copy_phase) int copy_phase)
{ {
DialectObj *dialect = self->dialect; DialectObj *dialect = self->dialect;
@ -1071,18 +1071,6 @@ join_append_data(WriterObj *self, unsigned int field_kind, void *field_data,
ADDCH(c); ADDCH(c);
} }
/* If field is empty check if it needs to be quoted.
*/
if (i == 0 && quote_empty) {
if (dialect->quoting == QUOTE_NONE) {
PyErr_Format(_csvstate_global->error_obj,
"single empty field record must be quoted");
return -1;
}
else
*quoted = 1;
}
if (*quoted) { if (*quoted) {
if (copy_phase) if (copy_phase)
ADDCH(dialect->quotechar); ADDCH(dialect->quotechar);
@ -1126,7 +1114,7 @@ join_check_rec_size(WriterObj *self, Py_ssize_t rec_len)
} }
static int static int
join_append(WriterObj *self, PyObject *field, int *quoted, int quote_empty) join_append(WriterObj *self, PyObject *field, int quoted)
{ {
unsigned int field_kind = -1; unsigned int field_kind = -1;
void *field_data = NULL; void *field_data = NULL;
@ -1141,7 +1129,7 @@ join_append(WriterObj *self, PyObject *field, int *quoted, int quote_empty)
field_len = PyUnicode_GET_LENGTH(field); field_len = PyUnicode_GET_LENGTH(field);
} }
rec_len = join_append_data(self, field_kind, field_data, field_len, rec_len = join_append_data(self, field_kind, field_data, field_len,
quote_empty, quoted, 0); &quoted, 0);
if (rec_len < 0) if (rec_len < 0)
return 0; return 0;
@ -1150,7 +1138,7 @@ join_append(WriterObj *self, PyObject *field, int *quoted, int quote_empty)
return 0; return 0;
self->rec_len = join_append_data(self, field_kind, field_data, field_len, self->rec_len = join_append_data(self, field_kind, field_data, field_len,
quote_empty, quoted, 1); &quoted, 1);
self->num_fields++; self->num_fields++;
return 1; return 1;
@ -1181,37 +1169,30 @@ join_append_lineterminator(WriterObj *self)
} }
PyDoc_STRVAR(csv_writerow_doc, PyDoc_STRVAR(csv_writerow_doc,
"writerow(sequence)\n" "writerow(iterable)\n"
"\n" "\n"
"Construct and write a CSV record from a sequence of fields. Non-string\n" "Construct and write a CSV record from an iterable of fields. Non-string\n"
"elements will be converted to string."); "elements will be converted to string.");
static PyObject * static PyObject *
csv_writerow(WriterObj *self, PyObject *seq) csv_writerow(WriterObj *self, PyObject *seq)
{ {
DialectObj *dialect = self->dialect; DialectObj *dialect = self->dialect;
Py_ssize_t len, i; PyObject *iter, *field, *line, *result;
PyObject *line, *result;
if (!PySequence_Check(seq)) iter = PyObject_GetIter(seq);
return PyErr_Format(_csvstate_global->error_obj, "sequence expected"); if (iter == NULL)
return PyErr_Format(_csvstate_global->error_obj,
len = PySequence_Length(seq); "iterable expected, not %.200s",
if (len < 0) seq->ob_type->tp_name);
return NULL;
/* Join all fields in internal buffer. /* Join all fields in internal buffer.
*/ */
join_reset(self); join_reset(self);
for (i = 0; i < len; i++) { while ((field = PyIter_Next(iter))) {
PyObject *field;
int append_ok; int append_ok;
int quoted; int quoted;
field = PySequence_GetItem(seq, i);
if (field == NULL)
return NULL;
switch (dialect->quoting) { switch (dialect->quoting) {
case QUOTE_NONNUMERIC: case QUOTE_NONNUMERIC:
quoted = !PyNumber_Check(field); quoted = !PyNumber_Check(field);
@ -1225,11 +1206,11 @@ csv_writerow(WriterObj *self, PyObject *seq)
} }
if (PyUnicode_Check(field)) { if (PyUnicode_Check(field)) {
append_ok = join_append(self, field, &quoted, len == 1); append_ok = join_append(self, field, quoted);
Py_DECREF(field); Py_DECREF(field);
} }
else if (field == Py_None) { else if (field == Py_None) {
append_ok = join_append(self, NULL, &quoted, len == 1); append_ok = join_append(self, NULL, quoted);
Py_DECREF(field); Py_DECREF(field);
} }
else { else {
@ -1237,19 +1218,37 @@ csv_writerow(WriterObj *self, PyObject *seq)
str = PyObject_Str(field); str = PyObject_Str(field);
Py_DECREF(field); Py_DECREF(field);
if (str == NULL) if (str == NULL) {
Py_DECREF(iter);
return NULL; return NULL;
append_ok = join_append(self, str, &quoted, len == 1); }
append_ok = join_append(self, str, quoted);
Py_DECREF(str); Py_DECREF(str);
} }
if (!append_ok) if (!append_ok) {
Py_DECREF(iter);
return NULL;
}
}
Py_DECREF(iter);
if (PyErr_Occurred())
return NULL;
if (self->num_fields > 0 && self->rec_size == 0) {
if (dialect->quoting == QUOTE_NONE) {
PyErr_Format(_csvstate_global->error_obj,
"single empty field record must be quoted");
return NULL;
}
self->num_fields--;
if (!join_append(self, NULL, 1))
return NULL; return NULL;
} }
/* Add line terminator. /* Add line terminator.
*/ */
if (!join_append_lineterminator(self)) if (!join_append_lineterminator(self))
return 0; return NULL;
line = PyUnicode_FromKindAndData(PyUnicode_4BYTE_KIND, line = PyUnicode_FromKindAndData(PyUnicode_4BYTE_KIND,
(void *) self->rec, self->rec_len); (void *) self->rec, self->rec_len);
@ -1261,9 +1260,9 @@ csv_writerow(WriterObj *self, PyObject *seq)
} }
PyDoc_STRVAR(csv_writerows_doc, PyDoc_STRVAR(csv_writerows_doc,
"writerows(sequence of sequences)\n" "writerows(iterable of iterables)\n"
"\n" "\n"
"Construct and write a series of sequences to a csv file. Non-string\n" "Construct and write a series of iterables to a csv file. Non-string\n"
"elements will be converted to string."); "elements will be converted to string.");
static PyObject * static PyObject *