Skip to content

Commit 50d947f

Browse files
authored
Merge pull request #1240 from danabr/auto-cast-ret-val-to-interface
Wrap returned objects in interface if method return type is interface
2 parents d44f1da + c46ab75 commit 50d947f

14 files changed

+189
-22
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ This document follows the conventions laid out in [Keep a CHANGELOG][].
1515
details about the cause of the failure
1616
- `clr.AddReference` no longer adds ".dll" implicitly
1717
- `PyIter(PyObject)` constructor replaced with static `PyIter.GetIter(PyObject)` method
18+
- Return values from .NET methods that return an interface are now automatically
19+
wrapped in that interface. This is a breaking change for users that rely on being
20+
able to access members that are part of the implementation class, but not the
21+
interface. Use the new __implementation__ or __raw_implementation__ properties to
22+
if you need to "downcast" to the implementation class.
1823

1924
### Fixed
2025

src/runtime/arrayobject.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,9 @@ public static IntPtr tp_new(IntPtr tp, IntPtr args, IntPtr kw)
4343
public static IntPtr mp_subscript(IntPtr ob, IntPtr idx)
4444
{
4545
var obj = (CLRObject)GetManagedObject(ob);
46+
var arrObj = (ArrayObject)GetManagedObjectType(ob);
4647
var items = obj.inst as Array;
47-
Type itemType = obj.inst.GetType().GetElementType();
48+
Type itemType = arrObj.type.GetElementType();
4849
int rank = items.Rank;
4950
int index;
5051
object value;

src/runtime/converter.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,21 @@ internal static IntPtr ToPython(object value, Type type)
173173
}
174174
}
175175

176+
if (type.IsInterface)
177+
{
178+
var ifaceObj = (InterfaceObject)ClassManager.GetClass(type);
179+
return ifaceObj.WrapObject(value);
180+
}
181+
182+
// We need to special case interface array handling to ensure we
183+
// produce the correct type. Value may be an array of some concrete
184+
// type (FooImpl[]), but we want access to go via the interface type
185+
// (IFoo[]).
186+
if (type.IsArray && type.GetElementType().IsInterface)
187+
{
188+
return CLRObject.GetInstHandle(value, type);
189+
}
190+
176191
// it the type is a python subclass of a managed type then return the
177192
// underlying python object rather than construct a new wrapper object.
178193
var pyderived = value as IPythonDerivedType;

src/runtime/interfaceobject.cs

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,43 @@ public static IntPtr tp_new(IntPtr tp, IntPtr args, IntPtr kw)
7171
return IntPtr.Zero;
7272
}
7373

74-
return CLRObject.GetInstHandle(obj, self.pyHandle);
74+
return self.WrapObject(obj);
75+
}
76+
77+
/// <summary>
78+
/// Wrap the given object in an interface object, so that only methods
79+
/// of the interface are available.
80+
/// </summary>
81+
public IntPtr WrapObject(object impl)
82+
{
83+
var objPtr = CLRObject.GetInstHandle(impl, pyHandle);
84+
return objPtr;
85+
}
86+
87+
/// <summary>
88+
/// Expose the wrapped implementation through attributes in both
89+
/// converted/encoded (__implementation__) and raw (__raw_implementation__) form.
90+
/// </summary>
91+
public static IntPtr tp_getattro(IntPtr ob, IntPtr key)
92+
{
93+
var clrObj = (CLRObject)GetManagedObject(ob);
94+
95+
if (!Runtime.PyString_Check(key))
96+
{
97+
return Exceptions.RaiseTypeError("string expected");
98+
}
99+
100+
string name = Runtime.GetManagedString(key);
101+
if (name == "__implementation__")
102+
{
103+
return Converter.ToPython(clrObj.inst);
104+
}
105+
else if (name == "__raw_implementation__")
106+
{
107+
return CLRObject.GetInstHandle(clrObj.inst);
108+
}
109+
110+
return Runtime.PyObject_GenericGetAttr(ob, key);
75111
}
76112
}
77113
}

src/runtime/managedtype.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,25 @@ internal static ManagedType GetManagedObject(IntPtr ob)
4545
return null;
4646
}
4747

48+
/// <summary>
49+
/// Given a Python object, return the associated managed object type or null.
50+
/// </summary>
51+
internal static ManagedType GetManagedObjectType(IntPtr ob)
52+
{
53+
if (ob != IntPtr.Zero)
54+
{
55+
IntPtr tp = Runtime.PyObject_TYPE(ob);
56+
var flags = Util.ReadCLong(tp, TypeOffset.tp_flags);
57+
if ((flags & TypeFlags.Managed) != 0)
58+
{
59+
tp = Marshal.ReadIntPtr(tp, TypeOffset.magic());
60+
var gc = (GCHandle)tp;
61+
return (ManagedType)gc.Target;
62+
}
63+
}
64+
return null;
65+
}
66+
4867

4968
internal static ManagedType GetManagedObjectErr(IntPtr ob)
5069
{

src/runtime/methodbinder.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -744,7 +744,7 @@ internal virtual IntPtr Invoke(IntPtr inst, IntPtr args, IntPtr kw, MethodBase i
744744
Type pt = pi[i].ParameterType;
745745
if (pi[i].IsOut || pt.IsByRef)
746746
{
747-
v = Converter.ToPython(binding.args[i], pt);
747+
v = Converter.ToPython(binding.args[i], pt.GetElementType());
748748
Runtime.PyTuple_SetItem(t, n, v);
749749
n++;
750750
}

src/runtime/typemanager.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,13 +164,13 @@ internal static IntPtr CreateType(ManagedType impl, Type clrType)
164164
// we want to do this after the slot stuff above in case the class itself implements a slot method
165165
InitializeSlots(type, impl.GetType());
166166

167-
if (!clrType.GetInterfaces().Any(ifc => ifc == typeof(IEnumerable) || ifc == typeof(IEnumerator)))
167+
if (!typeof(IEnumerable).IsAssignableFrom(clrType) &&
168+
!typeof(IEnumerator).IsAssignableFrom(clrType))
168169
{
169170
// The tp_iter slot should only be set for enumerable types.
170171
Marshal.WriteIntPtr(type, TypeOffset.tp_iter, IntPtr.Zero);
171172
}
172173

173-
174174
if (base_ != IntPtr.Zero)
175175
{
176176
Marshal.WriteIntPtr(type, TypeOffset.tp_base, base_);

src/testing/interfacetest.cs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ internal interface IInternalInterface
1111
{
1212
}
1313

14-
1514
public interface ISayHello1
1615
{
1716
string SayHello();
@@ -43,6 +42,27 @@ string ISayHello2.SayHello()
4342
return "hello 2";
4443
}
4544

45+
public ISayHello1 GetISayHello1()
46+
{
47+
return this;
48+
}
49+
50+
public void GetISayHello2(out ISayHello2 hello2)
51+
{
52+
hello2 = this;
53+
}
54+
55+
public ISayHello1 GetNoSayHello(out ISayHello2 hello2)
56+
{
57+
hello2 = null;
58+
return null;
59+
}
60+
61+
public ISayHello1 [] GetISayHello1Array()
62+
{
63+
return new[] { this };
64+
}
65+
4666
public interface IPublic
4767
{
4868
}

src/testing/subclasstest.cs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,24 @@ public static string test_bar(IInterfaceTest x, string s, int i)
8989
}
9090

9191
// test instances can be constructed in managed code
92-
public static IInterfaceTest create_instance(Type t)
92+
public static SubClassTest create_instance(Type t)
93+
{
94+
return (SubClassTest)t.GetConstructor(new Type[] { }).Invoke(new object[] { });
95+
}
96+
97+
public static IInterfaceTest create_instance_interface(Type t)
9398
{
9499
return (IInterfaceTest)t.GetConstructor(new Type[] { }).Invoke(new object[] { });
95100
}
96101

97-
// test instances pass through managed code unchanged
98-
public static IInterfaceTest pass_through(IInterfaceTest s)
102+
// test instances pass through managed code unchanged ...
103+
public static SubClassTest pass_through(SubClassTest s)
104+
{
105+
return s;
106+
}
107+
108+
// ... but the return type is an interface type, objects get wrapped
109+
public static IInterfaceTest pass_through_interface(IInterfaceTest s)
99110
{
100111
return s;
101112
}

src/tests/test_array.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,9 +1288,10 @@ def test_special_array_creation():
12881288
assert value[1].__class__ == inst.__class__
12891289
assert value.Length == 2
12901290

1291+
iface_class = ISayHello1(inst).__class__
12911292
value = Array[ISayHello1]([inst, inst])
1292-
assert value[0].__class__ == inst.__class__
1293-
assert value[1].__class__ == inst.__class__
1293+
assert value[0].__class__ == iface_class
1294+
assert value[1].__class__ == iface_class
12941295
assert value.Length == 2
12951296

12961297
inst = System.Exception("badness")

src/tests/test_generic.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,6 @@ def test_generic_method_type_handling():
319319
assert_generic_method_by_type(ShortEnum, ShortEnum.Zero)
320320
assert_generic_method_by_type(System.Object, InterfaceTest())
321321
assert_generic_method_by_type(InterfaceTest, InterfaceTest(), 1)
322-
assert_generic_method_by_type(ISayHello1, InterfaceTest(), 1)
323322

324323

325324
def test_correct_overload_selection():
@@ -548,10 +547,11 @@ def test_method_overload_selection_with_generic_types():
548547
value = MethodTest.Overloaded.__overloads__[vtype](input_)
549548
assert value.value.__class__ == inst.__class__
550549

550+
iface_class = ISayHello1(inst).__class__
551551
vtype = GenericWrapper[ISayHello1]
552552
input_ = vtype(inst)
553553
value = MethodTest.Overloaded.__overloads__[vtype](input_)
554-
assert value.value.__class__ == inst.__class__
554+
assert value.value.__class__ == iface_class
555555

556556
vtype = System.Array[GenericWrapper[int]]
557557
input_ = vtype([GenericWrapper[int](0), GenericWrapper[int](1)])
@@ -726,11 +726,12 @@ def test_overload_selection_with_arrays_of_generic_types():
726726
assert value[0].value.__class__ == inst.__class__
727727
assert value.Length == 2
728728

729+
iface_class = ISayHello1(inst).__class__
729730
gtype = GenericWrapper[ISayHello1]
730731
vtype = System.Array[gtype]
731732
input_ = vtype([gtype(inst), gtype(inst)])
732733
value = MethodTest.Overloaded.__overloads__[vtype](input_)
733-
assert value[0].value.__class__ == inst.__class__
734+
assert value[0].value.__class__ == iface_class
734735
assert value.Length == 2
735736

736737

src/tests/test_interface.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,62 @@ def test_explicit_cast_to_interface():
6161
assert hasattr(i1, 'SayHello')
6262
assert i1.SayHello() == 'hello 1'
6363
assert not hasattr(i1, 'HelloProperty')
64+
assert i1.__implementation__ == ob
65+
assert i1.__raw_implementation__ == ob
6466

6567
i2 = Test.ISayHello2(ob)
6668
assert type(i2).__name__ == 'ISayHello2'
6769
assert i2.SayHello() == 'hello 2'
6870
assert hasattr(i2, 'SayHello')
6971
assert not hasattr(i2, 'HelloProperty')
72+
73+
74+
def test_interface_object_returned_through_method():
75+
"""Test interface type is used if method return type is interface"""
76+
from Python.Test import InterfaceTest
77+
78+
ob = InterfaceTest()
79+
hello1 = ob.GetISayHello1()
80+
assert type(hello1).__name__ == 'ISayHello1'
81+
assert hello1.__implementation__.__class__.__name__ == "InterfaceTest"
82+
83+
assert hello1.SayHello() == 'hello 1'
84+
85+
86+
def test_interface_object_returned_through_out_param():
87+
"""Test interface type is used for out parameters of interface types"""
88+
from Python.Test import InterfaceTest
89+
90+
ob = InterfaceTest()
91+
hello2 = ob.GetISayHello2(None)
92+
assert type(hello2).__name__ == 'ISayHello2'
93+
94+
assert hello2.SayHello() == 'hello 2'
95+
96+
97+
def test_null_interface_object_returned():
98+
"""Test None is used also for methods with interface return types"""
99+
from Python.Test import InterfaceTest
100+
101+
ob = InterfaceTest()
102+
hello1, hello2 = ob.GetNoSayHello(None)
103+
assert hello1 is None
104+
assert hello2 is None
105+
106+
def test_interface_array_returned():
107+
"""Test interface type used for methods returning interface arrays"""
108+
from Python.Test import InterfaceTest
109+
110+
ob = InterfaceTest()
111+
hellos = ob.GetISayHello1Array()
112+
assert type(hellos[0]).__name__ == 'ISayHello1'
113+
assert hellos[0].__implementation__.__class__.__name__ == "InterfaceTest"
114+
115+
def test_implementation_access():
116+
"""Test the __implementation__ and __raw_implementation__ properties"""
117+
import System
118+
clrVal = System.Int32(100)
119+
i = System.IComparable(clrVal)
120+
assert 100 == i.__implementation__
121+
assert clrVal == i.__raw_implementation__
122+
assert i.__implementation__ != i.__raw_implementation__

src/tests/test_method.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -564,8 +564,10 @@ def test_explicit_overload_selection():
564564
value = MethodTest.Overloaded.__overloads__[InterfaceTest](inst)
565565
assert value.__class__ == inst.__class__
566566

567+
iface_class = ISayHello1(InterfaceTest()).__class__
567568
value = MethodTest.Overloaded.__overloads__[ISayHello1](inst)
568-
assert value.__class__ == inst.__class__
569+
assert value.__class__ != inst.__class__
570+
assert value.__class__ == iface_class
569571

570572
atype = Array[System.Object]
571573
value = MethodTest.Overloaded.__overloads__[str, int, atype](
@@ -718,11 +720,12 @@ def test_overload_selection_with_array_types():
718720
assert value[0].__class__ == inst.__class__
719721
assert value[1].__class__ == inst.__class__
720722

723+
iface_class = ISayHello1(inst).__class__
721724
vtype = Array[ISayHello1]
722725
input_ = vtype([inst, inst])
723726
value = MethodTest.Overloaded.__overloads__[vtype](input_)
724-
assert value[0].__class__ == inst.__class__
725-
assert value[1].__class__ == inst.__class__
727+
assert value[0].__class__ == iface_class
728+
assert value[1].__class__ == iface_class
726729

727730

728731
def test_explicit_overload_selection_failure():

src/tests/test_subclass.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,10 @@ def test_interface():
104104
assert ob.bar("bar", 2) == "bar/bar"
105105
assert FunctionsTest.test_bar(ob, "bar", 2) == "bar/bar"
106106

107-
x = FunctionsTest.pass_through(ob)
108-
assert id(x) == id(ob)
107+
# pass_through will convert from InterfaceTestClass -> IInterfaceTest,
108+
# causing a new wrapper object to be created. Hence id will differ.
109+
x = FunctionsTest.pass_through_interface(ob)
110+
assert id(x) != id(ob)
109111

110112

111113
def test_derived_class():
@@ -173,14 +175,14 @@ def test_create_instance():
173175
assert id(x) == id(ob)
174176

175177
InterfaceTestClass = interface_test_class_fixture(test_create_instance.__name__)
176-
ob2 = FunctionsTest.create_instance(InterfaceTestClass)
178+
ob2 = FunctionsTest.create_instance_interface(InterfaceTestClass)
177179
assert ob2.foo() == "InterfaceTestClass"
178180
assert FunctionsTest.test_foo(ob2) == "InterfaceTestClass"
179181
assert ob2.bar("bar", 2) == "bar/bar"
180182
assert FunctionsTest.test_bar(ob2, "bar", 2) == "bar/bar"
181183

182-
y = FunctionsTest.pass_through(ob2)
183-
assert id(y) == id(ob2)
184+
y = FunctionsTest.pass_through_interface(ob2)
185+
assert id(y) != id(ob2)
184186

185187

186188
def test_events():

0 commit comments

Comments
 (0)