Skip to content

Commit 3d01f64

Browse files
authored
Merge pull request numpy#28717 from jorenham/typing/fix-28708
TYP: fix string-like ``ndarray`` rich comparison operators
2 parents 3476ce0 + 689d05a commit 3d01f64

File tree

2 files changed

+55
-24
lines changed

2 files changed

+55
-24
lines changed

numpy/__init__.pyi

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ from numpy._typing import (
2222
NDArray,
2323
_SupportsArray,
2424
_NestedSequence,
25-
_FiniteNestedSequence,
2625
_ArrayLike,
2726
_ArrayLikeBool_co,
2827
_ArrayLikeUInt_co,
@@ -33,28 +32,27 @@ from numpy._typing import (
3332
_ArrayLikeComplex128_co,
3433
_ArrayLikeComplex_co,
3534
_ArrayLikeNumber_co,
35+
_ArrayLikeObject_co,
36+
_ArrayLikeBytes_co,
37+
_ArrayLikeStr_co,
38+
_ArrayLikeString_co,
3639
_ArrayLikeTD64_co,
3740
_ArrayLikeDT64_co,
38-
_ArrayLikeObject_co,
39-
4041
# DTypes
4142
DTypeLike,
4243
_DTypeLike,
4344
_DTypeLikeVoid,
4445
_VoidDTypeLike,
45-
4646
# Shapes
4747
_Shape,
4848
_ShapeLike,
49-
5049
# Scalars
5150
_CharLike_co,
5251
_IntLike_co,
5352
_FloatLike_co,
5453
_TD64Like_co,
5554
_NumberLike_co,
5655
_ScalarLike_co,
57-
5856
# `number` precision
5957
NBitBase,
6058
# NOTE: Do not remove the extended precision bit-types even if seemingly unused;
@@ -77,7 +75,6 @@ from numpy._typing import (
7775
_NBitSingle,
7876
_NBitDouble,
7977
_NBitLongDouble,
80-
8178
# Character codes
8279
_BoolCodes,
8380
_UInt8Codes,
@@ -119,7 +116,6 @@ from numpy._typing import (
119116
_VoidCodes,
120117
_ObjectCodes,
121118
_StringCodes,
122-
123119
_UnsignedIntegerCodes,
124120
_SignedIntegerCodes,
125121
_IntegerCodes,
@@ -130,7 +126,6 @@ from numpy._typing import (
130126
_CharacterCodes,
131127
_FlexibleCodes,
132128
_GenericCodes,
133-
134129
# Ufuncs
135130
_UFunc_Nin1_Nout1,
136131
_UFunc_Nin2_Nout1,
@@ -2552,55 +2547,77 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
25522547
@overload # ?-d
25532548
def __iter__(self, /) -> Iterator[Any]: ...
25542549

2555-
# The last overload is for catching recursive objects whose
2556-
# nesting is too deep.
2557-
# The first overload is for catching `bytes` (as they are a subtype of
2558-
# `Sequence[int]`) and `str`. As `str` is a recursive sequence of
2559-
# strings, it will pass through the final overload otherwise
2560-
2550+
#
25612551
@overload
25622552
def __lt__(self: _ArrayNumber_co, other: _ArrayLikeNumber_co, /) -> NDArray[np.bool]: ...
25632553
@overload
25642554
def __lt__(self: _ArrayTD64_co, other: _ArrayLikeTD64_co, /) -> NDArray[np.bool]: ...
25652555
@overload
25662556
def __lt__(self: NDArray[datetime64], other: _ArrayLikeDT64_co, /) -> NDArray[np.bool]: ...
25672557
@overload
2568-
def __lt__(self: NDArray[object_], other: Any, /) -> NDArray[np.bool]: ...
2558+
def __lt__(self: NDArray[bytes_], other: _ArrayLikeBytes_co, /) -> NDArray[np.bool]: ...
25692559
@overload
2570-
def __lt__(self: NDArray[Any], other: _ArrayLikeObject_co, /) -> NDArray[np.bool]: ...
2560+
def __lt__(
2561+
self: ndarray[Any, dtype[str_] | dtypes.StringDType], other: _ArrayLikeStr_co | _ArrayLikeString_co, /
2562+
) -> NDArray[np.bool]: ...
2563+
@overload
2564+
def __lt__(self: NDArray[object_], other: object, /) -> NDArray[np.bool]: ...
2565+
@overload
2566+
def __lt__(self, other: _ArrayLikeObject_co, /) -> NDArray[np.bool]: ...
25712567

2568+
#
25722569
@overload
25732570
def __le__(self: _ArrayNumber_co, other: _ArrayLikeNumber_co, /) -> NDArray[np.bool]: ...
25742571
@overload
25752572
def __le__(self: _ArrayTD64_co, other: _ArrayLikeTD64_co, /) -> NDArray[np.bool]: ...
25762573
@overload
25772574
def __le__(self: NDArray[datetime64], other: _ArrayLikeDT64_co, /) -> NDArray[np.bool]: ...
25782575
@overload
2579-
def __le__(self: NDArray[object_], other: Any, /) -> NDArray[np.bool]: ...
2576+
def __le__(self: NDArray[bytes_], other: _ArrayLikeBytes_co, /) -> NDArray[np.bool]: ...
25802577
@overload
2581-
def __le__(self: NDArray[Any], other: _ArrayLikeObject_co, /) -> NDArray[np.bool]: ...
2578+
def __le__(
2579+
self: ndarray[Any, dtype[str_] | dtypes.StringDType], other: _ArrayLikeStr_co | _ArrayLikeString_co, /
2580+
) -> NDArray[np.bool]: ...
2581+
@overload
2582+
def __le__(self: NDArray[object_], other: object, /) -> NDArray[np.bool]: ...
2583+
@overload
2584+
def __le__(self, other: _ArrayLikeObject_co, /) -> NDArray[np.bool]: ...
25822585

2586+
#
25832587
@overload
25842588
def __gt__(self: _ArrayNumber_co, other: _ArrayLikeNumber_co, /) -> NDArray[np.bool]: ...
25852589
@overload
25862590
def __gt__(self: _ArrayTD64_co, other: _ArrayLikeTD64_co, /) -> NDArray[np.bool]: ...
25872591
@overload
25882592
def __gt__(self: NDArray[datetime64], other: _ArrayLikeDT64_co, /) -> NDArray[np.bool]: ...
25892593
@overload
2590-
def __gt__(self: NDArray[object_], other: Any, /) -> NDArray[np.bool]: ...
2594+
def __gt__(self: NDArray[bytes_], other: _ArrayLikeBytes_co, /) -> NDArray[np.bool]: ...
25912595
@overload
2592-
def __gt__(self: NDArray[Any], other: _ArrayLikeObject_co, /) -> NDArray[np.bool]: ...
2596+
def __gt__(
2597+
self: ndarray[Any, dtype[str_] | dtypes.StringDType], other: _ArrayLikeStr_co | _ArrayLikeString_co, /
2598+
) -> NDArray[np.bool]: ...
2599+
@overload
2600+
def __gt__(self: NDArray[object_], other: object, /) -> NDArray[np.bool]: ...
2601+
@overload
2602+
def __gt__(self, other: _ArrayLikeObject_co, /) -> NDArray[np.bool]: ...
25932603

2604+
#
25942605
@overload
25952606
def __ge__(self: _ArrayNumber_co, other: _ArrayLikeNumber_co, /) -> NDArray[np.bool]: ...
25962607
@overload
25972608
def __ge__(self: _ArrayTD64_co, other: _ArrayLikeTD64_co, /) -> NDArray[np.bool]: ...
25982609
@overload
25992610
def __ge__(self: NDArray[datetime64], other: _ArrayLikeDT64_co, /) -> NDArray[np.bool]: ...
26002611
@overload
2601-
def __ge__(self: NDArray[object_], other: Any, /) -> NDArray[np.bool]: ...
2612+
def __ge__(self: NDArray[bytes_], other: _ArrayLikeBytes_co, /) -> NDArray[np.bool]: ...
2613+
@overload
2614+
def __ge__(
2615+
self: ndarray[Any, dtype[str_] | dtypes.StringDType], other: _ArrayLikeStr_co | _ArrayLikeString_co, /
2616+
) -> NDArray[np.bool]: ...
2617+
@overload
2618+
def __ge__(self: NDArray[object_], other: object, /) -> NDArray[np.bool]: ...
26022619
@overload
2603-
def __ge__(self: NDArray[Any], other: _ArrayLikeObject_co, /) -> NDArray[np.bool]: ...
2620+
def __ge__(self, other: _ArrayLikeObject_co, /) -> NDArray[np.bool]: ...
26042621

26052622
# Unary ops
26062623

numpy/typing/tests/data/pass/comparisons.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Any
3+
from typing import cast, Any
44
import numpy as np
55

66
c16 = np.complex128()
@@ -30,6 +30,9 @@
3030
AR_i: np.ndarray[Any, np.dtype[np.int_]] = np.array([1])
3131
AR_f: np.ndarray[Any, np.dtype[np.float64]] = np.array([1.0])
3232
AR_c: np.ndarray[Any, np.dtype[np.complex128]] = np.array([1.0j])
33+
AR_S: np.ndarray[Any, np.dtype[np.bytes_]] = np.array([b"a"], "S")
34+
AR_T = cast(np.ndarray[Any, np.dtypes.StringDType], np.array(["a"], "T"))
35+
AR_U: np.ndarray[Any, np.dtype[np.str_]] = np.array(["a"], "U")
3336
AR_m: np.ndarray[Any, np.dtype[np.timedelta64]] = np.array([np.timedelta64("1")])
3437
AR_M: np.ndarray[Any, np.dtype[np.datetime64]] = np.array([np.datetime64("1")])
3538
AR_O: np.ndarray[Any, np.dtype[np.object_]] = np.array([1], dtype=object)
@@ -66,6 +69,17 @@
6669
AR_c > AR_f
6770
AR_c > AR_c
6871

72+
AR_S > AR_S
73+
AR_S > b""
74+
75+
AR_T > AR_T
76+
AR_T > AR_U
77+
AR_T > ""
78+
79+
AR_U > AR_U
80+
AR_U > AR_T
81+
AR_U > ""
82+
6983
AR_m > AR_b
7084
AR_m > AR_u
7185
AR_m > AR_i

0 commit comments

Comments
 (0)