Skip to content

Commit 3e27819

Browse files
committed
IComparable and IEquatable implementations for PyInt, PyFloat, and PyString for primitive .NET types
1 parent eef67db commit 3e27819

File tree

10 files changed

+315
-4
lines changed

10 files changed

+315
-4
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ This document follows the conventions laid out in [Keep a CHANGELOG][].
99

1010
### Added
1111

12+
- Added `IComparable` and `IEquatable` implementations to `PyInt`, `PyFloat`, and `PyString`
13+
to compare with primitive .NET types like `long`.
14+
1215
### Changed
1316

1417
### Fixed

src/embed_tests/TestPyFloat.cs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,5 +126,32 @@ public void AsFloatBad()
126126
StringAssert.StartsWith("could not convert string to float", ex.Message);
127127
Assert.IsNull(a);
128128
}
129+
130+
[Test]
131+
public void CompareTo()
132+
{
133+
var v = new PyFloat(42);
134+
135+
Assert.AreEqual(0, v.CompareTo(42f));
136+
Assert.AreEqual(0, v.CompareTo(42d));
137+
138+
Assert.AreEqual(1, v.CompareTo(41f));
139+
Assert.AreEqual(1, v.CompareTo(41d));
140+
141+
Assert.AreEqual(-1, v.CompareTo(43f));
142+
Assert.AreEqual(-1, v.CompareTo(43d));
143+
}
144+
145+
[Test]
146+
public void Equals()
147+
{
148+
var v = new PyFloat(42);
149+
150+
Assert.IsTrue(v.Equals(42f));
151+
Assert.IsTrue(v.Equals(42d));
152+
153+
Assert.IsFalse(v.Equals(41f));
154+
Assert.IsFalse(v.Equals(41d));
155+
}
129156
}
130157
}

src/embed_tests/TestPyInt.cs

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,76 @@ public void ToBigInteger()
210210
CollectionAssert.AreEqual(expected, actual);
211211
}
212212

213+
[Test]
214+
public void CompareTo()
215+
{
216+
var v = new PyInt(42);
217+
218+
#region Signed
219+
Assert.AreEqual(0, v.CompareTo(42L));
220+
Assert.AreEqual(0, v.CompareTo(42));
221+
Assert.AreEqual(0, v.CompareTo((short)42));
222+
Assert.AreEqual(0, v.CompareTo((sbyte)42));
223+
224+
Assert.AreEqual(1, v.CompareTo(41L));
225+
Assert.AreEqual(1, v.CompareTo(41));
226+
Assert.AreEqual(1, v.CompareTo((short)41));
227+
Assert.AreEqual(1, v.CompareTo((sbyte)41));
228+
229+
Assert.AreEqual(-1, v.CompareTo(43L));
230+
Assert.AreEqual(-1, v.CompareTo(43));
231+
Assert.AreEqual(-1, v.CompareTo((short)43));
232+
Assert.AreEqual(-1, v.CompareTo((sbyte)43));
233+
#endregion Signed
234+
235+
#region Unsigned
236+
Assert.AreEqual(0, v.CompareTo(42UL));
237+
Assert.AreEqual(0, v.CompareTo(42U));
238+
Assert.AreEqual(0, v.CompareTo((ushort)42));
239+
Assert.AreEqual(0, v.CompareTo((byte)42));
240+
241+
Assert.AreEqual(1, v.CompareTo(41UL));
242+
Assert.AreEqual(1, v.CompareTo(41U));
243+
Assert.AreEqual(1, v.CompareTo((ushort)41));
244+
Assert.AreEqual(1, v.CompareTo((byte)41));
245+
246+
Assert.AreEqual(-1, v.CompareTo(43UL));
247+
Assert.AreEqual(-1, v.CompareTo(43U));
248+
Assert.AreEqual(-1, v.CompareTo((ushort)43));
249+
Assert.AreEqual(-1, v.CompareTo((byte)43));
250+
#endregion Unsigned
251+
}
252+
253+
[Test]
254+
public void Equals()
255+
{
256+
var v = new PyInt(42);
257+
258+
#region Signed
259+
Assert.True(v.Equals(42L));
260+
Assert.True(v.Equals(42));
261+
Assert.True(v.Equals((short)42));
262+
Assert.True(v.Equals((sbyte)42));
263+
264+
Assert.False(v.Equals(41L));
265+
Assert.False(v.Equals(41));
266+
Assert.False(v.Equals((short)41));
267+
Assert.False(v.Equals((sbyte)41));
268+
#endregion Signed
269+
270+
#region Unsigned
271+
Assert.True(v.Equals(42UL));
272+
Assert.True(v.Equals(42U));
273+
Assert.True(v.Equals((ushort)42));
274+
Assert.True(v.Equals((byte)42));
275+
276+
Assert.False(v.Equals(41UL));
277+
Assert.False(v.Equals(41U));
278+
Assert.False(v.Equals((ushort)41));
279+
Assert.False(v.Equals((byte)41));
280+
#endregion Unsigned
281+
}
282+
213283
[Test]
214284
public void ToBigIntegerLarge()
215285
{

src/embed_tests/TestPyString.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,5 +112,24 @@ public void TestUnicodeSurrogate()
112112
Assert.AreEqual(4, actual.Length());
113113
Assert.AreEqual(expected, actual.ToString());
114114
}
115+
116+
[Test]
117+
public void CompareTo()
118+
{
119+
var a = new PyString("foo");
120+
121+
Assert.AreEqual(0, a.CompareTo("foo"));
122+
Assert.AreEqual("foo".CompareTo("bar"), a.CompareTo("bar"));
123+
Assert.AreEqual("foo".CompareTo("foz"), a.CompareTo("foz"));
124+
}
125+
126+
[Test]
127+
public void Equals()
128+
{
129+
var a = new PyString("foo");
130+
131+
Assert.True(a.Equals("foo"));
132+
Assert.False(a.Equals("bar"));
133+
}
115134
}
116135
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using System;
2+
3+
namespace Python.Runtime;
4+
5+
partial class PyFloat : IComparable<double>, IComparable<float>
6+
, IEquatable<double>, IEquatable<float>
7+
{
8+
public override bool Equals(object o)
9+
{
10+
using var _ = Py.GIL();
11+
return o switch
12+
{
13+
double f64 => this.Equals(f64),
14+
float f32 => this.Equals(f32),
15+
_ => base.Equals(o),
16+
};
17+
}
18+
19+
public int CompareTo(double other) => this.ToDouble().CompareTo(other);
20+
21+
public int CompareTo(float other) => this.ToDouble().CompareTo(other);
22+
23+
public bool Equals(double other) => this.ToDouble().Equals(other);
24+
25+
public bool Equals(float other) => this.ToDouble().Equals(other);
26+
}

src/runtime/PythonTypes/PyFloat.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace Python.Runtime
88
/// PY3: https://docs.python.org/3/c-api/float.html
99
/// for details.
1010
/// </summary>
11-
public class PyFloat : PyNumber
11+
public partial class PyFloat : PyNumber
1212
{
1313
internal PyFloat(in StolenReference ptr) : base(ptr)
1414
{
@@ -100,6 +100,8 @@ public static PyFloat AsFloat(PyObject value)
100100
return new PyFloat(op.Steal());
101101
}
102102

103+
public double ToDouble() => Runtime.PyFloat_AsDouble(obj);
104+
103105
public override TypeCode GetTypeCode() => TypeCode.Double;
104106
}
105107
}
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
using System;
2+
3+
namespace Python.Runtime;
4+
5+
partial class PyInt : IComparable<long>, IComparable<int>, IComparable<sbyte>, IComparable<short>
6+
, IComparable<ulong>, IComparable<uint>, IComparable<ushort>, IComparable<byte>
7+
, IEquatable<long>, IEquatable<int>, IEquatable<short>, IEquatable<sbyte>
8+
, IEquatable<ulong>, IEquatable<uint>, IEquatable<ushort>, IEquatable<byte>
9+
{
10+
public override bool Equals(object o)
11+
{
12+
using var _ = Py.GIL();
13+
return o switch
14+
{
15+
long i64 => this.Equals(i64),
16+
int i32 => this.Equals(i32),
17+
short i16 => this.Equals(i16),
18+
sbyte i8 => this.Equals(i8),
19+
20+
ulong u64 => this.Equals(u64),
21+
uint u32 => this.Equals(u32),
22+
ushort u16 => this.Equals(u16),
23+
byte u8 => this.Equals(u8),
24+
25+
_ => base.Equals(o),
26+
};
27+
}
28+
29+
#region Signed
30+
public int CompareTo(long other)
31+
{
32+
using var pyOther = Runtime.PyInt_FromInt64(other);
33+
return this.CompareTo(pyOther.BorrowOrThrow());
34+
}
35+
36+
public int CompareTo(int other)
37+
{
38+
using var pyOther = Runtime.PyInt_FromInt32(other);
39+
return this.CompareTo(pyOther.BorrowOrThrow());
40+
}
41+
42+
public int CompareTo(short other)
43+
{
44+
using var pyOther = Runtime.PyInt_FromInt32(other);
45+
return this.CompareTo(pyOther.BorrowOrThrow());
46+
}
47+
48+
public int CompareTo(sbyte other)
49+
{
50+
using var pyOther = Runtime.PyInt_FromInt32(other);
51+
return this.CompareTo(pyOther.BorrowOrThrow());
52+
}
53+
54+
public bool Equals(long other)
55+
{
56+
using var pyOther = Runtime.PyInt_FromInt64(other);
57+
return this.Equals(pyOther.BorrowOrThrow());
58+
}
59+
60+
public bool Equals(int other)
61+
{
62+
using var pyOther = Runtime.PyInt_FromInt32(other);
63+
return this.Equals(pyOther.BorrowOrThrow());
64+
}
65+
66+
public bool Equals(short other)
67+
{
68+
using var pyOther = Runtime.PyInt_FromInt32(other);
69+
return this.Equals(pyOther.BorrowOrThrow());
70+
}
71+
72+
public bool Equals(sbyte other)
73+
{
74+
using var pyOther = Runtime.PyInt_FromInt32(other);
75+
return this.Equals(pyOther.BorrowOrThrow());
76+
}
77+
#endregion Signed
78+
79+
#region Unsigned
80+
public int CompareTo(ulong other)
81+
{
82+
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
83+
return this.CompareTo(pyOther.BorrowOrThrow());
84+
}
85+
86+
public int CompareTo(uint other)
87+
{
88+
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
89+
return this.CompareTo(pyOther.BorrowOrThrow());
90+
}
91+
92+
public int CompareTo(ushort other)
93+
{
94+
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
95+
return this.CompareTo(pyOther.BorrowOrThrow());
96+
}
97+
98+
public int CompareTo(byte other)
99+
{
100+
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
101+
return this.CompareTo(pyOther.BorrowOrThrow());
102+
}
103+
104+
public bool Equals(ulong other)
105+
{
106+
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
107+
return this.Equals(pyOther.BorrowOrThrow());
108+
}
109+
110+
public bool Equals(uint other)
111+
{
112+
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
113+
return this.Equals(pyOther.BorrowOrThrow());
114+
}
115+
116+
public bool Equals(ushort other)
117+
{
118+
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
119+
return this.Equals(pyOther.BorrowOrThrow());
120+
}
121+
122+
public bool Equals(byte other)
123+
{
124+
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
125+
return this.Equals(pyOther.BorrowOrThrow());
126+
}
127+
#endregion Unsigned
128+
}

src/runtime/PythonTypes/PyInt.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace Python.Runtime
99
/// Represents a Python integer object.
1010
/// See the documentation at https://docs.python.org/3/c-api/long.html
1111
/// </summary>
12-
public class PyInt : PyNumber, IFormattable
12+
public partial class PyInt : PyNumber, IFormattable
1313
{
1414
internal PyInt(in StolenReference ptr) : base(ptr)
1515
{

src/runtime/PythonTypes/PyObject.cs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1136,6 +1136,23 @@ public long Refcount
11361136
}
11371137
}
11381138

1139+
internal int CompareTo(BorrowedReference other)
1140+
{
1141+
int greater = Runtime.PyObject_RichCompareBool(this.Reference, other, Runtime.Py_GT);
1142+
Debug.Assert(greater != -1);
1143+
if (greater > 0)
1144+
return 1;
1145+
int less = Runtime.PyObject_RichCompareBool(this.Reference, other, Runtime.Py_LT);
1146+
Debug.Assert(less != -1);
1147+
return less > 0 ? -1 : 0;
1148+
}
1149+
1150+
internal bool Equals(BorrowedReference other)
1151+
{
1152+
int equal = Runtime.PyObject_RichCompareBool(this.Reference, other, Runtime.Py_EQ);
1153+
Debug.Assert(equal != -1);
1154+
return equal > 0;
1155+
}
11391156

11401157
public override bool TryGetMember(GetMemberBinder binder, out object? result)
11411158
{
@@ -1325,7 +1342,7 @@ private bool TryCompare(PyObject arg, int op, out object @out)
13251342
}
13261343
return true;
13271344
}
1328-
1345+
13291346
public override bool TryBinaryOperation(BinaryOperationBinder binder, object arg, out object? result)
13301347
{
13311348
using var _ = Py.GIL();

src/runtime/PythonTypes/PyString.cs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using System.Diagnostics;
23
using System.Runtime.Serialization;
34

45
namespace Python.Runtime
@@ -13,7 +14,7 @@ namespace Python.Runtime
1314
/// 2011-01-29: ...Then why does the string constructor call PyUnicode_FromUnicode()???
1415
/// </remarks>
1516
[Serializable]
16-
public class PyString : PySequence
17+
public class PyString : PySequence, IComparable<string>, IEquatable<string>
1718
{
1819
internal PyString(in StolenReference reference) : base(reference) { }
1920
internal PyString(BorrowedReference reference) : base(reference) { }
@@ -61,5 +62,23 @@ public static bool IsStringType(PyObject value)
6162
}
6263

6364
public override TypeCode GetTypeCode() => TypeCode.String;
65+
66+
internal string ToStringUnderGIL()
67+
{
68+
string? result = Runtime.GetManagedString(this.Reference);
69+
Debug.Assert(result is not null);
70+
return result!;
71+
}
72+
73+
public bool Equals(string? other)
74+
=> this.ToStringUnderGIL().Equals(other, StringComparison.CurrentCulture);
75+
public int CompareTo(string? other)
76+
=> string.Compare(this.ToStringUnderGIL(), other, StringComparison.CurrentCulture);
77+
78+
public override string ToString()
79+
{
80+
using var _ = Py.GIL();
81+
return this.ToStringUnderGIL();
82+
}
6483
}
6584
}

0 commit comments

Comments
 (0)