Skip to content

Commit d5fd966

Browse files
authored
TYP: Type MaskedArray.__{ge,gt,le,lt}__ (numpy#28689)
1 parent 7771624 commit d5fd966

File tree

3 files changed

+77
-6
lines changed

3 files changed

+77
-6
lines changed

numpy/ma/core.pyi

+6-6
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@ from numpy import (
2525
from numpy._globals import _NoValueType
2626
from numpy._typing import (
2727
ArrayLike,
28-
_IntLike_co,
2928
NDArray,
3029
_ArrayLike,
30+
_ArrayLikeInt,
3131
_ArrayLikeInt_co,
3232
_DTypeLikeBool,
33-
_ArrayLikeInt,
33+
_IntLike_co,
3434
_ScalarLike_co,
3535
_Shape,
3636
_ShapeLike,
@@ -422,10 +422,10 @@ class MaskedArray(ndarray[_ShapeT_co, _DTypeT_co]):
422422
def compress(self, condition, axis=..., out=...): ...
423423
def __eq__(self, other): ...
424424
def __ne__(self, other): ...
425-
def __ge__(self, other): ...
426-
def __gt__(self, other): ...
427-
def __le__(self, other): ...
428-
def __lt__(self, other): ...
425+
def __ge__(self, other: ArrayLike, /) -> _MaskedArray[bool_]: ...
426+
def __gt__(self, other: ArrayLike, /) -> _MaskedArray[bool_]: ...
427+
def __le__(self, other: ArrayLike, /) -> _MaskedArray[bool_]: ...
428+
def __lt__(self, other: ArrayLike, /) -> _MaskedArray[bool_]: ...
429429
def __add__(self, other): ...
430430
def __radd__(self, other): ...
431431
def __sub__(self, other): ...

numpy/typing/tests/data/fail/ma.pyi

+9
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,12 @@ m.argpartition(axis=(0,1)) # E: No overload variant
9292
m.argpartition(kind='cabbage') # E: No overload variant
9393
m.argpartition(order=lambda: 'cabbage') # E: No overload variant
9494
m.argpartition(AR_b) # E: No overload variant
95+
96+
m >= (lambda x: 'mango') # E: No overload variant
97+
98+
m > (lambda x: 'mango') # E: No overload variant
99+
100+
m <= (lambda x: 'mango') # E: No overload variant
101+
102+
m < (lambda x: 'mango') # E: No overload variant
103+

numpy/typing/tests/data/reveal/ma.pyi

+62
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ from typing_extensions import assert_type
33
from typing import Any, TypeAlias, TypeVar
44
from numpy._typing import _Shape, NDArray
55
from numpy import dtype, generic
6+
from datetime import datetime, timedelta
67

78

89
_ScalarT_co = TypeVar("_ScalarT_co", bound=generic, covariant=True)
@@ -11,12 +12,22 @@ MaskedNDArray: TypeAlias = np.ma.MaskedArray[_Shape, dtype[_ScalarT_co]]
1112
class MaskedNDArraySubclass(MaskedNDArray[np.complex128]): ...
1213

1314
AR_f4: NDArray[np.float32]
15+
AR_dt64: NDArray[np.datetime64]
16+
AR_td64: NDArray[np.timedelta64]
17+
AR_o: NDArray[np.timedelta64]
1418

1519
MAR_b: MaskedNDArray[np.bool]
1620
MAR_f4: MaskedNDArray[np.float32]
1721
MAR_f8: MaskedNDArray[np.float64]
1822
MAR_i8: MaskedNDArray[np.int64]
23+
MAR_dt64: MaskedNDArray[np.datetime64]
24+
MAR_td64: MaskedNDArray[np.timedelta64]
25+
MAR_o: MaskedNDArray[np.object_]
26+
MAR_s: MaskedNDArray[np.str_]
27+
MAR_byte: MaskedNDArray[np.bytes_]
28+
1929
MAR_subclass: MaskedNDArraySubclass
30+
2031
MAR_1d: np.ma.MaskedArray[tuple[int], np.dtype[Any]]
2132

2233
f4: np.float32
@@ -156,3 +167,54 @@ assert_type(MAR_f4.partition(1, axis=0, kind='introselect', order='K'), None)
156167

157168
assert_type(MAR_f4.argpartition(1), MaskedNDArray[np.intp])
158169
assert_type(MAR_1d.argpartition(1, axis=0, kind='introselect', order='K'), MaskedNDArray[np.intp])
170+
171+
assert_type(MAR_f4 >= 3, MaskedNDArray[np.bool])
172+
assert_type(MAR_i8 >= AR_td64, MaskedNDArray[np.bool])
173+
assert_type(MAR_b >= AR_td64, MaskedNDArray[np.bool])
174+
assert_type(MAR_td64 >= AR_td64, MaskedNDArray[np.bool])
175+
assert_type(MAR_dt64 >= AR_dt64, MaskedNDArray[np.bool])
176+
assert_type(MAR_o >= AR_o, MaskedNDArray[np.bool])
177+
assert_type(MAR_1d >= 0, MaskedNDArray[np.bool])
178+
assert_type(MAR_s >= MAR_s, MaskedNDArray[np.bool])
179+
assert_type(MAR_byte >= MAR_byte, MaskedNDArray[np.bool])
180+
181+
assert_type(MAR_f4 > 3, MaskedNDArray[np.bool])
182+
assert_type(MAR_i8 > AR_td64, MaskedNDArray[np.bool])
183+
assert_type(MAR_b > AR_td64, MaskedNDArray[np.bool])
184+
assert_type(MAR_td64 > AR_td64, MaskedNDArray[np.bool])
185+
assert_type(MAR_dt64 > AR_dt64, MaskedNDArray[np.bool])
186+
assert_type(MAR_o > AR_o, MaskedNDArray[np.bool])
187+
assert_type(MAR_1d > 0, MaskedNDArray[np.bool])
188+
assert_type(MAR_s > MAR_s, MaskedNDArray[np.bool])
189+
assert_type(MAR_byte > MAR_byte, MaskedNDArray[np.bool])
190+
191+
assert_type(MAR_f4 <= 3, MaskedNDArray[np.bool])
192+
assert_type(MAR_i8 <= AR_td64, MaskedNDArray[np.bool])
193+
assert_type(MAR_b <= AR_td64, MaskedNDArray[np.bool])
194+
assert_type(MAR_td64 <= AR_td64, MaskedNDArray[np.bool])
195+
assert_type(MAR_dt64 <= AR_dt64, MaskedNDArray[np.bool])
196+
assert_type(MAR_o <= AR_o, MaskedNDArray[np.bool])
197+
assert_type(MAR_1d <= 0, MaskedNDArray[np.bool])
198+
assert_type(MAR_s <= MAR_s, MaskedNDArray[np.bool])
199+
assert_type(MAR_byte <= MAR_byte, MaskedNDArray[np.bool])
200+
201+
assert_type(MAR_f4 < 3, MaskedNDArray[np.bool])
202+
assert_type(MAR_i8 < AR_td64, MaskedNDArray[np.bool])
203+
assert_type(MAR_b < AR_td64, MaskedNDArray[np.bool])
204+
assert_type(MAR_td64 < AR_td64, MaskedNDArray[np.bool])
205+
assert_type(MAR_dt64 < AR_dt64, MaskedNDArray[np.bool])
206+
assert_type(MAR_o < AR_o, MaskedNDArray[np.bool])
207+
assert_type(MAR_1d < 0, MaskedNDArray[np.bool])
208+
assert_type(MAR_s < MAR_s, MaskedNDArray[np.bool])
209+
assert_type(MAR_byte < MAR_byte, MaskedNDArray[np.bool])
210+
211+
assert_type(MAR_f4 <= 3, MaskedNDArray[np.bool])
212+
assert_type(MAR_i8 <= AR_td64, MaskedNDArray[np.bool])
213+
assert_type(MAR_b <= AR_td64, MaskedNDArray[np.bool])
214+
assert_type(MAR_td64 <= AR_td64, MaskedNDArray[np.bool])
215+
assert_type(MAR_dt64 <= AR_dt64, MaskedNDArray[np.bool])
216+
assert_type(MAR_o <= AR_o, MaskedNDArray[np.bool])
217+
assert_type(MAR_1d <= 0, MaskedNDArray[np.bool])
218+
assert_type(MAR_s <= MAR_s, MaskedNDArray[np.bool])
219+
assert_type(MAR_byte <= MAR_byte, MaskedNDArray[np.bool])
220+

0 commit comments

Comments
 (0)