From a66f4d6f0b1219db0d21d659bb1f9b2031632730 Mon Sep 17 00:00:00 2001 From: Victor Nova Date: Sat, 18 Dec 2021 12:28:27 -0800 Subject: [PATCH] Match generic and private methods upon runtime reload --- CHANGELOG.md | 1 + .../StateSerialization/MethodSerialization.cs | 35 +++++ src/runtime/Reflection/ParameterHelper.cs | 51 ++++++- .../StateSerialization/MaybeMethodBase.cs | 134 ++++++++++-------- src/runtime/runtime_data.cs | 2 +- 5 files changed, 160 insertions(+), 63 deletions(-) create mode 100644 src/embed_tests/StateSerialization/MethodSerialization.cs diff --git a/CHANGELOG.md b/CHANGELOG.md index bce1ec557..0a01e69fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -103,6 +103,7 @@ Instead, `PyIterable` does that. - Providing an invalid type parameter to a generic type or method produces a helpful Python error - Empty parameter names (as can be generated from F#) do not cause crashes - Unicode strings with surrogates were truncated when converting from Python +- `Reload` mode now supports generic methods (previously Python would stop seeing them after reload) ### Removed diff --git a/src/embed_tests/StateSerialization/MethodSerialization.cs b/src/embed_tests/StateSerialization/MethodSerialization.cs new file mode 100644 index 000000000..0e584fc37 --- /dev/null +++ b/src/embed_tests/StateSerialization/MethodSerialization.cs @@ -0,0 +1,35 @@ +using System.IO; +using System.Reflection; + +using NUnit.Framework; + +using Python.Runtime; + +namespace Python.EmbeddingTest.StateSerialization; + +public class MethodSerialization +{ + [Test] + public void GenericRoundtrip() + { + var method = typeof(MethodTestHost).GetMethod(nameof(MethodTestHost.Generic)); + var maybeMethod = new MaybeMethodBase(method); + var restored = SerializationRoundtrip(maybeMethod); + Assert.IsTrue(restored.Valid); + Assert.AreEqual(method, restored.Value); + } + + static T SerializationRoundtrip(T item) + { + using var buf = new MemoryStream(); + var formatter = RuntimeData.CreateFormatter(); + formatter.Serialize(buf, item); + buf.Position = 0; + return (T)formatter.Deserialize(buf); + } +} + +public class MethodTestHost +{ + public void Generic(T item, T[] array, ref T @ref) { } +} diff --git a/src/runtime/Reflection/ParameterHelper.cs b/src/runtime/Reflection/ParameterHelper.cs index 24fce63b1..bff9f7430 100644 --- a/src/runtime/Reflection/ParameterHelper.cs +++ b/src/runtime/Reflection/ParameterHelper.cs @@ -1,18 +1,20 @@ using System; using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; using System.Reflection; namespace Python.Runtime.Reflection; [Serializable] -struct ParameterHelper : IEquatable +class ParameterHelper : IEquatable { public readonly string TypeName; public readonly ParameterModifier Modifier; + public readonly ParameterHelper[]? GenericArguments; - public ParameterHelper(ParameterInfo tp) + public ParameterHelper(ParameterInfo tp) : this(tp.ParameterType) { - TypeName = tp.ParameterType.AssemblyQualifiedName; Modifier = ParameterModifier.None; if (tp.IsIn && tp.ParameterType.IsByRef) @@ -29,12 +31,55 @@ public ParameterHelper(ParameterInfo tp) } } + public ParameterHelper(Type type) + { + TypeName = type.AssemblyQualifiedName; + if (TypeName is null) + { + if (type.IsByRef || type.IsArray) + { + TypeName = type.IsArray ? "[]" : "&"; + GenericArguments = new[] { new ParameterHelper(type.GetElementType()) }; + } + else + { + Debug.Assert(type.ContainsGenericParameters); + TypeName = $"{type.Assembly}::{type.Namespace}/{type.Name}"; + GenericArguments = type.GenericTypeArguments.Select(t => new ParameterHelper(t)).ToArray(); + } + } + } + + public bool IsSpecialType => TypeName == "&" || TypeName == "[]"; + public bool Equals(ParameterInfo other) { return this.Equals(new ParameterHelper(other)); } public bool Matches(ParameterInfo other) => this.Equals(other); + + public bool Equals(ParameterHelper other) + { + if (other is null) return false; + + if (!(other.TypeName == TypeName && other.Modifier == Modifier)) + return false; + + if (GenericArguments == other.GenericArguments) return true; + + if (GenericArguments is not null && other.GenericArguments is not null) + { + if (GenericArguments.Length != other.GenericArguments.Length) return false; + for (int arg = 0; arg < GenericArguments.Length; arg++) + { + if (!GenericArguments[arg].Equals(other.GenericArguments[arg])) return false; + } + return true; + } + + return false; + } } enum ParameterModifier diff --git a/src/runtime/StateSerialization/MaybeMethodBase.cs b/src/runtime/StateSerialization/MaybeMethodBase.cs index a278df2cf..9fb8ae047 100644 --- a/src/runtime/StateSerialization/MaybeMethodBase.cs +++ b/src/runtime/StateSerialization/MaybeMethodBase.cs @@ -1,4 +1,5 @@ using System; +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Reflection; using System.Runtime.Serialization; @@ -17,8 +18,9 @@ internal struct MaybeMethodBase : ISerializable where T: MethodBase const string SerializationType = "t"; // Fhe parameters of the MethodBase const string SerializationParameters = "p"; - const string SerializationIsCtor = "c"; const string SerializationMethodName = "n"; + const string SerializationGenericParamCount = "G"; + const string SerializationFlags = "V"; public static implicit operator MaybeMethodBase (T? ob) => new (ob); @@ -62,6 +64,7 @@ public MaybeMethodBase(T? mi) { info = mi; name = mi?.ToString(); + Debug.Assert(name != null || info == null); deserializationException = null; } @@ -82,46 +85,15 @@ internal MaybeMethodBase(SerializationInfo serializationInfo, StreamingContext c { throw new SerializationException($"The underlying type {typeName} can't be found"); } + + var flags = (MaybeMethodFlags)serializationInfo.GetInt32(SerializationFlags); + int genericCount = serializationInfo.GetInt32(SerializationGenericParamCount); + // Get the method's parameters types var field_name = serializationInfo.GetString(SerializationMethodName); var param = (ParameterHelper[])serializationInfo.GetValue(SerializationParameters, typeof(ParameterHelper[])); - Type[] types = new Type[param.Length]; - bool hasRefType = false; - for (int i = 0; i < param.Length; i++) - { - var paramTypeName = param[i].TypeName; - types[i] = Type.GetType(paramTypeName); - if (types[i] == null) - { - throw new SerializationException($"The parameter of type {paramTypeName} can't be found"); - } - else if (types[i].IsByRef) - { - hasRefType = true; - } - } - MethodBase? mb = null; - if (serializationInfo.GetBoolean(SerializationIsCtor)) - { - // We never want the static constructor. - mb = tp.GetConstructor(ClassManager.BindingFlags&(~BindingFlags.Static), binder:null, types:types, modifiers:null); - } - else - { - mb = tp.GetMethod(field_name, ClassManager.BindingFlags, binder:null, types:types, modifiers:null); - } - - if (mb != null && hasRefType) - { - mb = CheckRefTypes(mb, param); - } - - // Do like in ClassManager.GetClassInfo - if(mb != null && ClassManager.ShouldBindMethod(mb)) - { - info = mb; - } + info = ScanForMethod(tp, field_name, genericCount, flags, param); } catch (Exception e) { @@ -129,28 +101,44 @@ internal MaybeMethodBase(SerializationInfo serializationInfo, StreamingContext c } } - MethodBase? CheckRefTypes(MethodBase mb, ParameterHelper[] ph) + static MethodBase ScanForMethod(Type declaringType, string name, int genericCount, MaybeMethodFlags flags, ParameterHelper[] parameters) { - // One more step: Changing: - // void MyFn (ref int a) - // to: - // void MyFn (out int a) - // will still find the function correctly as, `in`, `out` and `ref` - // are all represented as a reference type. Query the method we got - // and validate the parameters - if (ph.Length != 0) - { - foreach (var item in Enumerable.Zip(ph, mb.GetParameters(), (orig, current) => new {orig, current})) - { - if (!item.current.Equals(item.orig)) - { - // False positive - return null; - } - } - } + var bindingFlags = ClassManager.BindingFlags; + if (flags.HasFlag(MaybeMethodFlags.Constructor)) bindingFlags &= ~BindingFlags.Static; - return mb; + var alternatives = declaringType.GetMember(name, + flags.HasFlag(MaybeMethodFlags.Constructor) + ? MemberTypes.Constructor + : MemberTypes.Method, + bindingFlags); + + if (alternatives.Length == 0) + throw new MissingMethodException($"{declaringType}.{name}"); + + var visibility = flags & MaybeMethodFlags.Visibility; + + var result = alternatives.Cast().FirstOrDefault(m + => MatchesGenericCount(m, genericCount) && MatchesSignature(m, parameters) + && (Visibility(m) == visibility || ClassManager.ShouldBindMethod(m))); + + if (result is null) + throw new MissingMethodException($"Matching overload not found for {declaringType}.{name}"); + + return result; + } + + static bool MatchesGenericCount(MethodBase method, int genericCount) + => method.ContainsGenericParameters + ? method.GetGenericArguments().Length == genericCount + : genericCount == 0; + + static bool MatchesSignature(MethodBase method, ParameterHelper[] parameters) + { + var curr = method.GetParameters(); + if (curr.Length != parameters.Length) return false; + for (int i = 0; i < curr.Length; i++) + if (!parameters[i].Matches(curr[i])) return false; + return true; } public void GetObjectData(SerializationInfo serializationInfo, StreamingContext context) @@ -159,11 +147,39 @@ public void GetObjectData(SerializationInfo serializationInfo, StreamingContext if (Valid) { serializationInfo.AddValue(SerializationMethodName, info.Name); - serializationInfo.AddValue(SerializationType, info.ReflectedType.AssemblyQualifiedName); + serializationInfo.AddValue(SerializationGenericParamCount, + info.ContainsGenericParameters ? info.GetGenericArguments().Length : 0); + serializationInfo.AddValue(SerializationFlags, (int)Flags(info)); + string? typeName = info.ReflectedType.AssemblyQualifiedName; + Debug.Assert(typeName != null); + serializationInfo.AddValue(SerializationType, typeName); ParameterHelper[] parameters = (from p in info.GetParameters() select new ParameterHelper(p)).ToArray(); serializationInfo.AddValue(SerializationParameters, parameters, typeof(ParameterHelper[])); - serializationInfo.AddValue(SerializationIsCtor, info.IsConstructor); } } + + static MaybeMethodFlags Flags(MethodBase method) + { + var flags = MaybeMethodFlags.Default; + if (method.IsConstructor) flags |= MaybeMethodFlags.Constructor; + if (method.IsStatic) flags |= MaybeMethodFlags.Static; + if (method.IsPublic) flags |= MaybeMethodFlags.Public; + return flags; + } + + static MaybeMethodFlags Visibility(MethodBase method) + => Flags(method) & MaybeMethodFlags.Visibility; + } + + [Flags] + internal enum MaybeMethodFlags + { + Default = 0, + Constructor = 1, + Static = 2, + + // TODO: other kinds of visibility + Public = 32, + Visibility = Public, } } diff --git a/src/runtime/runtime_data.cs b/src/runtime/runtime_data.cs index f30a54dbe..a4726b479 100644 --- a/src/runtime/runtime_data.cs +++ b/src/runtime/runtime_data.cs @@ -247,7 +247,7 @@ private static void RestoreRuntimeDataObjects(SharedObjectsState storage) } } - private static IFormatter CreateFormatter() + internal static IFormatter CreateFormatter() { return FormatterType != null ? (IFormatter)Activator.CreateInstance(FormatterType)