diff --git a/src/semver/version.py b/src/semver/version.py index 2f1f8cb..0ff5b9f 100644 --- a/src/semver/version.py +++ b/src/semver/version.py @@ -41,7 +41,7 @@ def _comparator(operator: Comparator) -> Comparator: @wraps(operator) def wrapper(self: "Version", other: Comparable) -> bool: comparable_types = ( - Version, + type(self), dict, tuple, list, diff --git a/tests/test_subclass.py b/tests/test_subclass.py index b33f496..0d39000 100644 --- a/tests/test_subclass.py +++ b/tests/test_subclass.py @@ -1,4 +1,5 @@ from semver import Version +import pytest def test_subclass_from_versioninfo(): @@ -51,3 +52,20 @@ def __str__(self) -> str: dev_version = version.replace(prerelease="dev.0") assert str(dev_version) == "v1.1.0-dev.0" + + +def test_compare_with_subclass(): + class SemVerSubclass(Version): + pass + + with pytest.raises(TypeError): + SemVerSubclass.parse("1.0.0").compare(Version.parse("1.0.0")) + assert Version.parse("1.0.0").compare(SemVerSubclass.parse("1.0.0")) == 0 + + assert ( + SemVerSubclass.parse("1.0.0").__eq__(Version.parse("1.0.0")) is NotImplemented + ) + assert Version.parse("1.0.0").__eq__(SemVerSubclass.parse("1.0.0")) is True + + assert SemVerSubclass.parse("1.0.0") == Version.parse("1.0.0") + assert Version.parse("1.0.0") == SemVerSubclass.parse("1.0.0")