diff --git a/Lib/test/test_builtin.py b/Lib/test/test_builtin.py index 2b0c0179707..92e44d5f64d 100644 --- a/Lib/test/test_builtin.py +++ b/Lib/test/test_builtin.py @@ -408,6 +408,29 @@ class BuiltinTest(unittest.TestCase): unicode("345") ) + def test_filter_subclasses(self): + # test, that filter() never returns tuple, str or unicode subclasses + funcs = (None, lambda x: True) + class tuple2(tuple): + pass + class str2(str): + pass + inputs = { + tuple2: [(), (1,2,3)], + str2: ["", "123"] + } + if have_unicode: + class unicode2(unicode): + pass + inputs[unicode2] = [unicode(), unicode("123")] + + for func in funcs: + for (cls, inps) in inputs.iteritems(): + for inp in inps: + out = filter(func, cls(inp)) + self.assertEqual(inp, out) + self.assert_(not isinstance(out, cls)) + def test_float(self): self.assertEqual(float(3.14), 3.14) self.assertEqual(float(314), 314.0) diff --git a/Python/bltinmodule.c b/Python/bltinmodule.c index 2383b4fdd07..b74e09c26ac 100644 --- a/Python/bltinmodule.c +++ b/Python/bltinmodule.c @@ -1838,7 +1838,10 @@ filtertuple(PyObject *func, PyObject *tuple) int len = PyTuple_Size(tuple); if (len == 0) { - Py_INCREF(tuple); + if (PyTuple_CheckExact(tuple)) + Py_INCREF(tuple); + else + tuple = PyTuple_New(0); return tuple; } @@ -1895,8 +1898,15 @@ filterstring(PyObject *func, PyObject *strobj) int outlen = len; if (func == Py_None) { - /* No character is ever false -- share input string */ - Py_INCREF(strobj); + /* No character is ever false -- share input string + * (if it's not a subclass) */ + if (PyString_CheckExact(strobj)) + Py_INCREF(strobj); + else + strobj = PyString_FromStringAndSize( + PyString_AS_STRING(strobj), + len + ); return strobj; } if ((result = PyString_FromStringAndSize(NULL, len)) == NULL) @@ -1980,8 +1990,15 @@ filterunicode(PyObject *func, PyObject *strobj) int outlen = len; if (func == Py_None) { - /* No character is ever false -- share input string */ - Py_INCREF(strobj); + /* No character is ever false -- share input string + * (it if's not a subclass) */ + if (PyUnicode_CheckExact(strobj)) + Py_INCREF(strobj); + else + strobj = PyUnicode_FromUnicode( + PyUnicode_AS_UNICODE(strobj), + len + ); return strobj; } if ((result = PyUnicode_FromUnicode(NULL, len)) == NULL)