diff --git a/src/embed_tests/TestOperator.cs b/src/embed_tests/TestOperator.cs index ecdb0c1dc..8e9feb241 100644 --- a/src/embed_tests/TestOperator.cs +++ b/src/embed_tests/TestOperator.cs @@ -25,6 +25,17 @@ public class OperableObject { public int Num { get; set; } + public override int GetHashCode() + { + return unchecked(159832395 + Num.GetHashCode()); + } + + public override bool Equals(object obj) + { + return obj is OperableObject @object && + Num == @object.Num; + } + public OperableObject(int num) { Num = num; @@ -149,6 +160,103 @@ public OperableObject(int num) return new OperableObject(a.Num ^ b); } + public static bool operator ==(int a, OperableObject b) + { + return (a == b.Num); + } + public static bool operator ==(OperableObject a, OperableObject b) + { + return (a.Num == b.Num); + } + public static bool operator ==(OperableObject a, int b) + { + return (a.Num == b); + } + + public static bool operator !=(int a, OperableObject b) + { + return (a != b.Num); + } + public static bool operator !=(OperableObject a, OperableObject b) + { + return (a.Num != b.Num); + } + public static bool operator !=(OperableObject a, int b) + { + return (a.Num != b); + } + + public static bool operator <=(int a, OperableObject b) + { + return (a <= b.Num); + } + public static bool operator <=(OperableObject a, OperableObject b) + { + return (a.Num <= b.Num); + } + public static bool operator <=(OperableObject a, int b) + { + return (a.Num <= b); + } + + public static bool operator >=(int a, OperableObject b) + { + return (a >= b.Num); + } + public static bool operator >=(OperableObject a, OperableObject b) + { + return (a.Num >= b.Num); + } + public static bool operator >=(OperableObject a, int b) + { + return (a.Num >= b); + } + + public static bool operator >=(OperableObject a, PyObject b) + { + using (Py.GIL()) + { + // Assuming b is a tuple, take the first element. + int bNum = b[0].As(); + return a.Num >= bNum; + } + } + public static bool operator <=(OperableObject a, PyObject b) + { + using (Py.GIL()) + { + // Assuming b is a tuple, take the first element. + int bNum = b[0].As(); + return a.Num <= bNum; + } + } + + public static bool operator <(int a, OperableObject b) + { + return (a < b.Num); + } + public static bool operator <(OperableObject a, OperableObject b) + { + return (a.Num < b.Num); + } + public static bool operator <(OperableObject a, int b) + { + return (a.Num < b); + } + + public static bool operator >(int a, OperableObject b) + { + return (a > b.Num); + } + public static bool operator >(OperableObject a, OperableObject b) + { + return (a.Num > b.Num); + } + public static bool operator >(OperableObject a, int b) + { + return (a.Num > b); + } + public static OperableObject operator <<(OperableObject a, int offset) { return new OperableObject(a.Num << offset); @@ -161,7 +269,7 @@ public OperableObject(int num) } [Test] - public void OperatorOverloads() + public void SymmetricalOperatorOverloads() { string name = string.Format("{0}.{1}", typeof(OperableObject).DeclaringType.Name, @@ -206,6 +314,24 @@ public void OperatorOverloads() c = a ^ b assert c.Num == a.Num ^ b.Num + +c = a == b +assert c == (a.Num == b.Num) + +c = a != b +assert c == (a.Num != b.Num) + +c = a <= b +assert c == (a.Num <= b.Num) + +c = a >= b +assert c == (a.Num >= b.Num) + +c = a < b +assert c == (a.Num < b.Num) + +c = a > b +assert c == (a.Num > b.Num) "); } @@ -263,6 +389,51 @@ public void ForwardOperatorOverloads() c = a ^ b assert c.Num == a.Num ^ b + +c = a == b +assert c == (a.Num == b) + +c = a != b +assert c == (a.Num != b) + +c = a <= b +assert c == (a.Num <= b) + +c = a >= b +assert c == (a.Num >= b) + +c = a < b +assert c == (a.Num < b) + +c = a > b +assert c == (a.Num > b) +"); + } + + [Test] + public void TupleComparisonOperatorOverloads() + { + string name = string.Format("{0}.{1}", + typeof(OperableObject).DeclaringType.Name, + typeof(OperableObject).Name); + string module = MethodBase.GetCurrentMethod().DeclaringType.Namespace; + PythonEngine.Exec($@" +from {module} import * +cls = {name} +a = cls(2) +b = (1, 2) + +c = a >= b +assert c == (a.Num >= b[0]) + +c = a <= b +assert c == (a.Num <= b[0]) + +c = b >= a +assert c == (b[0] >= a.Num) + +c = b <= a +assert c == (b[0] <= a.Num) "); } @@ -304,6 +475,24 @@ public void ReverseOperatorOverloads() c = a ^ b assert c.Num == a ^ b.Num + +c = a == b +assert c == (a == b.Num) + +c = a != b +assert c == (a != b.Num) + +c = a <= b +assert c == (a <= b.Num) + +c = a >= b +assert c == (a >= b.Num) + +c = a < b +assert c == (a < b.Num) + +c = a > b +assert c == (a > b.Num) "); } diff --git a/src/runtime/classbase.cs b/src/runtime/classbase.cs index 0ff4ba154..872501267 100644 --- a/src/runtime/classbase.cs +++ b/src/runtime/classbase.cs @@ -21,6 +21,7 @@ internal class ClassBase : ManagedType [NonSerialized] internal List dotNetMembers; internal Indexer indexer; + internal Dictionary richcompare; internal MaybeType type; internal ClassBase(Type tp) @@ -35,6 +36,15 @@ internal virtual bool CanSubclass() return !type.Value.IsEnum; } + public readonly static Dictionary CilToPyOpMap = new Dictionary + { + ["op_Equality"] = Runtime.Py_EQ, + ["op_Inequality"] = Runtime.Py_NE, + ["op_LessThanOrEqual"] = Runtime.Py_LE, + ["op_GreaterThanOrEqual"] = Runtime.Py_GE, + ["op_LessThan"] = Runtime.Py_LT, + ["op_GreaterThan"] = Runtime.Py_GT, + }; /// /// Default implementation of [] semantics for reflected types. @@ -72,6 +82,30 @@ public static IntPtr tp_richcompare(IntPtr ob, IntPtr other, int op) { CLRObject co1; CLRObject co2; + IntPtr tp = Runtime.PyObject_TYPE(ob); + var cls = (ClassBase)GetManagedObject(tp); + // C# operator methods take precedence over IComparable. + // We first check if there's a comparison operator by looking up the richcompare table, + // otherwise fallback to checking if an IComparable interface is handled. + if (cls.richcompare.TryGetValue(op, out var methodObject)) + { + // Wrap the `other` argument of a binary comparison operator in a PyTuple. + IntPtr args = Runtime.PyTuple_New(1); + Runtime.XIncref(other); + Runtime.PyTuple_SetItem(args, 0, other); + + IntPtr value; + try + { + value = methodObject.Invoke(ob, args, IntPtr.Zero); + } + finally + { + Runtime.XDecref(args); // Free args pytuple + } + return value; + } + switch (op) { case Runtime.Py_EQ: diff --git a/src/runtime/classmanager.cs b/src/runtime/classmanager.cs index 64c985ce7..0cbff371f 100644 --- a/src/runtime/classmanager.cs +++ b/src/runtime/classmanager.cs @@ -259,6 +259,7 @@ private static void InitClassBase(Type type, ClassBase impl) ClassInfo info = GetClassInfo(type); impl.indexer = info.indexer; + impl.richcompare = new Dictionary(); // Now we allocate the Python type object to reflect the given // managed type, filling the Python type slots with thunks that @@ -284,6 +285,9 @@ private static void InitClassBase(Type type, ClassBase impl) Runtime.PyDict_SetItemString(dict, name, item.pyHandle); // Decref the item now that it's been used. item.DecrRefCount(); + if (ClassBase.CilToPyOpMap.TryGetValue(name, out var pyOp)) { + impl.richcompare.Add(pyOp, (MethodObject)item); + } } // If class has constructors, generate an __doc__ attribute. @@ -553,8 +557,7 @@ private static ClassInfo GetClassInfo(Type type) { string pyName = OperatorMethod.GetPyMethodName(name); string pyNameReverse = OperatorMethod.ReversePyMethodName(pyName); - MethodInfo[] forwardMethods, reverseMethods; - OperatorMethod.FilterMethods(mlist, out forwardMethods, out reverseMethods); + OperatorMethod.FilterMethods(mlist, out var forwardMethods, out var reverseMethods); // Only methods where the left operand is the declaring type. if (forwardMethods.Length > 0) ci.members[pyName] = new MethodObject(type, name, forwardMethods); diff --git a/src/runtime/methodbinder.cs b/src/runtime/methodbinder.cs index ba37c19c1..5de0ecc00 100644 --- a/src/runtime/methodbinder.cs +++ b/src/runtime/methodbinder.cs @@ -354,16 +354,17 @@ internal Binding Bind(IntPtr inst, IntPtr args, IntPtr kw, MethodBase info, Meth int kwargsMatched; int defaultsNeeded; bool isOperator = OperatorMethod.IsOperatorMethod(mi); - int clrnargs = pi.Length; // Binary operator methods will have 2 CLR args but only one Python arg // (unary operators will have 1 less each), since Python operator methods are bound. - isOperator = isOperator && pynargs == clrnargs - 1; + isOperator = isOperator && pynargs == pi.Length - 1; + bool isReverse = isOperator && OperatorMethod.IsReverse((MethodInfo)mi); // Only cast if isOperator. + if (isReverse && OperatorMethod.IsComparisonOp((MethodInfo)mi)) + continue; // Comparison operators in Python have no reverse mode. if (!MatchesArgumentCount(pynargs, pi, kwargDict, out paramsArray, out defaultArgList, out kwargsMatched, out defaultsNeeded) && !isOperator) { continue; } // Preprocessing pi to remove either the first or second argument. - bool isReverse = isOperator && OperatorMethod.IsReverse((MethodInfo)mi); // Only cast if isOperator. if (isOperator && !isReverse) { // The first Python arg is the right operand, while the bound instance is the left. // We need to skip the first (left operand) CLR argument. diff --git a/src/runtime/operatormethod.cs b/src/runtime/operatormethod.cs index 1e0244510..59bf944bc 100644 --- a/src/runtime/operatormethod.cs +++ b/src/runtime/operatormethod.cs @@ -15,6 +15,7 @@ internal static class OperatorMethod /// that identifies that operator's slot (e.g. nb_add) in heap space. /// public static Dictionary OpMethodMap { get; private set; } + public static Dictionary ComparisonOpMap { get; private set; } public readonly struct SlotDefinition { public SlotDefinition(string methodName, int typeOffset) @@ -24,6 +25,7 @@ public SlotDefinition(string methodName, int typeOffset) } public string MethodName { get; } public int TypeOffset { get; } + } private static PyObject _opType; @@ -49,6 +51,16 @@ static OperatorMethod() ["op_OnesComplement"] = new SlotDefinition("__invert__", TypeOffset.nb_invert), ["op_UnaryNegation"] = new SlotDefinition("__neg__", TypeOffset.nb_negative), ["op_UnaryPlus"] = new SlotDefinition("__pos__", TypeOffset.nb_positive), + ["op_OneComplement"] = new SlotDefinition("__invert__", TypeOffset.nb_invert), + }; + ComparisonOpMap = new Dictionary + { + ["op_Equality"] = "__eq__", + ["op_Inequality"] = "__ne__", + ["op_LessThanOrEqual"] = "__le__", + ["op_GreaterThanOrEqual"] = "__ge__", + ["op_LessThan"] = "__lt__", + ["op_GreaterThan"] = "__gt__", }; } @@ -72,8 +84,14 @@ public static bool IsOperatorMethod(MethodBase method) { return false; } - return OpMethodMap.ContainsKey(method.Name); + return OpMethodMap.ContainsKey(method.Name) || ComparisonOpMap.ContainsKey(method.Name); + } + + public static bool IsComparisonOp(MethodInfo method) + { + return ComparisonOpMap.ContainsKey(method.Name); } + /// /// For the operator methods of a CLR type, set the special slots of the /// corresponding Python type's operator methods. @@ -86,7 +104,9 @@ public static void FixupSlots(IntPtr pyType, Type clrType) Debug.Assert(_opType != null); foreach (var method in clrType.GetMethods(flags)) { - if (!IsOperatorMethod(method)) + // We only want to override slots for operators excluding + // comparison operators, which are handled by ClassBase.tp_richcompare. + if (!OpMethodMap.ContainsKey(method.Name)) { continue; } @@ -99,13 +119,18 @@ public static void FixupSlots(IntPtr pyType, Type clrType) // when used with a Python operator. // https://tenthousandmeters.com/blog/python-behind-the-scenes-6-how-python-object-system-works/ Marshal.WriteIntPtr(pyType, offset, func); - } } public static string GetPyMethodName(string clrName) { - return OpMethodMap[clrName].MethodName; + if (OpMethodMap.ContainsKey(clrName)) + { + return OpMethodMap[clrName].MethodName; + } else + { + return ComparisonOpMap[clrName]; + } } private static string GenerateDummyCode()