Skip to content

Provide __int__ instance method on enum types to support int(Enum.Member) #1661

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/runtime/classmanager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -346,15 +346,25 @@ private static ClassInfo GetClassInfo(Type type, ClassBase impl)
}
}

// only [Flags] enums support bitwise operations
if (type.IsEnum && type.IsFlagsEnum())
if (type.IsEnum)
{
var opsImpl = typeof(EnumOps<>).MakeGenericType(type);
foreach (var op in opsImpl.GetMethods(OpsHelper.BindingFlags))
{
local.Add(op.Name);
}
info = info.Concat(opsImpl.GetMethods(OpsHelper.BindingFlags)).ToArray();

// only [Flags] enums support bitwise operations
if (type.IsFlagsEnum())
{
opsImpl = typeof(FlagEnumOps<>).MakeGenericType(type);
foreach (var op in opsImpl.GetMethods(OpsHelper.BindingFlags))
{
local.Add(op.Name);
}
info = info.Concat(opsImpl.GetMethods(OpsHelper.BindingFlags)).ToArray();
}
}

// Now again to filter w/o losing overloaded member info
Expand Down
1 change: 1 addition & 0 deletions src/runtime/native/ITypeOffsets.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ interface ITypeOffsets
int nb_multiply { get; }
int nb_true_divide { get; }
int nb_and { get; }
int nb_int { get; }
int nb_or { get; }
int nb_xor { get; }
int nb_lshift { get; }
Expand Down
1 change: 1 addition & 0 deletions src/runtime/native/TypeOffset.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ static partial class TypeOffset
internal static int nb_and { get; private set; }
internal static int nb_or { get; private set; }
internal static int nb_xor { get; private set; }
internal static int nb_int { get; private set; }
internal static int nb_lshift { get; private set; }
internal static int nb_rshift { get; private set; }
internal static int nb_remainder { get; private set; }
Expand Down
23 changes: 17 additions & 6 deletions src/runtime/operatormethod.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Reflection;
using System.Runtime.InteropServices;
using System.Text;

namespace Python.Runtime
Expand Down Expand Up @@ -51,6 +51,8 @@ static OperatorMethod()
["op_OnesComplement"] = new SlotDefinition("__invert__", TypeOffset.nb_invert),
["op_UnaryNegation"] = new SlotDefinition("__neg__", TypeOffset.nb_negative),
["op_UnaryPlus"] = new SlotDefinition("__pos__", TypeOffset.nb_positive),

["__int__"] = new SlotDefinition("__int__", TypeOffset.nb_int),
};
ComparisonOpMap = new Dictionary<string, string>
{
Expand Down Expand Up @@ -97,14 +99,11 @@ public static bool IsComparisonOp(MethodBase method)
/// </summary>
public static void FixupSlots(BorrowedReference pyType, Type clrType)
{
const BindingFlags flags = BindingFlags.Public | BindingFlags.Static;
Debug.Assert(_opType != null);

var staticMethods =
clrType.IsEnum ? typeof(EnumOps<>).MakeGenericType(clrType).GetMethods(flags)
: clrType.GetMethods(flags);
var operatorCandidates = GetOperatorCandidates(clrType);

foreach (var method in staticMethods)
foreach (var method in operatorCandidates)
{
// We only want to override slots for operators excluding
// comparison operators, which are handled by ClassBase.tp_richcompare.
Expand All @@ -124,6 +123,18 @@ public static void FixupSlots(BorrowedReference pyType, Type clrType)
}
}

static IEnumerable<MethodInfo> GetOperatorCandidates(Type clrType)
{
const BindingFlags flags = BindingFlags.Public | BindingFlags.Static;
if (clrType.IsEnum)
{
return typeof(EnumOps<>).MakeGenericType(clrType).GetMethods(flags)
.Concat(typeof(FlagEnumOps<>).MakeGenericType(clrType).GetMethods(flags));
}

return clrType.GetMethods(flags);
}

public static string GetPyMethodName(string clrName)
{
if (OpMethodMap.ContainsKey(clrName))
Expand Down
14 changes: 13 additions & 1 deletion src/runtime/opshelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public static Expression EnumUnderlyingValue(Expression enumValue)
internal class OpsAttribute: Attribute { }

[Ops]
internal static class EnumOps<T> where T : Enum
internal static class FlagEnumOps<T> where T : Enum
{
static readonly Func<T, T, T> and = BinaryOp(Expression.And);
static readonly Func<T, T, T> or = BinaryOp(Expression.Or);
Expand Down Expand Up @@ -74,4 +74,16 @@ static Func<T, T> UnaryOp(Func<Expression, UnaryExpression> op)
});
}
}

[Ops]
internal static class EnumOps<T> where T : Enum
{
[ForbidPythonThreads]
#pragma warning disable IDE1006 // Naming Styles - must match Python
public static PyInt __int__(T value)
#pragma warning restore IDE1006 // Naming Styles
=> typeof(T).GetEnumUnderlyingType() == typeof(UInt64)
? new PyInt(Convert.ToUInt64(value))
: new PyInt(Convert.ToInt64(value));
}
}
7 changes: 5 additions & 2 deletions src/testing/enumtest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ public enum LongEnum : long
Two,
Three,
Four,
Five
Five,
Max = long.MaxValue,
Min = long.MinValue,
}

public enum ULongEnum : ulong
Expand All @@ -82,7 +84,8 @@ public enum ULongEnum : ulong
Two,
Three,
Four,
Five
Five,
Max = ulong.MaxValue,
}

[Flags]
Expand Down
9 changes: 9 additions & 0 deletions tests/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,15 @@ def test_ulong_enum():
assert Test.ULongEnum.Two == Test.ULongEnum(2)


def test_long_enum_to_int():
assert int(Test.LongEnum.Max) == 9223372036854775807
assert int(Test.LongEnum.Min) == -9223372036854775808


def test_ulong_enum_to_int():
assert int(Test.ULongEnum.Max) == 18446744073709551615


def test_instantiate_enum_fails():
"""Test that instantiation of an enum class fails."""
from System import DayOfWeek
Expand Down