diff --git a/Lib/test/test_ctypes/test_cast.py b/Lib/test/test_ctypes/test_cast.py index 604f44f03d61b2..e5ccccbc5ea582 100644 --- a/Lib/test/test_ctypes/test_cast.py +++ b/Lib/test/test_ctypes/test_cast.py @@ -1,8 +1,8 @@ import sys import unittest -from ctypes import (Structure, Union, POINTER, cast, sizeof, addressof, - c_void_p, c_char_p, c_wchar_p, - c_byte, c_short, c_int) +from ctypes import (Structure, Union, pointer, POINTER, sizeof, addressof, + c_void_p, c_char_p, c_wchar_p, cast, + c_byte, c_short, c_int, c_int16) class Test(unittest.TestCase): @@ -95,6 +95,100 @@ class MyUnion(Union): _fields_ = [("a", c_int)] self.assertRaises(TypeError, cast, array, MyUnion) + def test_pointer_identity(self): + class Struct(Structure): + _fields_ = [('a', c_int16)] + Struct3 = 3 * Struct + c_array = (2 * Struct3)( + Struct3(Struct(a=1), Struct(a=2), Struct(a=3)), + Struct3(Struct(a=4), Struct(a=5), Struct(a=6)) + ) + self.assertEqual(c_array[0][0].a, 1) + self.assertEqual(c_array[0][1].a, 2) + self.assertEqual(c_array[0][2].a, 3) + self.assertEqual(c_array[1][0].a, 4) + self.assertEqual(c_array[1][1].a, 5) + self.assertEqual(c_array[1][2].a, 6) + p_obj = cast(pointer(c_array), POINTER(pointer(c_array)._type_)) + obj = p_obj.contents + self.assertEqual(obj[0][0].a, 1) + self.assertEqual(obj[0][1].a, 2) + self.assertEqual(obj[0][2].a, 3) + self.assertEqual(obj[1][0].a, 4) + self.assertEqual(obj[1][1].a, 5) + self.assertEqual(obj[1][2].a, 6) + p_obj = cast(pointer(c_array[0]), POINTER(pointer(c_array)._type_)) + obj = p_obj.contents + self.assertEqual(obj[0][0].a, 1) + self.assertEqual(obj[0][1].a, 2) + self.assertEqual(obj[0][2].a, 3) + self.assertEqual(obj[1][0].a, 4) + self.assertEqual(obj[1][1].a, 5) + self.assertEqual(obj[1][2].a, 6) + StructPointer = POINTER(Struct) + s1 = Struct(a=10) + s2 = Struct(a=20) + s3 = Struct(a=30) + pointer_array = (3 * StructPointer)(pointer(s1), pointer(s2), pointer(s3)) + self.assertEqual(pointer_array[0][0].a, 10) + self.assertEqual(pointer_array[1][0].a, 20) + self.assertEqual(pointer_array[2][0].a, 30) + self.assertEqual(pointer_array[0].contents.a, 10) + self.assertEqual(pointer_array[1].contents.a, 20) + self.assertEqual(pointer_array[2].contents.a, 30) + p_obj = cast(pointer(pointer_array[0]), POINTER(pointer(pointer_array)._type_)) + obj = p_obj.contents + self.assertEqual(obj[0][0].a, 10) + self.assertEqual(obj[1][0].a, 20) + self.assertEqual(obj[2][0].a, 30) + self.assertEqual(obj[0].contents.a, 10) + self.assertEqual(obj[1].contents.a, 20) + self.assertEqual(obj[2].contents.a, 30) + class StructWithPointers(Structure): + _fields_ = [("s1", POINTER(Struct)), ("s2", POINTER(Struct))] + struct = StructWithPointers(s1=pointer(s1), s2=pointer(s2)) + p_obj = pointer(struct) + obj = p_obj.contents + self.assertEqual(obj.s1[0].a, 10) + self.assertEqual(obj.s2[0].a, 20) + self.assertEqual(obj.s1.contents.a, 10) + self.assertEqual(obj.s2.contents.a, 20) + p_obj = cast(pointer(struct), POINTER(pointer(pointer_array)._type_)) + obj = p_obj.contents + self.assertEqual(obj[0][0].a, 10) + self.assertEqual(obj[1][0].a, 20) + self.assertEqual(obj[0].contents.a, 10) + self.assertEqual(obj[1].contents.a, 20) + + def test_pointer_set_contents(self): + class Struct(Structure): + _fields_ = [('a', c_int16)] + p = pointer(Struct(a=23)) + self.assertEqual(p.contents.a, 23) + self.assertIs(p._type_, Struct) + cp = cast(p, POINTER(c_int16)) + self.assertEqual(cp.contents._type_, 'h') + cp.contents = c_int16(24) + self.assertEqual(cp.contents.value, 24) + self.assertEqual(p.contents.a, 24) + + pp = pointer(p) + self.assertIs(pp._type_, POINTER(Struct)) + + from code import interact; interact(local=locals()) + + cast(pp, POINTER(POINTER(c_int16))).contents.contents = c_int16(32) + + # self.assertIs(p.contents, pp.contents.contents) + + self.assertEqual(cast(p, POINTER(c_int16)).contents.value, 32) + self.assertEqual(p[0].a, 32) # works + self.assertEqual(pp[0].contents.a, 32) # works + self.assertEqual(pp.contents[0].a, 32) # works + + self.assertEqual(p.contents.a, 32) # fails, wat, holds 23 + self.assertEqual(pp.contents.contents.a, 32) # fails, wat, holds 23 + if __name__ == "__main__": unittest.main() diff --git a/Modules/_ctypes/_ctypes.c b/Modules/_ctypes/_ctypes.c index ed9efcad9ab0c8..51ab82722e38ef 100644 --- a/Modules/_ctypes/_ctypes.c +++ b/Modules/_ctypes/_ctypes.c @@ -5139,6 +5139,8 @@ static PyObject * Pointer_get_contents(CDataObject *self, void *closure) { StgDictObject *stgdict; + PyObject *ptr2ptr; + CDataObject *p2p; if (*(void **)self->b_ptr == NULL) { PyErr_SetString(PyExc_ValueError, @@ -5148,38 +5150,40 @@ Pointer_get_contents(CDataObject *self, void *closure) stgdict = PyObject_stgdict((PyObject *)self); assert(stgdict); /* Cannot be NULL for pointer instances */ + assert(stgdict->proto); - PyObject *keep = GetKeepedObjects(self); - if (keep != NULL) { - // check if it's a pointer to a pointer: - // pointers will have '0' key in the _objects - int ptr_probe = PyDict_ContainsString(keep, "0"); - if (ptr_probe < 0) { + if (self->b_objects != NULL && PyDict_CheckExact(self->b_objects)) { + // Pointer_set_contents uses KeepRef(self, 1, value); we retrieve that + ptr2ptr = PyDict_GetItemString(self->b_objects, "1"); + if (ptr2ptr == NULL) { + PyErr_SetString(PyExc_ValueError, + "Unexpected NULL pointer in _objects"); return NULL; } - if (ptr_probe) { - PyObject *item; - if (PyDict_GetItemStringRef(keep, "1", &item) < 0) { - return NULL; - } - if (item == NULL) { - PyErr_SetString(PyExc_ValueError, - "Unexpected NULL pointer in _objects"); - return NULL; - } -#ifndef NDEBUG - CDataObject *ptr2ptr = (CDataObject *)item; - // Don't construct a new object, - // return existing one instead to preserve refcount. - // Double-check that we are returning the same thing. + // if our base pointer is cast from another type, + // its `_type_` proto will be incompatible with the + // type of the object stored in `b_objects["1"]` because + // `_objects` is shared between casts and the original. + int res = PyObject_IsInstance(ptr2ptr, stgdict->proto); + if (res == -1) { + return NULL; + } + if (res) { + // It's not a cast: don't construct a new object, + // return existing one instead to preserve refcount + p2p = (CDataObject*) ptr2ptr; + printf("self->b_ptr=%lu\n", *(void**) self->b_ptr); + printf("p2p->b_ptr=%lu\n", *(void**) p2p->b_ptr); + printf("self->b_value.c=%lu\n", *(void**) self->b_value.c); + printf("p2p->b_value.c=%lu\n", *(void**) p2p->b_value.c); assert( - *(void**) self->b_ptr == ptr2ptr->b_ptr || - *(void**) self->b_value.c == ptr2ptr->b_ptr || - *(void**) self->b_ptr == ptr2ptr->b_value.c || - *(void**) self->b_value.c == ptr2ptr->b_value.c - ); -#endif - return item; + *(void**) self->b_ptr == *(void**) p2p->b_ptr || + *(void**) self->b_value.c == *(void**) p2p->b_ptr || + *(void**) self->b_ptr == *(void**) p2p->b_value.c || + *(void**) self->b_value.c == *(void**) p2p->b_value.c + ); // double-check that we are returning the same thing + Py_INCREF(ptr2ptr); + return ptr2ptr; } }