Issue #11707: Fast C version of functools.cmp_to_key()
This commit is contained in:
parent
271b27e5fe
commit
7ab9e22e34
|
@ -97,7 +97,7 @@ def cmp_to_key(mycmp):
|
||||||
"""Convert a cmp= function into a key= function"""
|
"""Convert a cmp= function into a key= function"""
|
||||||
class K(object):
|
class K(object):
|
||||||
__slots__ = ['obj']
|
__slots__ = ['obj']
|
||||||
def __init__(self, obj, *args):
|
def __init__(self, obj):
|
||||||
self.obj = obj
|
self.obj = obj
|
||||||
def __lt__(self, other):
|
def __lt__(self, other):
|
||||||
return mycmp(self.obj, other.obj) < 0
|
return mycmp(self.obj, other.obj) < 0
|
||||||
|
@ -115,6 +115,11 @@ def cmp_to_key(mycmp):
|
||||||
raise TypeError('hash not implemented')
|
raise TypeError('hash not implemented')
|
||||||
return K
|
return K
|
||||||
|
|
||||||
|
try:
|
||||||
|
from _functools import cmp_to_key
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
_CacheInfo = namedtuple("CacheInfo", "hits misses maxsize currsize")
|
_CacheInfo = namedtuple("CacheInfo", "hits misses maxsize currsize")
|
||||||
|
|
||||||
def lru_cache(maxsize=100):
|
def lru_cache(maxsize=100):
|
||||||
|
|
|
@ -435,18 +435,81 @@ class TestReduce(unittest.TestCase):
|
||||||
self.assertEqual(self.func(add, d), "".join(d.keys()))
|
self.assertEqual(self.func(add, d), "".join(d.keys()))
|
||||||
|
|
||||||
class TestCmpToKey(unittest.TestCase):
|
class TestCmpToKey(unittest.TestCase):
|
||||||
|
|
||||||
def test_cmp_to_key(self):
|
def test_cmp_to_key(self):
|
||||||
|
def cmp1(x, y):
|
||||||
|
return (x > y) - (x < y)
|
||||||
|
key = functools.cmp_to_key(cmp1)
|
||||||
|
self.assertEqual(key(3), key(3))
|
||||||
|
self.assertGreater(key(3), key(1))
|
||||||
|
def cmp2(x, y):
|
||||||
|
return int(x) - int(y)
|
||||||
|
key = functools.cmp_to_key(cmp2)
|
||||||
|
self.assertEqual(key(4.0), key('4'))
|
||||||
|
self.assertLess(key(2), key('35'))
|
||||||
|
|
||||||
|
def test_cmp_to_key_arguments(self):
|
||||||
|
def cmp1(x, y):
|
||||||
|
return (x > y) - (x < y)
|
||||||
|
key = functools.cmp_to_key(mycmp=cmp1)
|
||||||
|
self.assertEqual(key(obj=3), key(obj=3))
|
||||||
|
self.assertGreater(key(obj=3), key(obj=1))
|
||||||
|
with self.assertRaises((TypeError, AttributeError)):
|
||||||
|
key(3) > 1 # rhs is not a K object
|
||||||
|
with self.assertRaises((TypeError, AttributeError)):
|
||||||
|
1 < key(3) # lhs is not a K object
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
key = functools.cmp_to_key() # too few args
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
key = functools.cmp_to_key(cmp1, None) # too many args
|
||||||
|
key = functools.cmp_to_key(cmp1)
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
key() # too few args
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
key(None, None) # too many args
|
||||||
|
|
||||||
|
def test_bad_cmp(self):
|
||||||
|
def cmp1(x, y):
|
||||||
|
raise ZeroDivisionError
|
||||||
|
key = functools.cmp_to_key(cmp1)
|
||||||
|
with self.assertRaises(ZeroDivisionError):
|
||||||
|
key(3) > key(1)
|
||||||
|
|
||||||
|
class BadCmp:
|
||||||
|
def __lt__(self, other):
|
||||||
|
raise ZeroDivisionError
|
||||||
|
def cmp1(x, y):
|
||||||
|
return BadCmp()
|
||||||
|
with self.assertRaises(ZeroDivisionError):
|
||||||
|
key(3) > key(1)
|
||||||
|
|
||||||
|
def test_obj_field(self):
|
||||||
|
def cmp1(x, y):
|
||||||
|
return (x > y) - (x < y)
|
||||||
|
key = functools.cmp_to_key(mycmp=cmp1)
|
||||||
|
self.assertEqual(key(50).obj, 50)
|
||||||
|
|
||||||
|
def test_sort_int(self):
|
||||||
def mycmp(x, y):
|
def mycmp(x, y):
|
||||||
return y - x
|
return y - x
|
||||||
self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
|
self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
|
||||||
[4, 3, 2, 1, 0])
|
[4, 3, 2, 1, 0])
|
||||||
|
|
||||||
|
def test_sort_int_str(self):
|
||||||
|
def mycmp(x, y):
|
||||||
|
x, y = int(x), int(y)
|
||||||
|
return (x > y) - (x < y)
|
||||||
|
values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
|
||||||
|
values = sorted(values, key=functools.cmp_to_key(mycmp))
|
||||||
|
self.assertEqual([int(value) for value in values],
|
||||||
|
[0, 1, 1, 2, 3, 4, 5, 7, 10])
|
||||||
|
|
||||||
def test_hash(self):
|
def test_hash(self):
|
||||||
def mycmp(x, y):
|
def mycmp(x, y):
|
||||||
return y - x
|
return y - x
|
||||||
key = functools.cmp_to_key(mycmp)
|
key = functools.cmp_to_key(mycmp)
|
||||||
k = key(10)
|
k = key(10)
|
||||||
self.assertRaises(TypeError, hash(k))
|
self.assertRaises(TypeError, hash, k)
|
||||||
|
|
||||||
class TestTotalOrdering(unittest.TestCase):
|
class TestTotalOrdering(unittest.TestCase):
|
||||||
|
|
||||||
|
@ -655,6 +718,7 @@ class TestLRU(unittest.TestCase):
|
||||||
|
|
||||||
def test_main(verbose=None):
|
def test_main(verbose=None):
|
||||||
test_classes = (
|
test_classes = (
|
||||||
|
TestCmpToKey,
|
||||||
TestPartial,
|
TestPartial,
|
||||||
TestPartialSubclass,
|
TestPartialSubclass,
|
||||||
TestPythonPartial,
|
TestPythonPartial,
|
||||||
|
|
|
@ -97,6 +97,9 @@ Library
|
||||||
- Issue #10791: Implement missing method GzipFile.read1(), allowing GzipFile
|
- Issue #10791: Implement missing method GzipFile.read1(), allowing GzipFile
|
||||||
to be wrapped in a TextIOWrapper. Patch by Nadeem Vawda.
|
to be wrapped in a TextIOWrapper. Patch by Nadeem Vawda.
|
||||||
|
|
||||||
|
- Issue #11707: Added a fast C version of functools.cmp_to_key().
|
||||||
|
Patch by Filip Gruszczyński.
|
||||||
|
|
||||||
- Issue #11688: Add sqlite3.Connection.set_trace_callback(). Patch by
|
- Issue #11688: Add sqlite3.Connection.set_trace_callback(). Patch by
|
||||||
Torsten Landschoff.
|
Torsten Landschoff.
|
||||||
|
|
||||||
|
|
|
@ -330,6 +330,165 @@ static PyTypeObject partial_type = {
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
/* cmp_to_key ***************************************************************/
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
PyObject_HEAD;
|
||||||
|
PyObject *cmp;
|
||||||
|
PyObject *object;
|
||||||
|
} keyobject;
|
||||||
|
|
||||||
|
static void
|
||||||
|
keyobject_dealloc(keyobject *ko)
|
||||||
|
{
|
||||||
|
Py_DECREF(ko->cmp);
|
||||||
|
Py_XDECREF(ko->object);
|
||||||
|
PyObject_FREE(ko);
|
||||||
|
}
|
||||||
|
|
||||||
|
static int
|
||||||
|
keyobject_traverse(keyobject *ko, visitproc visit, void *arg)
|
||||||
|
{
|
||||||
|
Py_VISIT(ko->cmp);
|
||||||
|
if (ko->object)
|
||||||
|
Py_VISIT(ko->object);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
static PyMemberDef keyobject_members[] = {
|
||||||
|
{"obj", T_OBJECT,
|
||||||
|
offsetof(keyobject, object), 0,
|
||||||
|
PyDoc_STR("Value wrapped by a key function.")},
|
||||||
|
{NULL}
|
||||||
|
};
|
||||||
|
|
||||||
|
static PyObject *
|
||||||
|
keyobject_call(keyobject *ko, PyObject *args, PyObject *kw);
|
||||||
|
|
||||||
|
static PyObject *
|
||||||
|
keyobject_richcompare(PyObject *ko, PyObject *other, int op);
|
||||||
|
|
||||||
|
static PyTypeObject keyobject_type = {
|
||||||
|
PyVarObject_HEAD_INIT(&PyType_Type, 0)
|
||||||
|
"functools.KeyWrapper", /* tp_name */
|
||||||
|
sizeof(keyobject), /* tp_basicsize */
|
||||||
|
0, /* tp_itemsize */
|
||||||
|
/* methods */
|
||||||
|
(destructor)keyobject_dealloc, /* tp_dealloc */
|
||||||
|
0, /* tp_print */
|
||||||
|
0, /* tp_getattr */
|
||||||
|
0, /* tp_setattr */
|
||||||
|
0, /* tp_reserved */
|
||||||
|
0, /* tp_repr */
|
||||||
|
0, /* tp_as_number */
|
||||||
|
0, /* tp_as_sequence */
|
||||||
|
0, /* tp_as_mapping */
|
||||||
|
0, /* tp_hash */
|
||||||
|
(ternaryfunc)keyobject_call, /* tp_call */
|
||||||
|
0, /* tp_str */
|
||||||
|
PyObject_GenericGetAttr, /* tp_getattro */
|
||||||
|
0, /* tp_setattro */
|
||||||
|
0, /* tp_as_buffer */
|
||||||
|
Py_TPFLAGS_DEFAULT, /* tp_flags */
|
||||||
|
0, /* tp_doc */
|
||||||
|
(traverseproc)keyobject_traverse, /* tp_traverse */
|
||||||
|
0, /* tp_clear */
|
||||||
|
keyobject_richcompare, /* tp_richcompare */
|
||||||
|
0, /* tp_weaklistoffset */
|
||||||
|
0, /* tp_iter */
|
||||||
|
0, /* tp_iternext */
|
||||||
|
0, /* tp_methods */
|
||||||
|
keyobject_members, /* tp_members */
|
||||||
|
0, /* tp_getset */
|
||||||
|
};
|
||||||
|
|
||||||
|
static PyObject *
|
||||||
|
keyobject_call(keyobject *ko, PyObject *args, PyObject *kwds)
|
||||||
|
{
|
||||||
|
PyObject *object;
|
||||||
|
keyobject *result;
|
||||||
|
static char *kwargs[] = {"obj", NULL};
|
||||||
|
|
||||||
|
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O:K", kwargs, &object))
|
||||||
|
return NULL;
|
||||||
|
result = PyObject_New(keyobject, &keyobject_type);
|
||||||
|
if (!result)
|
||||||
|
return NULL;
|
||||||
|
Py_INCREF(ko->cmp);
|
||||||
|
result->cmp = ko->cmp;
|
||||||
|
Py_INCREF(object);
|
||||||
|
result->object = object;
|
||||||
|
return (PyObject *)result;
|
||||||
|
}
|
||||||
|
|
||||||
|
static PyObject *
|
||||||
|
keyobject_richcompare(PyObject *ko, PyObject *other, int op)
|
||||||
|
{
|
||||||
|
PyObject *res;
|
||||||
|
PyObject *args;
|
||||||
|
PyObject *x;
|
||||||
|
PyObject *y;
|
||||||
|
PyObject *compare;
|
||||||
|
PyObject *answer;
|
||||||
|
static PyObject *zero;
|
||||||
|
|
||||||
|
if (zero == NULL) {
|
||||||
|
zero = PyLong_FromLong(0);
|
||||||
|
if (!zero)
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (Py_TYPE(other) != &keyobject_type){
|
||||||
|
PyErr_Format(PyExc_TypeError, "other argument must be K instance");
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
compare = ((keyobject *) ko)->cmp;
|
||||||
|
assert(compare != NULL);
|
||||||
|
x = ((keyobject *) ko)->object;
|
||||||
|
y = ((keyobject *) other)->object;
|
||||||
|
if (!x || !y){
|
||||||
|
PyErr_Format(PyExc_AttributeError, "object");
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Call the user's comparison function and translate the 3-way
|
||||||
|
* result into true or false (or error).
|
||||||
|
*/
|
||||||
|
args = PyTuple_New(2);
|
||||||
|
if (args == NULL)
|
||||||
|
return NULL;
|
||||||
|
Py_INCREF(x);
|
||||||
|
Py_INCREF(y);
|
||||||
|
PyTuple_SET_ITEM(args, 0, x);
|
||||||
|
PyTuple_SET_ITEM(args, 1, y);
|
||||||
|
res = PyObject_Call(compare, args, NULL);
|
||||||
|
Py_DECREF(args);
|
||||||
|
if (res == NULL)
|
||||||
|
return NULL;
|
||||||
|
answer = PyObject_RichCompare(res, zero, op);
|
||||||
|
Py_DECREF(res);
|
||||||
|
return answer;
|
||||||
|
}
|
||||||
|
|
||||||
|
static PyObject *
|
||||||
|
functools_cmp_to_key(PyObject *self, PyObject *args, PyObject *kwds){
|
||||||
|
PyObject *cmp;
|
||||||
|
static char *kwargs[] = {"mycmp", NULL};
|
||||||
|
|
||||||
|
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O:cmp_to_key", kwargs, &cmp))
|
||||||
|
return NULL;
|
||||||
|
keyobject *object = PyObject_New(keyobject, &keyobject_type);
|
||||||
|
if (!object)
|
||||||
|
return NULL;
|
||||||
|
Py_INCREF(cmp);
|
||||||
|
object->cmp = cmp;
|
||||||
|
object->object = NULL;
|
||||||
|
return (PyObject *)object;
|
||||||
|
}
|
||||||
|
|
||||||
|
PyDoc_STRVAR(functools_cmp_to_key_doc,
|
||||||
|
"Convert a cmp= function into a key= function.");
|
||||||
|
|
||||||
/* reduce (used to be a builtin) ********************************************/
|
/* reduce (used to be a builtin) ********************************************/
|
||||||
|
|
||||||
static PyObject *
|
static PyObject *
|
||||||
|
@ -413,6 +572,8 @@ PyDoc_STRVAR(module_doc,
|
||||||
|
|
||||||
static PyMethodDef module_methods[] = {
|
static PyMethodDef module_methods[] = {
|
||||||
{"reduce", functools_reduce, METH_VARARGS, functools_reduce_doc},
|
{"reduce", functools_reduce, METH_VARARGS, functools_reduce_doc},
|
||||||
|
{"cmp_to_key", functools_cmp_to_key, METH_VARARGS | METH_KEYWORDS,
|
||||||
|
functools_cmp_to_key_doc},
|
||||||
{NULL, NULL} /* sentinel */
|
{NULL, NULL} /* sentinel */
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue