bpo-16575: Add checks for unions passed by value to functions. (GH-16799)
This commit is contained in:
parent
bdac32e9fe
commit
79d4ed102a
|
@ -576,6 +576,86 @@ class StructureTestCase(unittest.TestCase):
|
|||
self.assertEqual(f2, [0x4567, 0x0123, 0xcdef, 0x89ab,
|
||||
0x3210, 0x7654, 0xba98, 0xfedc])
|
||||
|
||||
def test_union_by_value(self):
|
||||
# See bpo-16575
|
||||
|
||||
# These should mirror the structures in Modules/_ctypes/_ctypes_test.c
|
||||
|
||||
class Nested1(Structure):
|
||||
_fields_ = [
|
||||
('an_int', c_int),
|
||||
('another_int', c_int),
|
||||
]
|
||||
|
||||
class Test4(Union):
|
||||
_fields_ = [
|
||||
('a_long', c_long),
|
||||
('a_struct', Nested1),
|
||||
]
|
||||
|
||||
class Nested2(Structure):
|
||||
_fields_ = [
|
||||
('an_int', c_int),
|
||||
('a_union', Test4),
|
||||
]
|
||||
|
||||
class Test5(Structure):
|
||||
_fields_ = [
|
||||
('an_int', c_int),
|
||||
('nested', Nested2),
|
||||
('another_int', c_int),
|
||||
]
|
||||
|
||||
test4 = Test4()
|
||||
dll = CDLL(_ctypes_test.__file__)
|
||||
with self.assertRaises(TypeError) as ctx:
|
||||
func = dll._testfunc_union_by_value1
|
||||
func.restype = c_long
|
||||
func.argtypes = (Test4,)
|
||||
result = func(test4)
|
||||
self.assertEqual(ctx.exception.args[0], 'item 1 in _argtypes_ passes '
|
||||
'a union by value, which is unsupported.')
|
||||
test5 = Test5()
|
||||
with self.assertRaises(TypeError) as ctx:
|
||||
func = dll._testfunc_union_by_value2
|
||||
func.restype = c_long
|
||||
func.argtypes = (Test5,)
|
||||
result = func(test5)
|
||||
self.assertEqual(ctx.exception.args[0], 'item 1 in _argtypes_ passes '
|
||||
'a union by value, which is unsupported.')
|
||||
|
||||
# passing by reference should be OK
|
||||
test4.a_long = 12345;
|
||||
func = dll._testfunc_union_by_reference1
|
||||
func.restype = c_long
|
||||
func.argtypes = (POINTER(Test4),)
|
||||
result = func(byref(test4))
|
||||
self.assertEqual(result, 12345)
|
||||
self.assertEqual(test4.a_long, 0)
|
||||
self.assertEqual(test4.a_struct.an_int, 0)
|
||||
self.assertEqual(test4.a_struct.another_int, 0)
|
||||
test4.a_struct.an_int = 0x12340000
|
||||
test4.a_struct.another_int = 0x5678
|
||||
func = dll._testfunc_union_by_reference2
|
||||
func.restype = c_long
|
||||
func.argtypes = (POINTER(Test4),)
|
||||
result = func(byref(test4))
|
||||
self.assertEqual(result, 0x12345678)
|
||||
self.assertEqual(test4.a_long, 0)
|
||||
self.assertEqual(test4.a_struct.an_int, 0)
|
||||
self.assertEqual(test4.a_struct.another_int, 0)
|
||||
test5.an_int = 0x12000000
|
||||
test5.nested.an_int = 0x345600
|
||||
test5.another_int = 0x78
|
||||
func = dll._testfunc_union_by_reference3
|
||||
func.restype = c_long
|
||||
func.argtypes = (POINTER(Test5),)
|
||||
result = func(byref(test5))
|
||||
self.assertEqual(result, 0x12345678)
|
||||
self.assertEqual(test5.an_int, 0)
|
||||
self.assertEqual(test5.nested.an_int, 0)
|
||||
self.assertEqual(test5.another_int, 0)
|
||||
|
||||
class PointerMemberTestCase(unittest.TestCase):
|
||||
|
||||
def test(self):
|
||||
|
|
|
@ -504,6 +504,9 @@ StructUnionType_new(PyTypeObject *type, PyObject *args, PyObject *kwds, int isSt
|
|||
Py_DECREF(result);
|
||||
return NULL;
|
||||
}
|
||||
if (!isStruct) {
|
||||
dict->flags |= TYPEFLAG_HASUNION;
|
||||
}
|
||||
/* replace the class dict by our updated stgdict, which holds info
|
||||
about storage requirements of the instances */
|
||||
if (-1 == PyDict_Update((PyObject *)dict, result->tp_dict)) {
|
||||
|
@ -2383,6 +2386,27 @@ converters_from_argtypes(PyObject *ob)
|
|||
for (i = 0; i < nArgs; ++i) {
|
||||
PyObject *cnv;
|
||||
PyObject *tp = PyTuple_GET_ITEM(ob, i);
|
||||
StgDictObject *stgdict = PyType_stgdict(tp);
|
||||
|
||||
if (stgdict != NULL) {
|
||||
if (stgdict->flags & TYPEFLAG_HASUNION) {
|
||||
Py_DECREF(converters);
|
||||
Py_DECREF(ob);
|
||||
if (!PyErr_Occurred()) {
|
||||
PyErr_Format(PyExc_TypeError,
|
||||
"item %zd in _argtypes_ passes a union by "
|
||||
"value, which is unsupported.",
|
||||
i + 1);
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
/*
|
||||
if (stgdict->flags & TYPEFLAG_HASBITFIELD) {
|
||||
printf("found stgdict with bitfield\n");
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
if (_PyObject_LookupAttrId(tp, &PyId_from_param, &cnv) <= 0) {
|
||||
Py_DECREF(converters);
|
||||
Py_DECREF(ob);
|
||||
|
|
|
@ -131,6 +131,69 @@ _testfunc_array_in_struct2a(Test3B in)
|
|||
return result;
|
||||
}
|
||||
|
||||
typedef union {
|
||||
long a_long;
|
||||
struct {
|
||||
int an_int;
|
||||
int another_int;
|
||||
} a_struct;
|
||||
} Test4;
|
||||
|
||||
typedef struct {
|
||||
int an_int;
|
||||
struct {
|
||||
int an_int;
|
||||
Test4 a_union;
|
||||
} nested;
|
||||
int another_int;
|
||||
} Test5;
|
||||
|
||||
EXPORT(long)
|
||||
_testfunc_union_by_value1(Test4 in) {
|
||||
long result = in.a_long + in.a_struct.an_int + in.a_struct.another_int;
|
||||
|
||||
/* As the union/struct are passed by value, changes to them shouldn't be
|
||||
* reflected in the caller.
|
||||
*/
|
||||
memset(&in, 0, sizeof(in));
|
||||
return result;
|
||||
}
|
||||
|
||||
EXPORT(long)
|
||||
_testfunc_union_by_value2(Test5 in) {
|
||||
long result = in.an_int + in.nested.an_int;
|
||||
|
||||
/* As the union/struct are passed by value, changes to them shouldn't be
|
||||
* reflected in the caller.
|
||||
*/
|
||||
memset(&in, 0, sizeof(in));
|
||||
return result;
|
||||
}
|
||||
|
||||
EXPORT(long)
|
||||
_testfunc_union_by_reference1(Test4 *in) {
|
||||
long result = in->a_long;
|
||||
|
||||
memset(in, 0, sizeof(Test4));
|
||||
return result;
|
||||
}
|
||||
|
||||
EXPORT(long)
|
||||
_testfunc_union_by_reference2(Test4 *in) {
|
||||
long result = in->a_struct.an_int + in->a_struct.another_int;
|
||||
|
||||
memset(in, 0, sizeof(Test4));
|
||||
return result;
|
||||
}
|
||||
|
||||
EXPORT(long)
|
||||
_testfunc_union_by_reference3(Test5 *in) {
|
||||
long result = in->an_int + in->nested.an_int + in->another_int;
|
||||
|
||||
memset(in, 0, sizeof(Test5));
|
||||
return result;
|
||||
}
|
||||
|
||||
EXPORT(void)testfunc_array(int values[4])
|
||||
{
|
||||
printf("testfunc_array %d %d %d %d\n",
|
||||
|
|
|
@ -288,6 +288,8 @@ PyObject *_ctypes_callproc(PPROC pProc,
|
|||
|
||||
#define TYPEFLAG_ISPOINTER 0x100
|
||||
#define TYPEFLAG_HASPOINTER 0x200
|
||||
#define TYPEFLAG_HASUNION 0x400
|
||||
#define TYPEFLAG_HASBITFIELD 0x800
|
||||
|
||||
#define DICTFLAG_FINAL 0x1000
|
||||
|
||||
|
|
|
@ -440,6 +440,13 @@ PyCStructUnionType_update_stgdict(PyObject *type, PyObject *fields, int isStruct
|
|||
PyMem_Free(stgdict->ffi_type_pointer.elements);
|
||||
|
||||
basedict = PyType_stgdict((PyObject *)((PyTypeObject *)type)->tp_base);
|
||||
if (basedict) {
|
||||
stgdict->flags |= (basedict->flags &
|
||||
(TYPEFLAG_HASUNION | TYPEFLAG_HASBITFIELD));
|
||||
}
|
||||
if (!isStruct) {
|
||||
stgdict->flags |= TYPEFLAG_HASUNION;
|
||||
}
|
||||
if (basedict && !use_broken_old_ctypes_semantics) {
|
||||
size = offset = basedict->size;
|
||||
align = basedict->align;
|
||||
|
@ -515,8 +522,10 @@ PyCStructUnionType_update_stgdict(PyObject *type, PyObject *fields, int isStruct
|
|||
stgdict->ffi_type_pointer.elements[ffi_ofs + i] = &dict->ffi_type_pointer;
|
||||
if (dict->flags & (TYPEFLAG_ISPOINTER | TYPEFLAG_HASPOINTER))
|
||||
stgdict->flags |= TYPEFLAG_HASPOINTER;
|
||||
stgdict->flags |= dict->flags & (TYPEFLAG_HASUNION | TYPEFLAG_HASBITFIELD);
|
||||
dict->flags |= DICTFLAG_FINAL; /* mark field type final */
|
||||
if (PyTuple_Size(pair) == 3) { /* bits specified */
|
||||
stgdict->flags |= TYPEFLAG_HASBITFIELD;
|
||||
switch(dict->ffi_type_pointer.type) {
|
||||
case FFI_TYPE_UINT8:
|
||||
case FFI_TYPE_UINT16:
|
||||
|
|
Loading…
Reference in New Issue