Skip to content

MAINT: Misc np.array_api annotation fixes #19969

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Sep 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions numpy/array_api/_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def eye(
n_cols: Optional[int] = None,
/,
*,
k: Optional[int] = 0,
k: int = 0,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
) -> Array:
Expand Down Expand Up @@ -232,7 +232,7 @@ def linspace(
return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint))


def meshgrid(*arrays: Sequence[Array], indexing: str = "xy") -> List[Array, ...]:
def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]:
"""
Array API compatible wrapper for :py:func:`np.meshgrid <numpy.meshgrid>`.

Expand Down
4 changes: 2 additions & 2 deletions numpy/array_api/_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpy as np


def broadcast_arrays(*arrays: Sequence[Array]) -> List[Array]:
def broadcast_arrays(*arrays: Array) -> List[Array]:
"""
Array API compatible wrapper for :py:func:`np.broadcast_arrays <numpy.broadcast_arrays>`.

Expand Down Expand Up @@ -98,7 +98,7 @@ def iinfo(type: Union[Dtype, Array], /) -> iinfo_object:
return iinfo_object(ii.bits, ii.max, ii.min)


def result_type(*arrays_and_dtypes: Sequence[Union[Array, Dtype]]) -> Dtype:
def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype:
"""
Array API compatible wrapper for :py:func:`np.result_type <numpy.result_type>`.

Expand Down
30 changes: 23 additions & 7 deletions numpy/array_api/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
"PyCapsule",
]

from typing import Any, Literal, Sequence, Type, Union
import sys
from typing import Any, Literal, Sequence, Type, Union, TYPE_CHECKING, TypeVar

from . import (
Array,
from ._array_object import Array
from numpy import (
dtype,
int8,
int16,
int32,
Expand All @@ -33,12 +35,26 @@

# This should really be recursive, but that isn't supported yet. See the
# similar comment in numpy/typing/_array_like.py
NestedSequence = Sequence[Sequence[Any]]
_T = TypeVar("_T")
NestedSequence = Sequence[Sequence[_T]]
Comment on lines +38 to +39
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NestedSequence needs a free (subscriptable) parameter based on how it's used in np.array_api.asarray:

NestedSequence[bool | int | float],

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Off topic: mypy has some basic support for recursive types these days,
so it is now possible to define a proper arbitrary-nested sequence type with the help of protocols (xref #19894)


Device = Literal["cpu"]
Dtype = Type[
Union[int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64]
]
if TYPE_CHECKING or sys.version_info >= (3, 9):
Dtype = dtype[Union[
int8,
int16,
int32,
int64,
uint8,
uint16,
uint32,
uint64,
float32,
float64,
]]
else:
Dtype = dtype

SupportsDLPack = Any
SupportsBufferProtocol = Any
PyCapsule = Any