Skip to content

Commit 187816e

Browse files
committed
Finish comparison operator impl and add tests
1 parent 9567a80 commit 187816e

File tree

4 files changed

+77
-1
lines changed

4 files changed

+77
-1
lines changed

src/embed_tests/TestOperator.cs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,12 @@ public void SymmetricalOperatorOverloads()
270270
271271
c = a > b
272272
assert c == (a.Num > b.Num)
273+
274+
c = a == b
275+
assert c == (a.Num == b.Num)
276+
277+
c = a != b
278+
assert c == (a.Num != b.Num)
273279
");
274280
}
275281

@@ -339,6 +345,12 @@ public void ForwardOperatorOverloads()
339345
340346
c = a > b
341347
assert c == (a.Num > b)
348+
349+
c = a == b
350+
assert c == (a.Num == b)
351+
352+
c = a != b
353+
assert c == (a.Num != b)
342354
");
343355
}
344356

@@ -392,6 +404,12 @@ public void ReverseOperatorOverloads()
392404
393405
c = a > b
394406
assert c == (a > b.Num)
407+
408+
c = a == b
409+
assert c == (a == b.Num)
410+
411+
c = a != b
412+
assert c == (a != b.Num)
395413
");
396414

397415
}

src/runtime/classbase.cs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ namespace Python.Runtime
1919
internal class ClassBase : ManagedType
2020
{
2121
internal Indexer indexer;
22+
internal Hashtable richcompare;
2223
internal Type type;
2324

2425
internal ClassBase(Type tp)
@@ -32,6 +33,15 @@ internal virtual bool CanSubclass()
3233
return !type.IsEnum;
3334
}
3435

36+
public readonly static Dictionary<int, string> PyToCilOpMap = new Dictionary<int, string>
37+
{
38+
[Runtime.Py_EQ] = "op_Equality",
39+
[Runtime.Py_NE] = "op_Inequality",
40+
[Runtime.Py_GT] = "op_GreaterThan",
41+
[Runtime.Py_GE] = "op_GreaterThanOrEqual",
42+
[Runtime.Py_LT] = "op_LessThan",
43+
[Runtime.Py_LE] = "op_LessThanOrEqual",
44+
};
3545

3646
/// <summary>
3747
/// Default implementation of [] semantics for reflected types.
@@ -64,6 +74,42 @@ public static IntPtr tp_richcompare(IntPtr ob, IntPtr other, int op)
6474
{
6575
CLRObject co1;
6676
CLRObject co2;
77+
IntPtr tp = Runtime.PyObject_TYPE(ob);
78+
var cls = (ClassBase)GetManagedObject(tp);
79+
// C# operator methods take precedence over IComparable.
80+
// We first check if there's a comparison operator by looking up the richcompare table,
81+
// otherwise fallback to checking if an IComparable interface is handled.
82+
if (PyToCilOpMap.ContainsKey(op)) {
83+
string CilOp = PyToCilOpMap[op];
84+
if (cls.richcompare.Contains(CilOp)) {
85+
var methodObject = (MethodObject)cls.richcompare[CilOp];
86+
IntPtr args = other;
87+
var free = false;
88+
if (!Runtime.PyTuple_Check(other))
89+
{
90+
// Wrap the `other` argument of a binary comparison operator in a PyTuple.
91+
args = Runtime.PyTuple_New(1);
92+
Runtime.XIncref(other);
93+
Runtime.PyTuple_SetItem(args, 0, other);
94+
free = true;
95+
}
96+
97+
IntPtr value;
98+
try
99+
{
100+
value = methodObject.Invoke(ob, args, IntPtr.Zero);
101+
}
102+
finally
103+
{
104+
if (free)
105+
{
106+
Runtime.XDecref(args); // Free args pytuple
107+
}
108+
}
109+
return value;
110+
}
111+
}
112+
67113
switch (op)
68114
{
69115
case Runtime.Py_EQ:

src/runtime/classmanager.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ private static void InitClassBase(Type type, ClassBase impl)
197197
ClassInfo info = GetClassInfo(type);
198198

199199
impl.indexer = info.indexer;
200+
impl.richcompare = new Hashtable();
200201

201202
// Now we allocate the Python type object to reflect the given
202203
// managed type, filling the Python type slots with thunks that
@@ -217,6 +218,9 @@ private static void InitClassBase(Type type, ClassBase impl)
217218
Runtime.PyDict_SetItemString(dict, name, item.pyHandle);
218219
// Decref the item now that it's been used.
219220
item.DecrRefCount();
221+
if (ClassBase.PyToCilOpMap.ContainsValue(name)) {
222+
impl.richcompare.Add(name, iter.Value);
223+
}
220224
}
221225

222226
// If class has constructors, generate an __doc__ attribute.

src/runtime/operatormethod.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ static OperatorMethod()
5050
["op_UnaryNegation"] = new SlotDefinition("__neg__", TypeOffset.nb_negative),
5151
["op_UnaryPlus"] = new SlotDefinition("__pos__", TypeOffset.nb_positive),
5252
["op_OneComplement"] = new SlotDefinition("__invert__", TypeOffset.nb_invert),
53+
["op_Equality"] = new SlotDefinition("__eq__", TypeOffset.tp_richcompare),
54+
["op_Inequality"] = new SlotDefinition("__ne__", TypeOffset.tp_richcompare),
5355
["op_GreaterThan"] = new SlotDefinition("__gt__", TypeOffset.tp_richcompare),
5456
["op_GreaterThanOrEqual"] = new SlotDefinition("__ge__", TypeOffset.tp_richcompare),
5557
["op_LessThan"] = new SlotDefinition("__lt__", TypeOffset.tp_richcompare),
@@ -79,6 +81,12 @@ public static bool IsOperatorMethod(MethodBase method)
7981
}
8082
return OpMethodMap.ContainsKey(method.Name);
8183
}
84+
85+
public static bool IsComparisonOp(MethodInfo method)
86+
{
87+
return OpMethodMap[method.Name].TypeOffset == TypeOffset.tp_richcompare;
88+
}
89+
8290
/// <summary>
8391
/// For the operator methods of a CLR type, set the special slots of the
8492
/// corresponding Python type's operator methods.
@@ -91,7 +99,7 @@ public static void FixupSlots(IntPtr pyType, Type clrType)
9199
Debug.Assert(_opType != null);
92100
foreach (var method in clrType.GetMethods(flags))
93101
{
94-
if (!IsOperatorMethod(method))
102+
if (!IsOperatorMethod(method) || IsComparisonOp(method)) // We don't want to override ClassBase.tp_richcompare.
95103
{
96104
continue;
97105
}

0 commit comments

Comments
 (0)