Skip to content

ENH: Add dtype support to the array comparison ops #18128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 107 additions & 30 deletions numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,16 @@ from numpy.core._internal import _ctypes
from numpy.typing import (
# Arrays
ArrayLike,
_ArrayND,
_ArrayOrScalar,
_NestedSequence,
_RecursiveSequence,
_ArrayLikeNumber_co,
_ArrayLikeTD64_co,
_ArrayLikeDT64_co,

# DTypes

DTypeLike,
_SupportsDType,
_VoidDTypeLike,
Expand Down Expand Up @@ -127,6 +135,7 @@ from typing import (
Iterable,
List,
Mapping,
NoReturn,
Optional,
overload,
Sequence,
Expand Down Expand Up @@ -584,19 +593,19 @@ where: Any
who: Any

_NdArraySubClass = TypeVar("_NdArraySubClass", bound=ndarray)
_DTypeScalar = TypeVar("_DTypeScalar", bound=generic)
_DTypeScalar_co = TypeVar("_DTypeScalar_co", covariant=True, bound=generic)
_ByteOrder = Literal["S", "<", ">", "=", "|", "L", "B", "N", "I"]

class dtype(Generic[_DTypeScalar]):
class dtype(Generic[_DTypeScalar_co]):
names: Optional[Tuple[str, ...]]
# Overload for subclass of generic
@overload
def __new__(
cls,
dtype: Type[_DTypeScalar],
dtype: Type[_DTypeScalar_co],
align: bool = ...,
copy: bool = ...,
) -> dtype[_DTypeScalar]: ...
) -> dtype[_DTypeScalar_co]: ...
# Overloads for string aliases, Python types, and some assorted
# other special cases. Order is sometimes important because of the
# subtype relationships
Expand Down Expand Up @@ -711,10 +720,10 @@ class dtype(Generic[_DTypeScalar]):
@overload
def __new__(
cls,
dtype: dtype[_DTypeScalar],
dtype: dtype[_DTypeScalar_co],
align: bool = ...,
copy: bool = ...,
) -> dtype[_DTypeScalar]: ...
) -> dtype[_DTypeScalar_co]: ...
# TODO: handle _SupportsDType better
@overload
def __new__(
Expand Down Expand Up @@ -791,7 +800,7 @@ class dtype(Generic[_DTypeScalar]):
@property
def str(self) -> builtins.str: ...
@property
def type(self) -> Type[_DTypeScalar]: ...
def type(self) -> Type[_DTypeScalar_co]: ...

class _flagsobj:
aligned: bool
Expand Down Expand Up @@ -1319,6 +1328,7 @@ class _ArrayOrScalarCommon:
) -> _NdArraySubClass: ...

_DType = TypeVar("_DType", bound=dtype[Any])
_DType_co = TypeVar("_DType_co", covariant=True, bound=dtype[Any])

# TODO: Set the `bound` to something more suitable once we
# have proper shape support
Expand All @@ -1327,7 +1337,7 @@ _ShapeType = TypeVar("_ShapeType", bound=Any)
_BufferType = Union[ndarray, bytes, bytearray, memoryview]
_Casting = Literal["no", "equiv", "safe", "same_kind", "unsafe"]

class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType]):
class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
@property
def base(self) -> Optional[ndarray]: ...
@property
Expand All @@ -1352,7 +1362,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType]):
order: _OrderKACF = ...,
) -> _ArraySelf: ...
@overload
def __array__(self, __dtype: None = ...) -> ndarray[Any, _DType]: ...
def __array__(self, __dtype: None = ...) -> ndarray[Any, _DType_co]: ...
@overload
def __array__(self, __dtype: DTypeLike) -> ndarray[Any, dtype[Any]]: ...
@property
Expand Down Expand Up @@ -1464,10 +1474,77 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType]):
def __iter__(self) -> Any: ...
def __contains__(self, key) -> bool: ...
def __index__(self) -> int: ...
def __lt__(self, other: ArrayLike) -> Union[ndarray, bool_]: ...
def __le__(self, other: ArrayLike) -> Union[ndarray, bool_]: ...
def __gt__(self, other: ArrayLike) -> Union[ndarray, bool_]: ...
def __ge__(self, other: ArrayLike) -> Union[ndarray, bool_]: ...

# The last overload is for catching recursive objects whose
# nesting is too deep.
# The first overload is for catching `bytes` (as they are a subtype of
# `Sequence[int]`) and `str`. As `str` is a recusive sequence of
# strings, it will pass through the final overload otherwise

@overload
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To summarize the various overloads:

  1. Overload for filtering out any str/bytes-based array-likes. This is needed as bytes would otherwise be recognized as a Sequence[int] sub-type (which is bad) and str would otherwise be caught by overload 6. (as str is a sequence of strings, which is a sequence of strings, etc.).
  2. Number-based overload
  3. Timedelta-based overload
  4. Datetime-based overload
  5. Object-based overload. other is typed as Any here as object arrays can be extremely flexible (depending on the actual underlying objects).
  6. A final overload to handle all sequences whose nesting level is too deep for the previous overloads.

def __lt__(self: _ArrayND[Any], other: _NestedSequence[Union[str, bytes]]) -> NoReturn: ...
@overload
def __lt__(self: _ArrayND[Union[number[Any], bool_]], other: _ArrayLikeNumber_co) -> _ArrayOrScalar[bool_]: ...
@overload
def __lt__(self: _ArrayND[Union[bool_, integer[Any], timedelta64]], other: _ArrayLikeTD64_co) -> _ArrayOrScalar[bool_]: ...
@overload
def __lt__(self: _ArrayND[datetime64], other: _ArrayLikeDT64_co) -> _ArrayOrScalar[bool_]: ...
@overload
def __lt__(self: _ArrayND[object_], other: Any) -> _ArrayOrScalar[bool_]: ...
@overload
def __lt__(
self: _ArrayND[Union[number[Any], datetime64, timedelta64, bool_]],
other: _RecursiveSequence,
) -> _ArrayOrScalar[bool_]: ...

@overload
def __le__(self: _ArrayND[Any], other: _NestedSequence[Union[str, bytes]]) -> NoReturn: ...
@overload
def __le__(self: _ArrayND[Union[number[Any], bool_]], other: _ArrayLikeNumber_co) -> _ArrayOrScalar[bool_]: ...
@overload
def __le__(self: _ArrayND[Union[bool_, integer[Any], timedelta64]], other: _ArrayLikeTD64_co) -> _ArrayOrScalar[bool_]: ...
@overload
def __le__(self: _ArrayND[datetime64], other: _ArrayLikeDT64_co) -> _ArrayOrScalar[bool_]: ...
@overload
def __le__(self: _ArrayND[object_], other: Any) -> _ArrayOrScalar[bool_]: ...
@overload
def __le__(
self: _ArrayND[Union[number[Any], datetime64, timedelta64, bool_]],
other: _RecursiveSequence,
) -> _ArrayOrScalar[bool_]: ...

@overload
def __gt__(self: _ArrayND[Any], other: _NestedSequence[Union[str, bytes]]) -> NoReturn: ...
@overload
def __gt__(self: _ArrayND[Union[number[Any], bool_]], other: _ArrayLikeNumber_co) -> _ArrayOrScalar[bool_]: ...
@overload
def __gt__(self: _ArrayND[Union[bool_, integer[Any], timedelta64]], other: _ArrayLikeTD64_co) -> _ArrayOrScalar[bool_]: ...
@overload
def __gt__(self: _ArrayND[datetime64], other: _ArrayLikeDT64_co) -> _ArrayOrScalar[bool_]: ...
@overload
def __gt__(self: _ArrayND[object_], other: Any) -> _ArrayOrScalar[bool_]: ...
@overload
def __gt__(
self: _ArrayND[Union[number[Any], datetime64, timedelta64, bool_]],
other: _RecursiveSequence,
) -> _ArrayOrScalar[bool_]: ...

@overload
def __ge__(self: _ArrayND[Any], other: _NestedSequence[Union[str, bytes]]) -> NoReturn: ...
@overload
def __ge__(self: _ArrayND[Union[number[Any], bool_]], other: _ArrayLikeNumber_co) -> _ArrayOrScalar[bool_]: ...
@overload
def __ge__(self: _ArrayND[Union[bool_, integer[Any], timedelta64]], other: _ArrayLikeTD64_co) -> _ArrayOrScalar[bool_]: ...
@overload
def __ge__(self: _ArrayND[datetime64], other: _ArrayLikeDT64_co) -> _ArrayOrScalar[bool_]: ...
@overload
def __ge__(self: _ArrayND[object_], other: Any) -> _ArrayOrScalar[bool_]: ...
@overload
def __ge__(
self: _ArrayND[Union[number[Any], datetime64, timedelta64, bool_]],
other: _RecursiveSequence,
) -> _ArrayOrScalar[bool_]: ...

def __matmul__(self, other: ArrayLike) -> Any: ...
# NOTE: `ndarray` does not implement `__imatmul__`
def __rmatmul__(self, other: ArrayLike) -> Any: ...
Expand Down Expand Up @@ -1516,7 +1593,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType]):
def __ior__(self: _ArraySelf, other: ArrayLike) -> _ArraySelf: ...
# Keep `dtype` at the bottom to avoid name conflicts with `np.dtype`
@property
def dtype(self) -> _DType: ...
def dtype(self) -> _DType_co: ...

# NOTE: while `np.generic` is not technically an instance of `ABCMeta`,
# the `@abstractmethod` decorator is herein used to (forcefully) deny
Expand Down Expand Up @@ -1586,10 +1663,10 @@ class number(generic, Generic[_NBit1]): # type: ignore
__rpow__: _NumberOp
__truediv__: _NumberOp
__rtruediv__: _NumberOp
__lt__: _ComparisonOp[_NumberLike_co]
__le__: _ComparisonOp[_NumberLike_co]
__gt__: _ComparisonOp[_NumberLike_co]
__ge__: _ComparisonOp[_NumberLike_co]
__lt__: _ComparisonOp[_NumberLike_co, _ArrayLikeNumber_co]
__le__: _ComparisonOp[_NumberLike_co, _ArrayLikeNumber_co]
__gt__: _ComparisonOp[_NumberLike_co, _ArrayLikeNumber_co]
__ge__: _ComparisonOp[_NumberLike_co, _ArrayLikeNumber_co]

class bool_(generic):
def __init__(self, __value: object = ...) -> None: ...
Expand Down Expand Up @@ -1628,10 +1705,10 @@ class bool_(generic):
__rmod__: _BoolMod
__divmod__: _BoolDivMod
__rdivmod__: _BoolDivMod
__lt__: _ComparisonOp[_NumberLike_co]
__le__: _ComparisonOp[_NumberLike_co]
__gt__: _ComparisonOp[_NumberLike_co]
__ge__: _ComparisonOp[_NumberLike_co]
__lt__: _ComparisonOp[_NumberLike_co, _ArrayLikeNumber_co]
__le__: _ComparisonOp[_NumberLike_co, _ArrayLikeNumber_co]
__gt__: _ComparisonOp[_NumberLike_co, _ArrayLikeNumber_co]
__ge__: _ComparisonOp[_NumberLike_co, _ArrayLikeNumber_co]

class object_(generic):
def __init__(self, __value: object = ...) -> None: ...
Expand Down Expand Up @@ -1660,10 +1737,10 @@ class datetime64(generic):
@overload
def __sub__(self, other: _TD64Like_co) -> datetime64: ...
def __rsub__(self, other: datetime64) -> timedelta64: ...
__lt__: _ComparisonOp[datetime64]
__le__: _ComparisonOp[datetime64]
__gt__: _ComparisonOp[datetime64]
__ge__: _ComparisonOp[datetime64]
__lt__: _ComparisonOp[datetime64, _ArrayLikeDT64_co]
__le__: _ComparisonOp[datetime64, _ArrayLikeDT64_co]
__gt__: _ComparisonOp[datetime64, _ArrayLikeDT64_co]
__ge__: _ComparisonOp[datetime64, _ArrayLikeDT64_co]

# Support for `__index__` was added in python 3.8 (bpo-20092)
if sys.version_info >= (3, 8):
Expand Down Expand Up @@ -1762,10 +1839,10 @@ class timedelta64(generic):
def __rmod__(self, other: timedelta64) -> timedelta64: ...
def __divmod__(self, other: timedelta64) -> Tuple[int64, timedelta64]: ...
def __rdivmod__(self, other: timedelta64) -> Tuple[int64, timedelta64]: ...
__lt__: _ComparisonOp[Union[timedelta64, _IntLike_co, _BoolLike_co]]
__le__: _ComparisonOp[Union[timedelta64, _IntLike_co, _BoolLike_co]]
__gt__: _ComparisonOp[Union[timedelta64, _IntLike_co, _BoolLike_co]]
__ge__: _ComparisonOp[Union[timedelta64, _IntLike_co, _BoolLike_co]]
__lt__: _ComparisonOp[_TD64Like_co, _ArrayLikeTD64_co]
__le__: _ComparisonOp[_TD64Like_co, _ArrayLikeTD64_co]
__gt__: _ComparisonOp[_TD64Like_co, _ArrayLikeTD64_co]
__ge__: _ComparisonOp[_TD64Like_co, _ArrayLikeTD64_co]

class unsignedinteger(integer[_NBit1]):
# NOTE: `uint64 + signedinteger -> float64`
Expand Down
5 changes: 5 additions & 0 deletions numpy/typing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,18 +302,23 @@ class _8Bit(_16Bit): ... # type: ignore[misc]
ArrayLike as ArrayLike,
_ArrayLike,
_NestedSequence,
_RecursiveSequence,
_SupportsArray,
_ArrayND,
_ArrayOrScalar,
_ArrayLikeBool_co,
_ArrayLikeUInt_co,
_ArrayLikeInt_co,
_ArrayLikeFloat_co,
_ArrayLikeComplex_co,
_ArrayLikeNumber_co,
_ArrayLikeTD64_co,
_ArrayLikeDT64_co,
_ArrayLikeObject_co,
_ArrayLikeVoid_co,
_ArrayLikeStr_co,
_ArrayLikeBytes_co,

)

if __doc__ is not None:
Expand Down
18 changes: 16 additions & 2 deletions numpy/typing/_array_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
integer,
floating,
complexfloating,
number,
timedelta64,
datetime64,
object_,
Expand All @@ -33,15 +34,17 @@
HAVE_PROTOCOL = True

_T = TypeVar("_T")
_ScalarType = TypeVar("_ScalarType", bound=generic)
_DType = TypeVar("_DType", bound="dtype[Any]")
_DType_co = TypeVar("_DType_co", covariant=True, bound="dtype[Any]")

if TYPE_CHECKING or HAVE_PROTOCOL:
# The `_SupportsArray` protocol only cares about the default dtype
# (i.e. `dtype=None`) of the to-be returned array.
# Concrete implementations of the protocol are responsible for adding
# any and all remaining overloads
class _SupportsArray(Protocol[_DType]):
def __array__(self, dtype: None = ...) -> ndarray[Any, _DType]: ...
class _SupportsArray(Protocol[_DType_co]):
def __array__(self, dtype: None = ...) -> ndarray[Any, _DType_co]: ...
else:
_SupportsArray = Any

Expand Down Expand Up @@ -100,6 +103,10 @@ def __array__(self, dtype: None = ...) -> ndarray[Any, _DType]: ...
"dtype[Union[bool_, integer[Any], floating[Any], complexfloating[Any, Any]]]",
Union[bool, int, float, complex],
]
_ArrayLikeNumber_co = _ArrayLike[
"dtype[Union[bool_, number[Any]]]",
Union[bool, int, float, complex],
]
_ArrayLikeTD64_co = _ArrayLike[
"dtype[Union[bool_, integer[Any], timedelta64]]",
Union[bool, int],
Expand All @@ -116,3 +123,10 @@ def __array__(self, dtype: None = ...) -> ndarray[Any, _DType]: ...
"dtype[bytes_]",
bytes,
]

if TYPE_CHECKING:
_ArrayND = ndarray[Any, dtype[_ScalarType]]
_ArrayOrScalar = Union[_ScalarType, _ArrayND[_ScalarType]]
else:
_ArrayND = Any
_ArrayOrScalar = Any
16 changes: 10 additions & 6 deletions numpy/typing/_callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

"""

from __future__ import annotations

import sys
from typing import (
Union,
Expand All @@ -21,6 +23,7 @@

from numpy import (
ndarray,
dtype,
generic,
bool_,
timedelta64,
Expand All @@ -44,7 +47,7 @@
_NumberLike_co,
)
from . import NBitBase
from ._array_like import ArrayLike
from ._array_like import ArrayLike, _ArrayOrScalar

if sys.version_info >= (3, 8):
from typing import Protocol
Expand All @@ -58,8 +61,9 @@
HAVE_PROTOCOL = True

if TYPE_CHECKING or HAVE_PROTOCOL:
_T = TypeVar("_T")
_2Tuple = Tuple[_T, _T]
_T1 = TypeVar("_T1")
_T2 = TypeVar("_T2")
_2Tuple = Tuple[_T1, _T1]

_NBit1 = TypeVar("_NBit1", bound=NBitBase)
_NBit2 = TypeVar("_NBit2", bound=NBitBase)
Expand Down Expand Up @@ -316,11 +320,11 @@ def __call__(
class _NumberOp(Protocol):
def __call__(self, __other: _NumberLike_co) -> Any: ...

class _ComparisonOp(Protocol[_T]):
class _ComparisonOp(Protocol[_T1, _T2]):
@overload
def __call__(self, __other: _T) -> bool_: ...
def __call__(self, __other: _T1) -> bool_: ...
@overload
def __call__(self, __other: ArrayLike) -> Union[ndarray, bool_]: ...
def __call__(self, __other: _T2) -> _ArrayOrScalar[bool_]: ...

else:
_BoolOp = Any
Expand Down
Loading