From d7f31bfa430ed725bca554f21d83f5594e6650a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Bourbonnais?= <6788684+BadSingleton@users.noreply.github.com> Date: Tue, 6 Feb 2024 15:47:57 -0500 Subject: [PATCH 1/6] Expose an API for users to specify their own formatter Adds post-serialization and pre-deserialization hooks for additional customization. --- src/runtime/StateSerialization/RuntimeData.cs | 31 ++++++++++++------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/src/runtime/StateSerialization/RuntimeData.cs b/src/runtime/StateSerialization/RuntimeData.cs index 204e15b5b..61ead10f8 100644 --- a/src/runtime/StateSerialization/RuntimeData.cs +++ b/src/runtime/StateSerialization/RuntimeData.cs @@ -17,20 +17,29 @@ namespace Python.Runtime { public static class RuntimeData { - private static Type? _formatterType; - public static Type? FormatterType + + public delegate IFormatter FormatterFactoryDelegate(); + private readonly static FormatterFactoryDelegate DefaultFormatter = () => new BinaryFormatter(); + private static FormatterFactoryDelegate? _formatter { get; set; } = null; + + public static FormatterFactoryDelegate Formatter { - get => _formatterType; - set + get { - if (!typeof(IFormatter).IsAssignableFrom(value)) + if (_formatter is null) { - throw new ArgumentException("Not a type implemented IFormatter"); + return DefaultFormatter; } - _formatterType = value; + return _formatter; + } + set + { + _formatter = value; } } - + public delegate void SerializationHookDelegate(); + public static SerializationHookDelegate? PostStashHook {get; set;} = null; + public static SerializationHookDelegate? PreRestoreHook {get; set;} = null; public static ICLRObjectStorer? WrappersStorer { get; set; } /// @@ -74,6 +83,7 @@ internal static void Stash() using NewReference capsule = PyCapsule_New(mem, IntPtr.Zero, IntPtr.Zero); int res = PySys_SetObject("clr_data", capsule.BorrowOrThrow()); PythonException.ThrowIfIsNotZero(res); + PostStashHook?.Invoke(); } internal static void RestoreRuntimeData() @@ -90,6 +100,7 @@ internal static void RestoreRuntimeData() private static void RestoreRuntimeDataImpl() { + PreRestoreHook?.Invoke(); BorrowedReference capsule = PySys_GetObject("clr_data"); if (capsule.IsNull) { @@ -252,9 +263,7 @@ private static void RestoreRuntimeDataObjects(SharedObjectsState storage) internal static IFormatter CreateFormatter() { - return FormatterType != null ? - (IFormatter)Activator.CreateInstance(FormatterType) - : new BinaryFormatter(); + return Formatter(); } } } From aa37f89725bab6b12ba61da2ef1546930cfdc4a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Bourbonnais?= <6788684+BadSingleton@users.noreply.github.com> Date: Wed, 21 Feb 2024 15:07:32 -0500 Subject: [PATCH 2/6] Add API for capsuling data when serializing * And revert the breaking change of the FormatterType removal * Add CHANGELOG entry --- CHANGELOG.md | 3 + src/runtime/StateSerialization/RuntimeData.cs | 114 ++++++++++++++++-- 2 files changed, 106 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fdab9bf64..fd78f138f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,9 @@ This document follows the conventions laid out in [Keep a CHANGELOG][]. ### Added ### Changed +- Added a `FormatterFactory` member in RuntimeData to create formatters with parameters. For compatibility, the `FormatterType` member is still present and has precedence when defining both `FormatterFactory` and `FormatterType` +- Added a post-serialization and a pre-deserialization step callbacks to extend (de)serialization process +- Added an API to stash serialized data on Python capsules ### Fixed diff --git a/src/runtime/StateSerialization/RuntimeData.cs b/src/runtime/StateSerialization/RuntimeData.cs index 61ead10f8..ba52ec9f5 100644 --- a/src/runtime/StateSerialization/RuntimeData.cs +++ b/src/runtime/StateSerialization/RuntimeData.cs @@ -18,28 +18,47 @@ namespace Python.Runtime public static class RuntimeData { - public delegate IFormatter FormatterFactoryDelegate(); - private readonly static FormatterFactoryDelegate DefaultFormatter = () => new BinaryFormatter(); - private static FormatterFactoryDelegate? _formatter { get; set; } = null; + private readonly static Func DefaultFormatter = () => new BinaryFormatter(); + private static Func? _formatterFactory { get; set; } = null; - public static FormatterFactoryDelegate Formatter + public static Func FormatterFactory { get { - if (_formatter is null) + if (_formatterFactory is null) { return DefaultFormatter; } - return _formatter; + return _formatterFactory; } set { - _formatter = value; + _formatterFactory = value; } } - public delegate void SerializationHookDelegate(); - public static SerializationHookDelegate? PostStashHook {get; set;} = null; - public static SerializationHookDelegate? PreRestoreHook {get; set;} = null; + + private static Type? _formatterType = null; + public static Type? FormatterType + { + get => _formatterType; + set + { + if (!typeof(IFormatter).IsAssignableFrom(value)) + { + throw new ArgumentException("Not a type implemented IFormatter"); + } + _formatterType = value; + } + } + + /// + /// Callback called as a last step in the serialization process + /// + public static Action? PostStashHook {get; set;} = null; + /// + /// Callback called as the first step in the deserialization process + /// + public static Action? PreRestoreHook {get; set;} = null; public static ICLRObjectStorer? WrappersStorer { get; set; } /// @@ -261,9 +280,82 @@ private static void RestoreRuntimeDataObjects(SharedObjectsState storage) } } + /// + /// Frees the pointer stored in a Python capsule and the capsule stored + /// on the sys module object with the given key if it exists. + /// + /// + /// The memory on the capsule must have been allocated via StashDataInCapsule + /// + /// The name given to the capsule on the `sys` module object + public static void FreeCapsuleData(string key) + { + BorrowedReference oldCapsule = PySys_GetObject(key); + if (!oldCapsule.IsNull) + { + IntPtr oldData = PyCapsule_GetPointer(oldCapsule, IntPtr.Zero); + Marshal.FreeHGlobal(oldData); + PyCapsule_SetPointer(oldCapsule, IntPtr.Zero); + PySys_SetObject(key, null); + } + } + + /// + /// Stores the data parameter in a Python capsule and stores + /// the capsule on the `sys` module object with the name . + /// This method allocates global memory to hold the data of the + /// parameter. + /// + /// + /// No checks on pre-existing names on the `sys` module object are made. + /// + /// The name given to the capsule on the `sys` module object + /// The data to be contained in the capsule + public static void StashDataInCapsule(string key, byte[] data) + { + IntPtr mem = Marshal.AllocHGlobal(IntPtr.Size + data.Length); + // store the length of the buffer first + Marshal.WriteIntPtr(mem, (IntPtr)data.Length); + Marshal.Copy(data, 0, mem + IntPtr.Size, data.Length); + + using NewReference capsule = PyCapsule_New(mem, IntPtr.Zero, IntPtr.Zero); + int res = PySys_SetObject(key, capsule.BorrowOrThrow()); + PythonException.ThrowIfIsNotZero(res); + } + + /// + /// Retreives the pointer to previously stored data on a Python capsule. + /// Throws if the object corresponding to the parameter + /// on the `sys` module object is not a capsule. + /// + /// The name given to the capsule on the `sys` module object + /// The pointer to the data, or IntPtr.Zero if name matches the key + public static IntPtr GetDataFromCapsule(string key) + { + BorrowedReference capsule = PySys_GetObject(key); + if (capsule.IsNull) + { + // nothing to do. + return IntPtr.Zero; + } + var ptr = PyCapsule_GetPointer(capsule, IntPtr.Zero); + if (ptr == IntPtr.Zero) + { + // The PyCapsule API returns NULL on error; NULL cannot be stored + // as a capsule's value + PythonException.ThrowIfIsNull(null); + } + return ptr; + } + internal static IFormatter CreateFormatter() { - return Formatter(); + + if (FormatterType != null) + { + return (IFormatter)Activator.CreateInstance(FormatterType); + } + return FormatterFactory(); } } } From 647c75c782be3fdd350e4b559a2536e34a204edf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Bourbonnais?= <6788684+BadSingleton@users.noreply.github.com> Date: Sun, 17 Mar 2024 11:05:23 -0400 Subject: [PATCH 3/6] Part one of review fixes --- src/runtime/StateSerialization/RuntimeData.cs | 45 ++++++++++++------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/src/runtime/StateSerialization/RuntimeData.cs b/src/runtime/StateSerialization/RuntimeData.cs index ba52ec9f5..c20040c76 100644 --- a/src/runtime/StateSerialization/RuntimeData.cs +++ b/src/runtime/StateSerialization/RuntimeData.cs @@ -54,11 +54,11 @@ public static Type? FormatterType /// /// Callback called as a last step in the serialization process /// - public static Action? PostStashHook {get; set;} = null; + public static Action? PostStashHook { get; set; } = null; /// /// Callback called as the first step in the deserialization process /// - public static Action? PreRestoreHook {get; set;} = null; + public static Action? PreRestoreHook { get; set; } = null; public static ICLRObjectStorer? WrappersStorer { get; set; } /// @@ -280,16 +280,17 @@ private static void RestoreRuntimeDataObjects(SharedObjectsState storage) } } + static readonly string serialization_key_namepsace = "pythonnet_serialization_"; /// - /// Frees the pointer stored in a Python capsule and the capsule stored - /// on the sys module object with the given key if it exists. + /// Removes the serialization capsule from the `sys` module object. /// /// - /// The memory on the capsule must have been allocated via StashDataInCapsule + /// The serialization data must have been set with StashSerializationData /// /// The name given to the capsule on the `sys` module object - public static void FreeCapsuleData(string key) + public static void FreeSerializationData(string key) { + key = serialization_key_namepsace + key; BorrowedReference oldCapsule = PySys_GetObject(key); if (!oldCapsule.IsNull) { @@ -301,42 +302,51 @@ public static void FreeCapsuleData(string key) } /// - /// Stores the data parameter in a Python capsule and stores + /// Stores the data in the argument in a Python capsule and stores /// the capsule on the `sys` module object with the name . - /// This method allocates global memory to hold the data of the - /// parameter. /// /// /// No checks on pre-existing names on the `sys` module object are made. /// /// The name given to the capsule on the `sys` module object - /// The data to be contained in the capsule - public static void StashDataInCapsule(string key, byte[] data) + /// A MemoryStream that contains the data to be placed in the capsule + public static void StashSerializationData(string key, MemoryStream stream) { + var data = stream.GetBuffer(); IntPtr mem = Marshal.AllocHGlobal(IntPtr.Size + data.Length); // store the length of the buffer first Marshal.WriteIntPtr(mem, (IntPtr)data.Length); Marshal.Copy(data, 0, mem + IntPtr.Size, data.Length); + try + { using NewReference capsule = PyCapsule_New(mem, IntPtr.Zero, IntPtr.Zero); int res = PySys_SetObject(key, capsule.BorrowOrThrow()); PythonException.ThrowIfIsNotZero(res); + } + catch + { + Marshal.FreeHGlobal(mem); + } + } + static byte[] emptyBuffer = new byte[0]; /// - /// Retreives the pointer to previously stored data on a Python capsule. + /// Retreives the previously stored data on a Python capsule. /// Throws if the object corresponding to the parameter /// on the `sys` module object is not a capsule. /// /// The name given to the capsule on the `sys` module object - /// The pointer to the data, or IntPtr.Zero if name matches the key - public static IntPtr GetDataFromCapsule(string key) + /// A MemoryStream containing the previously saved serialization data. + /// The stream is empty if no name matches the key. + public static MemoryStream GetSerializationData(string key) { BorrowedReference capsule = PySys_GetObject(key); if (capsule.IsNull) { // nothing to do. - return IntPtr.Zero; + return new MemoryStream(emptyBuffer, writable:false); } var ptr = PyCapsule_GetPointer(capsule, IntPtr.Zero); if (ptr == IntPtr.Zero) @@ -345,7 +355,10 @@ public static IntPtr GetDataFromCapsule(string key) // as a capsule's value PythonException.ThrowIfIsNull(null); } - return ptr; + var len = (int)Marshal.ReadIntPtr(ptr); + byte[] buffer = new byte[len]; + Marshal.Copy(ptr+IntPtr.Size, buffer, 0, len); + return new MemoryStream(buffer, writable:false); } internal static IFormatter CreateFormatter() From 02158f3bc5d1eed9c8951e2d795b325096b0a2b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Bourbonnais?= <6788684+BadSingleton@users.noreply.github.com> Date: Fri, 29 Mar 2024 13:39:46 -0400 Subject: [PATCH 4/6] Adds tests for custom serialization --- tests/domain_tests/TestRunner.cs | 117 +++++++++++++++++++++++ tests/domain_tests/test_domain_reload.py | 3 + 2 files changed, 120 insertions(+) diff --git a/tests/domain_tests/TestRunner.cs b/tests/domain_tests/TestRunner.cs index 4f6a3ea28..bbee81b3d 100644 --- a/tests/domain_tests/TestRunner.cs +++ b/tests/domain_tests/TestRunner.cs @@ -1132,6 +1132,66 @@ import System ", }, + new TestCase + { + Name = "test_serialize_unserializable_object", + DotNetBefore = @" + namespace TestNamespace + { + public class NotSerializableTextWriter : System.IO.TextWriter + { + override public System.Text.Encoding Encoding { get { return System.Text.Encoding.ASCII;} } + } + [System.Serializable] + public static class SerializableWriter + { + private static System.IO.TextWriter _writer = null; + public static System.IO.TextWriter Writer {get { return _writer; }} + public static void CreateInternalWriter() + { + _writer = System.IO.TextWriter.Synchronized(new NotSerializableTextWriter()); + } + } + } +", + DotNetAfter = @" + namespace TestNamespace + { + public class NotSerializableTextWriter : System.IO.TextWriter + { + override public System.Text.Encoding Encoding { get { return System.Text.Encoding.ASCII;} } + } + [System.Serializable] + public static class SerializableWriter + { + private static System.IO.TextWriter _writer = null; + public static System.IO.TextWriter Writer {get { return _writer; }} + public static void CreateInternalWriter() + { + _writer = System.IO.TextWriter.Synchronized(new NotSerializableTextWriter()); + } + } + } + ", + PythonCode = @" +import sys + +def before_reload(): + import clr + import System + clr.AddReference('DomainTests') + import TestNamespace + TestNamespace.SerializableWriter.CreateInternalWriter(); + sys.__obj = TestNamespace.SerializableWriter.Writer + sys.__obj.WriteLine('test') + +def after_reload(): + import clr + import System + sys.__obj.WriteLine('test') + + ", + } }; /// @@ -1142,7 +1202,59 @@ import System const string CaseRunnerTemplate = @" using System; using System.IO; +using System.Runtime.Serialization; +using System.Runtime.Serialization.Formatters.Binary; using Python.Runtime; + +namespace Serialization +{{ + // Classes in this namespace is mostly useful for test_serialize_unserializable_object + class NotSerializableSerializer : ISerializationSurrogate + {{ + public NotSerializableSerializer() + {{ + }} + public void GetObjectData(object obj, SerializationInfo info, StreamingContext context) + {{ + info.AddValue(""notSerialized_tp"", obj.GetType()); + }} + public object SetObjectData(object obj, SerializationInfo info, StreamingContext context, ISurrogateSelector selector) + {{ + if (info == null) + {{ + return null; + }} + Type typeObj = info.GetValue(""notSerialized_tp"", typeof(Type)) as Type; + if (typeObj == null) + {{ + return null; + }} + + obj = Activator.CreateInstance(typeObj); + return obj; + }} + }} + class NonSerializableSelector : SurrogateSelector + {{ + public override ISerializationSurrogate GetSurrogate(Type type, StreamingContext context, out ISurrogateSelector selector) + {{ + if (type == null) + {{ + throw new ArgumentNullException(); + }} + selector = (ISurrogateSelector)this; + if (type.IsSerializable) + {{ + return null; // use whichever default + }} + else + {{ + return (ISerializationSurrogate)(new NotSerializableSerializer()); + }} + }} + }} +}} + namespace CaseRunner {{ class CaseRunner @@ -1151,6 +1263,11 @@ public static int Main() {{ try {{ + RuntimeData.FormatterFactory = () => + {{ + return new BinaryFormatter(){{SurrogateSelector = new Serialization.NonSerializableSelector()}}; + }}; + PythonEngine.Initialize(); using (Py.GIL()) {{ diff --git a/tests/domain_tests/test_domain_reload.py b/tests/domain_tests/test_domain_reload.py index 8999e481b..1e5e8e81b 100644 --- a/tests/domain_tests/test_domain_reload.py +++ b/tests/domain_tests/test_domain_reload.py @@ -88,3 +88,6 @@ def test_nested_type(): def test_import_after_reload(): _run_test("import_after_reload") + +def test_import_after_reload(): + _run_test("test_serialize_unserializable_object") \ No newline at end of file From a05890013b13613374a979b10b985364f6e79d39 Mon Sep 17 00:00:00 2001 From: Benedikt Reinartz Date: Fri, 26 Apr 2024 11:31:04 +0200 Subject: [PATCH 5/6] Add NoopFormatter and fall back to it if BinaryFormatter is not available --- .../StateSerialization/NoopFormatter.cs | 14 ++++++ src/runtime/StateSerialization/RuntimeData.cs | 47 ++++++++++--------- 2 files changed, 40 insertions(+), 21 deletions(-) create mode 100644 src/runtime/StateSerialization/NoopFormatter.cs diff --git a/src/runtime/StateSerialization/NoopFormatter.cs b/src/runtime/StateSerialization/NoopFormatter.cs new file mode 100644 index 000000000..f05b7ebb2 --- /dev/null +++ b/src/runtime/StateSerialization/NoopFormatter.cs @@ -0,0 +1,14 @@ +using System; +using System.IO; +using System.Runtime.Serialization; + +namespace Python.Runtime; + +public class NoopFormatter : IFormatter { + public object Deserialize(Stream s) => throw new NotImplementedException(); + public void Serialize(Stream s, object o) {} + + public SerializationBinder? Binder { get; set; } + public StreamingContext Context { get; set; } + public ISurrogateSelector? SurrogateSelector { get; set; } +} diff --git a/src/runtime/StateSerialization/RuntimeData.cs b/src/runtime/StateSerialization/RuntimeData.cs index c20040c76..c4dead138 100644 --- a/src/runtime/StateSerialization/RuntimeData.cs +++ b/src/runtime/StateSerialization/RuntimeData.cs @@ -1,7 +1,5 @@ using System; -using System.Collections; using System.Collections.Generic; -using System.Collections.ObjectModel; using System.Diagnostics; using System.IO; using System.Linq; @@ -18,21 +16,28 @@ namespace Python.Runtime public static class RuntimeData { - private readonly static Func DefaultFormatter = () => new BinaryFormatter(); - private static Func? _formatterFactory { get; set; } = null; - - public static Func FormatterFactory + public readonly static Func DefaultFormatterFactory = () => { - get + try { - if (_formatterFactory is null) - { - return DefaultFormatter; - } - return _formatterFactory; + return new BinaryFormatter(); } + catch + { + return new NoopFormatter(); + } + }; + + private static Func _formatterFactory { get; set; } = DefaultFormatterFactory; + + public static Func FormatterFactory + { + get => _formatterFactory; set { + if (value == null) + throw new ArgumentNullException(nameof(value)); + _formatterFactory = value; } } @@ -302,8 +307,8 @@ public static void FreeSerializationData(string key) } /// - /// Stores the data in the argument in a Python capsule and stores - /// the capsule on the `sys` module object with the name . + /// Stores the data in the argument in a Python capsule and stores + /// the capsule on the `sys` module object with the name . /// /// /// No checks on pre-existing names on the `sys` module object are made. @@ -320,10 +325,10 @@ public static void StashSerializationData(string key, MemoryStream stream) try { - using NewReference capsule = PyCapsule_New(mem, IntPtr.Zero, IntPtr.Zero); - int res = PySys_SetObject(key, capsule.BorrowOrThrow()); - PythonException.ThrowIfIsNotZero(res); - } + using NewReference capsule = PyCapsule_New(mem, IntPtr.Zero, IntPtr.Zero); + int res = PySys_SetObject(key, capsule.BorrowOrThrow()); + PythonException.ThrowIfIsNotZero(res); + } catch { Marshal.FreeHGlobal(mem); @@ -333,12 +338,12 @@ public static void StashSerializationData(string key, MemoryStream stream) static byte[] emptyBuffer = new byte[0]; /// - /// Retreives the previously stored data on a Python capsule. + /// Retreives the previously stored data on a Python capsule. /// Throws if the object corresponding to the parameter /// on the `sys` module object is not a capsule. /// /// The name given to the capsule on the `sys` module object - /// A MemoryStream containing the previously saved serialization data. + /// A MemoryStream containing the previously saved serialization data. /// The stream is empty if no name matches the key. public static MemoryStream GetSerializationData(string key) { @@ -351,7 +356,7 @@ public static MemoryStream GetSerializationData(string key) var ptr = PyCapsule_GetPointer(capsule, IntPtr.Zero); if (ptr == IntPtr.Zero) { - // The PyCapsule API returns NULL on error; NULL cannot be stored + // The PyCapsule API returns NULL on error; NULL cannot be stored // as a capsule's value PythonException.ThrowIfIsNull(null); } From 6e37568dd167c5d6b972dc2870bf64591bc0d8ad Mon Sep 17 00:00:00 2001 From: Benedikt Reinartz Date: Mon, 6 May 2024 14:22:27 +0200 Subject: [PATCH 6/6] Use TryGetBuffer instead of GetBuffer --- src/runtime/StateSerialization/RuntimeData.cs | 31 ++++++++++++------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/src/runtime/StateSerialization/RuntimeData.cs b/src/runtime/StateSerialization/RuntimeData.cs index c4dead138..8eda9ce0b 100644 --- a/src/runtime/StateSerialization/RuntimeData.cs +++ b/src/runtime/StateSerialization/RuntimeData.cs @@ -317,21 +317,28 @@ public static void FreeSerializationData(string key) /// A MemoryStream that contains the data to be placed in the capsule public static void StashSerializationData(string key, MemoryStream stream) { - var data = stream.GetBuffer(); - IntPtr mem = Marshal.AllocHGlobal(IntPtr.Size + data.Length); - // store the length of the buffer first - Marshal.WriteIntPtr(mem, (IntPtr)data.Length); - Marshal.Copy(data, 0, mem + IntPtr.Size, data.Length); - - try + if (stream.TryGetBuffer(out var data)) { - using NewReference capsule = PyCapsule_New(mem, IntPtr.Zero, IntPtr.Zero); - int res = PySys_SetObject(key, capsule.BorrowOrThrow()); - PythonException.ThrowIfIsNotZero(res); + IntPtr mem = Marshal.AllocHGlobal(IntPtr.Size + data.Count); + + // store the length of the buffer first + Marshal.WriteIntPtr(mem, (IntPtr)data.Count); + Marshal.Copy(data.Array, data.Offset, mem + IntPtr.Size, data.Count); + + try + { + using NewReference capsule = PyCapsule_New(mem, IntPtr.Zero, IntPtr.Zero); + int res = PySys_SetObject(key, capsule.BorrowOrThrow()); + PythonException.ThrowIfIsNotZero(res); + } + catch + { + Marshal.FreeHGlobal(mem); + } } - catch + else { - Marshal.FreeHGlobal(mem); + throw new NotImplementedException($"{nameof(stream)} must be exposable"); } }