Skip to content

Commit bac7b36

Browse files
authored
Merge pull request #28650 from charris/backport-28644
TYP: fix `ndarray.tolist()` and `.item()` for unknown dtype
2 parents 9470412 + 0a7f819 commit bac7b36

File tree

2 files changed

+23
-43
lines changed

2 files changed

+23
-43
lines changed

numpy/__init__.pyi

Lines changed: 14 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,35 +1093,11 @@ class _SupportsItem(Protocol[_T_co]):
10931093
class _SupportsDLPack(Protocol[_T_contra]):
10941094
def __dlpack__(self, /, *, stream: _T_contra | None = None) -> CapsuleType: ...
10951095

1096-
@type_check_only
1097-
class _HasShape(Protocol[_ShapeT_co]):
1098-
@property
1099-
def shape(self, /) -> _ShapeT_co: ...
1100-
11011096
@type_check_only
11021097
class _HasDType(Protocol[_T_co]):
11031098
@property
11041099
def dtype(self, /) -> _T_co: ...
11051100

1106-
@type_check_only
1107-
class _HasShapeAndSupportsItem(_HasShape[_ShapeT_co], _SupportsItem[_T_co], Protocol[_ShapeT_co, _T_co]): ...
1108-
1109-
# matches any `x` on `x.type.item() -> _T_co`, e.g. `dtype[np.int8]` gives `_T_co: int`
1110-
@type_check_only
1111-
class _HasTypeWithItem(Protocol[_T_co]):
1112-
@property
1113-
def type(self, /) -> type[_SupportsItem[_T_co]]: ...
1114-
1115-
# matches any `x` on `x.shape: _ShapeT_co` and `x.dtype.type.item() -> _T_co`,
1116-
# useful for capturing the item-type (`_T_co`) of the scalar-type of an array with
1117-
# specific shape (`_ShapeT_co`).
1118-
@type_check_only
1119-
class _HasShapeAndDTypeWithItem(Protocol[_ShapeT_co, _T_co]):
1120-
@property
1121-
def shape(self, /) -> _ShapeT_co: ...
1122-
@property
1123-
def dtype(self, /) -> _HasTypeWithItem[_T_co]: ...
1124-
11251101
@type_check_only
11261102
class _HasRealAndImag(Protocol[_RealT_co, _ImagT_co]):
11271103
@property
@@ -2204,29 +2180,26 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DType_co]):
22042180
@property
22052181
def flat(self) -> flatiter[Self]: ...
22062182

2207-
@overload # special casing for `StringDType`, which has no scalar type
2208-
def item(self: ndarray[Any, dtypes.StringDType], /) -> str: ...
2209-
@overload
2210-
def item(self: ndarray[Any, dtypes.StringDType], arg0: SupportsIndex | tuple[SupportsIndex, ...] = ..., /) -> str: ...
2211-
@overload
2212-
def item(self: ndarray[Any, dtypes.StringDType], /, *args: SupportsIndex) -> str: ...
22132183
@overload # use the same output type as that of the underlying `generic`
2214-
def item(self: _HasShapeAndDTypeWithItem[Any, _T], /) -> _T: ...
2215-
@overload
2216-
def item(self: _HasShapeAndDTypeWithItem[Any, _T], arg0: SupportsIndex | tuple[SupportsIndex, ...] = ..., /) -> _T: ...
2217-
@overload
2218-
def item(self: _HasShapeAndDTypeWithItem[Any, _T], /, *args: SupportsIndex) -> _T: ...
2184+
def item(self: NDArray[generic[_T]], i0: SupportsIndex | tuple[SupportsIndex, ...] = ..., /, *args: SupportsIndex) -> _T: ...
2185+
@overload # special casing for `StringDType`, which has no scalar type
2186+
def item(
2187+
self: ndarray[Any, dtypes.StringDType],
2188+
arg0: SupportsIndex | tuple[SupportsIndex, ...] = ...,
2189+
/,
2190+
*args: SupportsIndex,
2191+
) -> str: ...
22192192

22202193
@overload
2221-
def tolist(self: _HasShapeAndSupportsItem[tuple[()], _T], /) -> _T: ...
2194+
def tolist(self: ndarray[tuple[()], dtype[generic[_T]]], /) -> _T: ...
22222195
@overload
2223-
def tolist(self: _HasShapeAndSupportsItem[tuple[int], _T], /) -> list[_T]: ...
2196+
def tolist(self: ndarray[tuple[int], dtype[generic[_T]]], /) -> list[_T]: ...
22242197
@overload
2225-
def tolist(self: _HasShapeAndSupportsItem[tuple[int, int], _T], /) -> list[list[_T]]: ...
2198+
def tolist(self: ndarray[tuple[int, int], dtype[generic[_T]]], /) -> list[list[_T]]: ...
22262199
@overload
2227-
def tolist(self: _HasShapeAndSupportsItem[tuple[int, int, int], _T], /) -> list[list[list[_T]]]: ...
2200+
def tolist(self: ndarray[tuple[int, int, int], dtype[generic[_T]]], /) -> list[list[list[_T]]]: ...
22282201
@overload
2229-
def tolist(self: _HasShapeAndSupportsItem[Any, _T], /) -> _T | list[_T] | list[list[_T]] | list[list[list[Any]]]: ...
2202+
def tolist(self, /) -> Any: ...
22302203

22312204
@overload
22322205
def resize(self, new_shape: _ShapeLike, /, *, refcheck: builtins.bool = ...) -> None: ...
@@ -5379,7 +5352,7 @@ class matrix(ndarray[_2DShapeT_co, _DType_co]):
53795352
def ptp(self, axis: None | _ShapeLike = ..., out: _ArrayT = ...) -> _ArrayT: ...
53805353

53815354
def squeeze(self, axis: None | _ShapeLike = ...) -> matrix[_2D, _DType_co]: ...
5382-
def tolist(self: _SupportsItem[_T]) -> list[list[_T]]: ...
5355+
def tolist(self: matrix[Any, dtype[generic[_T]]]) -> list[list[_T]]: ... # pyright: ignore[reportIncompatibleMethodOverride]
53835356
def ravel(self, /, order: _OrderKACF = "C") -> matrix[tuple[L[1], int], _DType_co]: ... # pyright: ignore[reportIncompatibleMethodOverride]
53845357
def flatten(self, /, order: _OrderKACF = "C") -> matrix[tuple[L[1], int], _DType_co]: ... # pyright: ignore[reportIncompatibleMethodOverride]
53855358

numpy/typing/tests/data/reveal/ndarray_conversion.pyi

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,15 @@ assert_type(b1_0d.tolist(), bool)
3030
assert_type(u2_1d.tolist(), list[int])
3131
assert_type(i4_2d.tolist(), list[list[int]])
3232
assert_type(f8_3d.tolist(), list[list[list[float]]])
33-
assert_type(cG_4d.tolist(), complex | list[complex] | list[list[complex]] | list[list[list[Any]]])
34-
assert_type(i0_nd.tolist(), int | list[int] | list[list[int]] | list[list[list[Any]]])
33+
assert_type(cG_4d.tolist(), Any)
34+
assert_type(i0_nd.tolist(), Any)
35+
36+
# regression tests for numpy/numpy#27944
37+
any_dtype: np.ndarray[Any, Any]
38+
any_sctype: np.ndarray[Any, Any]
39+
assert_type(any_dtype.tolist(), Any)
40+
assert_type(any_sctype.tolist(), Any)
41+
3542

3643
# itemset does not return a value
3744
# tostring is pretty simple

0 commit comments

Comments
 (0)