diff --git a/src/runtime/classbase.cs b/src/runtime/classbase.cs index 7214a7ba1..3089465e3 100644 --- a/src/runtime/classbase.cs +++ b/src/runtime/classbase.cs @@ -67,68 +67,108 @@ public virtual IntPtr type_subscript(IntPtr idx) //==================================================================== // Standard comparison implementation for instances of reflected types. //==================================================================== -#if (PYTHON32 || PYTHON33 || PYTHON34 || PYTHON35) - public static IntPtr tp_richcompare(IntPtr ob, IntPtr other, int op) { - if (op != Runtime.Py_EQ && op != Runtime.Py_NE) - { - Runtime.XIncref(Runtime.PyNotImplemented); - return Runtime.PyNotImplemented; - } - - IntPtr pytrue = Runtime.PyTrue; - IntPtr pyfalse = Runtime.PyFalse; - - // swap true and false for NE - if (op != Runtime.Py_EQ) - { - pytrue = Runtime.PyFalse; - pyfalse = Runtime.PyTrue; - } - - if (ob == other) { - Runtime.XIncref(pytrue); - return pytrue; - } - - CLRObject co1 = GetManagedObject(ob) as CLRObject; - CLRObject co2 = GetManagedObject(other) as CLRObject; - if (null == co2) { - Runtime.XIncref(pyfalse); - return pyfalse; - } - - Object o1 = co1.inst; - Object o2 = co2.inst; - if (Object.Equals(o1, o2)) { - Runtime.XIncref(pytrue); - return pytrue; - } - - Runtime.XIncref(pyfalse); - return pyfalse; - } -#else - public static int tp_compare(IntPtr ob, IntPtr other) - { - if (ob == other) - { - return 0; - } - - CLRObject co1 = GetManagedObject(ob) as CLRObject; - CLRObject co2 = GetManagedObject(other) as CLRObject; - Object o1 = co1.inst; - Object o2 = co2.inst; - - if (Object.Equals(o1, o2)) + public static IntPtr tp_richcompare(IntPtr ob, IntPtr other, int op) { + CLRObject co1; + CLRObject co2; + switch (op) { - return 0; + case Runtime.Py_EQ: + case Runtime.Py_NE: + IntPtr pytrue = Runtime.PyTrue; + IntPtr pyfalse = Runtime.PyFalse; + + // swap true and false for NE + if (op != Runtime.Py_EQ) + { + pytrue = Runtime.PyFalse; + pyfalse = Runtime.PyTrue; + } + + if (ob == other) + { + Runtime.XIncref(pytrue); + return pytrue; + } + + co1 = GetManagedObject(ob) as CLRObject; + co2 = GetManagedObject(other) as CLRObject; + if (null == co2) + { + Runtime.XIncref(pyfalse); + return pyfalse; + } + + Object o1 = co1.inst; + Object o2 = co2.inst; + + if (Object.Equals(o1, o2)) + { + Runtime.XIncref(pytrue); + return pytrue; + } + + Runtime.XIncref(pyfalse); + return pyfalse; + case Runtime.Py_LT: + case Runtime.Py_LE: + case Runtime.Py_GT: + case Runtime.Py_GE: + co1 = GetManagedObject(ob) as CLRObject; + co2 = GetManagedObject(other) as CLRObject; + if(co1 == null || co2 == null) + return Exceptions.RaiseTypeError("Cannot get managed object"); + var co1Comp = co1.inst as IComparable; + if (co1Comp == null) + return Exceptions.RaiseTypeError("Cannot convert object of type " + co1.GetType() + " to IComparable"); + try + { + var cmp = co1Comp.CompareTo(co2.inst); + + IntPtr pyCmp; + if (cmp < 0) + { + if (op == Runtime.Py_LT || op == Runtime.Py_LE) + { + pyCmp = Runtime.PyTrue; + } + else + { + pyCmp = Runtime.PyFalse; + } + } + else if (cmp == 0) + { + if (op == Runtime.Py_LE || op == Runtime.Py_GE) + { + pyCmp = Runtime.PyTrue; + } + else + { + pyCmp = Runtime.PyFalse; + } + } + else + { + if (op == Runtime.Py_GE || op == Runtime.Py_GT) { + pyCmp = Runtime.PyTrue; + } + else { + pyCmp = Runtime.PyFalse; + } + } + Runtime.XIncref(pyCmp); + return pyCmp; + } + catch (ArgumentException e) + { + return Exceptions.RaiseTypeError(e.Message); + } + default: + Runtime.XIncref(Runtime.PyNotImplemented); + return Runtime.PyNotImplemented; } - return -1; } -#endif - //==================================================================== // Standard iteration support for instances of reflected types. This diff --git a/src/runtime/runtime.cs b/src/runtime/runtime.cs index 9fbab2d74..7f5123012 100644 --- a/src/runtime/runtime.cs +++ b/src/runtime/runtime.cs @@ -232,13 +232,13 @@ internal static void Initialize() } #if (PYTHON32 || PYTHON33 || PYTHON34 || PYTHON35) - IntPtr op = Runtime.PyImport_ImportModule("builtins"); - IntPtr dict = Runtime.PyObject_GetAttrString(op, "__dict__"); - PyNotImplemented = Runtime.PyObject_GetAttrString(op, "NotImplemented"); + IntPtr op = Runtime.PyImport_ImportModule("builtins"); + IntPtr dict = Runtime.PyObject_GetAttrString(op, "__dict__"); #else IntPtr dict = Runtime.PyImport_GetModuleDict(); IntPtr op = Runtime.PyDict_GetItemString(dict, "__builtin__"); #endif + PyNotImplemented = Runtime.PyObject_GetAttrString(op, "NotImplemented"); PyBaseObjectType = Runtime.PyObject_GetAttrString(op, "object"); PyModuleType = Runtime.PyObject_Type(op); @@ -263,7 +263,7 @@ internal static void Initialize() Runtime.XDecref(op); #if (PYTHON32 || PYTHON33 || PYTHON34 || PYTHON35) - Runtime.XDecref(dict); + Runtime.XDecref(dict); #endif op = Runtime.PyString_FromString("string"); @@ -275,9 +275,9 @@ internal static void Initialize() Runtime.XDecref(op); #if (PYTHON32 || PYTHON33 || PYTHON34 || PYTHON35) - op = Runtime.PyBytes_FromString("bytes"); - PyBytesType = Runtime.PyObject_Type(op); - Runtime.XDecref(op); + op = Runtime.PyBytes_FromString("bytes"); + PyBytesType = Runtime.PyObject_Type(op); + Runtime.XDecref(op); #endif op = Runtime.PyTuple_New(0); @@ -397,16 +397,18 @@ internal static int AtExit() internal static IntPtr PyTypeType; #if (PYTHON32 || PYTHON33 || PYTHON34 || PYTHON35) - internal static IntPtr PyBytesType; - internal static IntPtr PyNotImplemented; - internal const int Py_LT = 0; - internal const int Py_LE = 1; - internal const int Py_EQ = 2; - internal const int Py_NE = 3; - internal const int Py_GT = 4; - internal static IntPtr _PyObject_NextNotImplemented; + internal static IntPtr PyBytesType; + internal static IntPtr _PyObject_NextNotImplemented; #endif + internal static IntPtr PyNotImplemented; + internal const int Py_LT = 0; + internal const int Py_LE = 1; + internal const int Py_EQ = 2; + internal const int Py_NE = 3; + internal const int Py_GT = 4; + internal const int Py_GE = 5; + internal static IntPtr PyTrue; internal static IntPtr PyFalse; internal static IntPtr PyNone; diff --git a/src/tests/test_class.py b/src/tests/test_class.py index eaab0cd45..afb631622 100644 --- a/src/tests/test_class.py +++ b/src/tests/test_class.py @@ -1,9 +1,12 @@ -from System.Collections import Hashtable -from Python.Test import ClassTest -import sys, os, string, unittest, types +import clr +import types +import unittest + import Python.Test as Test import System import six +from Python.Test import ClassTest +from System.Collections import Hashtable if six.PY3: DictProxyType = type(object.__dict__) @@ -209,6 +212,44 @@ def testAddAndRemoveClassAttribute(self): del TimeSpan.new_method self.assertFalse(hasattr(ts, "new_method")) + def testComparisons(self): + from System import DateTimeOffset + + d1 = DateTimeOffset.Parse("2016-11-14") + d2 = DateTimeOffset.Parse("2016-11-15") + + self.assertEqual(d1 == d2, False) + self.assertEqual(d1 != d2, True) + + self.assertEqual(d1 < d2, True) + self.assertEqual(d1 <= d2, True) + self.assertEqual(d1 >= d2, False) + self.assertEqual(d1 > d2, False) + + self.assertEqual(d1 == d1, True) + self.assertEqual(d1 != d1, False) + + self.assertEqual(d1 < d1, False) + self.assertEqual(d1 <= d1, True) + self.assertEqual(d1 >= d1, True) + self.assertEqual(d1 > d1, False) + + self.assertEqual(d2 == d1, False) + self.assertEqual(d2 != d1, True) + + self.assertEqual(d2 < d1, False) + self.assertEqual(d2 <= d1, False) + self.assertEqual(d2 >= d1, True) + self.assertEqual(d2 > d1, True) + + self.assertRaises(TypeError, lambda: d1 < None) + self.assertRaises(TypeError, lambda: d1 < System.Guid()) + + # ClassTest does not implement IComparable + c1 = ClassTest() + c2 = ClassTest() + self.assertRaises(TypeError, lambda: c1 < c2) + class ClassicClass: def kind(self):