Skip to content

ENH,API: Add a protocol for representing nested sequences #18155

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions doc/release/upcoming_changes/18155.new_feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
Added a protocol for representing nested sequences
--------------------------------------------------

The new `~numpy.typing.NestedSequence` protocol has been added to `numpy.typing`.
Per its name, the protocol can be used in static type checking for
representing `sequences <collections.abc.Sequence>` with arbitrary levels of nesting.

For example:

.. code-block:: python

from __future__ import annotations
from typing import Any, List, TYPE_CHECKING
import numpy as np
import numpy.typing as npt

def get_dtype(seq: npt.NestedSequence[int]) -> np.dtype[np.int_]:
return np.asarray(seq).dtype

a = get_dtype([1])
b = get_dtype([[1]])
c = get_dtype([[[1]]])
d = get_dtype([[[[1]]]])

if TYPE_CHECKING:
reveal_locals()
# note: Revealed local types are:
# note: a: numpy.dtype[numpy.signedinteger[numpy.typing._64Bit]]
# note: b: numpy.dtype[numpy.signedinteger[numpy.typing._64Bit]]
# note: c: numpy.dtype[numpy.signedinteger[numpy.typing._64Bit]]
# note: d: numpy.dtype[numpy.signedinteger[numpy.typing._64Bit]]
12 changes: 9 additions & 3 deletions numpy/typing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ class _8Bit(_16Bit): ... # type: ignore[misc]
# Clean up the namespace
del TYPE_CHECKING, final, List

from ._nested_sequence import NestedSequence
from ._nbit import (
_NBitByte,
_NBitShort,
Expand Down Expand Up @@ -299,7 +300,6 @@ class _8Bit(_16Bit): ... # type: ignore[misc]
from ._array_like import (
ArrayLike,
_ArrayLike,
_NestedSequence,
_SupportsArray,
_ArrayLikeBool,
_ArrayLikeUInt,
Expand All @@ -315,10 +315,16 @@ class _8Bit(_16Bit): ... # type: ignore[misc]
)

if __doc__ is not None:
import textwrap
from ._add_docstring import _docstrings
__doc__ += _docstrings
__doc__ += '\n.. autoclass:: numpy.typing.NBitBase\n'
del _docstrings
__doc__ += textwrap.dedent("""
.. autoclass:: numpy.typing.NBitBase
.. autoclass:: numpy.typing.NestedSequence
:members: __init__
:exclude-members: __init__
""")
del textwrap, _docstrings

from numpy._pytesttester import PytestTester
test = PytestTester(__name__)
Expand Down
42 changes: 21 additions & 21 deletions numpy/typing/_array_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
str_,
bytes_,
)

from ._nested_sequence import NestedSequence
from ._dtype_like import DTypeLike

if sys.version_info >= (3, 8):
Expand All @@ -45,22 +47,14 @@ def __array__(self, dtype: None = ...) -> ndarray[Any, _DType]: ...
else:
_SupportsArray = Any

# TODO: Wait for support for recursive types
_NestedSequence = Union[
_T,
Sequence[_T],
Sequence[Sequence[_T]],
Sequence[Sequence[Sequence[_T]]],
Sequence[Sequence[Sequence[Sequence[_T]]]],
]
_RecursiveSequence = Sequence[Sequence[Sequence[Sequence[Sequence[Any]]]]]

# A union representing array-like objects; consists of two typevars:
# One representing types that can be parametrized w.r.t. `np.dtype`
# and another one for the rest
_ArrayLike = Union[
_NestedSequence[_SupportsArray[_DType]],
_NestedSequence[_T],
_SupportsArray[_DType],
NestedSequence[_SupportsArray[_DType]],
_T,
NestedSequence[_T],
]

# TODO: support buffer protocols once
Expand All @@ -70,12 +64,9 @@ def __array__(self, dtype: None = ...) -> ndarray[Any, _DType]: ...
# is resolved. See also the mypy issue:
#
# https://github.com/python/typing/issues/593
ArrayLike = Union[
_RecursiveSequence,
_ArrayLike[
"dtype[Any]",
Union[bool, int, float, complex, str, bytes]
],
ArrayLike = _ArrayLike[
"dtype[Any]",
Union[bool, int, float, complex, str, bytes],
]

# `ArrayLike<X>`: array-like objects that can be coerced into `X`
Expand Down Expand Up @@ -104,10 +95,19 @@ def __array__(self, dtype: None = ...) -> ndarray[Any, _DType]: ...
"dtype[Union[bool_, integer[Any], timedelta64]]",
Union[bool, int],
]
_ArrayLikeDT64 = _NestedSequence[_SupportsArray["dtype[datetime64]"]]
_ArrayLikeObject = _NestedSequence[_SupportsArray["dtype[object_]"]]
_ArrayLikeDT64 = Union[
_SupportsArray["dtype[datetime64]"],
NestedSequence[_SupportsArray["dtype[datetime64]"]],
]
_ArrayLikeObject = Union[
_SupportsArray["dtype[object_]"],
NestedSequence[_SupportsArray["dtype[object_]"]],
]

_ArrayLikeVoid = _NestedSequence[_SupportsArray["dtype[void]"]]
_ArrayLikeVoid = Union[
_SupportsArray["dtype[void]"],
NestedSequence[_SupportsArray["dtype[void]"]],
]
_ArrayLikeStr = _ArrayLike[
"dtype[str_]",
str,
Expand Down
169 changes: 169 additions & 0 deletions numpy/typing/_nested_sequence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
"""A module containing the `NestedSequence` protocol."""

import sys
from abc import abstractmethod, ABCMeta
from collections.abc import Sequence
from typing import (
Any,
Callable,
Generic,
Iterable,
Iterator,
overload,
TYPE_CHECKING,
TypeVar,
Union,
)

import numpy as np

if sys.version_info >= (3, 8):
from typing import Protocol, runtime_checkable
HAVE_PROTOCOL = True
else:
try:
from typing_extensions import Protocol, runtime_checkable
except ImportError:
HAVE_PROTOCOL = False
else:
HAVE_PROTOCOL = True

__all__ = ["NestedSequence"]

_TT = TypeVar("_TT", bound=type)
_T_co = TypeVar("_T_co", covariant=True)

_SeqOrScalar = Union[_T_co, "NestedSequence[_T_co]"]

_NBitInt = f"_{8 * np.int_().itemsize}Bit"
_DOC = f"""A protocol for representing nested sequences.

Runtime usage of the protocol requires either Python >= 3.8 or
the typing-extensions_ package.

.. _typing-extensions: https://pypi.org/project/typing-extensions/

See Also
--------
:class:`collections.abc.Sequence`
ABCs for read-only and mutable :term:`sequences<sequence>`.

Examples
--------
.. code-block:: python

>>> from __future__ import annotations
>>> from typing import Any, List, TYPE_CHECKING
>>> import numpy as np
>>> import numpy.typing as npt

>>> def get_dtype(seq: npt.NestedSequence[int]) -> np.dtype[np.int_]:
... return np.asarray(seq).dtype

>>> a = get_dtype([1])
>>> b = get_dtype([[1]])
>>> c = get_dtype([[[1]]])
>>> d = get_dtype([[[[1]]]])

>>> if TYPE_CHECKING:
... reveal_locals()
... # note: Revealed local types are:
... # note: a: numpy.dtype[numpy.signedinteger[numpy.typing.{_NBitInt}]]
... # note: b: numpy.dtype[numpy.signedinteger[numpy.typing.{_NBitInt}]]
... # note: c: numpy.dtype[numpy.signedinteger[numpy.typing.{_NBitInt}]]
... # note: d: numpy.dtype[numpy.signedinteger[numpy.typing.{_NBitInt}]]

"""


def _set_module_and_doc(module: str, doc: str) -> Callable[[_TT], _TT]:
"""A decorator for setting ``__module__`` and `__doc__`."""
def decorator(func):
func.__module__ = module
func.__doc__ = doc
return func
return decorator


# E: Error message for `_ProtocolMeta` and `_ProtocolMixin`
_ERR_MSG = (
"runtime usage of `NestedSequence` requires "
"either typing-extensions or Python >= 3.8"
)


class _ProtocolMeta(ABCMeta):
"""Metaclass of `_ProtocolMixin`."""

def __instancecheck__(self, params):
raise RuntimeError(_ERR_MSG)

def __subclasscheck__(self, params):
raise RuntimeError(_ERR_MSG)


class _ProtocolMixin(Generic[_T_co], metaclass=_ProtocolMeta):
"""A mixin that raises upon executing methods that require `typing.Protocol`."""

__slots__ = ()

def __init__(self):
raise RuntimeError(_ERR_MSG)

def __init_subclass__(cls):
if cls is not NestedSequence:
raise RuntimeError(_ERR_MSG)
super().__init_subclass__()


# Plan B in case `typing.Protocol` is unavailable.
#
# A `RuntimeError` will be raised if one attempts to execute
# methods that absolutelly require `typing.Protocol`.
if not TYPE_CHECKING and not HAVE_PROTOCOL:
Protocol = _ProtocolMixin


@_set_module_and_doc("numpy.typing", doc=_DOC)
@runtime_checkable
class NestedSequence(Protocol[_T_co]):
if not TYPE_CHECKING:
__slots__ = ()

# Can't directly inherit from `collections.abc.Sequence`
# (as it is not a Protocol), but we can forward to its' methods
def __contains__(self, x: object) -> bool:
"""Return ``x in self``."""
return Sequence.__contains__(self, x) # type: ignore[operator]

@overload
@abstractmethod
def __getitem__(self, i: int) -> _SeqOrScalar[_T_co]: ...
@overload
@abstractmethod
def __getitem__(self, s: slice) -> "NestedSequence[_T_co]": ...
@abstractmethod
def __getitem__(self, s):
"""Return ``self[s]``."""
raise NotImplementedError("Trying to call an abstract method")

def __iter__(self) -> Iterator[_SeqOrScalar[_T_co]]:
"""Return ``iter(self)``."""
return Sequence.__iter__(self) # type: ignore[arg-type]

@abstractmethod
def __len__(self) -> int:
"""Return ``len(self)``."""
raise NotImplementedError("Trying to call an abstract method")

def __reversed__(self) -> Iterator[_SeqOrScalar[_T_co]]:
"""Return ``reversed(self)``."""
return Sequence.__reversed__(self) # type: ignore[arg-type]

def count(self, value: Any) -> int:
"""Return the number of occurrences of `value`."""
return Sequence.count(self, value) # type: ignore[arg-type]

def index(self, value: Any, start: int = 0, stop: int = sys.maxsize) -> int:
"""Return the first index of `value`."""
return Sequence.index(self, value, start, stop) # type: ignore[arg-type]
17 changes: 17 additions & 0 deletions numpy/typing/tests/data/fail/nested_sequence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import Sequence, Tuple, List
import numpy.typing as npt

a: Sequence[float]
b: List[complex]
c: Tuple[str, ...]
d: int
e: str

def func(a: npt.NestedSequence[int]) -> None:
...

reveal_type(func(a)) # E: incompatible type
reveal_type(func(b)) # E: incompatible type
reveal_type(func(c)) # E: incompatible type
reveal_type(func(d)) # E: incompatible type
reveal_type(func(e)) # E: incompatible type
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,12 @@ def add(a: np.floating[T], b: np.integer[T]) -> np.floating[T]:
reveal_type(add(f4, i8)) # E: {float64}
reveal_type(add(f8, i4)) # E: {float64}
reveal_type(add(f4, i4)) # E: {float32}


def get_dtype(seq: npt.NestedSequence[int]) -> np.dtype[np.int_]:
return np.asarray(seq).dtype


reveal_type(get_dtype([1])) # E: numpy.dtype[{int_}]
reveal_type(get_dtype([[1]])) # E: numpy.dtype[{int_}]
reveal_type(get_dtype([[[1]]])) # E: numpy.dtype[{int_}]
25 changes: 25 additions & 0 deletions numpy/typing/tests/data/reveal/nested_sequence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import Sequence, Tuple, List, Any
import numpy.typing as npt

a: Sequence[int]
b: Sequence[Sequence[int]]
c: Sequence[Sequence[Sequence[int]]]
d: Sequence[Sequence[Sequence[Sequence[int]]]]
e: Sequence[bool]
f: Tuple[int, ...]
g: List[int]
h: Sequence[Any]

def func(a: npt.NestedSequence[int]) -> None:
...

reveal_type(func(a)) # E: None
reveal_type(func(b)) # E: None
reveal_type(func(c)) # E: None
reveal_type(func(d)) # E: None
reveal_type(func(e)) # E: None
reveal_type(func(f)) # E: None
reveal_type(func(g)) # E: None
reveal_type(func(h)) # E: None

reveal_type(isinstance(1, npt.NestedSequence)) # E: bool
Loading