Skip to content

Commit 2b3868f

Browse files
committed
managed types can now be subclassed in python and override virtual methods
1 parent 69d9933 commit 2b3868f

File tree

11 files changed

+421
-85
lines changed

11 files changed

+421
-85
lines changed

pythonnet/src/runtime/Python.Runtime.csproj

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@
115115
<OutputPath>bin\x64\Debug\</OutputPath>
116116
<DefineConstants Condition=" '$(DefineConstants)'==''">TRACE;DEBUG;PYTHON27,UCS4</DefineConstants>
117117
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
118-
<Optimize>true</Optimize>
118+
<Optimize>false</Optimize>
119119
<DebugType>full</DebugType>
120120
<PlatformTarget>x64</PlatformTarget>
121121
<CodeAnalysisIgnoreBuiltInRuleSets>false</CodeAnalysisIgnoreBuiltInRuleSets>
@@ -171,6 +171,7 @@
171171
<Compile Include="arrayobject.cs" />
172172
<Compile Include="assemblyinfo.cs" />
173173
<Compile Include="assemblymanager.cs" />
174+
<Compile Include="classderived.cs" />
174175
<Compile Include="classbase.cs" />
175176
<Compile Include="classmanager.cs" />
176177
<Compile Include="classobject.cs" />

pythonnet/src/runtime/assemblymanager.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ public static bool LoadImplicit(string name, out bool fromFile) {
260260
// be valid namespaces (to better match Python import semantics).
261261
//===================================================================
262262

263-
static void ScanAssembly(Assembly assembly) {
263+
internal static void ScanAssembly(Assembly assembly) {
264264

265265
// A couple of things we want to do here: first, we want to
266266
// gather a list of all of the namespaces contributed to by

pythonnet/src/runtime/classbase.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,11 @@ public static IntPtr tp_richcompare(IntPtr ob, IntPtr other, int op) {
8282

8383
CLRObject co1 = GetManagedObject(ob) as CLRObject;
8484
CLRObject co2 = GetManagedObject(other) as CLRObject;
85+
if (null == co2) {
86+
Runtime.Incref(pyfalse);
87+
return pyfalse;
88+
}
89+
8590
Object o1 = co1.inst;
8691
Object o2 = co2.inst;
8792

pythonnet/src/runtime/classderived.cs

Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
// ==========================================================================
2+
// This software is subject to the provisions of the Zope Public License,
3+
// Version 2.0 (ZPL). A copy of the ZPL should accompany this distribution.
4+
// THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
5+
// WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
6+
// WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
7+
// FOR A PARTICULAR PURPOSE.
8+
// ==========================================================================
9+
10+
using System;
11+
using System.Reflection;
12+
using System.Reflection.Emit;
13+
using System.Collections.Generic;
14+
using System.Threading;
15+
using System.Linq;
16+
17+
namespace Python.Runtime
18+
{
19+
20+
/// <summary>
21+
/// Managed class that provides the implementation for reflected types.
22+
/// Managed classes and value types are represented in Python by actual
23+
/// Python type objects. Each of those type objects is associated with
24+
/// an instance of ClassObject, which provides its implementation.
25+
/// </summary>
26+
27+
internal class ClassDerivedObject : ClassObject
28+
{
29+
static private Dictionary<string, AssemblyBuilder> assemblyBuilders;
30+
static private Dictionary<Tuple<string, string>, ModuleBuilder> moduleBuilders;
31+
32+
static ClassDerivedObject()
33+
{
34+
assemblyBuilders = new Dictionary<string, AssemblyBuilder>();
35+
moduleBuilders = new Dictionary<Tuple<string, string>, ModuleBuilder>();
36+
}
37+
38+
internal ClassDerivedObject(Type tp)
39+
: base(tp)
40+
{
41+
}
42+
43+
//====================================================================
44+
// Implements __new__ for derived classes of reflected classes.
45+
//====================================================================
46+
new public static IntPtr tp_new(IntPtr tp, IntPtr args, IntPtr kw)
47+
{
48+
// derived classes have a __pyobj__ field that points back to the python object
49+
// (see Trampoline.InvokeMethod and CreateDerivedType)
50+
IntPtr pyobj = ClassObject.tp_new(tp, args, kw);
51+
CLRObject obj = (CLRObject)ManagedType.GetManagedObject(pyobj);
52+
FieldInfo fi = obj.inst.GetType().GetField("__pyobj__");
53+
fi.SetValue(obj.inst, pyobj);
54+
return pyobj;
55+
}
56+
57+
//====================================================================
58+
// Creates a new managed type derived from a base type with any virtual
59+
// methods overriden to call out to python if the associated python
60+
// object has overriden the method.
61+
//====================================================================
62+
internal static Type CreateDerivedType(string name,
63+
Type baseType,
64+
string namespaceStr,
65+
string assemblyName,
66+
string moduleName="Python.Runtime.Dynamic.dll")
67+
{
68+
if (null != namespaceStr)
69+
name = namespaceStr + "." + name;
70+
71+
if (null == assemblyName)
72+
assemblyName = Assembly.GetExecutingAssembly().FullName;
73+
74+
ModuleBuilder moduleBuilder = GetModuleBuilder(assemblyName, moduleName);
75+
TypeBuilder typeBuilder = moduleBuilder.DefineType(name, TypeAttributes.Public | TypeAttributes.Class);
76+
typeBuilder.SetParent(baseType);
77+
78+
// add a field for storing the python object pointer
79+
FieldBuilder fb = typeBuilder.DefineField("__pyobj__", typeof(IntPtr), FieldAttributes.Public);
80+
81+
// override any virtual methods
82+
MethodInfo[] methods = baseType.GetMethods();
83+
List<string> baseMethodNames = new List<string>();
84+
foreach (MethodInfo method in methods)
85+
{
86+
if (!method.Attributes.HasFlag(MethodAttributes.Virtual) | method.Attributes.HasFlag(MethodAttributes.Final))
87+
continue;
88+
89+
ParameterInfo[] parameters = method.GetParameters();
90+
Type[] parameterTypes = (from param in parameters select param.ParameterType).ToArray();
91+
92+
// create a method for calling the original method
93+
string baseMethodName = "_" + baseType.Name + "__" + method.Name;
94+
baseMethodNames.Add(baseMethodName);
95+
MethodBuilder mb = typeBuilder.DefineMethod(baseMethodName,
96+
MethodAttributes.Public | MethodAttributes.Final | MethodAttributes.HideBySig,
97+
method.ReturnType,
98+
parameterTypes);
99+
100+
// emit the assembly for calling the original method using call instead of callvirt
101+
ILGenerator il = mb.GetILGenerator();
102+
il.Emit(OpCodes.Ldarg_0);
103+
for (int i = 0; i < parameters.Length; ++i)
104+
il.Emit(OpCodes.Ldarg, i + 1);
105+
il.Emit(OpCodes.Call, method);
106+
il.Emit(OpCodes.Ret);
107+
108+
// override the original method with a new one that dispatches to python
109+
mb = typeBuilder.DefineMethod(method.Name,
110+
MethodAttributes.Public | MethodAttributes.ReuseSlot |
111+
MethodAttributes.Virtual | MethodAttributes.HideBySig,
112+
method.CallingConvention,
113+
method.ReturnType,
114+
parameterTypes);
115+
116+
il = mb.GetILGenerator();
117+
il.DeclareLocal(typeof(Object[]));
118+
il.Emit(OpCodes.Ldarg_0);
119+
il.Emit(OpCodes.Ldstr, method.Name);
120+
il.Emit(OpCodes.Ldstr, baseMethodName);
121+
il.Emit(OpCodes.Ldc_I4, parameters.Length);
122+
il.Emit(OpCodes.Newarr, typeof(System.Object));
123+
il.Emit(OpCodes.Stloc_0);
124+
for (int i = 0; i < parameters.Length; ++i)
125+
{
126+
il.Emit(OpCodes.Ldloc_0);
127+
il.Emit(OpCodes.Ldc_I4, i);
128+
il.Emit(OpCodes.Ldarg, i + 1);
129+
if (parameterTypes[i].IsPrimitive)
130+
il.Emit(OpCodes.Box, parameterTypes[i]);
131+
il.Emit(OpCodes.Stelem, typeof(Object));
132+
}
133+
il.Emit(OpCodes.Ldloc_0);
134+
if (method.ReturnType == typeof(void))
135+
{
136+
il.Emit(OpCodes.Call, typeof(Trampoline).GetMethod("InvokeMethodVoid"));
137+
}
138+
else
139+
{
140+
il.Emit(OpCodes.Call, typeof(Trampoline).GetMethod("InvokeMethod").MakeGenericMethod(method.ReturnType));
141+
}
142+
il.Emit(OpCodes.Ret);
143+
}
144+
145+
Type type = typeBuilder.CreateType();
146+
147+
// scan the assembly so the newly added class can be imported
148+
Assembly assembly = Assembly.GetAssembly(type);
149+
AssemblyManager.ScanAssembly(assembly);
150+
151+
return type;
152+
}
153+
154+
155+
private static ModuleBuilder GetModuleBuilder(string assemblyName, string moduleName)
156+
{
157+
// find or create a dynamic assembly and module
158+
AppDomain domain = AppDomain.CurrentDomain;
159+
ModuleBuilder moduleBuilder = null;
160+
161+
if (moduleBuilders.ContainsKey(Tuple.Create(assemblyName, moduleName)))
162+
{
163+
moduleBuilder = moduleBuilders[Tuple.Create(assemblyName, moduleName)];
164+
}
165+
else
166+
{
167+
AssemblyBuilder assemblyBuilder = null;
168+
if (assemblyBuilders.ContainsKey(assemblyName))
169+
{
170+
assemblyBuilder = assemblyBuilders[assemblyName];
171+
}
172+
else
173+
{
174+
assemblyBuilder = domain.DefineDynamicAssembly(new AssemblyName(assemblyName),
175+
AssemblyBuilderAccess.Run);
176+
assemblyBuilders[assemblyName] = assemblyBuilder;
177+
}
178+
179+
moduleBuilder = assemblyBuilder.DefineDynamicModule(moduleName);
180+
moduleBuilders[Tuple.Create(assemblyName, moduleName)] = moduleBuilder;
181+
}
182+
183+
return moduleBuilder;
184+
}
185+
186+
}
187+
188+
// This has to be public as it's called from methods on dynamically built classes
189+
// potentially in other assemblies
190+
public class Trampoline
191+
{
192+
//====================================================================
193+
// This is the implementaion of the overriden methods in the derived
194+
// type. It looks for a python method with the same name as the method
195+
// on the managed base class and if it exists and isn't the managed
196+
// method binding (ie it has been overriden in the derived python
197+
// class) it calls it, otherwise it calls the base method.
198+
//====================================================================
199+
public static T InvokeMethod<T>(Object obj, string methodName, string origMethodName, Object[] args)
200+
{
201+
FieldInfo fi = obj.GetType().GetField("__pyobj__");
202+
IntPtr ptr = (IntPtr)fi.GetValue(obj);
203+
if (null != ptr)
204+
{
205+
IntPtr gs = Runtime.PyGILState_Ensure();
206+
try
207+
{
208+
PyObject pyobj = new PyObject(ptr);
209+
PyObject method = pyobj.GetAttr(methodName, new PyObject(Runtime.PyNone));
210+
if (method.Handle != Runtime.PyNone)
211+
{
212+
// if the method hasn't been overriden then it will be a managed object
213+
ManagedType managedMethod = ManagedType.GetManagedObject(method.Handle);
214+
if (null == managedMethod)
215+
{
216+
PyObject[] pyargs = new PyObject[args.Length];
217+
for (int i = 0; i < args.Length; ++i)
218+
{
219+
pyargs[i] = new PyObject(Converter.ToPython(args[i], args[i].GetType()));
220+
}
221+
222+
PyObject py_result = method.Invoke(pyargs);
223+
return (T)py_result.AsManagedObject(typeof(T));
224+
}
225+
}
226+
}
227+
finally
228+
{
229+
Runtime.PyGILState_Release(gs);
230+
}
231+
}
232+
233+
return (T)obj.GetType().InvokeMember(origMethodName,
234+
BindingFlags.InvokeMethod,
235+
null,
236+
obj,
237+
args);
238+
}
239+
240+
public static void InvokeMethodVoid(Object obj, string methodName, string origMethodName, Object[] args)
241+
{
242+
FieldInfo fi = obj.GetType().GetField("__pyobj__");
243+
IntPtr ptr = (IntPtr)fi.GetValue(obj);
244+
if (null != ptr)
245+
{
246+
IntPtr gs = Runtime.PyGILState_Ensure();
247+
try
248+
{
249+
PyObject pyobj = new PyObject(ptr);
250+
PyObject method = pyobj.GetAttr(methodName, new PyObject(Runtime.PyNone));
251+
if (method.Handle != Runtime.PyNone)
252+
{
253+
// if the method hasn't been overriden then it will be a managed object
254+
ManagedType managedMethod = ManagedType.GetManagedObject(method.Handle);
255+
if (null == managedMethod)
256+
{
257+
PyObject[] pyargs = new PyObject[args.Length];
258+
for (int i = 0; i < args.Length; ++i)
259+
{
260+
pyargs[i] = new PyObject(Converter.ToPython(args[i], args[i].GetType()));
261+
}
262+
263+
PyObject py_result = method.Invoke(pyargs);
264+
return;
265+
}
266+
}
267+
}
268+
finally
269+
{
270+
Runtime.PyGILState_Release(gs);
271+
}
272+
}
273+
274+
obj.GetType().InvokeMember(origMethodName,
275+
BindingFlags.InvokeMethod,
276+
null,
277+
obj,
278+
args);
279+
}
280+
}
281+
}

pythonnet/src/runtime/classmanager.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,12 @@ private static ClassBase CreateClass(Type type) {
105105
impl = new ExceptionClassObject(type);
106106
}
107107

108-
else {
108+
else if (null != type.GetField("__pyobj__")) {
109+
impl = new ClassDerivedObject(type);
110+
}
111+
112+
else
113+
{
109114
impl = new ClassObject(type);
110115
}
111116

pythonnet/src/runtime/genericutil.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ static GenericUtil() {
3737
//====================================================================
3838

3939
internal static void Register(Type t) {
40+
if (null == t.Namespace || null == t.Name)
41+
return;
42+
4043
Dictionary<string, List<string>> nsmap = null;
4144
mapping.TryGetValue(t.Namespace, out nsmap);
4245
if (nsmap == null) {

pythonnet/src/runtime/metatype.cs

Lines changed: 2 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ public static IntPtr tp_new(IntPtr tp, IntPtr args, IntPtr kw) {
4646
return Exceptions.RaiseTypeError("invalid argument list");
4747
}
4848

49-
//IntPtr name = Runtime.PyTuple_GetItem(args, 0);
49+
IntPtr name = Runtime.PyTuple_GetItem(args, 0);
5050
IntPtr bases = Runtime.PyTuple_GetItem(args, 1);
5151
IntPtr dict = Runtime.PyTuple_GetItem(args, 2);
5252

@@ -88,45 +88,7 @@ public static IntPtr tp_new(IntPtr tp, IntPtr args, IntPtr kw) {
8888
);
8989
}
9090

91-
// hack for now... fix for 1.0
92-
//return TypeManager.CreateSubType(args);
93-
94-
95-
// right way
96-
97-
IntPtr func = Marshal.ReadIntPtr(Runtime.PyTypeType,
98-
TypeOffset.tp_new);
99-
IntPtr type = NativeCall.Call_3(func, tp, args, kw);
100-
if (type == IntPtr.Zero) {
101-
return IntPtr.Zero;
102-
}
103-
104-
int flags = TypeFlags.Default;
105-
flags |= TypeFlags.Managed;
106-
flags |= TypeFlags.HeapType;
107-
flags |= TypeFlags.BaseType;
108-
flags |= TypeFlags.Subclass;
109-
flags |= TypeFlags.HaveGC;
110-
Marshal.WriteIntPtr(type, TypeOffset.tp_flags, (IntPtr)flags);
111-
112-
TypeManager.CopySlot(base_type, type, TypeOffset.tp_dealloc);
113-
114-
// Hmm - the standard subtype_traverse, clear look at ob_size to
115-
// do things, so to allow gc to work correctly we need to move
116-
// our hidden handle out of ob_size. Then, in theory we can
117-
// comment this out and still not crash.
118-
TypeManager.CopySlot(base_type, type, TypeOffset.tp_traverse);
119-
TypeManager.CopySlot(base_type, type, TypeOffset.tp_clear);
120-
121-
122-
// for now, move up hidden handle...
123-
IntPtr gc = Marshal.ReadIntPtr(base_type, TypeOffset.magic());
124-
Marshal.WriteIntPtr(type, TypeOffset.magic(), gc);
125-
126-
//DebugUtil.DumpType(base_type);
127-
//DebugUtil.DumpType(type);
128-
129-
return type;
91+
return TypeManager.CreateSubType(name, base_type, dict);
13092
}
13193

13294

0 commit comments

Comments
 (0)