Skip to content

Commit c75757c

Browse files
authored
TYP: Type MaskedArray.count and np.ma.count (numpy#28735)
1 parent fec8624 commit c75757c

File tree

3 files changed

+40
-2
lines changed

3 files changed

+40
-2
lines changed

numpy/ma/core.pyi

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ from numpy import (
1919
expand_dims,
2020
float64,
2121
generic,
22+
int_,
2223
intp,
2324
ndarray,
2425
)
@@ -452,7 +453,17 @@ class MaskedArray(ndarray[_ShapeT_co, _DTypeT_co]):
452453
@property # type: ignore[misc]
453454
def real(self): ...
454455
get_real: Any
455-
def count(self, axis=..., keepdims=...): ...
456+
457+
# keep in sync with `np.ma.count`
458+
@overload
459+
def count(self, axis: None = None, keepdims: Literal[False] | _NoValueType = ...) -> int: ...
460+
@overload
461+
def count(self, axis: _ShapeLike, keepdims: bool | _NoValueType = ...) -> NDArray[int_]: ...
462+
@overload
463+
def count(self, axis: _ShapeLike | None = ..., *, keepdims: Literal[True]) -> NDArray[int_]: ...
464+
@overload
465+
def count(self, axis: _ShapeLike | None, keepdims: Literal[True]) -> NDArray[int_]: ...
466+
456467
def ravel(self, order=...): ...
457468
def reshape(self, *s, **kwargs): ...
458469
def resize(self, newshape, refcheck=..., order=...): ...
@@ -949,7 +960,15 @@ sum: _frommethod
949960
swapaxes: _frommethod
950961
trace: _frommethod
951962
var: _frommethod
952-
count: _frommethod
963+
964+
@overload
965+
def count(self: ArrayLike, axis: None = None, keepdims: Literal[False] | _NoValueType = ...) -> int: ...
966+
@overload
967+
def count(self: ArrayLike, axis: _ShapeLike, keepdims: bool | _NoValueType = ...) -> NDArray[int_]: ...
968+
@overload
969+
def count(self: ArrayLike, axis: _ShapeLike | None = ..., *, keepdims: Literal[True]) -> NDArray[int_]: ...
970+
@overload
971+
def count(self: ArrayLike, axis: _ShapeLike | None, keepdims: Literal[True]) -> NDArray[int_]: ...
953972

954973
@overload
955974
def argmin(

numpy/typing/tests/data/fail/ma.pyi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,7 @@ m > (lambda x: 'mango') # E: No overload variant
104104
m <= (lambda x: 'mango') # E: No overload variant
105105

106106
m < (lambda x: 'mango') # E: No overload variant
107+
108+
m.count(axis=0.) # E: No overload variant
109+
110+
np.ma.count(m, axis=0.) # E: No overload variant

numpy/typing/tests/data/reveal/ma.pyi

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,3 +232,18 @@ assert_type(MAR_1d <= 0, MaskedNDArray[np.bool])
232232
assert_type(MAR_s <= MAR_s, MaskedNDArray[np.bool])
233233
assert_type(MAR_byte <= MAR_byte, MaskedNDArray[np.bool])
234234

235+
assert_type(MAR_byte.count(), int)
236+
assert_type(MAR_f4.count(axis=None), int)
237+
assert_type(MAR_f4.count(axis=0), NDArray[np.int_])
238+
assert_type(MAR_b.count(axis=(0,1)), NDArray[np.int_])
239+
assert_type(MAR_o.count(keepdims=True), NDArray[np.int_])
240+
assert_type(MAR_o.count(axis=None, keepdims=True), NDArray[np.int_])
241+
assert_type(MAR_o.count(None, True), NDArray[np.int_])
242+
243+
assert_type(np.ma.count(MAR_byte), int)
244+
assert_type(np.ma.count(MAR_byte, axis=None), int)
245+
assert_type(np.ma.count(MAR_f4, axis=0), NDArray[np.int_])
246+
assert_type(np.ma.count(MAR_b, axis=(0,1)), NDArray[np.int_])
247+
assert_type(np.ma.count(MAR_o, keepdims=True), NDArray[np.int_])
248+
assert_type(np.ma.count(MAR_o, axis=None, keepdims=True), NDArray[np.int_])
249+
assert_type(np.ma.count(MAR_o, None, True), NDArray[np.int_])

0 commit comments

Comments
 (0)