Skip to content

Commit 010de6c

Browse files
committed
make .NET objects that have __call__ method callable from Python
Implemented by adding tp_call to ClassBase, that uses reflection to find __call__ methods in .NET, and falls back to invoking __call__ method from Python base classes. fixes pythonnet#890 this is an amalgamation of d46878c, 5bb1007, and 960457f from https://github.com/losttech/pythonnet
1 parent c9626df commit 010de6c

File tree

4 files changed

+235
-0
lines changed

4 files changed

+235
-0
lines changed

src/embed_tests/CallableObject.cs

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
using System;
2+
using System.Collections.Generic;
3+
4+
using NUnit.Framework;
5+
6+
using Python.Runtime;
7+
8+
namespace Python.EmbeddingTest
9+
{
10+
public class CallableObject
11+
{
12+
[OneTimeSetUp]
13+
public void SetUp()
14+
{
15+
PythonEngine.Initialize();
16+
using (Py.GIL())
17+
{
18+
using var locals = new PyDict();
19+
PythonEngine.Exec(CallViaInheritance.BaseClassSource, locals: locals.Handle);
20+
CustomBaseTypeProvider.BaseClass = new PyType(locals[CallViaInheritance.BaseClassName]);
21+
PythonEngine.InteropConfiguration.PythonBaseTypeProviders.Add(new CustomBaseTypeProvider());
22+
}
23+
}
24+
25+
[OneTimeTearDown]
26+
public void Dispose()
27+
{
28+
PythonEngine.Shutdown();
29+
}
30+
[Test]
31+
public void CallMethodMakesObjectCallable()
32+
{
33+
var doubler = new DerivedDoubler();
34+
using (Py.GIL())
35+
{
36+
dynamic applyObjectTo21 = PythonEngine.Eval("lambda o: o(21)");
37+
Assert.AreEqual(doubler.__call__(21), (int)applyObjectTo21(doubler.ToPython()));
38+
}
39+
}
40+
[Test]
41+
public void CallMethodCanBeInheritedFromPython()
42+
{
43+
var callViaInheritance = new CallViaInheritance();
44+
using (Py.GIL())
45+
{
46+
dynamic applyObjectTo14 = PythonEngine.Eval("lambda o: o(14)");
47+
Assert.AreEqual(callViaInheritance.Call(14), (int)applyObjectTo14(callViaInheritance.ToPython()));
48+
}
49+
}
50+
51+
[Test]
52+
public void CanOverwriteCall()
53+
{
54+
var callViaInheritance = new CallViaInheritance();
55+
using var _ = Py.GIL();
56+
using var scope = Py.CreateScope();
57+
scope.Set("o", callViaInheritance);
58+
scope.Exec("orig_call = o.Call");
59+
scope.Exec("o.Call = lambda a: orig_call(a*7)");
60+
int result = scope.Eval<int>("o.Call(5)");
61+
Assert.AreEqual(105, result);
62+
}
63+
64+
class Doubler
65+
{
66+
public int __call__(int arg) => 2 * arg;
67+
}
68+
69+
class DerivedDoubler : Doubler { }
70+
71+
class CallViaInheritance
72+
{
73+
public const string BaseClassName = "Forwarder";
74+
public static readonly string BaseClassSource = $@"
75+
class MyCallableBase:
76+
def __call__(self, val):
77+
return self.Call(val)
78+
79+
class {BaseClassName}(MyCallableBase): pass
80+
";
81+
public int Call(int arg) => 3 * arg;
82+
}
83+
84+
class CustomBaseTypeProvider : IPythonBaseTypeProvider
85+
{
86+
internal static PyType BaseClass;
87+
88+
public IEnumerable<PyType> GetBaseTypes(Type type, IList<PyType> existingBases)
89+
{
90+
Assert.Greater(BaseClass.Refcount, 0);
91+
return type != typeof(CallViaInheritance)
92+
? existingBases
93+
: new[] { BaseClass };
94+
}
95+
}
96+
}
97+
}

src/runtime/PythonReflection.cs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Python.Runtime
6+
{
7+
using System.Diagnostics;
8+
using System.Runtime.InteropServices;
9+
10+
class PythonReflection
11+
{
12+
static IEnumerable<IntPtr> GetTypesWithPythonBasesInHierarchy(IntPtr tp)
13+
{
14+
Debug.Assert(ManagedType.IsManagedType(new BorrowedReference(tp)));
15+
16+
var candidateQueue = new Queue<IntPtr>();
17+
candidateQueue.Enqueue(tp);
18+
while (candidateQueue.Count > 0)
19+
{
20+
BorrowedReference tpRef = new(candidateQueue.Dequeue());
21+
BorrowedReference bases = PyType.GetBases(tpRef);
22+
if (bases != null)
23+
{
24+
long baseCount = Runtime.PyTuple_Size(bases);
25+
bool hasPythonBase = false;
26+
for (long baseIndex = 0; baseIndex < baseCount; baseIndex++)
27+
{
28+
BorrowedReference @base = Runtime.PyTuple_GetItem(bases, baseIndex);
29+
if (ManagedType.IsManagedType(@base))
30+
{
31+
candidateQueue.Enqueue(@base.DangerousGetAddress());
32+
}
33+
else
34+
{
35+
hasPythonBase = true;
36+
}
37+
}
38+
39+
if (hasPythonBase) yield return tpRef.DangerousGetAddress();
40+
}
41+
else
42+
{
43+
tpRef = PyType.GetBase(tpRef);
44+
if (tpRef != null && ManagedType.IsManagedType(tpRef))
45+
candidateQueue.Enqueue(tpRef.DangerousGetAddress());
46+
}
47+
}
48+
}
49+
}
50+
}

src/runtime/classbase.cs

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55

66
namespace Python.Runtime
77
{
8+
using System.Linq;
9+
using System.Reflection;
10+
811
/// <summary>
912
/// Base class for Python types that reflect managed types / classes.
1013
/// Concrete subclasses include ClassObject and DelegateObject. This
@@ -557,5 +560,76 @@ public static int mp_ass_subscript(IntPtr ob, IntPtr idx, IntPtr v)
557560

558561
return 0;
559562
}
563+
564+
public static IntPtr tp_call(IntPtr ob, IntPtr args, IntPtr kw)
565+
{
566+
IntPtr tp = Runtime.PyObject_TYPE(ob);
567+
var self = (ClassBase)GetManagedObject(tp);
568+
569+
if (!self.type.Valid)
570+
{
571+
return Exceptions.RaiseTypeError(self.type.DeletedMessage);
572+
}
573+
574+
Type type = self.type.Value;
575+
576+
var calls = type.GetMethods(BindingFlags.Public | BindingFlags.Instance)
577+
.Where(m => m.Name == "__call__")
578+
.ToList();
579+
if (calls.Count > 0)
580+
{
581+
var callBinder = new MethodBinder();
582+
foreach (MethodInfo call in calls)
583+
{
584+
callBinder.AddMethod(call);
585+
}
586+
return callBinder.Invoke(ob, args, kw);
587+
}
588+
589+
return InvokeCallInheritedFromPython(new BorrowedReference(ob), args, kw);
590+
}
591+
592+
/// <summary>
593+
/// Find bases defined in Python and use their __call__ if any
594+
/// </summary>
595+
static IntPtr InvokeCallInheritedFromPython(BorrowedReference ob, IntPtr args, IntPtr kw)
596+
{
597+
BorrowedReference tp = Runtime.PyObject_TYPE(ob);
598+
using var super = new PyObject(new BorrowedReference(Runtime.PySuper_Type));
599+
using var pyInst = new PyObject(ob);
600+
using var none = new PyObject(new BorrowedReference(Runtime.PyNone));
601+
BorrowedReference mro = PyType.GetMRO(tp);
602+
nint mroLen = Runtime.PyTuple_Size(mro);
603+
for (int baseIndex = 0; baseIndex < mroLen - 1; baseIndex++)
604+
{
605+
BorrowedReference @base = Runtime.PyTuple_GetItem(mro, baseIndex);
606+
if (!IsManagedType(@base)) continue;
607+
608+
BorrowedReference nextBase = Runtime.PyTuple_GetItem(mro, baseIndex + 1);
609+
if (ManagedType.IsManagedType(nextBase)) continue;
610+
611+
// call via super
612+
using var managedBase = new PyObject(@base);
613+
using var superInstance = super.Invoke(managedBase, pyInst);
614+
using var call = Runtime.PyObject_GetAttrString(superInstance.Reference, "__call__");
615+
if (call.IsNull())
616+
{
617+
if (Exceptions.ExceptionMatches(Exceptions.AttributeError))
618+
{
619+
Runtime.PyErr_Clear();
620+
continue;
621+
}
622+
else
623+
{
624+
return IntPtr.Zero;
625+
}
626+
}
627+
628+
return Runtime.PyObject_Call(call.DangerousGetAddress(), args, kw);
629+
}
630+
631+
Exceptions.SetError(Exceptions.TypeError, "object is not callable");
632+
return IntPtr.Zero;
633+
}
560634
}
561635
}

src/runtime/pytype.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,20 @@ internal static BorrowedReference GetBase(BorrowedReference type)
121121
return new BorrowedReference(basePtr);
122122
}
123123

124+
internal static BorrowedReference GetBases(BorrowedReference type)
125+
{
126+
Debug.Assert(IsType(type));
127+
IntPtr basesPtr = Marshal.ReadIntPtr(type.DangerousGetAddress(), TypeOffset.tp_bases);
128+
return new BorrowedReference(basesPtr);
129+
}
130+
131+
internal static BorrowedReference GetMRO(BorrowedReference type)
132+
{
133+
Debug.Assert(IsType(type));
134+
IntPtr basesPtr = Marshal.ReadIntPtr(type.DangerousGetAddress(), TypeOffset.tp_mro);
135+
return new BorrowedReference(basesPtr);
136+
}
137+
124138
private static IntPtr EnsureIsType(in StolenReference reference)
125139
{
126140
IntPtr address = reference.DangerousGetAddressOrNull();

0 commit comments

Comments
 (0)