diff --git a/Directory.Build.props b/Directory.Build.props index 8c5b53685..965610f91 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -6,7 +6,7 @@ Python.NET 10.0 false - $([System.IO.File]::ReadAllText("version.txt")) + $([System.IO.File]::ReadAllText($(MSBuildThisFileDirectory)version.txt)) $(FullVersion.Split('-', 2)[0]) $(FullVersion.Split('-', 2)[1]) diff --git a/pyproject.toml b/pyproject.toml index 5ee89d3b7..ba84e7ef4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ license = {text = "MIT"} readme = "README.rst" dependencies = [ - "clr_loader>=0.1.7" + "clr_loader==0.1.7" ] classifiers = [ diff --git a/pythonnet/__init__.py b/pythonnet/__init__.py index 9876a0bec..8f3478713 100644 --- a/pythonnet/__init__.py +++ b/pythonnet/__init__.py @@ -121,7 +121,7 @@ def load( def unload() -> None: - """Explicitly unload a laoded runtime and shut down Python.NET""" + """Explicitly unload a loaded runtime and shut down Python.NET""" global _RUNTIME, _LOADER_ASSEMBLY if _LOADER_ASSEMBLY is not None: diff --git a/src/embed_tests/TestConverter.cs b/src/embed_tests/TestConverter.cs index e586eda1b..0686d528b 100644 --- a/src/embed_tests/TestConverter.cs +++ b/src/embed_tests/TestConverter.cs @@ -148,7 +148,7 @@ public void PyIntImplicit() { var i = new PyInt(1); var ni = (PyObject)i.As(); - Assert.AreEqual(i.rawPtr, ni.rawPtr); + Assert.IsTrue(PythonReferenceComparer.Instance.Equals(i, ni)); } [Test] @@ -178,8 +178,11 @@ public void RawPyObjectProxy() var clrObject = (CLRObject)ManagedType.GetManagedObject(pyObjectProxy); Assert.AreSame(pyObject, clrObject.inst); - var proxiedHandle = pyObjectProxy.GetAttr("Handle").As(); - Assert.AreEqual(pyObject.Handle, proxiedHandle); +#pragma warning disable CS0612 // Type or member is obsolete + const string handlePropertyName = nameof(PyObject.Handle); +#pragma warning restore CS0612 // Type or member is obsolete + var proxiedHandle = pyObjectProxy.GetAttr(handlePropertyName).As(); + Assert.AreEqual(pyObject.DangerousGetAddressOrNull(), proxiedHandle); } // regression for https://github.com/pythonnet/pythonnet/issues/451 diff --git a/src/embed_tests/TestDomainReload.cs b/src/embed_tests/TestDomainReload.cs index 498119d1e..a0f9b63eb 100644 --- a/src/embed_tests/TestDomainReload.cs +++ b/src/embed_tests/TestDomainReload.cs @@ -99,8 +99,7 @@ from Python.EmbeddingTest.Domain import MyClass { Debug.Assert(obj.AsManagedObject(type).GetType() == type); // We only needs its Python handle - PyRuntime.XIncref(obj); - return obj.Handle; + return new NewReference(obj).DangerousMoveToPointer(); } } } diff --git a/src/embed_tests/TestFinalizer.cs b/src/embed_tests/TestFinalizer.cs index 40ab03395..b748a2244 100644 --- a/src/embed_tests/TestFinalizer.cs +++ b/src/embed_tests/TestFinalizer.cs @@ -212,7 +212,9 @@ public void ValidateRefCount() Assert.AreEqual(ptr, e.Handle); Assert.AreEqual(2, e.ImpactedObjects.Count); // Fix for this test, don't do this on general environment +#pragma warning disable CS0618 // Type or member is obsolete Runtime.Runtime.XIncref(e.Reference); +#pragma warning restore CS0618 // Type or member is obsolete return false; }; Finalizer.Instance.IncorrectRefCntResolver += handler; @@ -234,8 +236,9 @@ private static IntPtr CreateStringGarbage() { PyString s1 = new PyString("test_string"); // s2 steal a reference from s1 - PyString s2 = new PyString(StolenReference.DangerousFromPointer(s1.Handle)); - return s1.Handle; + IntPtr address = s1.Reference.DangerousGetAddress(); + PyString s2 = new (StolenReference.DangerousFromPointer(address)); + return address; } } } diff --git a/src/embed_tests/TestNativeTypeOffset.cs b/src/embed_tests/TestNativeTypeOffset.cs index 2d31fe506..d692c24e6 100644 --- a/src/embed_tests/TestNativeTypeOffset.cs +++ b/src/embed_tests/TestNativeTypeOffset.cs @@ -33,7 +33,8 @@ public void LoadNativeTypeOffsetClass() { PyObject sys = Py.Import("sys"); // We can safely ignore the "m" abi flag - var abiflags = sys.GetAttr("abiflags", "".ToPython()).ToString().Replace("m", ""); + var abiflags = sys.HasAttr("abiflags") ? sys.GetAttr("abiflags").ToString() : ""; + abiflags = abiflags.Replace("m", ""); if (!string.IsNullOrEmpty(abiflags)) { string typeName = "Python.Runtime.NativeTypeOffset, Python.Runtime"; diff --git a/src/embed_tests/TestPythonException.cs b/src/embed_tests/TestPythonException.cs index a7cf05c83..a248b6a1f 100644 --- a/src/embed_tests/TestPythonException.cs +++ b/src/embed_tests/TestPythonException.cs @@ -161,7 +161,7 @@ def __init__(self, val): using var tbObj = tbPtr.MoveToPyObject(); // the type returned from PyErr_NormalizeException should not be the same type since a new // exception was raised by initializing the exception - Assert.AreNotEqual(type.Handle, typeObj.Handle); + Assert.IsFalse(PythonReferenceComparer.Instance.Equals(type, typeObj)); // the message should now be the string from the throw exception during normalization Assert.AreEqual("invalid literal for int() with base 10: 'dummy string'", strObj.ToString()); } diff --git a/src/runtime/InternString.cs b/src/runtime/InternString.cs index b6d9a0e4a..decb3981d 100644 --- a/src/runtime/InternString.cs +++ b/src/runtime/InternString.cs @@ -42,7 +42,7 @@ public static void Initialize() Debug.Assert(name == op.As()); SetIntern(name, op); var field = type.GetField("f" + name, PyIdentifierFieldFlags)!; - field.SetValue(null, op.rawPtr); + field.SetValue(null, op.DangerousGetAddressOrNull()); } } @@ -76,7 +76,7 @@ public static bool TryGetInterned(BorrowedReference op, out string s) private static void SetIntern(string s, PyString op) { _string2interns.Add(s, op); - _intern2strings.Add(op.rawPtr, s); + _intern2strings.Add(op.Reference.DangerousGetAddress(), s); } } } diff --git a/src/runtime/PythonTypes/PyObject.cs b/src/runtime/PythonTypes/PyObject.cs index 3d48e22ed..ce86753eb 100644 --- a/src/runtime/PythonTypes/PyObject.cs +++ b/src/runtime/PythonTypes/PyObject.cs @@ -27,7 +27,7 @@ public partial class PyObject : DynamicObject, IDisposable, ISerializable public StackTrace Traceback { get; } = new StackTrace(1); #endif - protected internal IntPtr rawPtr = IntPtr.Zero; + protected IntPtr rawPtr = IntPtr.Zero; internal readonly int run = Runtime.GetRun(); internal BorrowedReference obj => new (rawPtr); @@ -252,6 +252,8 @@ internal void Leak() rawPtr = IntPtr.Zero; } + internal IntPtr DangerousGetAddressOrNull() => rawPtr; + internal void CheckRun() { if (run != Runtime.GetRun()) diff --git a/src/runtime/StateSerialization/RuntimeData.cs b/src/runtime/StateSerialization/RuntimeData.cs index 204e15b5b..0fa484ada 100644 --- a/src/runtime/StateSerialization/RuntimeData.cs +++ b/src/runtime/StateSerialization/RuntimeData.cs @@ -1,20 +1,407 @@ using System; -using System.Collections; using System.Collections.Generic; -using System.Collections.ObjectModel; using System.Diagnostics; using System.IO; using System.Linq; +using System.Reflection; +using System.Reflection.Emit; using System.Runtime.InteropServices; using System.Runtime.Serialization; using System.Runtime.Serialization.Formatters.Binary; - using Python.Runtime.StateSerialization; using static Python.Runtime.Runtime; namespace Python.Runtime { + [System.Serializable] + public sealed class NotSerializedException: SerializationException + { + static string _message = "The underlying C# object has been deleted."; + public NotSerializedException() : base(_message){} + private NotSerializedException(SerializationInfo info, StreamingContext context) : base(info, context){} + override public void GetObjectData(SerializationInfo info, StreamingContext context) => base.GetObjectData(info, context); + } + + [Serializable] + internal static class NonSerializedTypeBuilder + { + + internal static AssemblyName nonSerializedAssemblyName = + new AssemblyName("Python.Runtime.NonSerialized.dll, Version=0.0.0.0, Culture=neutral, PublicKeyToken=null"); + internal static AssemblyBuilder assemblyForNonSerializedClasses = + AppDomain.CurrentDomain.DefineDynamicAssembly(nonSerializedAssemblyName, AssemblyBuilderAccess.Run); + internal static ModuleBuilder moduleBuilder = assemblyForNonSerializedClasses.DefineDynamicModule("NotSerializedModule"); + internal static HashSet dontReimplementMethods = new(){"Finalize", "Dispose", "GetType", "ReferenceEquals", "GetHashCode", "Equals"}; + internal const string notSerializedSuffix = "_NotSerialized"; + // dummy field name to mark classes created by the "non-serializer" so we don't loop-inherit + // on multiple cycles of de/serialization. We use a static field instead of an attribute + // becaues of a bug in mono. Put a space in the name so users will be extremely unlikely + // to create a field with the same name. + internal const string notSerializedFieldName = "__PyNet NonSerialized"; + + private static Func hasVisibility = (tp, attr) => (tp.Attributes & TypeAttributes.VisibilityMask) == attr; + private static Func isNestedType = (tp) => hasVisibility(tp, TypeAttributes.NestedPrivate) || hasVisibility(tp, TypeAttributes.NestedPublic) || hasVisibility(tp, TypeAttributes.NestedFamily) || hasVisibility(tp, TypeAttributes.NestedAssembly); + private static Func isPrivateType = (tp) => hasVisibility(tp, TypeAttributes.NotPublic) || hasVisibility(tp, TypeAttributes.NestedPrivate) || hasVisibility(tp, TypeAttributes.NestedFamily) || hasVisibility( tp, TypeAttributes.NestedAssembly); + private static Func isPublicType = (tp) => hasVisibility(tp, TypeAttributes.Public) || hasVisibility(tp,TypeAttributes.NestedPublic); + private static Func CanCreateType = (tp) => isPublicType(tp) && ((tp.Attributes & TypeAttributes.Sealed) == 0); + + public static object? CreateNewObject(Type baseType) + { + var myType = CreateType(baseType); + if (myType is null) + { + return null; + } + var myObject = Activator.CreateInstance(myType); + return myObject; + } + + static void FillTypeMethods(TypeBuilder tb) + { + var constructors = tb.BaseType.GetConstructors(); + if (constructors.Count() == 0) + { + // no constructors defined, at least declare a default + ConstructorBuilder constructor = tb.DefineDefaultConstructor(MethodAttributes.Public | MethodAttributes.SpecialName | MethodAttributes.RTSpecialName); + } + else + { + foreach (var ctor in constructors) + { + var ctorParams = (from param in ctor.GetParameters() select param.ParameterType).ToArray(); + var ctorbuilder = tb.DefineConstructor(ctor.Attributes, ctor.CallingConvention, ctorParams); + ctorbuilder.GetILGenerator().Emit(OpCodes.Ret); + + } + var parameterless = tb.DefineConstructor(MethodAttributes.Public | MethodAttributes.SpecialName | MethodAttributes.RTSpecialName, CallingConventions.Standard | CallingConventions.HasThis, Type.EmptyTypes); + parameterless.GetILGenerator().Emit(OpCodes.Ret); + } + + var properties = tb.BaseType.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static); + foreach (var prop in properties) + { + CreateProperty(tb, prop); + } + + var methods = tb.BaseType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static); + foreach (var meth in methods) + { + CreateMethod(tb, meth); + } + + ImplementEqualityAndHash(tb); + } + + static string MakeName(Type tp) + { + string @out = tp.Name + notSerializedSuffix; + var parentType = tp.DeclaringType; + while (parentType is not null) + { + // If we have a nested class, we need the whole nester/nestee + // chain with the suffix for each. + @out = parentType.Name + notSerializedSuffix + "+" + @out; + parentType = parentType.DeclaringType; + } + return @out; + } + + public static Type? CreateType(Type tp) + { + if (!CanCreateType(tp)) + { + return null; + } + + Type existingType = assemblyForNonSerializedClasses.GetType(MakeName(tp), throwOnError:false); + if (existingType is not null) + { + return existingType; + } + var parentType = tp.DeclaringType; + if (parentType is not null) + { + // parent types for nested types must be created first. Climb up the + // declaring type chain until we find a "top-level" class. + while (parentType.DeclaringType is not null) + { + parentType = parentType.DeclaringType; + } + CreateTypeInternal(parentType); + Type nestedType = assemblyForNonSerializedClasses.GetType(MakeName(tp), throwOnError:true); + return nestedType; + } + return CreateTypeInternal(tp); + } + + private static Type? CreateTypeInternal(Type baseType) + { + if (!isPublicType(baseType)) + { + // we can't derive from non-public types. + return null; + } + Type existingType = assemblyForNonSerializedClasses.GetType(MakeName(baseType), throwOnError:false); + if (existingType is not null) + { + return existingType; + } + + TypeBuilder tb = GetTypeBuilder(baseType); + SetNonSerializedAttr(tb); + FillTypeMethods(tb); + + var nestedtypes = baseType.GetNestedTypes(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static); + List nestedBuilders = new(); + foreach (var nested in nestedtypes) + { + if (isPrivateType(nested)) + { + continue; + } + var nestedBuilder = tb.DefineNestedType(nested.Name + notSerializedSuffix, + TypeAttributes.NestedPublic, + nested + ); + nestedBuilders.Add(nestedBuilder); + } + var outTp = tb.CreateType(); + foreach(var builder in nestedBuilders) + { + FillTypeMethods(builder); + SetNonSerializedAttr(builder); + builder.CreateType(); + } + return outTp; + } + + private static void ImplementEqualityAndHash(TypeBuilder tb) + { + var hashCodeMb = tb.DefineMethod("GetHashCode", + MethodAttributes.Public | MethodAttributes.Final | MethodAttributes.ReuseSlot, + CallingConventions.Standard, + typeof(int), + Type.EmptyTypes + ); + var getHashIlGen = hashCodeMb.GetILGenerator(); + getHashIlGen.Emit(OpCodes.Ldarg_0); + getHashIlGen.EmitCall(OpCodes.Call, typeof(object).GetMethod("GetType"), Type.EmptyTypes); + getHashIlGen.EmitCall(OpCodes.Call, typeof(Type).GetProperty("Name").GetMethod, Type.EmptyTypes); + getHashIlGen.EmitCall(OpCodes.Call, typeof(string).GetMethod("GetHashCode", Type.EmptyTypes), Type.EmptyTypes); + getHashIlGen.Emit(OpCodes.Ret); + + Type[] equalsArgs = new Type[] {typeof(object), typeof(object)}; + var equalsMb = tb.DefineMethod("Equals", + MethodAttributes.Public | MethodAttributes.Final | MethodAttributes.ReuseSlot, + CallingConventions.Standard, + typeof(bool), + equalsArgs + ); + var equalsIlGen = equalsMb.GetILGenerator(); + equalsIlGen.Emit(OpCodes.Ldarg_0); // this + equalsIlGen.Emit(OpCodes.Ldarg_1); // the other object + equalsIlGen.EmitCall(OpCodes.Call, typeof(object).GetMethod("ReferenceEquals"), equalsArgs); + equalsIlGen.Emit(OpCodes.Ret); + } + + private static void SetNonSerializedAttr(TypeBuilder tb) + { + // Name of the function says we're adding an attribute, but for some + // reason on Mono the attribute is not added, and no exceptions are + // thrown. + tb.DefineField(notSerializedFieldName, typeof(int), FieldAttributes.Public | FieldAttributes.Static); + } + + public static bool IsNonSerializedType(Type tp) + { + return tp.GetField(NonSerializedTypeBuilder.notSerializedFieldName, BindingFlags.Public | BindingFlags.Static) is not null; + + } + + private static TypeBuilder GetTypeBuilder(Type baseType) + { + string typeSignature = baseType.Name + notSerializedSuffix; + TypeBuilder tb = moduleBuilder.DefineType(typeSignature, + baseType.Attributes, + baseType, + baseType.GetInterfaces()); + return tb; + } + + static ILGenerator GenerateExceptionILCode(dynamic builder) + { + ILGenerator ilgen = builder.GetILGenerator(); + var seriExc = typeof(NotSerializedException); + var exCtorInfo = seriExc.GetConstructor(new Type[]{}); + ilgen.Emit(OpCodes.Newobj, exCtorInfo); + ilgen.ThrowException(seriExc); + return ilgen; + } + + private static MethodAttributes GetMethodAttrs (MethodInfo minfo) + { + var methAttributes = minfo.Attributes; + // Always implement/shadow the method + methAttributes &=(~MethodAttributes.Abstract); + methAttributes &=(~MethodAttributes.NewSlot); + methAttributes |= MethodAttributes.ReuseSlot; + methAttributes |= MethodAttributes.HideBySig; + methAttributes |= MethodAttributes.Final; + + if (minfo.IsFinal) + { + // can't override a final method, new it instead. + methAttributes &= (~MethodAttributes.Virtual); + methAttributes |= MethodAttributes.NewSlot; + } + + return methAttributes; + } + + private static void CreateProperty(TypeBuilder tb, PropertyInfo pinfo) + { + string propertyName = pinfo.Name; + Type propertyType = pinfo.PropertyType; + FieldBuilder fieldBuilder = tb.DefineField("_" + propertyName, propertyType, FieldAttributes.Private); + PropertyBuilder propertyBuilder = tb.DefineProperty(propertyName, pinfo.Attributes, propertyType, null); + if (pinfo.GetMethod is not null) + { + var methAttributes = GetMethodAttrs(pinfo.GetMethod); + + MethodBuilder getPropMthdBldr = + tb.DefineMethod("get_" + propertyName, + methAttributes, + propertyType, + Type.EmptyTypes); + GenerateExceptionILCode(getPropMthdBldr); + propertyBuilder.SetGetMethod(getPropMthdBldr); + } + if (pinfo.SetMethod is not null) + { + var methAttributes = GetMethodAttrs(pinfo.SetMethod); + MethodBuilder setPropMthdBldr = + tb.DefineMethod("set_" + propertyName, + methAttributes, + null, + new[] { propertyType }); + + GenerateExceptionILCode(setPropMthdBldr); + propertyBuilder.SetSetMethod(setPropMthdBldr); + } + } + + private static void CreateMethod(TypeBuilder tb, MethodInfo minfo) + { + string methodName = minfo.Name; + + if (dontReimplementMethods.Contains(methodName)) + { + // Some methods must *not* be reimplemented (who wants to throw from Dispose?) + // and some methods we need to implement in a more specific way (Equals, GetHashCode) + return; + } + var methAttributes = GetMethodAttrs(minfo); + var @params = (from paraminfo in minfo.GetParameters() select paraminfo.ParameterType).ToArray(); + MethodBuilder mbuilder = tb.DefineMethod(methodName, methAttributes, minfo.CallingConvention, minfo.ReturnType, @params); + GenerateExceptionILCode(mbuilder); + } + } + + class NotSerializableSerializer : ISerializationSurrogate + { + + public NotSerializableSerializer() + { + } + + public void GetObjectData(object obj, SerializationInfo info, StreamingContext context) + { + // This type is private to System.Runtime.Serialization. We get an + // object of this type when, amongst others, the type didn't exist (yet?) + // (dll not loaded, type was removed/renamed) when we previously + // deserialized the previous domain objects. Don't serialize this + // object. + if (obj.GetType().Name == "TypeLoadExceptionHolder") + { + obj = null!; + return; + } + + MaybeType type = obj.GetType(); + + if (NonSerializedTypeBuilder.IsNonSerializedType(type.Value)) + { + // Don't serialize a _NotSerialized. Serialize the base type, and deserialize as a _NotSerialized + type = type.Value.BaseType; + obj = null!; + } + + info.AddValue("notSerialized_tp", type); + + } + + public object SetObjectData(object obj, SerializationInfo info, StreamingContext context, ISurrogateSelector selector) + { + if (info is null) + { + // `obj` is of type TypeLoadExceptionHolder. This means the type + // we're trying to load doesn't exist anymore or we haven't created + // it yet, and the runtime doesn't even gives us the chance to + // recover from this as info is null. We may even get objects + // this serializer did not serialize in a previous domain, + // like in the case of the "namespace_rename" domain reload + // test: the object successfully serialized, but it cannot be + // deserialized. + // just return null. + return null!; + } + + object nameObj = null!; + try + { + nameObj = info.GetValue("notSerialized_tp", typeof(object)); + } + catch + { + // we didn't find the expected information. We don't know + // what to do with this; return null. + return null!; + } + Debug.Assert(nameObj.GetType() == typeof(MaybeType)); + MaybeType name = (MaybeType)nameObj; + Debug.Assert(name.Valid); + if (!name.Valid) + { + // The type couldn't be loaded + return null!; + } + + obj = NonSerializedTypeBuilder.CreateNewObject(name.Value)!; + return obj; + } + } + + class NonSerializableSelector : SurrogateSelector + { + public override ISerializationSurrogate? GetSurrogate (Type type, StreamingContext context, out ISurrogateSelector selector) + { + if (type is null) + { + throw new ArgumentNullException(); + } + selector = this; + if (type.IsSerializable) + { + return null; // use whichever default + } + else + { + return new NotSerializableSerializer(); + } + } + } + public static class RuntimeData { private static Type? _formatterType; @@ -47,6 +434,73 @@ static void ClearCLRData () } } + internal static void SerializeNonSerializableTypes () + { + // Serialize the Types (Type objects) that couldn't be (de)serialized. + // This needs to be done otherwise at deserialization time we get + // TypeLoadExceptionHolder objects and we can't even recover. + + // We don't serialize the "_NotSerialized" Types, we serialize their base + // to recreate the "_NotSerialized" versions on the next domain load. + + Dictionary invalidTypes = new(); + foreach(var tp in NonSerializedTypeBuilder.assemblyForNonSerializedClasses.GetTypes()) + { + invalidTypes[tp.FullName] = new MaybeType(tp.BaseType); + } + + // delete previous data if any + BorrowedReference oldCapsule = PySys_GetObject("clr_nonSerializedTypes"); + if (!oldCapsule.IsNull) + { + IntPtr oldData = PyCapsule_GetPointer(oldCapsule, IntPtr.Zero); + PyMem_Free(oldData); + PyCapsule_SetPointer(oldCapsule, IntPtr.Zero); + } + IFormatter formatter = CreateFormatter(); + var ms = new MemoryStream(); + formatter.Serialize(ms, invalidTypes); + + Debug.Assert(ms.Length <= int.MaxValue); + byte[] data = ms.GetBuffer(); + + IntPtr mem = PyMem_Malloc(ms.Length + IntPtr.Size); + Marshal.WriteIntPtr(mem, (IntPtr)ms.Length); + Marshal.Copy(data, 0, mem + IntPtr.Size, (int)ms.Length); + + using NewReference capsule = PyCapsule_New(mem, IntPtr.Zero, IntPtr.Zero); + int res = PySys_SetObject("clr_nonSerializedTypes", capsule.BorrowOrThrow()); + PythonException.ThrowIfIsNotZero(res); + + } + + internal static void DeserializeNonSerializableTypes () + { + BorrowedReference capsule = PySys_GetObject("clr_nonSerializedTypes"); + if (capsule.IsNull) + { + // nothing to do. + return; + } + // get the memory stream from the capsule. + IntPtr mem = PyCapsule_GetPointer(capsule, IntPtr.Zero); + int length = (int)Marshal.ReadIntPtr(mem); + byte[] data = new byte[length]; + Marshal.Copy(mem + IntPtr.Size, data, 0, length); + var ms = new MemoryStream(data); + var formatter = CreateFormatter(); + var storage = (Dictionary)formatter.Deserialize(ms); + foreach(var item in storage) + { + if(item.Value.Valid) + { + // recreate the "_NotSerialized" Types + NonSerializedTypeBuilder.CreateType(item.Value.Value); + } + } + + } + internal static void Stash() { var runtimeStorage = new PythonNetState @@ -74,6 +528,8 @@ internal static void Stash() using NewReference capsule = PyCapsule_New(mem, IntPtr.Zero, IntPtr.Zero); int res = PySys_SetObject("clr_data", capsule.BorrowOrThrow()); PythonException.ThrowIfIsNotZero(res); + SerializeNonSerializableTypes(); + } internal static void RestoreRuntimeData() @@ -90,6 +546,9 @@ internal static void RestoreRuntimeData() private static void RestoreRuntimeDataImpl() { + // The "_NotSerialized" Types must exist before the rest of the data + // is deserialized. + DeserializeNonSerializableTypes(); BorrowedReference capsule = PySys_GetObject("clr_data"); if (capsule.IsNull) { @@ -123,19 +582,6 @@ public static void ClearStash() PySys_SetObject("clr_data", default); } - static bool CheckSerializable (object o) - { - Type type = o.GetType(); - do - { - if (!type.IsSerializable) - { - return false; - } - } while ((type = type.BaseType) != null); - return true; - } - private static SharedObjectsState SaveRuntimeDataObjects() { var contexts = new Dictionary>(PythonReferenceComparer.Instance); @@ -150,7 +596,6 @@ private static SharedObjectsState SaveRuntimeDataObjects() foreach (var pyObj in extensions) { var extension = (ExtensionType)ManagedType.GetManagedObject(pyObj)!; - Debug.Assert(CheckSerializable(extension)); var context = extension.Save(pyObj); if (context is not null) { @@ -170,6 +615,7 @@ private static SharedObjectsState SaveRuntimeDataObjects() .ToList(); foreach (var pyObj in reflectedObjects) { + // Console.WriteLine($"saving object: {pyObj} {pyObj.rawPtr} "); // Wrapper must be the CLRObject var clrObj = (CLRObject)ManagedType.GetManagedObject(pyObj)!; object inst = clrObj.inst; @@ -199,10 +645,6 @@ private static SharedObjectsState SaveRuntimeDataObjects() { if (!item.Stored) { - if (!CheckSerializable(item.Instance)) - { - continue; - } var clrO = wrappers[item.Instance].First(); foreach (var @ref in item.PyRefs) { @@ -254,7 +696,10 @@ internal static IFormatter CreateFormatter() { return FormatterType != null ? (IFormatter)Activator.CreateInstance(FormatterType) - : new BinaryFormatter(); + : new BinaryFormatter() + { + SurrogateSelector = new NonSerializableSelector(), + }; } } } diff --git a/src/runtime/Types/ReflectedClrType.cs b/src/runtime/Types/ReflectedClrType.cs index b787939be..d3d89bdb8 100644 --- a/src/runtime/Types/ReflectedClrType.cs +++ b/src/runtime/Types/ReflectedClrType.cs @@ -116,6 +116,6 @@ static ReflectedClrType AllocateClass(Type clrType) return new ReflectedClrType(type.Steal()); } - public override bool Equals(PyObject? other) => other != null && rawPtr == other.rawPtr; + public override bool Equals(PyObject? other) => rawPtr == other?.DangerousGetAddressOrNull(); public override int GetHashCode() => rawPtr.GetHashCode(); } diff --git a/src/runtime/Util/PythonReferenceComparer.cs b/src/runtime/Util/PythonReferenceComparer.cs index dd78f912d..63c35df57 100644 --- a/src/runtime/Util/PythonReferenceComparer.cs +++ b/src/runtime/Util/PythonReferenceComparer.cs @@ -13,10 +13,10 @@ public sealed class PythonReferenceComparer : IEqualityComparer public static PythonReferenceComparer Instance { get; } = new PythonReferenceComparer(); public bool Equals(PyObject? x, PyObject? y) { - return x?.rawPtr == y?.rawPtr; + return x?.DangerousGetAddressOrNull() == y?.DangerousGetAddressOrNull(); } - public int GetHashCode(PyObject obj) => obj.rawPtr.GetHashCode(); + public int GetHashCode(PyObject obj) => obj.DangerousGetAddressOrNull().GetHashCode(); private PythonReferenceComparer() { } } diff --git a/tests/domain_tests/TestRunner.cs b/tests/domain_tests/TestRunner.cs index 4f6a3ea28..0097ea202 100644 --- a/tests/domain_tests/TestRunner.cs +++ b/tests/domain_tests/TestRunner.cs @@ -1132,6 +1132,168 @@ import System ", }, + new TestCase + { + Name = "serialize_not_serializable", + DotNetBefore = @" + namespace TestNamespace + { + + public class NotSerializableTextWriter : System.IO.TextWriter + { + override public System.Text.Encoding Encoding { get { return System.Text.Encoding.ASCII;} } + } + + [System.Serializable] + public static class SerializableWriter + { + private static System.IO.TextWriter _writer = null; + + public static System.IO.TextWriter Writer {get { return _writer; }} + + public static void SetWriter() + { + _writer = System.IO.TextWriter.Synchronized(new NotSerializableTextWriter()); + } + } + }", + DotNetAfter = @" + namespace TestNamespace + { + + public class NotSerializableTextWriter : System.IO.TextWriter + { + override public System.Text.Encoding Encoding { get { return System.Text.Encoding.ASCII;} } + } + + [System.Serializable] + public static class SerializableWriter + { + private static System.IO.TextWriter _writer = null; + + public static System.IO.TextWriter Writer {get { return _writer; }} + + public static void SetWriter(System.IO.TextWriter w) + { + _writer = System.IO.TextWriter.Synchronized(w); + } + } + }", + PythonCode = @" +import clr +import sys +clr.AddReference('DomainTests') +import TestNamespace +import System + +def before_reload(): + + TestNamespace.SerializableWriter.SetWriter(); + sys.log_writer = TestNamespace.SerializableWriter.Writer + +def after_reload(): + + assert sys.log_writer is not None + try: + encoding = sys.log_writer.Write('baba') + except System.Runtime.Serialization.SerializationException: + pass + else: + raise AssertionError('Serialized non-serializable objects should be deserialized to throwing objects') +", + }, + new TestCase + { + Name = "serialize_not_serializable_interface", + DotNetBefore = @" + namespace TestNamespace + { + public interface MyInterface + { + int InterfaceMethod(); + } + + public class NotSerializableInterfaceImplement : MyInterface + { + public int value = -1; + int MyInterface.InterfaceMethod() + { + return value; + } + } + + [System.Serializable] + public static class SerializableWriter + { + private static MyInterface _iface = null; + + public static MyInterface Writer {get { return _iface; }} + + public static void SetInterface() + { + var temp = new NotSerializableInterfaceImplement(); + temp.value = 12315; + _iface = temp; + } + } + }", + DotNetAfter = @" + namespace TestNamespace + { + public interface MyInterface + { + int InterfaceMethod(); + } + + public class NotSerializableInterfaceImplement : MyInterface + { + public int value = -1; + int MyInterface.InterfaceMethod() + { + return value; + } + } + + [System.Serializable] + public static class SerializableWriter + { + private static MyInterface _iface = null; + + public static MyInterface Writer {get { return _iface; }} + + public static void SetInterface() + { + var temp = new NotSerializableInterfaceImplement(); + temp.value = 123124; + _iface = temp; + } + } + }", + PythonCode = @" +import clr +import sys +clr.AddReference('DomainTests') +import TestNamespace +import System + +def before_reload(): + + TestNamespace.SerializableWriter.SetInterface(); + sys.log_writer = TestNamespace.SerializableWriter.Writer + assert(sys.log_writer.InterfaceMethod() == 12315) + +def after_reload(): + + assert sys.log_writer is not None + try: + retcode = sys.log_writer.InterfaceMethod() + print(f'retcode of InterfaceMethod is {retcode}') + except System.Runtime.Serialization.SerializationException: + pass + else: + raise AssertionError('Serialized non-serializable objects should be deserialized to throwing objects') +", + }, }; /// diff --git a/tests/domain_tests/test_domain_reload.py b/tests/domain_tests/test_domain_reload.py index 8999e481b..bbff73493 100644 --- a/tests/domain_tests/test_domain_reload.py +++ b/tests/domain_tests/test_domain_reload.py @@ -1,6 +1,7 @@ import subprocess import os import platform +from unittest import skip import pytest @@ -88,3 +89,10 @@ def test_nested_type(): def test_import_after_reload(): _run_test("import_after_reload") + +def test_serialize_not_serializable(): + _run_test("serialize_not_serializable") + +@skip("interface methods cannot be overriden") +def test_serialize_not_serializable_interface(): + _run_test("serialize_not_serializable_interface") \ No newline at end of file diff --git a/tests/test_import.py b/tests/test_import.py index 25877be15..877eacd84 100644 --- a/tests/test_import.py +++ b/tests/test_import.py @@ -15,7 +15,7 @@ def test_relative_missing_import(): def test_import_all_on_second_time(): """Test import all attributes after a normal import without '*'. - Due to import * only allowed at module level, the test body splited + Due to import * only allowed at module level, the test body splitted to a module file.""" from . import importtest del sys.modules[importtest.__name__]