Skip to content

Match generic and private methods upon runtime reload #1637

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ Instead, `PyIterable` does that.
- Providing an invalid type parameter to a generic type or method produces a helpful Python error
- Empty parameter names (as can be generated from F#) do not cause crashes
- Unicode strings with surrogates were truncated when converting from Python
- `Reload` mode now supports generic methods (previously Python would stop seeing them after reload)

### Removed

Expand Down
35 changes: 35 additions & 0 deletions src/embed_tests/StateSerialization/MethodSerialization.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
using System.IO;
using System.Reflection;

using NUnit.Framework;

using Python.Runtime;

namespace Python.EmbeddingTest.StateSerialization;

public class MethodSerialization
{
[Test]
public void GenericRoundtrip()
{
var method = typeof(MethodTestHost).GetMethod(nameof(MethodTestHost.Generic));
var maybeMethod = new MaybeMethodBase<MethodBase>(method);
var restored = SerializationRoundtrip(maybeMethod);
Assert.IsTrue(restored.Valid);
Assert.AreEqual(method, restored.Value);
}

static T SerializationRoundtrip<T>(T item)
{
using var buf = new MemoryStream();
var formatter = RuntimeData.CreateFormatter();
formatter.Serialize(buf, item);
buf.Position = 0;
return (T)formatter.Deserialize(buf);
}
}

public class MethodTestHost
{
public void Generic<T>(T item, T[] array, ref T @ref) { }
}
51 changes: 48 additions & 3 deletions src/runtime/Reflection/ParameterHelper.cs
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Reflection;

namespace Python.Runtime.Reflection;

[Serializable]
struct ParameterHelper : IEquatable<ParameterInfo>
class ParameterHelper : IEquatable<ParameterInfo>
{
public readonly string TypeName;
public readonly ParameterModifier Modifier;
public readonly ParameterHelper[]? GenericArguments;

public ParameterHelper(ParameterInfo tp)
public ParameterHelper(ParameterInfo tp) : this(tp.ParameterType)
{
TypeName = tp.ParameterType.AssemblyQualifiedName;
Modifier = ParameterModifier.None;

if (tp.IsIn && tp.ParameterType.IsByRef)
Expand All @@ -29,12 +31,55 @@ public ParameterHelper(ParameterInfo tp)
}
}

public ParameterHelper(Type type)
{
TypeName = type.AssemblyQualifiedName;
if (TypeName is null)
{
if (type.IsByRef || type.IsArray)
{
TypeName = type.IsArray ? "[]" : "&";
GenericArguments = new[] { new ParameterHelper(type.GetElementType()) };
}
else
{
Debug.Assert(type.ContainsGenericParameters);
TypeName = $"{type.Assembly}::{type.Namespace}/{type.Name}";
GenericArguments = type.GenericTypeArguments.Select(t => new ParameterHelper(t)).ToArray();
}
}
}

public bool IsSpecialType => TypeName == "&" || TypeName == "[]";

public bool Equals(ParameterInfo other)
{
return this.Equals(new ParameterHelper(other));
}

public bool Matches(ParameterInfo other) => this.Equals(other);

public bool Equals(ParameterHelper other)
{
if (other is null) return false;

if (!(other.TypeName == TypeName && other.Modifier == Modifier))
return false;

if (GenericArguments == other.GenericArguments) return true;

if (GenericArguments is not null && other.GenericArguments is not null)
{
if (GenericArguments.Length != other.GenericArguments.Length) return false;
for (int arg = 0; arg < GenericArguments.Length; arg++)
{
if (!GenericArguments[arg].Equals(other.GenericArguments[arg])) return false;
}
return true;
}

return false;
}
}

enum ParameterModifier
Expand Down
134 changes: 75 additions & 59 deletions src/runtime/StateSerialization/MaybeMethodBase.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using System.Runtime.Serialization;
Expand All @@ -17,8 +18,9 @@ internal struct MaybeMethodBase<T> : ISerializable where T: MethodBase
const string SerializationType = "t";
// Fhe parameters of the MethodBase
const string SerializationParameters = "p";
const string SerializationIsCtor = "c";
const string SerializationMethodName = "n";
const string SerializationGenericParamCount = "G";
const string SerializationFlags = "V";

public static implicit operator MaybeMethodBase<T> (T? ob) => new (ob);

Expand Down Expand Up @@ -62,6 +64,7 @@ public MaybeMethodBase(T? mi)
{
info = mi;
name = mi?.ToString();
Debug.Assert(name != null || info == null);
deserializationException = null;
}

Expand All @@ -82,75 +85,60 @@ internal MaybeMethodBase(SerializationInfo serializationInfo, StreamingContext c
{
throw new SerializationException($"The underlying type {typeName} can't be found");
}

var flags = (MaybeMethodFlags)serializationInfo.GetInt32(SerializationFlags);
int genericCount = serializationInfo.GetInt32(SerializationGenericParamCount);

// Get the method's parameters types
var field_name = serializationInfo.GetString(SerializationMethodName);
var param = (ParameterHelper[])serializationInfo.GetValue(SerializationParameters, typeof(ParameterHelper[]));
Type[] types = new Type[param.Length];
bool hasRefType = false;
for (int i = 0; i < param.Length; i++)
{
var paramTypeName = param[i].TypeName;
types[i] = Type.GetType(paramTypeName);
if (types[i] == null)
{
throw new SerializationException($"The parameter of type {paramTypeName} can't be found");
}
else if (types[i].IsByRef)
{
hasRefType = true;
}
}

MethodBase? mb = null;
if (serializationInfo.GetBoolean(SerializationIsCtor))
{
// We never want the static constructor.
mb = tp.GetConstructor(ClassManager.BindingFlags&(~BindingFlags.Static), binder:null, types:types, modifiers:null);
}
else
{
mb = tp.GetMethod(field_name, ClassManager.BindingFlags, binder:null, types:types, modifiers:null);
}

if (mb != null && hasRefType)
{
mb = CheckRefTypes(mb, param);
}

// Do like in ClassManager.GetClassInfo
if(mb != null && ClassManager.ShouldBindMethod(mb))
{
info = mb;
}
info = ScanForMethod(tp, field_name, genericCount, flags, param);
}
catch (Exception e)
{
deserializationException = e;
}
}

MethodBase? CheckRefTypes(MethodBase mb, ParameterHelper[] ph)
static MethodBase ScanForMethod(Type declaringType, string name, int genericCount, MaybeMethodFlags flags, ParameterHelper[] parameters)
{
// One more step: Changing:
// void MyFn (ref int a)
// to:
// void MyFn (out int a)
// will still find the function correctly as, `in`, `out` and `ref`
// are all represented as a reference type. Query the method we got
// and validate the parameters
if (ph.Length != 0)
{
foreach (var item in Enumerable.Zip(ph, mb.GetParameters(), (orig, current) => new {orig, current}))
{
if (!item.current.Equals(item.orig))
{
// False positive
return null;
}
}
}
var bindingFlags = ClassManager.BindingFlags;
if (flags.HasFlag(MaybeMethodFlags.Constructor)) bindingFlags &= ~BindingFlags.Static;

return mb;
var alternatives = declaringType.GetMember(name,
flags.HasFlag(MaybeMethodFlags.Constructor)
? MemberTypes.Constructor
: MemberTypes.Method,
bindingFlags);

if (alternatives.Length == 0)
throw new MissingMethodException($"{declaringType}.{name}");

var visibility = flags & MaybeMethodFlags.Visibility;

var result = alternatives.Cast<MethodBase>().FirstOrDefault(m
=> MatchesGenericCount(m, genericCount) && MatchesSignature(m, parameters)
&& (Visibility(m) == visibility || ClassManager.ShouldBindMethod(m)));

if (result is null)
throw new MissingMethodException($"Matching overload not found for {declaringType}.{name}");

return result;
}

static bool MatchesGenericCount(MethodBase method, int genericCount)
=> method.ContainsGenericParameters
? method.GetGenericArguments().Length == genericCount
: genericCount == 0;

static bool MatchesSignature(MethodBase method, ParameterHelper[] parameters)
{
var curr = method.GetParameters();
if (curr.Length != parameters.Length) return false;
for (int i = 0; i < curr.Length; i++)
if (!parameters[i].Matches(curr[i])) return false;
return true;
}

public void GetObjectData(SerializationInfo serializationInfo, StreamingContext context)
Expand All @@ -159,11 +147,39 @@ public void GetObjectData(SerializationInfo serializationInfo, StreamingContext
if (Valid)
{
serializationInfo.AddValue(SerializationMethodName, info.Name);
serializationInfo.AddValue(SerializationType, info.ReflectedType.AssemblyQualifiedName);
serializationInfo.AddValue(SerializationGenericParamCount,
info.ContainsGenericParameters ? info.GetGenericArguments().Length : 0);
serializationInfo.AddValue(SerializationFlags, (int)Flags(info));
string? typeName = info.ReflectedType.AssemblyQualifiedName;
Debug.Assert(typeName != null);
serializationInfo.AddValue(SerializationType, typeName);
ParameterHelper[] parameters = (from p in info.GetParameters() select new ParameterHelper(p)).ToArray();
serializationInfo.AddValue(SerializationParameters, parameters, typeof(ParameterHelper[]));
serializationInfo.AddValue(SerializationIsCtor, info.IsConstructor);
}
}

static MaybeMethodFlags Flags(MethodBase method)
{
var flags = MaybeMethodFlags.Default;
if (method.IsConstructor) flags |= MaybeMethodFlags.Constructor;
if (method.IsStatic) flags |= MaybeMethodFlags.Static;
if (method.IsPublic) flags |= MaybeMethodFlags.Public;
return flags;
}

static MaybeMethodFlags Visibility(MethodBase method)
=> Flags(method) & MaybeMethodFlags.Visibility;
}

[Flags]
internal enum MaybeMethodFlags
{
Default = 0,
Constructor = 1,
Static = 2,

// TODO: other kinds of visibility
Public = 32,
Visibility = Public,
}
}
2 changes: 1 addition & 1 deletion src/runtime/runtime_data.cs
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ private static void RestoreRuntimeDataObjects(SharedObjectsState storage)
}
}

private static IFormatter CreateFormatter()
internal static IFormatter CreateFormatter()
{
return FormatterType != null ?
(IFormatter)Activator.CreateInstance(FormatterType)
Expand Down