Skip to content

Commit e44aa46

Browse files
authored
Support comparison operators (#1347)
1 parent ed6763c commit e44aa46

File tree

5 files changed

+262
-10
lines changed

5 files changed

+262
-10
lines changed

src/embed_tests/TestOperator.cs

+190-1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,17 @@ public class OperableObject
2525
{
2626
public int Num { get; set; }
2727

28+
public override int GetHashCode()
29+
{
30+
return unchecked(159832395 + Num.GetHashCode());
31+
}
32+
33+
public override bool Equals(object obj)
34+
{
35+
return obj is OperableObject @object &&
36+
Num == @object.Num;
37+
}
38+
2839
public OperableObject(int num)
2940
{
3041
Num = num;
@@ -149,6 +160,103 @@ public OperableObject(int num)
149160
return new OperableObject(a.Num ^ b);
150161
}
151162

163+
public static bool operator ==(int a, OperableObject b)
164+
{
165+
return (a == b.Num);
166+
}
167+
public static bool operator ==(OperableObject a, OperableObject b)
168+
{
169+
return (a.Num == b.Num);
170+
}
171+
public static bool operator ==(OperableObject a, int b)
172+
{
173+
return (a.Num == b);
174+
}
175+
176+
public static bool operator !=(int a, OperableObject b)
177+
{
178+
return (a != b.Num);
179+
}
180+
public static bool operator !=(OperableObject a, OperableObject b)
181+
{
182+
return (a.Num != b.Num);
183+
}
184+
public static bool operator !=(OperableObject a, int b)
185+
{
186+
return (a.Num != b);
187+
}
188+
189+
public static bool operator <=(int a, OperableObject b)
190+
{
191+
return (a <= b.Num);
192+
}
193+
public static bool operator <=(OperableObject a, OperableObject b)
194+
{
195+
return (a.Num <= b.Num);
196+
}
197+
public static bool operator <=(OperableObject a, int b)
198+
{
199+
return (a.Num <= b);
200+
}
201+
202+
public static bool operator >=(int a, OperableObject b)
203+
{
204+
return (a >= b.Num);
205+
}
206+
public static bool operator >=(OperableObject a, OperableObject b)
207+
{
208+
return (a.Num >= b.Num);
209+
}
210+
public static bool operator >=(OperableObject a, int b)
211+
{
212+
return (a.Num >= b);
213+
}
214+
215+
public static bool operator >=(OperableObject a, PyObject b)
216+
{
217+
using (Py.GIL())
218+
{
219+
// Assuming b is a tuple, take the first element.
220+
int bNum = b[0].As<int>();
221+
return a.Num >= bNum;
222+
}
223+
}
224+
public static bool operator <=(OperableObject a, PyObject b)
225+
{
226+
using (Py.GIL())
227+
{
228+
// Assuming b is a tuple, take the first element.
229+
int bNum = b[0].As<int>();
230+
return a.Num <= bNum;
231+
}
232+
}
233+
234+
public static bool operator <(int a, OperableObject b)
235+
{
236+
return (a < b.Num);
237+
}
238+
public static bool operator <(OperableObject a, OperableObject b)
239+
{
240+
return (a.Num < b.Num);
241+
}
242+
public static bool operator <(OperableObject a, int b)
243+
{
244+
return (a.Num < b);
245+
}
246+
247+
public static bool operator >(int a, OperableObject b)
248+
{
249+
return (a > b.Num);
250+
}
251+
public static bool operator >(OperableObject a, OperableObject b)
252+
{
253+
return (a.Num > b.Num);
254+
}
255+
public static bool operator >(OperableObject a, int b)
256+
{
257+
return (a.Num > b);
258+
}
259+
152260
public static OperableObject operator <<(OperableObject a, int offset)
153261
{
154262
return new OperableObject(a.Num << offset);
@@ -161,7 +269,7 @@ public OperableObject(int num)
161269
}
162270

163271
[Test]
164-
public void OperatorOverloads()
272+
public void SymmetricalOperatorOverloads()
165273
{
166274
string name = string.Format("{0}.{1}",
167275
typeof(OperableObject).DeclaringType.Name,
@@ -206,6 +314,24 @@ public void OperatorOverloads()
206314
207315
c = a ^ b
208316
assert c.Num == a.Num ^ b.Num
317+
318+
c = a == b
319+
assert c == (a.Num == b.Num)
320+
321+
c = a != b
322+
assert c == (a.Num != b.Num)
323+
324+
c = a <= b
325+
assert c == (a.Num <= b.Num)
326+
327+
c = a >= b
328+
assert c == (a.Num >= b.Num)
329+
330+
c = a < b
331+
assert c == (a.Num < b.Num)
332+
333+
c = a > b
334+
assert c == (a.Num > b.Num)
209335
");
210336
}
211337

@@ -263,6 +389,51 @@ public void ForwardOperatorOverloads()
263389
264390
c = a ^ b
265391
assert c.Num == a.Num ^ b
392+
393+
c = a == b
394+
assert c == (a.Num == b)
395+
396+
c = a != b
397+
assert c == (a.Num != b)
398+
399+
c = a <= b
400+
assert c == (a.Num <= b)
401+
402+
c = a >= b
403+
assert c == (a.Num >= b)
404+
405+
c = a < b
406+
assert c == (a.Num < b)
407+
408+
c = a > b
409+
assert c == (a.Num > b)
410+
");
411+
}
412+
413+
[Test]
414+
public void TupleComparisonOperatorOverloads()
415+
{
416+
string name = string.Format("{0}.{1}",
417+
typeof(OperableObject).DeclaringType.Name,
418+
typeof(OperableObject).Name);
419+
string module = MethodBase.GetCurrentMethod().DeclaringType.Namespace;
420+
PythonEngine.Exec($@"
421+
from {module} import *
422+
cls = {name}
423+
a = cls(2)
424+
b = (1, 2)
425+
426+
c = a >= b
427+
assert c == (a.Num >= b[0])
428+
429+
c = a <= b
430+
assert c == (a.Num <= b[0])
431+
432+
c = b >= a
433+
assert c == (b[0] >= a.Num)
434+
435+
c = b <= a
436+
assert c == (b[0] <= a.Num)
266437
");
267438
}
268439

@@ -304,6 +475,24 @@ public void ReverseOperatorOverloads()
304475
305476
c = a ^ b
306477
assert c.Num == a ^ b.Num
478+
479+
c = a == b
480+
assert c == (a == b.Num)
481+
482+
c = a != b
483+
assert c == (a != b.Num)
484+
485+
c = a <= b
486+
assert c == (a <= b.Num)
487+
488+
c = a >= b
489+
assert c == (a >= b.Num)
490+
491+
c = a < b
492+
assert c == (a < b.Num)
493+
494+
c = a > b
495+
assert c == (a > b.Num)
307496
");
308497

309498
}

src/runtime/classbase.cs

+34
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ internal class ClassBase : ManagedType
2121
[NonSerialized]
2222
internal List<string> dotNetMembers;
2323
internal Indexer indexer;
24+
internal Dictionary<int, MethodObject> richcompare;
2425
internal MaybeType type;
2526

2627
internal ClassBase(Type tp)
@@ -35,6 +36,15 @@ internal virtual bool CanSubclass()
3536
return !type.Value.IsEnum;
3637
}
3738

39+
public readonly static Dictionary<string, int> CilToPyOpMap = new Dictionary<string, int>
40+
{
41+
["op_Equality"] = Runtime.Py_EQ,
42+
["op_Inequality"] = Runtime.Py_NE,
43+
["op_LessThanOrEqual"] = Runtime.Py_LE,
44+
["op_GreaterThanOrEqual"] = Runtime.Py_GE,
45+
["op_LessThan"] = Runtime.Py_LT,
46+
["op_GreaterThan"] = Runtime.Py_GT,
47+
};
3848

3949
/// <summary>
4050
/// Default implementation of [] semantics for reflected types.
@@ -72,6 +82,30 @@ public static IntPtr tp_richcompare(IntPtr ob, IntPtr other, int op)
7282
{
7383
CLRObject co1;
7484
CLRObject co2;
85+
IntPtr tp = Runtime.PyObject_TYPE(ob);
86+
var cls = (ClassBase)GetManagedObject(tp);
87+
// C# operator methods take precedence over IComparable.
88+
// We first check if there's a comparison operator by looking up the richcompare table,
89+
// otherwise fallback to checking if an IComparable interface is handled.
90+
if (cls.richcompare.TryGetValue(op, out var methodObject))
91+
{
92+
// Wrap the `other` argument of a binary comparison operator in a PyTuple.
93+
IntPtr args = Runtime.PyTuple_New(1);
94+
Runtime.XIncref(other);
95+
Runtime.PyTuple_SetItem(args, 0, other);
96+
97+
IntPtr value;
98+
try
99+
{
100+
value = methodObject.Invoke(ob, args, IntPtr.Zero);
101+
}
102+
finally
103+
{
104+
Runtime.XDecref(args); // Free args pytuple
105+
}
106+
return value;
107+
}
108+
75109
switch (op)
76110
{
77111
case Runtime.Py_EQ:

src/runtime/classmanager.cs

+5-2
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ private static void InitClassBase(Type type, ClassBase impl)
259259
ClassInfo info = GetClassInfo(type);
260260

261261
impl.indexer = info.indexer;
262+
impl.richcompare = new Dictionary<int, MethodObject>();
262263

263264
// Now we allocate the Python type object to reflect the given
264265
// managed type, filling the Python type slots with thunks that
@@ -284,6 +285,9 @@ private static void InitClassBase(Type type, ClassBase impl)
284285
Runtime.PyDict_SetItemString(dict, name, item.pyHandle);
285286
// Decref the item now that it's been used.
286287
item.DecrRefCount();
288+
if (ClassBase.CilToPyOpMap.TryGetValue(name, out var pyOp)) {
289+
impl.richcompare.Add(pyOp, (MethodObject)item);
290+
}
287291
}
288292

289293
// If class has constructors, generate an __doc__ attribute.
@@ -553,8 +557,7 @@ private static ClassInfo GetClassInfo(Type type)
553557
{
554558
string pyName = OperatorMethod.GetPyMethodName(name);
555559
string pyNameReverse = OperatorMethod.ReversePyMethodName(pyName);
556-
MethodInfo[] forwardMethods, reverseMethods;
557-
OperatorMethod.FilterMethods(mlist, out forwardMethods, out reverseMethods);
560+
OperatorMethod.FilterMethods(mlist, out var forwardMethods, out var reverseMethods);
558561
// Only methods where the left operand is the declaring type.
559562
if (forwardMethods.Length > 0)
560563
ci.members[pyName] = new MethodObject(type, name, forwardMethods);

src/runtime/methodbinder.cs

+4-3
Original file line numberDiff line numberDiff line change
@@ -354,16 +354,17 @@ internal Binding Bind(IntPtr inst, IntPtr args, IntPtr kw, MethodBase info, Meth
354354
int kwargsMatched;
355355
int defaultsNeeded;
356356
bool isOperator = OperatorMethod.IsOperatorMethod(mi);
357-
int clrnargs = pi.Length;
358357
// Binary operator methods will have 2 CLR args but only one Python arg
359358
// (unary operators will have 1 less each), since Python operator methods are bound.
360-
isOperator = isOperator && pynargs == clrnargs - 1;
359+
isOperator = isOperator && pynargs == pi.Length - 1;
360+
bool isReverse = isOperator && OperatorMethod.IsReverse((MethodInfo)mi); // Only cast if isOperator.
361+
if (isReverse && OperatorMethod.IsComparisonOp((MethodInfo)mi))
362+
continue; // Comparison operators in Python have no reverse mode.
361363
if (!MatchesArgumentCount(pynargs, pi, kwargDict, out paramsArray, out defaultArgList, out kwargsMatched, out defaultsNeeded) && !isOperator)
362364
{
363365
continue;
364366
}
365367
// Preprocessing pi to remove either the first or second argument.
366-
bool isReverse = isOperator && OperatorMethod.IsReverse((MethodInfo)mi); // Only cast if isOperator.
367368
if (isOperator && !isReverse) {
368369
// The first Python arg is the right operand, while the bound instance is the left.
369370
// We need to skip the first (left operand) CLR argument.

0 commit comments

Comments
 (0)