From 5c4149995ff903cd7c387ca620b475cfcedad489 Mon Sep 17 00:00:00 2001 From: Victor Milovanov Date: Sat, 20 Feb 2021 14:36:04 -0800 Subject: [PATCH] reworked Enum marshaling - enums are no longer converted to and from PyLong automatically https://github.com/pythonnet/pythonnet/issues/1220 - one can construct an instance of MyEnum from Python using MyEnum(numeric_val), e.g. MyEnum(10) - in the above, if MyEnum does not have [Flags] and does not have value 10 defined, to create MyEnum with value 10 one must call MyEnum(10, True). Here True is an unnamed parameter, that allows unchecked conversion - legacy behavior has been moved to a codec (EnumPyLongCodec); enums can now be encoded by codecs - flags enums support bitwise ops via EnumOps class --- CHANGELOG.md | 3 + src/embed_tests/TestOperator.cs | 8 ++ src/runtime/Codecs/EnumPyLongCodec.cs | 68 +++++++++++++++++ src/runtime/classmanager.cs | 11 +++ src/runtime/classobject.cs | 47 ++++++++++-- src/runtime/converter.cs | 74 ++++++++---------- src/runtime/exceptions.cs | 2 +- src/runtime/operatormethod.cs | 16 ++-- src/runtime/opshelper.cs | 77 +++++++++++++++++++ src/runtime/polyfill/ReflectionPolyfills.cs | 3 + src/runtime/pylong.cs | 5 +- src/runtime/pyobject.cs | 16 ++-- src/runtime/runtime.cs | 6 +- tests/test_array.py | 2 +- tests/test_conversion.py | 35 --------- tests/test_enum.py | 84 +++++++++++---------- tests/test_indexer.py | 6 +- 17 files changed, 316 insertions(+), 147 deletions(-) create mode 100644 src/runtime/Codecs/EnumPyLongCodec.cs create mode 100644 src/runtime/opshelper.cs diff --git a/CHANGELOG.md b/CHANGELOG.md index 0d9a79d21..e5f262620 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,9 @@ when .NET expects an integer [#1342][i1342] - BREAKING: to call Python from .NET `Runtime.PythonDLL` property must be set to Python DLL name or the DLL must be loaded in advance. This must be done before calling any other Python.NET functions. - BREAKING: `PyObject.Length()` now raises a `PythonException` when object does not support a concept of length. +- BREAKING: disabled implicit conversion from C# enums to Python `int` and back. +One must now either use enum members (e.g. `MyEnum.Option`), or use enum constructor +(e.g. `MyEnum(42)` or `MyEnum(42, True)` when `MyEnum` does not have a member with value 42). - Sign Runtime DLL with a strong name - Implement loading through `clr_loader` instead of the included `ClrModule`, enables support for .NET Core diff --git a/src/embed_tests/TestOperator.cs b/src/embed_tests/TestOperator.cs index 8e9feb241..68a6e8e35 100644 --- a/src/embed_tests/TestOperator.cs +++ b/src/embed_tests/TestOperator.cs @@ -335,6 +335,14 @@ public void SymmetricalOperatorOverloads() "); } + [Test] + public void EnumOperator() + { + PythonEngine.Exec($@" +from System.IO import FileAccess +c = FileAccess.Read | FileAccess.Write"); + } + [Test] public void OperatorOverloadMissingArgument() { diff --git a/src/runtime/Codecs/EnumPyLongCodec.cs b/src/runtime/Codecs/EnumPyLongCodec.cs new file mode 100644 index 000000000..7dab98028 --- /dev/null +++ b/src/runtime/Codecs/EnumPyLongCodec.cs @@ -0,0 +1,68 @@ +using System; + +namespace Python.Runtime.Codecs +{ + [Obsolete] + public sealed class EnumPyLongCodec : IPyObjectEncoder, IPyObjectDecoder + { + public static EnumPyLongCodec Instance { get; } = new EnumPyLongCodec(); + + public bool CanDecode(PyObject objectType, Type targetType) + { + return targetType.IsEnum + && objectType.IsSubclass(new BorrowedReference(Runtime.PyLongType)); + } + + public bool CanEncode(Type type) + { + return type == typeof(object) || type == typeof(ValueType) || type.IsEnum; + } + + public bool TryDecode(PyObject pyObj, out T value) + { + value = default; + if (!typeof(T).IsEnum) return false; + + Type etype = Enum.GetUnderlyingType(typeof(T)); + + if (!PyLong.IsLongType(pyObj)) return false; + + object result; + try + { + result = pyObj.AsManagedObject(etype); + } + catch (InvalidCastException) + { + return false; + } + + if (Enum.IsDefined(typeof(T), result) || typeof(T).IsFlagsEnum()) + { + value = (T)Enum.ToObject(typeof(T), result); + return true; + } + + return false; + } + + public PyObject TryEncode(object value) + { + if (value is null) return null; + + var enumType = value.GetType(); + if (!enumType.IsEnum) return null; + + try + { + return new PyLong((long)value); + } + catch (InvalidCastException) + { + return new PyLong((ulong)value); + } + } + + private EnumPyLongCodec() { } + } +} diff --git a/src/runtime/classmanager.cs b/src/runtime/classmanager.cs index 1ee06e682..306962f56 100644 --- a/src/runtime/classmanager.cs +++ b/src/runtime/classmanager.cs @@ -403,6 +403,17 @@ private static ClassInfo GetClassInfo(Type type) } } + // only [Flags] enums support bitwise operations + if (type.IsEnum && type.IsFlagsEnum()) + { + var opsImpl = typeof(EnumOps<>).MakeGenericType(type); + foreach (var op in opsImpl.GetMethods(OpsHelper.BindingFlags)) + { + local[op.Name] = 1; + } + info = info.Concat(opsImpl.GetMethods(OpsHelper.BindingFlags)).ToArray(); + } + // Now again to filter w/o losing overloaded member info for (i = 0; i < info.Length; i++) { diff --git a/src/runtime/classobject.cs b/src/runtime/classobject.cs index 4aa97f648..1a2532044 100644 --- a/src/runtime/classobject.cs +++ b/src/runtime/classobject.cs @@ -50,8 +50,9 @@ internal NewReference GetDocString() /// /// Implements __new__ for reflected classes and value types. /// - public static IntPtr tp_new(IntPtr tp, IntPtr args, IntPtr kw) + public static IntPtr tp_new(IntPtr tpRaw, IntPtr args, IntPtr kw) { + var tp = new BorrowedReference(tpRaw); var self = GetManagedObject(tp) as ClassObject; // Sanity check: this ensures a graceful error if someone does @@ -87,7 +88,7 @@ public static IntPtr tp_new(IntPtr tp, IntPtr args, IntPtr kw) return IntPtr.Zero; } - return CLRObject.GetInstHandle(result, tp); + return CLRObject.GetInstHandle(result, tp).DangerousMoveToPointerOrNull(); } if (type.IsAbstract) @@ -98,8 +99,7 @@ public static IntPtr tp_new(IntPtr tp, IntPtr args, IntPtr kw) if (type.IsEnum) { - Exceptions.SetError(Exceptions.TypeError, "cannot instantiate enumeration"); - return IntPtr.Zero; + return NewEnum(type, new BorrowedReference(args), tp).DangerousMoveToPointerOrNull(); } object obj = self.binder.InvokeRaw(IntPtr.Zero, args, kw); @@ -108,7 +108,44 @@ public static IntPtr tp_new(IntPtr tp, IntPtr args, IntPtr kw) return IntPtr.Zero; } - return CLRObject.GetInstHandle(obj, tp); + return CLRObject.GetInstHandle(obj, tp).DangerousMoveToPointerOrNull(); + } + + private static NewReference NewEnum(Type type, BorrowedReference args, BorrowedReference tp) + { + nint argCount = Runtime.PyTuple_Size(args); + bool allowUnchecked = false; + if (argCount == 2) + { + var allow = Runtime.PyTuple_GetItem(args, 1); + if (!Converter.ToManaged(allow, typeof(bool), out var allowObj, true) || allowObj is null) + { + Exceptions.RaiseTypeError("second argument to enum constructor must be a boolean"); + return default; + } + allowUnchecked |= (bool)allowObj; + } + + if (argCount < 1 || argCount > 2) + { + Exceptions.SetError(Exceptions.TypeError, "no constructors match given arguments"); + return default; + } + + var op = Runtime.PyTuple_GetItem(args, 0); + if (!Converter.ToManaged(op, type.GetEnumUnderlyingType(), out object result, true)) + { + return default; + } + + if (!allowUnchecked && !Enum.IsDefined(type, result) && !type.IsFlagsEnum()) + { + Exceptions.SetError(Exceptions.ValueError, "Invalid enumeration value. Pass True as the second argument if unchecked conversion is desired"); + return default; + } + + object enumValue = Enum.ToObject(type, result); + return CLRObject.GetInstHandle(enumValue, tp); } diff --git a/src/runtime/converter.cs b/src/runtime/converter.cs index f3b378113..4de334b5f 100644 --- a/src/runtime/converter.cs +++ b/src/runtime/converter.cs @@ -27,7 +27,6 @@ private Converter() private static Type int16Type; private static Type int32Type; private static Type int64Type; - private static Type flagsType; private static Type boolType; private static Type typeType; @@ -42,7 +41,6 @@ static Converter() singleType = typeof(Single); doubleType = typeof(Double); decimalType = typeof(Decimal); - flagsType = typeof(FlagsAttribute); boolType = typeof(Boolean); typeType = typeof(Type); } @@ -148,7 +146,8 @@ internal static IntPtr ToPython(object value, Type type) return result; } - if (Type.GetTypeCode(type) == TypeCode.Object && value.GetType() != typeof(object)) { + if (Type.GetTypeCode(type) == TypeCode.Object && value.GetType() != typeof(object) + || type.IsEnum) { var encoded = PyObjectConversions.TryEncode(value, type); if (encoded != null) { result = encoded.Handle; @@ -203,6 +202,11 @@ internal static IntPtr ToPython(object value, Type type) type = value.GetType(); + if (type.IsEnum) + { + return CLRObject.GetInstHandle(value, type); + } + TypeCode tc = Type.GetTypeCode(type); switch (tc) @@ -317,6 +321,18 @@ internal static bool ToManaged(IntPtr value, Type type, } return Converter.ToManagedValue(value, type, out result, setError); } + /// + /// Return a managed object for the given Python object, taking funny + /// byref types into account. + /// + /// A Python object + /// The desired managed type + /// Receives the managed object + /// If true, call Exceptions.SetError with the reason for failure. + /// True on success + internal static bool ToManaged(BorrowedReference value, Type type, + out object result, bool setError) + => ToManaged(value.DangerousGetAddress(), type, out result, setError); internal static bool ToManagedValue(BorrowedReference value, Type obType, out object result, bool setError) @@ -398,11 +414,6 @@ internal static bool ToManagedValue(IntPtr value, Type obType, return ToArray(value, obType, out result, setError); } - if (obType.IsEnum) - { - return ToEnum(value, obType, out result, setError); - } - // Conversion to 'Object' is done based on some reasonable default // conversions (Python string -> managed string, Python int -> Int32 etc.). if (obType == objectType) @@ -497,7 +508,7 @@ internal static bool ToManagedValue(IntPtr value, Type obType, } TypeCode typeCode = Type.GetTypeCode(obType); - if (typeCode == TypeCode.Object) + if (typeCode == TypeCode.Object || obType.IsEnum) { IntPtr pyType = Runtime.PyObject_TYPE(value); if (PyObjectConversions.TryDecode(value, pyType, obType, out result)) @@ -516,8 +527,17 @@ internal static bool ToManagedValue(IntPtr value, Type obType, /// private static bool ToPrimitive(IntPtr value, Type obType, out object result, bool setError) { - TypeCode tc = Type.GetTypeCode(obType); result = null; + if (obType.IsEnum) + { + if (setError) + { + Exceptions.SetError(Exceptions.TypeError, "since Python.NET 3.0 int can not be converted to Enum implicitly. Use Enum(int_value)"); + } + return false; + } + + TypeCode tc = Type.GetTypeCode(obType); IntPtr op = IntPtr.Zero; switch (tc) @@ -876,40 +896,6 @@ private static bool ToArray(IntPtr value, Type obType, out object result, bool s result = items; return true; } - - - /// - /// Convert a Python value to a correctly typed managed enum instance. - /// - private static bool ToEnum(IntPtr value, Type obType, out object result, bool setError) - { - Type etype = Enum.GetUnderlyingType(obType); - result = null; - - if (!ToPrimitive(value, etype, out result, setError)) - { - return false; - } - - if (Enum.IsDefined(obType, result)) - { - result = Enum.ToObject(obType, result); - return true; - } - - if (obType.GetCustomAttributes(flagsType, true).Length > 0) - { - result = Enum.ToObject(obType, result); - return true; - } - - if (setError) - { - Exceptions.SetError(Exceptions.ValueError, "invalid enumeration value"); - } - - return false; - } } public static class ConverterExtension diff --git a/src/runtime/exceptions.cs b/src/runtime/exceptions.cs index da8653853..06d2d55b5 100644 --- a/src/runtime/exceptions.cs +++ b/src/runtime/exceptions.cs @@ -340,7 +340,7 @@ public static void Clear() public static void warn(string message, IntPtr exception, int stacklevel) { if (exception == IntPtr.Zero || - (Runtime.PyObject_IsSubclass(exception, Exceptions.Warning) != 1)) + (Runtime.PyObject_IsSubclass(new BorrowedReference(exception), new BorrowedReference(Exceptions.Warning)) != 1)) { Exceptions.RaiseTypeError("Invalid exception"); } diff --git a/src/runtime/operatormethod.cs b/src/runtime/operatormethod.cs index 59bf944bc..e44dc3be1 100644 --- a/src/runtime/operatormethod.cs +++ b/src/runtime/operatormethod.cs @@ -51,7 +51,6 @@ static OperatorMethod() ["op_OnesComplement"] = new SlotDefinition("__invert__", TypeOffset.nb_invert), ["op_UnaryNegation"] = new SlotDefinition("__neg__", TypeOffset.nb_negative), ["op_UnaryPlus"] = new SlotDefinition("__pos__", TypeOffset.nb_positive), - ["op_OneComplement"] = new SlotDefinition("__invert__", TypeOffset.nb_invert), }; ComparisonOpMap = new Dictionary { @@ -80,7 +79,7 @@ public static void Shutdown() public static bool IsOperatorMethod(MethodBase method) { - if (!method.IsSpecialName) + if (!method.IsSpecialName && !method.IsOpsHelper()) { return false; } @@ -102,7 +101,12 @@ public static void FixupSlots(IntPtr pyType, Type clrType) { const BindingFlags flags = BindingFlags.Public | BindingFlags.Static; Debug.Assert(_opType != null); - foreach (var method in clrType.GetMethods(flags)) + + var staticMethods = + clrType.IsEnum ? typeof(EnumOps<>).MakeGenericType(clrType).GetMethods(flags) + : clrType.GetMethods(flags); + + foreach (var method in staticMethods) { // We only want to override slots for operators excluding // comparison operators, which are handled by ClassBase.tp_richcompare. @@ -170,9 +174,11 @@ public static string ReversePyMethodName(string pyName) /// public static bool IsReverse(MethodInfo method) { - Type declaringType = method.DeclaringType; + Type primaryType = method.IsOpsHelper() + ? method.DeclaringType.GetGenericArguments()[0] + : method.DeclaringType; Type leftOperandType = method.GetParameters()[0].ParameterType; - return leftOperandType != declaringType; + return leftOperandType != primaryType; } public static void FilterMethods(MethodInfo[] methods, out MethodInfo[] forwardMethods, out MethodInfo[] reverseMethods) diff --git a/src/runtime/opshelper.cs b/src/runtime/opshelper.cs new file mode 100644 index 000000000..59f7704b7 --- /dev/null +++ b/src/runtime/opshelper.cs @@ -0,0 +1,77 @@ +using System; +using System.Linq.Expressions; +using System.Reflection; + +using static Python.Runtime.OpsHelper; + +namespace Python.Runtime +{ + static class OpsHelper + { + public static BindingFlags BindingFlags => BindingFlags.Public | BindingFlags.Static; + + public static Func Binary(Func func) + { + var a = Expression.Parameter(typeof(T), "a"); + var b = Expression.Parameter(typeof(T), "b"); + var body = func(a, b); + var lambda = Expression.Lambda>(body, a, b); + return lambda.Compile(); + } + + public static Func Unary(Func func) + { + var value = Expression.Parameter(typeof(T), "value"); + var body = func(value); + var lambda = Expression.Lambda>(body, value); + return lambda.Compile(); + } + + public static bool IsOpsHelper(this MethodBase method) + => method.DeclaringType.GetCustomAttribute() is not null; + + public static Expression EnumUnderlyingValue(Expression enumValue) + => Expression.Convert(enumValue, enumValue.Type.GetEnumUnderlyingType()); + } + + [AttributeUsage(AttributeTargets.Class, AllowMultiple = false)] + internal class OpsAttribute: Attribute { } + + [Ops] + internal static class EnumOps where T : Enum + { + static readonly Func and = BinaryOp(Expression.And); + static readonly Func or = BinaryOp(Expression.Or); + static readonly Func xor = BinaryOp(Expression.ExclusiveOr); + + static readonly Func invert = UnaryOp(Expression.OnesComplement); + + public static T op_BitwiseAnd(T a, T b) => and(a, b); + public static T op_BitwiseOr(T a, T b) => or(a, b); + public static T op_ExclusiveOr(T a, T b) => xor(a, b); + public static T op_OnesComplement(T value) => invert(value); + + static Expression FromNumber(Expression number) + => Expression.Convert(number, typeof(T)); + + static Func BinaryOp(Func op) + { + return Binary((a, b) => + { + var numericA = EnumUnderlyingValue(a); + var numericB = EnumUnderlyingValue(b); + var numericResult = op(numericA, numericB); + return FromNumber(numericResult); + }); + } + static Func UnaryOp(Func op) + { + return Unary(value => + { + var numeric = EnumUnderlyingValue(value); + var numericResult = op(numeric); + return FromNumber(numericResult); + }); + } + } +} diff --git a/src/runtime/polyfill/ReflectionPolyfills.cs b/src/runtime/polyfill/ReflectionPolyfills.cs index 65f9b83de..36bd39cef 100644 --- a/src/runtime/polyfill/ReflectionPolyfills.cs +++ b/src/runtime/polyfill/ReflectionPolyfills.cs @@ -30,5 +30,8 @@ public static T GetCustomAttribute(this Assembly assembly) where T: Attribute .Cast() .SingleOrDefault(); } + + public static bool IsFlagsEnum(this Type type) + => type.GetCustomAttribute() is not null; } } diff --git a/src/runtime/pylong.cs b/src/runtime/pylong.cs index fdfd26aba..8cb814cf6 100644 --- a/src/runtime/pylong.cs +++ b/src/runtime/pylong.cs @@ -188,11 +188,8 @@ public PyLong(string value) : base(FromString(value)) /// - /// IsLongType Method - /// - /// /// Returns true if the given object is a Python long. - /// + /// public static bool IsLongType(PyObject value) { return Runtime.PyLong_Check(value.obj); diff --git a/src/runtime/pyobject.cs b/src/runtime/pyobject.cs index 382ed8ccd..81578a7a8 100644 --- a/src/runtime/pyobject.cs +++ b/src/runtime/pyobject.cs @@ -930,17 +930,21 @@ public bool IsInstance(PyObject typeOrClass) /// - /// IsSubclass Method - /// - /// - /// Return true if the object is identical to or derived from the + /// Return true if the object is identical to or derived from the /// given Python type or class. This method always succeeds. - /// + /// public bool IsSubclass(PyObject typeOrClass) { if (typeOrClass == null) throw new ArgumentNullException(nameof(typeOrClass)); - int r = Runtime.PyObject_IsSubclass(obj, typeOrClass.obj); + return IsSubclass(typeOrClass.Reference); + } + + internal bool IsSubclass(BorrowedReference typeOrClass) + { + if (typeOrClass.IsNull) throw new ArgumentNullException(nameof(typeOrClass)); + + int r = Runtime.PyObject_IsSubclass(Reference, typeOrClass); if (r < 0) { Runtime.PyErr_Clear(); diff --git a/src/runtime/runtime.cs b/src/runtime/runtime.cs index caa160bcf..263b4473e 100644 --- a/src/runtime/runtime.cs +++ b/src/runtime/runtime.cs @@ -1109,7 +1109,7 @@ internal static int PyObject_Compare(IntPtr value1, IntPtr value2) internal static int PyObject_IsInstance(IntPtr ob, IntPtr type) => Delegates.PyObject_IsInstance(ob, type); - internal static int PyObject_IsSubclass(IntPtr ob, IntPtr type) => Delegates.PyObject_IsSubclass(ob, type); + internal static int PyObject_IsSubclass(BorrowedReference ob, BorrowedReference type) => Delegates.PyObject_IsSubclass(ob, type); internal static int PyCallable_Check(IntPtr pointer) => Delegates.PyCallable_Check(pointer); @@ -2314,7 +2314,7 @@ static Delegates() PyObject_CallObject = (delegate* unmanaged[Cdecl])GetFunctionByName(nameof(PyObject_CallObject), GetUnmanagedDll(_PythonDll)); PyObject_RichCompareBool = (delegate* unmanaged[Cdecl])GetFunctionByName(nameof(PyObject_RichCompareBool), GetUnmanagedDll(_PythonDll)); PyObject_IsInstance = (delegate* unmanaged[Cdecl])GetFunctionByName(nameof(PyObject_IsInstance), GetUnmanagedDll(_PythonDll)); - PyObject_IsSubclass = (delegate* unmanaged[Cdecl])GetFunctionByName(nameof(PyObject_IsSubclass), GetUnmanagedDll(_PythonDll)); + PyObject_IsSubclass = (delegate* unmanaged[Cdecl])GetFunctionByName(nameof(PyObject_IsSubclass), GetUnmanagedDll(_PythonDll)); PyCallable_Check = (delegate* unmanaged[Cdecl])GetFunctionByName(nameof(PyCallable_Check), GetUnmanagedDll(_PythonDll)); PyObject_IsTrue = (delegate* unmanaged[Cdecl])GetFunctionByName(nameof(PyObject_IsTrue), GetUnmanagedDll(_PythonDll)); PyObject_Not = (delegate* unmanaged[Cdecl])GetFunctionByName(nameof(PyObject_Not), GetUnmanagedDll(_PythonDll)); @@ -2599,7 +2599,7 @@ static Delegates() internal static delegate* unmanaged[Cdecl] PyObject_CallObject { get; } internal static delegate* unmanaged[Cdecl] PyObject_RichCompareBool { get; } internal static delegate* unmanaged[Cdecl] PyObject_IsInstance { get; } - internal static delegate* unmanaged[Cdecl] PyObject_IsSubclass { get; } + internal static delegate* unmanaged[Cdecl] PyObject_IsSubclass { get; } internal static delegate* unmanaged[Cdecl] PyCallable_Check { get; } internal static delegate* unmanaged[Cdecl] PyObject_IsTrue { get; } internal static delegate* unmanaged[Cdecl] PyObject_Not { get; } diff --git a/tests/test_array.py b/tests/test_array.py index 2b1a289ad..d6f08a961 100644 --- a/tests/test_array.py +++ b/tests/test_array.py @@ -680,7 +680,7 @@ def test_enum_array(): items[-1] = ShortEnum.Zero assert items[-1] == ShortEnum.Zero - with pytest.raises(ValueError): + with pytest.raises(TypeError): ob = Test.EnumArrayTest() ob.items[0] = 99 diff --git a/tests/test_conversion.py b/tests/test_conversion.py index aea95e164..eec2bcde6 100644 --- a/tests/test_conversion.py +++ b/tests/test_conversion.py @@ -601,41 +601,6 @@ class Foo(object): assert ob.ObjectField == Foo -def test_enum_conversion(): - """Test enum conversion.""" - from Python.Test import ShortEnum - - ob = ConversionTest() - assert ob.EnumField == ShortEnum.Zero - - ob.EnumField = ShortEnum.One - assert ob.EnumField == ShortEnum.One - - ob.EnumField = 0 - assert ob.EnumField == ShortEnum.Zero - assert ob.EnumField == 0 - - ob.EnumField = 1 - assert ob.EnumField == ShortEnum.One - assert ob.EnumField == 1 - - with pytest.raises(ValueError): - ob = ConversionTest() - ob.EnumField = 10 - - with pytest.raises(ValueError): - ob = ConversionTest() - ob.EnumField = 255 - - with pytest.raises(OverflowError): - ob = ConversionTest() - ob.EnumField = 1000000 - - with pytest.raises(TypeError): - ob = ConversionTest() - ob.EnumField = "spam" - - def test_null_conversion(): """Test null conversion.""" import System diff --git a/tests/test_enum.py b/tests/test_enum.py index 27fe7e9ef..1f0711a94 100644 --- a/tests/test_enum.py +++ b/tests/test_enum.py @@ -22,69 +22,69 @@ def test_enum_get_member(): """Test access to enum members.""" from System import DayOfWeek - assert DayOfWeek.Sunday == 0 - assert DayOfWeek.Monday == 1 - assert DayOfWeek.Tuesday == 2 - assert DayOfWeek.Wednesday == 3 - assert DayOfWeek.Thursday == 4 - assert DayOfWeek.Friday == 5 - assert DayOfWeek.Saturday == 6 + assert DayOfWeek.Sunday == DayOfWeek(0) + assert DayOfWeek.Monday == DayOfWeek(1) + assert DayOfWeek.Tuesday == DayOfWeek(2) + assert DayOfWeek.Wednesday == DayOfWeek(3) + assert DayOfWeek.Thursday == DayOfWeek(4) + assert DayOfWeek.Friday == DayOfWeek(5) + assert DayOfWeek.Saturday == DayOfWeek(6) def test_byte_enum(): """Test byte enum.""" - assert Test.ByteEnum.Zero == 0 - assert Test.ByteEnum.One == 1 - assert Test.ByteEnum.Two == 2 + assert Test.ByteEnum.Zero == Test.ByteEnum(0) + assert Test.ByteEnum.One == Test.ByteEnum(1) + assert Test.ByteEnum.Two == Test.ByteEnum(2) def test_sbyte_enum(): """Test sbyte enum.""" - assert Test.SByteEnum.Zero == 0 - assert Test.SByteEnum.One == 1 - assert Test.SByteEnum.Two == 2 + assert Test.SByteEnum.Zero == Test.SByteEnum(0) + assert Test.SByteEnum.One == Test.SByteEnum(1) + assert Test.SByteEnum.Two == Test.SByteEnum(2) def test_short_enum(): """Test short enum.""" - assert Test.ShortEnum.Zero == 0 - assert Test.ShortEnum.One == 1 - assert Test.ShortEnum.Two == 2 + assert Test.ShortEnum.Zero == Test.ShortEnum(0) + assert Test.ShortEnum.One == Test.ShortEnum(1) + assert Test.ShortEnum.Two == Test.ShortEnum(2) def test_ushort_enum(): """Test ushort enum.""" - assert Test.UShortEnum.Zero == 0 - assert Test.UShortEnum.One == 1 - assert Test.UShortEnum.Two == 2 + assert Test.UShortEnum.Zero == Test.UShortEnum(0) + assert Test.UShortEnum.One == Test.UShortEnum(1) + assert Test.UShortEnum.Two == Test.UShortEnum(2) def test_int_enum(): """Test int enum.""" - assert Test.IntEnum.Zero == 0 - assert Test.IntEnum.One == 1 - assert Test.IntEnum.Two == 2 + assert Test.IntEnum.Zero == Test.IntEnum(0) + assert Test.IntEnum.One == Test.IntEnum(1) + assert Test.IntEnum.Two == Test.IntEnum(2) def test_uint_enum(): """Test uint enum.""" - assert Test.UIntEnum.Zero == 0 - assert Test.UIntEnum.One == 1 - assert Test.UIntEnum.Two == 2 + assert Test.UIntEnum.Zero == Test.UIntEnum(0) + assert Test.UIntEnum.One == Test.UIntEnum(1) + assert Test.UIntEnum.Two == Test.UIntEnum(2) def test_long_enum(): """Test long enum.""" - assert Test.LongEnum.Zero == 0 - assert Test.LongEnum.One == 1 - assert Test.LongEnum.Two == 2 + assert Test.LongEnum.Zero == Test.LongEnum(0) + assert Test.LongEnum.One == Test.LongEnum(1) + assert Test.LongEnum.Two == Test.LongEnum(2) def test_ulong_enum(): """Test ulong enum.""" - assert Test.ULongEnum.Zero == 0 - assert Test.ULongEnum.One == 1 - assert Test.ULongEnum.Two == 2 + assert Test.ULongEnum.Zero == Test.ULongEnum(0) + assert Test.ULongEnum.One == Test.ULongEnum(1) + assert Test.ULongEnum.Two == Test.ULongEnum(2) def test_instantiate_enum_fails(): @@ -117,29 +117,31 @@ def test_enum_set_member_fails(): del DayOfWeek.Sunday -def test_enum_with_flags_attr_conversion(): +def test_enum_undefined_value(): """Test enumeration conversion with FlagsAttribute set.""" # This works because the FlagsField enum has FlagsAttribute. - Test.FieldTest().FlagsField = 99 + Test.FieldTest().FlagsField = Test.FlagsEnum(99) # This should fail because our test enum doesn't have it. with pytest.raises(ValueError): - Test.FieldTest().EnumField = 99 - + Test.FieldTest().EnumField = Test.ShortEnum(20) + + # explicitly permit undefined values + Test.FieldTest().EnumField = Test.ShortEnum(20, True) def test_enum_conversion(): """Test enumeration conversion.""" ob = Test.FieldTest() - assert ob.EnumField == 0 + assert ob.EnumField == Test.ShortEnum(0) ob.EnumField = Test.ShortEnum.One - assert ob.EnumField == 1 - - with pytest.raises(ValueError): - Test.FieldTest().EnumField = 20 + assert ob.EnumField == Test.ShortEnum(1) with pytest.raises(OverflowError): - Test.FieldTest().EnumField = 100000 + Test.FieldTest().EnumField = Test.ShortEnum(100000) with pytest.raises(TypeError): Test.FieldTest().EnumField = "str" + + with pytest.raises(TypeError): + Test.FieldTest().EnumField = 1 diff --git a/tests/test_indexer.py b/tests/test_indexer.py index 7992f76b0..0af6e6c45 100644 --- a/tests/test_indexer.py +++ b/tests/test_indexer.py @@ -400,8 +400,10 @@ def test_enum_indexer(): ob[key] = "eggs" assert ob[key] == "eggs" - ob[1] = "spam" - assert ob[1] == "spam" + with pytest.raises(TypeError): + ob[1] = "spam" + with pytest.raises(TypeError): + ob[1] with pytest.raises(TypeError): ob = Test.EnumIndexerTest()