Skip to content

Commit b59e2bb

Browse files
committed
Add nested class support
plus guard rails for private classes and review fixes
1 parent c75ee46 commit b59e2bb

File tree

1 file changed

+124
-28
lines changed

1 file changed

+124
-28
lines changed

src/runtime/StateSerialization/RuntimeData.cs

+124-28
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
using System;
2-
using System.Collections;
32
using System.Collections.Generic;
4-
using System.Collections.ObjectModel;
53
using System.Diagnostics;
6-
using System.Dynamic;
74
using System.IO;
85
using System.Linq;
96
using System.Reflection;
@@ -45,38 +42,140 @@ internal static class NonSerializedTypeBuilder
4542
internal static HashSet<string> dontReimplementMethods = new(){"Finalize", "Dispose", "GetType", "ReferenceEquals", "GetHashCode", "Equals"};
4643
const string notSerializedSuffix = "_NotSerialized";
4744

48-
public static object CreateNewObject(Type baseType)
45+
private static Func<Type, TypeAttributes, bool> hasVisibility = (tp, attr) => (tp.Attributes & TypeAttributes.VisibilityMask) == attr;
46+
private static Func<Type, bool> isNestedType = (tp) => hasVisibility(tp, TypeAttributes.NestedPrivate) || hasVisibility(tp, TypeAttributes.NestedPublic) || hasVisibility(tp, TypeAttributes.NestedFamily) || hasVisibility(tp, TypeAttributes.NestedAssembly);
47+
private static Func<Type, bool> isPrivateType = (tp) => hasVisibility(tp, TypeAttributes.NotPublic) || hasVisibility(tp, TypeAttributes.NestedPrivate) || hasVisibility(tp, TypeAttributes.NestedFamily) || hasVisibility( tp, TypeAttributes.NestedAssembly);
48+
private static Func<Type, bool> isPublicType = (tp) => hasVisibility(tp, TypeAttributes.Public) || hasVisibility(tp,TypeAttributes.NestedPublic);
49+
50+
public static object? CreateNewObject(Type baseType)
4951
{
5052
var myType = CreateType(baseType);
53+
if (myType is null)
54+
{
55+
return null;
56+
}
5157
var myObject = Activator.CreateInstance(myType);
5258
return myObject;
5359
}
5460

55-
public static Type CreateType(Type tp)
61+
static void FillTypeMethods(TypeBuilder tb)
5662
{
57-
Type existingType = assemblyForNonSerializedClasses.GetType(tp.Name + "_NotSerialized", throwOnError:false);
58-
if (existingType is not null)
63+
var constructors = tb.BaseType.GetConstructors();
64+
if (constructors.Count() == 0)
5965
{
60-
return existingType;
66+
// no constructors defined, at least declare a default
67+
ConstructorBuilder constructor = tb.DefineDefaultConstructor(MethodAttributes.Public | MethodAttributes.SpecialName | MethodAttributes.RTSpecialName);
6168
}
69+
else
70+
{
71+
foreach (var ctor in constructors)
72+
{
73+
var ctorParams = (from param in ctor.GetParameters() select param.ParameterType).ToArray();
74+
var ctorbuilder = tb.DefineConstructor(ctor.Attributes, ctor.CallingConvention, ctorParams);
75+
ctorbuilder.GetILGenerator().Emit(OpCodes.Ret);
6276

63-
TypeBuilder tb = GetTypeBuilder(tp);
64-
ConstructorBuilder constructor = tb.DefineDefaultConstructor(MethodAttributes.Public | MethodAttributes.SpecialName | MethodAttributes.RTSpecialName);
65-
var properties = tp.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static);
77+
}
78+
var parameterless = tb.DefineConstructor(MethodAttributes.Public | MethodAttributes.SpecialName | MethodAttributes.RTSpecialName, CallingConventions.Standard | CallingConventions.HasThis, Type.EmptyTypes);
79+
parameterless.GetILGenerator().Emit(OpCodes.Ret);
80+
}
81+
82+
var properties = tb.BaseType.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static);
6683
foreach (var prop in properties)
6784
{
6885
CreateProperty(tb, prop);
6986
}
7087

71-
var methods = tp.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static);
88+
var methods = tb.BaseType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static);
7289
foreach (var meth in methods)
7390
{
7491
CreateMethod(tb, meth);
7592
}
7693

7794
ImplementEqualityAndHash(tb);
95+
}
7896

79-
return tb.CreateType();
97+
static string MakeName(Type tp)
98+
{
99+
const string suffix = "_NotSerialized";
100+
string @out = tp.Name + suffix;
101+
var parentType = tp.DeclaringType;
102+
while (parentType is not null)
103+
{
104+
// If we have a nested class, we need the whole nester/nestee
105+
// chain with the suffix for each.
106+
@out = parentType.Name + suffix + "+" + @out;
107+
parentType = parentType.DeclaringType;
108+
}
109+
return @out;
110+
}
111+
112+
public static Type? CreateType(Type tp)
113+
{
114+
if (!isPublicType(tp))
115+
{
116+
return null;
117+
}
118+
119+
Type existingType = assemblyForNonSerializedClasses.GetType(MakeName(tp), throwOnError:false);
120+
if (existingType is not null)
121+
{
122+
return existingType;
123+
}
124+
var parentType = tp.DeclaringType;
125+
if (parentType is not null)
126+
{
127+
// parent types for nested types must be created first. Climb up the
128+
// declaring type chain until we find a "top-level" class.
129+
while (parentType.DeclaringType is not null)
130+
{
131+
parentType = parentType.DeclaringType;
132+
}
133+
CreateTypeInternal(parentType);
134+
Type nestedType = assemblyForNonSerializedClasses.GetType(MakeName(tp), throwOnError:true);
135+
return nestedType;
136+
}
137+
return CreateTypeInternal(tp);
138+
}
139+
140+
private static Type? CreateTypeInternal(Type baseType)
141+
{
142+
if (!isPublicType(baseType))
143+
{
144+
// we can't derive from non-public types.
145+
return null;
146+
}
147+
Type existingType = assemblyForNonSerializedClasses.GetType(MakeName(baseType), throwOnError:false);
148+
if (existingType is not null)
149+
{
150+
return existingType;
151+
}
152+
153+
TypeBuilder tb = GetTypeBuilder(baseType);
154+
SetNonSerialiedAttr(tb);
155+
FillTypeMethods(tb);
156+
157+
var nestedtypes = baseType.GetNestedTypes(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static);
158+
List<TypeBuilder> nestedBuilders = new();
159+
foreach (var nested in nestedtypes)
160+
{
161+
if (isPrivateType(nested))
162+
{
163+
continue;
164+
}
165+
var nestedBuilder = tb.DefineNestedType(nested.Name + notSerializedSuffix,
166+
TypeAttributes.NestedPublic,
167+
nested
168+
);
169+
nestedBuilders.Add(nestedBuilder);
170+
}
171+
var outTp = tb.CreateType();
172+
foreach(var builder in nestedBuilders)
173+
{
174+
FillTypeMethods(builder);
175+
SetNonSerialiedAttr(builder);
176+
builder.CreateType();
177+
}
178+
return outTp;
80179
}
81180

82181
private static void ImplementEqualityAndHash(TypeBuilder tb)
@@ -91,7 +190,7 @@ private static void ImplementEqualityAndHash(TypeBuilder tb)
91190
getHashIlGen.Emit(OpCodes.Ldarg_0);
92191
getHashIlGen.EmitCall(OpCodes.Call, typeof(object).GetMethod("GetType"), Type.EmptyTypes);
93192
getHashIlGen.EmitCall(OpCodes.Call, typeof(Type).GetProperty("Name").GetMethod, Type.EmptyTypes);
94-
getHashIlGen.EmitCall(OpCodes.Call, typeof(string).GetMethod("GetHashCode"), Type.EmptyTypes);
193+
getHashIlGen.EmitCall(OpCodes.Call, typeof(string).GetMethod("GetHashCode", Type.EmptyTypes), Type.EmptyTypes);
95194
getHashIlGen.Emit(OpCodes.Ret);
96195

97196
Type[] equalsArgs = new Type[] {typeof(object), typeof(object)};
@@ -108,19 +207,20 @@ private static void ImplementEqualityAndHash(TypeBuilder tb)
108207
equalsIlGen.Emit(OpCodes.Ret);
109208
}
110209

210+
private static void SetNonSerialiedAttr(TypeBuilder tb)
211+
{
212+
ConstructorInfo attrCtorInfo = typeof(PyNet_NotSerializedAttribute).GetConstructor(new Type[]{});
213+
CustomAttributeBuilder attrBuilder = new CustomAttributeBuilder(attrCtorInfo,new object[]{});
214+
tb.SetCustomAttribute(attrBuilder);
215+
}
216+
111217
private static TypeBuilder GetTypeBuilder(Type baseType)
112218
{
113219
string typeSignature = baseType.Name + notSerializedSuffix;
114-
115220
TypeBuilder tb = moduleBuilder.DefineType(typeSignature,
116221
baseType.Attributes,
117222
baseType,
118223
baseType.GetInterfaces());
119-
120-
ConstructorInfo attrCtorInfo = typeof(PyNet_NotSerializedAttribute).GetConstructor(new Type[]{});
121-
CustomAttributeBuilder attrBuilder = new CustomAttributeBuilder(attrCtorInfo,new object[]{});
122-
tb.SetCustomAttribute(attrBuilder);
123-
124224
return tb;
125225
}
126226

@@ -188,7 +288,6 @@ private static void CreateProperty(TypeBuilder tb, PropertyInfo pinfo)
188288

189289
private static void CreateMethod(TypeBuilder tb, MethodInfo minfo)
190290
{
191-
Console.WriteLine($"overimplementing method for: {minfo} {minfo.IsVirtual} {minfo.IsFinal} ");
192291
string methodName = minfo.Name;
193292

194293
if (dontReimplementMethods.Contains(methodName))
@@ -226,8 +325,7 @@ public void GetObjectData(object obj, SerializationInfo info, StreamingContext c
226325

227326
MaybeType type = obj.GetType();
228327

229-
var hasAttr = (from attr in obj.GetType().CustomAttributes select attr.AttributeType == typeof(PyNet_NotSerializedAttribute)).Count() != 0;
230-
if (hasAttr)
328+
if (type.Value.CustomAttributes.Any((attr) => attr.AttributeType == typeof(NonSerializedAttribute)))
231329
{
232330
// Don't serialize a _NotSerialized. Serialize the base type, and deserialize as a _NotSerialized
233331
type = type.Value.BaseType;
@@ -257,7 +355,7 @@ public object SetObjectData(object obj, SerializationInfo info, StreamingContext
257355
object nameObj = null!;
258356
try
259357
{
260-
nameObj = info.GetValue($"notSerialized_tp", typeof(object));
358+
nameObj = info.GetValue("notSerialized_tp", typeof(object));
261359
}
262360
catch
263361
{
@@ -274,7 +372,7 @@ public object SetObjectData(object obj, SerializationInfo info, StreamingContext
274372
return null!;
275373
}
276374

277-
obj = NonSerializedTypeBuilder.CreateNewObject(name.Value);
375+
obj = NonSerializedTypeBuilder.CreateNewObject(name.Value)!;
278376
return obj;
279377
}
280378
}
@@ -287,14 +385,13 @@ class NonSerializableSelector : SurrogateSelector
287385
{
288386
throw new ArgumentNullException();
289387
}
388+
selector = this;
290389
if (type.IsSerializable)
291390
{
292-
selector = this;
293391
return null; // use whichever default
294392
}
295393
else
296394
{
297-
selector = this;
298395
return new NotSerializableSerializer();
299396
}
300397
}
@@ -597,7 +694,6 @@ internal static IFormatter CreateFormatter()
597694
: new BinaryFormatter()
598695
{
599696
SurrogateSelector = new NonSerializableSelector(),
600-
// Binder = new CustomizedBinder()
601697
};
602698
}
603699
}

0 commit comments

Comments
 (0)