Skip to content

IComparable and IEquatable implementations #2322

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ This document follows the conventions laid out in [Keep a CHANGELOG][].

- Added `ToPythonAs<T>()` extension method to allow for explicit conversion using a specific type. ([#2311][i2311])

- Added `IComparable` and `IEquatable` implementations to `PyInt`, `PyFloat`, and `PyString`
to compare with primitive .NET types like `long`.

### Changed

### Fixed
Expand Down
27 changes: 27 additions & 0 deletions src/embed_tests/TestPyFloat.cs
Original file line number Diff line number Diff line change
Expand Up @@ -126,5 +126,32 @@ public void AsFloatBad()
StringAssert.StartsWith("could not convert string to float", ex.Message);
Assert.IsNull(a);
}

[Test]
public void CompareTo()
{
var v = new PyFloat(42);

Assert.AreEqual(0, v.CompareTo(42f));
Assert.AreEqual(0, v.CompareTo(42d));

Assert.AreEqual(1, v.CompareTo(41f));
Assert.AreEqual(1, v.CompareTo(41d));

Assert.AreEqual(-1, v.CompareTo(43f));
Assert.AreEqual(-1, v.CompareTo(43d));
}

[Test]
public void Equals()
{
var v = new PyFloat(42);

Assert.IsTrue(v.Equals(42f));
Assert.IsTrue(v.Equals(42d));

Assert.IsFalse(v.Equals(41f));
Assert.IsFalse(v.Equals(41d));
}
}
}
70 changes: 70 additions & 0 deletions src/embed_tests/TestPyInt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,76 @@ public void ToBigInteger()
CollectionAssert.AreEqual(expected, actual);
}

[Test]
public void CompareTo()
{
var v = new PyInt(42);

#region Signed
Assert.AreEqual(0, v.CompareTo(42L));
Assert.AreEqual(0, v.CompareTo(42));
Assert.AreEqual(0, v.CompareTo((short)42));
Assert.AreEqual(0, v.CompareTo((sbyte)42));

Assert.AreEqual(1, v.CompareTo(41L));
Assert.AreEqual(1, v.CompareTo(41));
Assert.AreEqual(1, v.CompareTo((short)41));
Assert.AreEqual(1, v.CompareTo((sbyte)41));

Assert.AreEqual(-1, v.CompareTo(43L));
Assert.AreEqual(-1, v.CompareTo(43));
Assert.AreEqual(-1, v.CompareTo((short)43));
Assert.AreEqual(-1, v.CompareTo((sbyte)43));
#endregion Signed

#region Unsigned
Assert.AreEqual(0, v.CompareTo(42UL));
Assert.AreEqual(0, v.CompareTo(42U));
Assert.AreEqual(0, v.CompareTo((ushort)42));
Assert.AreEqual(0, v.CompareTo((byte)42));

Assert.AreEqual(1, v.CompareTo(41UL));
Assert.AreEqual(1, v.CompareTo(41U));
Assert.AreEqual(1, v.CompareTo((ushort)41));
Assert.AreEqual(1, v.CompareTo((byte)41));

Assert.AreEqual(-1, v.CompareTo(43UL));
Assert.AreEqual(-1, v.CompareTo(43U));
Assert.AreEqual(-1, v.CompareTo((ushort)43));
Assert.AreEqual(-1, v.CompareTo((byte)43));
#endregion Unsigned
}

[Test]
public void Equals()
{
var v = new PyInt(42);

#region Signed
Assert.True(v.Equals(42L));
Assert.True(v.Equals(42));
Assert.True(v.Equals((short)42));
Assert.True(v.Equals((sbyte)42));

Assert.False(v.Equals(41L));
Assert.False(v.Equals(41));
Assert.False(v.Equals((short)41));
Assert.False(v.Equals((sbyte)41));
#endregion Signed

#region Unsigned
Assert.True(v.Equals(42UL));
Assert.True(v.Equals(42U));
Assert.True(v.Equals((ushort)42));
Assert.True(v.Equals((byte)42));

Assert.False(v.Equals(41UL));
Assert.False(v.Equals(41U));
Assert.False(v.Equals((ushort)41));
Assert.False(v.Equals((byte)41));
#endregion Unsigned
}

[Test]
public void ToBigIntegerLarge()
{
Expand Down
19 changes: 19 additions & 0 deletions src/embed_tests/TestPyString.cs
Original file line number Diff line number Diff line change
Expand Up @@ -112,5 +112,24 @@ public void TestUnicodeSurrogate()
Assert.AreEqual(4, actual.Length());
Assert.AreEqual(expected, actual.ToString());
}

[Test]
public void CompareTo()
{
var a = new PyString("foo");

Assert.AreEqual(0, a.CompareTo("foo"));
Assert.AreEqual("foo".CompareTo("bar"), a.CompareTo("bar"));
Assert.AreEqual("foo".CompareTo("foz"), a.CompareTo("foz"));
}

[Test]
public void Equals()
{
var a = new PyString("foo");

Assert.True(a.Equals("foo"));
Assert.False(a.Equals("bar"));
}
}
}
34 changes: 34 additions & 0 deletions src/runtime/PythonTypes/PyFloat.IComparable.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
using System;

namespace Python.Runtime;

partial class PyFloat : IComparable<double>, IComparable<float>
, IEquatable<double>, IEquatable<float>
, IComparable<PyFloat?>, IEquatable<PyFloat?>
{
public override bool Equals(object o)
{
using var _ = Py.GIL();
return o switch
{
double f64 => this.Equals(f64),
float f32 => this.Equals(f32),
_ => base.Equals(o),
};
}

public int CompareTo(double other) => this.ToDouble().CompareTo(other);

public int CompareTo(float other) => this.ToDouble().CompareTo(other);

public bool Equals(double other) => this.ToDouble().Equals(other);

public bool Equals(float other) => this.ToDouble().Equals(other);

public int CompareTo(PyFloat? other)
{
return other is null ? 1 : this.CompareTo(other.BorrowNullable());
}

public bool Equals(PyFloat? other) => base.Equals(other);
}
4 changes: 3 additions & 1 deletion src/runtime/PythonTypes/PyFloat.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace Python.Runtime
/// PY3: https://docs.python.org/3/c-api/float.html
/// for details.
/// </summary>
public class PyFloat : PyNumber
public partial class PyFloat : PyNumber
{
internal PyFloat(in StolenReference ptr) : base(ptr)
{
Expand Down Expand Up @@ -100,6 +100,8 @@ public static PyFloat AsFloat(PyObject value)
return new PyFloat(op.Steal());
}

public double ToDouble() => Runtime.PyFloat_AsDouble(obj);

public override TypeCode GetTypeCode() => TypeCode.Double;
}
}
136 changes: 136 additions & 0 deletions src/runtime/PythonTypes/PyInt.IComparable.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
using System;

namespace Python.Runtime;

partial class PyInt : IComparable<long>, IComparable<int>, IComparable<sbyte>, IComparable<short>
, IComparable<ulong>, IComparable<uint>, IComparable<ushort>, IComparable<byte>
, IEquatable<long>, IEquatable<int>, IEquatable<short>, IEquatable<sbyte>
, IEquatable<ulong>, IEquatable<uint>, IEquatable<ushort>, IEquatable<byte>
, IComparable<PyInt?>, IEquatable<PyInt?>
{
public override bool Equals(object o)
{
using var _ = Py.GIL();
return o switch
{
long i64 => this.Equals(i64),
int i32 => this.Equals(i32),
short i16 => this.Equals(i16),
sbyte i8 => this.Equals(i8),

ulong u64 => this.Equals(u64),
uint u32 => this.Equals(u32),
ushort u16 => this.Equals(u16),
byte u8 => this.Equals(u8),

_ => base.Equals(o),
};
}

#region Signed
public int CompareTo(long other)
{
using var pyOther = Runtime.PyInt_FromInt64(other);
return this.CompareTo(pyOther.BorrowOrThrow());
}

public int CompareTo(int other)
{
using var pyOther = Runtime.PyInt_FromInt32(other);
return this.CompareTo(pyOther.BorrowOrThrow());
}

public int CompareTo(short other)
{
using var pyOther = Runtime.PyInt_FromInt32(other);
return this.CompareTo(pyOther.BorrowOrThrow());
}

public int CompareTo(sbyte other)
{
using var pyOther = Runtime.PyInt_FromInt32(other);
return this.CompareTo(pyOther.BorrowOrThrow());
}

public bool Equals(long other)
{
using var pyOther = Runtime.PyInt_FromInt64(other);
return this.Equals(pyOther.BorrowOrThrow());
}

public bool Equals(int other)
{
using var pyOther = Runtime.PyInt_FromInt32(other);
return this.Equals(pyOther.BorrowOrThrow());
}

public bool Equals(short other)
{
using var pyOther = Runtime.PyInt_FromInt32(other);
return this.Equals(pyOther.BorrowOrThrow());
}

public bool Equals(sbyte other)
{
using var pyOther = Runtime.PyInt_FromInt32(other);
return this.Equals(pyOther.BorrowOrThrow());
}
#endregion Signed

#region Unsigned
public int CompareTo(ulong other)
{
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
return this.CompareTo(pyOther.BorrowOrThrow());
}

public int CompareTo(uint other)
{
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
return this.CompareTo(pyOther.BorrowOrThrow());
}

public int CompareTo(ushort other)
{
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
return this.CompareTo(pyOther.BorrowOrThrow());
}

public int CompareTo(byte other)
{
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
return this.CompareTo(pyOther.BorrowOrThrow());
}

public bool Equals(ulong other)
{
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
return this.Equals(pyOther.BorrowOrThrow());
}

public bool Equals(uint other)
{
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
return this.Equals(pyOther.BorrowOrThrow());
}

public bool Equals(ushort other)
{
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
return this.Equals(pyOther.BorrowOrThrow());
}

public bool Equals(byte other)
{
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
return this.Equals(pyOther.BorrowOrThrow());
}
#endregion Unsigned

public int CompareTo(PyInt? other)
{
return other is null ? 1 : this.CompareTo(other.BorrowNullable());
}

public bool Equals(PyInt? other) => base.Equals(other);
}
2 changes: 1 addition & 1 deletion src/runtime/PythonTypes/PyInt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace Python.Runtime
/// Represents a Python integer object.
/// See the documentation at https://docs.python.org/3/c-api/long.html
/// </summary>
public class PyInt : PyNumber, IFormattable
public partial class PyInt : PyNumber, IFormattable
{
internal PyInt(in StolenReference ptr) : base(ptr)
{
Expand Down
19 changes: 18 additions & 1 deletion src/runtime/PythonTypes/PyObject.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,23 @@ public long Refcount
}
}

internal int CompareTo(BorrowedReference other)
{
int greater = Runtime.PyObject_RichCompareBool(this.Reference, other, Runtime.Py_GT);
Debug.Assert(greater != -1);
if (greater > 0)
return 1;
int less = Runtime.PyObject_RichCompareBool(this.Reference, other, Runtime.Py_LT);
Debug.Assert(less != -1);
return less > 0 ? -1 : 0;
}

internal bool Equals(BorrowedReference other)
{
int equal = Runtime.PyObject_RichCompareBool(this.Reference, other, Runtime.Py_EQ);
Debug.Assert(equal != -1);
return equal > 0;
}

public override bool TryGetMember(GetMemberBinder binder, out object? result)
{
Expand Down Expand Up @@ -1325,7 +1342,7 @@ private bool TryCompare(PyObject arg, int op, out object @out)
}
return true;
}

public override bool TryBinaryOperation(BinaryOperationBinder binder, object arg, out object? result)
{
using var _ = Py.GIL();
Expand Down
Loading