Skip to content

Commit bc3265d

Browse files
authored
make .NET objects that have __call__ method callable from Python (pythonnet#1589)
Implemented by adding tp_call to ClassBase, that uses reflection to find __call__ methods in .NET fixes pythonnet#890 this is an amalgamation of d46878c, 5bb1007, and 960457f from https://github.com/losttech/pythonnet
1 parent f591024 commit bc3265d

File tree

7 files changed

+169
-2
lines changed

7 files changed

+169
-2
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ See [Mixins/collections.py](src/runtime/Mixins/collections.py).
2323
- .NET arrays implement Python buffer protocol
2424
- Python.NET will correctly resolve .NET methods, that accept `PyList`, `PyInt`,
2525
and other `PyObject` derived types when called from Python.
26+
- .NET classes, that have `__call__` method are callable from Python
2627
- `PyIterable` type, that wraps any iterable object in Python
2728

2829

src/embed_tests/CallableObject.cs

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
using System;
2+
using System.Collections.Generic;
3+
4+
using NUnit.Framework;
5+
6+
using Python.Runtime;
7+
8+
namespace Python.EmbeddingTest
9+
{
10+
public class CallableObject
11+
{
12+
[OneTimeSetUp]
13+
public void SetUp()
14+
{
15+
PythonEngine.Initialize();
16+
using var locals = new PyDict();
17+
PythonEngine.Exec(CallViaInheritance.BaseClassSource, locals: locals.Handle);
18+
CustomBaseTypeProvider.BaseClass = new PyType(locals[CallViaInheritance.BaseClassName]);
19+
PythonEngine.InteropConfiguration.PythonBaseTypeProviders.Add(new CustomBaseTypeProvider());
20+
}
21+
22+
[OneTimeTearDown]
23+
public void Dispose()
24+
{
25+
PythonEngine.Shutdown();
26+
}
27+
[Test]
28+
public void CallMethodMakesObjectCallable()
29+
{
30+
var doubler = new DerivedDoubler();
31+
dynamic applyObjectTo21 = PythonEngine.Eval("lambda o: o(21)");
32+
Assert.AreEqual(doubler.__call__(21), (int)applyObjectTo21(doubler.ToPython()));
33+
}
34+
[Test]
35+
public void CallMethodCanBeInheritedFromPython()
36+
{
37+
var callViaInheritance = new CallViaInheritance();
38+
dynamic applyObjectTo14 = PythonEngine.Eval("lambda o: o(14)");
39+
Assert.AreEqual(callViaInheritance.Call(14), (int)applyObjectTo14(callViaInheritance.ToPython()));
40+
}
41+
42+
[Test]
43+
public void CanOverwriteCall()
44+
{
45+
var callViaInheritance = new CallViaInheritance();
46+
using var scope = Py.CreateScope();
47+
scope.Set("o", callViaInheritance);
48+
scope.Exec("orig_call = o.Call");
49+
scope.Exec("o.Call = lambda a: orig_call(a*7)");
50+
int result = scope.Eval<int>("o.Call(5)");
51+
Assert.AreEqual(105, result);
52+
}
53+
54+
class Doubler
55+
{
56+
public int __call__(int arg) => 2 * arg;
57+
}
58+
59+
class DerivedDoubler : Doubler { }
60+
61+
class CallViaInheritance
62+
{
63+
public const string BaseClassName = "Forwarder";
64+
public static readonly string BaseClassSource = $@"
65+
class MyCallableBase:
66+
def __call__(self, val):
67+
return self.Call(val)
68+
69+
class {BaseClassName}(MyCallableBase): pass
70+
";
71+
public int Call(int arg) => 3 * arg;
72+
}
73+
74+
class CustomBaseTypeProvider : IPythonBaseTypeProvider
75+
{
76+
internal static PyType BaseClass;
77+
78+
public IEnumerable<PyType> GetBaseTypes(Type type, IList<PyType> existingBases)
79+
{
80+
Assert.Greater(BaseClass.Refcount, 0);
81+
return type != typeof(CallViaInheritance)
82+
? existingBases
83+
: new[] { BaseClass };
84+
}
85+
}
86+
}
87+
}

src/runtime/classbase.cs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
using System;
22
using System.Collections;
33
using System.Collections.Generic;
4+
using System.Diagnostics;
5+
using System.Linq;
6+
using System.Reflection;
47
using System.Runtime.InteropServices;
58

69
namespace Python.Runtime
@@ -557,5 +560,44 @@ public static int mp_ass_subscript(IntPtr ob, IntPtr idx, IntPtr v)
557560

558561
return 0;
559562
}
563+
564+
static IntPtr tp_call_impl(IntPtr ob, IntPtr args, IntPtr kw)
565+
{
566+
IntPtr tp = Runtime.PyObject_TYPE(ob);
567+
var self = (ClassBase)GetManagedObject(tp);
568+
569+
if (!self.type.Valid)
570+
{
571+
return Exceptions.RaiseTypeError(self.type.DeletedMessage);
572+
}
573+
574+
Type type = self.type.Value;
575+
576+
var calls = GetCallImplementations(type).ToList();
577+
Debug.Assert(calls.Count > 0);
578+
var callBinder = new MethodBinder();
579+
foreach (MethodInfo call in calls)
580+
{
581+
callBinder.AddMethod(call);
582+
}
583+
return callBinder.Invoke(ob, args, kw);
584+
}
585+
586+
static IEnumerable<MethodInfo> GetCallImplementations(Type type)
587+
=> type.GetMethods(BindingFlags.Public | BindingFlags.Instance)
588+
.Where(m => m.Name == "__call__");
589+
590+
static readonly Interop.TernaryFunc tp_call_delegate = tp_call_impl;
591+
592+
public virtual void InitializeSlots(SlotsHolder slotsHolder)
593+
{
594+
if (!this.type.Valid) return;
595+
596+
if (GetCallImplementations(this.type.Value).Any()
597+
&& !slotsHolder.IsHolding(TypeOffset.tp_call))
598+
{
599+
TypeManager.InitializeSlot(ObjectReference, TypeOffset.tp_call, tp_call_delegate, slotsHolder);
600+
}
601+
}
560602
}
561603
}

src/runtime/classmanager.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,9 @@ internal static Dictionary<ManagedType, InterDomainContext> RestoreRuntimeData(R
162162
Runtime.PyType_Modified(pair.Value.TypeReference);
163163
var context = contexts[pair.Value.pyHandle];
164164
pair.Value.Load(context);
165+
var slotsHolder = TypeManager.GetSlotsHolder(pyType);
166+
pair.Value.InitializeSlots(slotsHolder);
167+
Runtime.PyType_Modified(pair.Value.TypeReference);
165168
loadedObjs.Add(pair.Value, context);
166169
}
167170

src/runtime/interop.cs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,13 @@ internal static ThunkInfo GetThunk(MethodInfo method, string funcType = null)
242242
return ThunkInfo.Empty;
243243
}
244244
Delegate d = Delegate.CreateDelegate(dt, method);
245-
var info = new ThunkInfo(d);
246-
allocatedThunks[info.Address] = d;
245+
return GetThunk(d);
246+
}
247+
248+
internal static ThunkInfo GetThunk(Delegate @delegate)
249+
{
250+
var info = new ThunkInfo(@delegate);
251+
allocatedThunks[info.Address] = @delegate;
247252
return info;
248253
}
249254

src/runtime/pytype.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,20 @@ internal static BorrowedReference GetBase(BorrowedReference type)
121121
return new BorrowedReference(basePtr);
122122
}
123123

124+
internal static BorrowedReference GetBases(BorrowedReference type)
125+
{
126+
Debug.Assert(IsType(type));
127+
IntPtr basesPtr = Marshal.ReadIntPtr(type.DangerousGetAddress(), TypeOffset.tp_bases);
128+
return new BorrowedReference(basesPtr);
129+
}
130+
131+
internal static BorrowedReference GetMRO(BorrowedReference type)
132+
{
133+
Debug.Assert(IsType(type));
134+
IntPtr basesPtr = Marshal.ReadIntPtr(type.DangerousGetAddress(), TypeOffset.tp_mro);
135+
return new BorrowedReference(basesPtr);
136+
}
137+
124138
private static IntPtr EnsureIsType(in StolenReference reference)
125139
{
126140
IntPtr address = reference.DangerousGetAddressOrNull();

src/runtime/typemanager.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,10 @@ static void InitializeClass(PyType pyType, ClassBase impl, Type clrType)
404404
impl.tpHandle = type;
405405
impl.pyHandle = type;
406406

407+
impl.InitializeSlots(slotsHolder);
408+
409+
Runtime.PyType_Modified(pyType.Reference);
410+
407411
//DebugUtil.DumpType(type);
408412
}
409413

@@ -787,6 +791,12 @@ static void InitializeSlot(IntPtr type, int slotOffset, MethodInfo method, Slots
787791
InitializeSlot(type, slotOffset, thunk, slotsHolder);
788792
}
789793

794+
internal static void InitializeSlot(BorrowedReference type, int slotOffset, Delegate impl, SlotsHolder slotsHolder)
795+
{
796+
var thunk = Interop.GetThunk(impl);
797+
InitializeSlot(type.DangerousGetAddress(), slotOffset, thunk, slotsHolder);
798+
}
799+
790800
static void InitializeSlot(IntPtr type, int slotOffset, ThunkInfo thunk, SlotsHolder slotsHolder)
791801
{
792802
Marshal.WriteIntPtr(type, slotOffset, thunk.Address);
@@ -848,6 +858,9 @@ private static SlotsHolder CreateSolotsHolder(IntPtr type)
848858
_slotsHolders.Add(type, holder);
849859
return holder;
850860
}
861+
862+
internal static SlotsHolder GetSlotsHolder(PyType type)
863+
=> _slotsHolders[type.Handle];
851864
}
852865

853866

@@ -873,6 +886,8 @@ public SlotsHolder(IntPtr type)
873886
_type = type;
874887
}
875888

889+
public bool IsHolding(int offset) => _slots.ContainsKey(offset);
890+
876891
public void Set(int offset, ThunkInfo thunk)
877892
{
878893
_slots[offset] = thunk;

0 commit comments

Comments
 (0)