Skip to content

Allow substituting base types for CLR types (as seen from Python) #1487

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 175 additions & 0 deletions src/embed_tests/Inheritance.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Runtime.InteropServices;

using NUnit.Framework;

using Python.Runtime;

namespace Python.EmbeddingTest
{
public class Inheritance
{
[OneTimeSetUp]
public void SetUp()
{
PythonEngine.Initialize();
var locals = new PyDict();
PythonEngine.Exec(InheritanceTestBaseClassWrapper.ClassSourceCode, locals: locals.Handle);
ExtraBaseTypeProvider.ExtraBase = new PyType(locals[InheritanceTestBaseClassWrapper.ClassName]);
var baseTypeProviders = PythonEngine.InteropConfiguration.PythonBaseTypeProviders;
baseTypeProviders.Add(new ExtraBaseTypeProvider());
baseTypeProviders.Add(new NoEffectBaseTypeProvider());
}

[OneTimeTearDown]
public void Dispose()
{
PythonEngine.Shutdown();
}

[Test]
public void ExtraBase_PassesInstanceCheck()
{
var inherited = new Inherited();
bool properlyInherited = PyIsInstance(inherited, ExtraBaseTypeProvider.ExtraBase);
Assert.IsTrue(properlyInherited);
}

static dynamic PyIsInstance => PythonEngine.Eval("isinstance");

[Test]
public void InheritingWithExtraBase_CreatesNewClass()
{
PyObject a = ExtraBaseTypeProvider.ExtraBase;
var inherited = new Inherited();
PyObject inheritedClass = inherited.ToPython().GetAttr("__class__");
Assert.IsFalse(PythonReferenceComparer.Instance.Equals(a, inheritedClass));
}

[Test]
public void InheritedFromInheritedClassIsSelf()
{
using var scope = Py.CreateScope();
scope.Exec($"from {typeof(Inherited).Namespace} import {nameof(Inherited)}");
scope.Exec($"class B({nameof(Inherited)}): pass");
PyObject b = scope.Eval("B");
PyObject bInstance = b.Invoke();
PyObject bInstanceClass = bInstance.GetAttr("__class__");
Assert.IsTrue(PythonReferenceComparer.Instance.Equals(b, bInstanceClass));
}

[Test]
public void Grandchild_PassesExtraBaseInstanceCheck()
{
using var scope = Py.CreateScope();
scope.Exec($"from {typeof(Inherited).Namespace} import {nameof(Inherited)}");
scope.Exec($"class B({nameof(Inherited)}): pass");
PyObject b = scope.Eval("B");
PyObject bInst = b.Invoke();
bool properlyInherited = PyIsInstance(bInst, ExtraBaseTypeProvider.ExtraBase);
Assert.IsTrue(properlyInherited);
}

[Test]
public void CallInheritedClrMethod_WithExtraPythonBase()
{
var instance = new Inherited().ToPython();
string result = instance.InvokeMethod(nameof(PythonWrapperBase.WrapperBaseMethod)).As<string>();
Assert.AreEqual(result, nameof(PythonWrapperBase.WrapperBaseMethod));
}

[Test]
public void CallExtraBaseMethod()
{
var instance = new Inherited();
using var scope = Py.CreateScope();
scope.Set(nameof(instance), instance);
int actual = instance.ToPython().InvokeMethod("callVirt").As<int>();
Assert.AreEqual(expected: Inherited.OverridenVirtValue, actual);
}

[Test]
public void SetAdHocAttributes_WhenExtraBasePresent()
{
var instance = new Inherited();
using var scope = Py.CreateScope();
scope.Set(nameof(instance), instance);
scope.Exec($"super({nameof(instance)}.__class__, {nameof(instance)}).set_x_to_42()");
int actual = scope.Eval<int>($"{nameof(instance)}.{nameof(Inherited.XProp)}");
Assert.AreEqual(expected: Inherited.X, actual);
}
}

class ExtraBaseTypeProvider : IPythonBaseTypeProvider
{
internal static PyType ExtraBase;
public IEnumerable<PyType> GetBaseTypes(Type type, IList<PyType> existingBases)
{
if (type == typeof(InheritanceTestBaseClassWrapper))
{
return new[] { PyType.Get(type.BaseType), ExtraBase };
}
return existingBases;
}
}

class NoEffectBaseTypeProvider : IPythonBaseTypeProvider
{
public IEnumerable<PyType> GetBaseTypes(Type type, IList<PyType> existingBases)
=> existingBases;
}

public class PythonWrapperBase
{
public string WrapperBaseMethod() => nameof(WrapperBaseMethod);
}

public class InheritanceTestBaseClassWrapper : PythonWrapperBase
{
public const string ClassName = "InheritanceTestBaseClass";
public const string ClassSourceCode = "class " + ClassName +
@":
def virt(self):
return 42
def set_x_to_42(self):
self.XProp = 42
def callVirt(self):
return self.virt()
def __getattr__(self, name):
return '__getattr__:' + name
def __setattr__(self, name, value):
value[name] = name
" + ClassName + " = " + ClassName + "\n";
}

public class Inherited : InheritanceTestBaseClassWrapper
{
public const int OverridenVirtValue = -42;
public const int X = 42;
readonly Dictionary<string, object> extras = new Dictionary<string, object>();
public int virt() => OverridenVirtValue;
public int XProp
{
get
{
using (var scope = Py.CreateScope())
{
scope.Set("this", this);
try
{
return scope.Eval<int>($"super(this.__class__, this).{nameof(XProp)}");
}
catch (PythonException ex) when (ex.Type.Handle == Exceptions.AttributeError)
{
if (this.extras.TryGetValue(nameof(this.XProp), out object value))
return (int)value;
throw;
}
}
}
set => this.extras[nameof(this.XProp)] = value;
}
}
}
34 changes: 34 additions & 0 deletions src/runtime/DefaultBaseTypeProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
using System;
using System.Collections.Generic;

namespace Python.Runtime
{
/// <summary>Minimal Python base type provider</summary>
public sealed class DefaultBaseTypeProvider : IPythonBaseTypeProvider
{
public IEnumerable<PyType> GetBaseTypes(Type type, IList<PyType> existingBases)
{
if (type is null)
throw new ArgumentNullException(nameof(type));
if (existingBases is null)
throw new ArgumentNullException(nameof(existingBases));
if (existingBases.Count > 0)
throw new ArgumentException("To avoid confusion, this type provider requires the initial set of base types to be empty");

return new[] { new PyType(GetBaseType(type)) };
}

static BorrowedReference GetBaseType(Type type)
{
if (type == typeof(Exception))
return new BorrowedReference(Exceptions.Exception);

return type.BaseType is not null
? ClassManager.GetClass(type.BaseType).ObjectReference
: new BorrowedReference(Runtime.PyBaseObjectType);
}

DefaultBaseTypeProvider(){}
public static DefaultBaseTypeProvider Instance { get; } = new DefaultBaseTypeProvider();
}
}
14 changes: 14 additions & 0 deletions src/runtime/IPythonBaseTypeProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;

namespace Python.Runtime
{
public interface IPythonBaseTypeProvider
{
/// <summary>
/// Get Python types, that should be presented to Python as the base types
/// for the specified .NET type.
/// </summary>
IEnumerable<PyType> GetBaseTypes(Type type, IList<PyType> existingBases);
}
}
25 changes: 25 additions & 0 deletions src/runtime/InteropConfiguration.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
namespace Python.Runtime
{
using System;
using System.Collections.Generic;

public sealed class InteropConfiguration
{
internal readonly PythonBaseTypeProviderGroup pythonBaseTypeProviders
= new PythonBaseTypeProviderGroup();

/// <summary>Enables replacing base types of CLR types as seen from Python</summary>
public IList<IPythonBaseTypeProvider> PythonBaseTypeProviders => this.pythonBaseTypeProviders;

public static InteropConfiguration MakeDefault()
{
return new InteropConfiguration
{
PythonBaseTypeProviders =
{
DefaultBaseTypeProvider.Instance,
},
};
}
}
}
24 changes: 24 additions & 0 deletions src/runtime/PythonBaseTypeProviderGroup.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using System;
using System.Collections.Generic;
using System.Linq;

namespace Python.Runtime
{
class PythonBaseTypeProviderGroup : List<IPythonBaseTypeProvider>, IPythonBaseTypeProvider
{
public IEnumerable<PyType> GetBaseTypes(Type type, IList<PyType> existingBases)
{
if (type is null)
throw new ArgumentNullException(nameof(type));
if (existingBases is null)
throw new ArgumentNullException(nameof(existingBases));

foreach (var provider in this)
{
existingBases = provider.GetBaseTypes(type, existingBases).ToList();
}

return existingBases;
}
}
}
22 changes: 22 additions & 0 deletions src/runtime/PythonReferenceComparer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#nullable enable
using System.Collections.Generic;

namespace Python.Runtime
{
/// <summary>
/// Compares Python object wrappers by Python object references.
/// <para>Similar to <see cref="object.ReferenceEquals"/> but for Python objects</para>
/// </summary>
public sealed class PythonReferenceComparer : IEqualityComparer<PyObject>
{
public static PythonReferenceComparer Instance { get; } = new PythonReferenceComparer();
public bool Equals(PyObject? x, PyObject? y)
{
return x?.Handle == y?.Handle;
}

public int GetHashCode(PyObject obj) => obj.Handle.GetHashCode();

private PythonReferenceComparer() { }
}
}
7 changes: 7 additions & 0 deletions src/runtime/StolenReference.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ public override bool Equals(object obj)

[Pure]
public override int GetHashCode() => Pointer.GetHashCode();

[Pure]
public static StolenReference DangerousFromPointer(IntPtr ptr)
{
if (ptr == IntPtr.Zero) throw new ArgumentNullException(nameof(ptr));
return new StolenReference(ptr);
}
}

static class StolenReferenceExtensions
Expand Down
12 changes: 10 additions & 2 deletions src/runtime/classmanager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,14 @@ private static ClassBase CreateClass(Type type)

private static void InitClassBase(Type type, ClassBase impl)
{
// Ensure, that matching Python type exists first.
// It is required for self-referential classes
// (e.g. with members, that refer to the same class)
var pyType = TypeManager.GetOrCreateClass(type);

// Set the handle attributes on the implementing instance.
impl.tpHandle = impl.pyHandle = pyType.Handle;

// First, we introspect the managed type and build some class
// information, including generating the member descriptors
// that we'll be putting in the Python class __dict__.
Expand All @@ -261,12 +269,12 @@ private static void InitClassBase(Type type, ClassBase impl)
impl.indexer = info.indexer;
impl.richcompare = new Dictionary<int, MethodObject>();

// Now we allocate the Python type object to reflect the given
// Now we force initialize the Python type object to reflect the given
// managed type, filling the Python type slots with thunks that
// point to the managed methods providing the implementation.


var pyType = TypeManager.GetType(impl, type);
TypeManager.GetOrInitializeClass(impl, type);

// Finally, initialize the class __dict__ and return the object.
using var dict = Runtime.PyObject_GenericGetDict(pyType.Reference);
Expand Down
17 changes: 16 additions & 1 deletion src/runtime/pythonengine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ public static ShutdownMode ShutdownMode
private static IntPtr _pythonHome = IntPtr.Zero;
private static IntPtr _programName = IntPtr.Zero;
private static IntPtr _pythonPath = IntPtr.Zero;
private static InteropConfiguration interopConfiguration = InteropConfiguration.MakeDefault();

public PythonEngine()
{
Expand Down Expand Up @@ -68,6 +69,18 @@ internal static DelegateManager DelegateManager
}
}

public static InteropConfiguration InteropConfiguration
{
get => interopConfiguration;
set
{
if (IsInitialized)
throw new NotSupportedException("Changing interop configuration when engine is running is not supported");

interopConfiguration = value ?? throw new ArgumentNullException(nameof(InteropConfiguration));
}
}

public static string ProgramName
{
get
Expand Down Expand Up @@ -334,6 +347,8 @@ public static void Shutdown(ShutdownMode mode)
PyObjectConversions.Reset();

initialized = false;

InteropConfiguration = InteropConfiguration.MakeDefault();
}

/// <summary>
Expand Down Expand Up @@ -563,7 +578,7 @@ public static ulong GetPythonThreadID()
/// Interrupts the execution of a thread.
/// </summary>
/// <param name="pythonThreadID">The Python thread ID.</param>
/// <returns>The number of thread states modified; this is normally one, but will be zero if the thread id isn’t found.</returns>
/// <returns>The number of thread states modified; this is normally one, but will be zero if the thread id is not found.</returns>
public static int Interrupt(ulong pythonThreadID)
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
Expand Down
Loading