diff --git a/doc/release/upcoming_changes/18155.new_feature.rst b/doc/release/upcoming_changes/18155.new_feature.rst new file mode 100644 index 000000000000..23495fda7838 --- /dev/null +++ b/doc/release/upcoming_changes/18155.new_feature.rst @@ -0,0 +1,31 @@ +Added a protocol for representing nested sequences +-------------------------------------------------- + +The new `~numpy.typing.NestedSequence` protocol has been added to `numpy.typing`. +Per its name, the protocol can be used in static type checking for +representing `sequences <collections.abc.Sequence>` with arbitrary levels of nesting. + +For example: + +.. code-block:: python + + from __future__ import annotations + from typing import Any, List, TYPE_CHECKING + import numpy as np + import numpy.typing as npt + + def get_dtype(seq: npt.NestedSequence[int]) -> np.dtype[np.int_]: + return np.asarray(seq).dtype + + a = get_dtype([1]) + b = get_dtype([[1]]) + c = get_dtype([[[1]]]) + d = get_dtype([[[[1]]]]) + + if TYPE_CHECKING: + reveal_locals() + # note: Revealed local types are: + # note: a: numpy.dtype[numpy.signedinteger[numpy.typing._64Bit]] + # note: b: numpy.dtype[numpy.signedinteger[numpy.typing._64Bit]] + # note: c: numpy.dtype[numpy.signedinteger[numpy.typing._64Bit]] + # note: d: numpy.dtype[numpy.signedinteger[numpy.typing._64Bit]] diff --git a/numpy/typing/__init__.py b/numpy/typing/__init__.py index 86bba57be326..6ff8296c3cec 100644 --- a/numpy/typing/__init__.py +++ b/numpy/typing/__init__.py @@ -229,6 +229,7 @@ class _8Bit(_16Bit): ... # type: ignore[misc] # Clean up the namespace del TYPE_CHECKING, final, List +from ._nested_sequence import NestedSequence from ._nbit import ( _NBitByte, _NBitShort, @@ -299,7 +300,6 @@ class _8Bit(_16Bit): ... # type: ignore[misc] from ._array_like import ( ArrayLike, _ArrayLike, - _NestedSequence, _SupportsArray, _ArrayLikeBool, _ArrayLikeUInt, @@ -315,10 +315,16 @@ class _8Bit(_16Bit): ...
# type: ignore[misc] ) if __doc__ is not None: + import textwrap from ._add_docstring import _docstrings __doc__ += _docstrings - __doc__ += '\n.. autoclass:: numpy.typing.NBitBase\n' - del _docstrings + __doc__ += textwrap.dedent(""" + .. autoclass:: numpy.typing.NBitBase + .. autoclass:: numpy.typing.NestedSequence + :members: __init__ + :exclude-members: __init__ + """) + del textwrap, _docstrings from numpy._pytesttester import PytestTester test = PytestTester(__name__) diff --git a/numpy/typing/_array_like.py b/numpy/typing/_array_like.py index d6473442c37b..e5ec627e93d1 100644 --- a/numpy/typing/_array_like.py +++ b/numpy/typing/_array_like.py @@ -19,6 +19,8 @@ str_, bytes_, ) + +from ._nested_sequence import NestedSequence from ._dtype_like import DTypeLike if sys.version_info >= (3, 8): @@ -45,22 +47,14 @@ def __array__(self, dtype: None = ...) -> ndarray[Any, _DType]: ... else: _SupportsArray = Any -# TODO: Wait for support for recursive types -_NestedSequence = Union[ - _T, - Sequence[_T], - Sequence[Sequence[_T]], - Sequence[Sequence[Sequence[_T]]], - Sequence[Sequence[Sequence[Sequence[_T]]]], -] -_RecursiveSequence = Sequence[Sequence[Sequence[Sequence[Sequence[Any]]]]] - # A union representing array-like objects; consists of two typevars: # One representing types that can be parametrized w.r.t. `np.dtype` # and another one for the rest _ArrayLike = Union[ - _NestedSequence[_SupportsArray[_DType]], - _NestedSequence[_T], + _SupportsArray[_DType], + NestedSequence[_SupportsArray[_DType]], + _T, + NestedSequence[_T], ] # TODO: support buffer protocols once @@ -70,12 +64,9 @@ def __array__(self, dtype: None = ...) -> ndarray[Any, _DType]: ... # is resolved. 
See also the mypy issue: # # https://github.com/python/typing/issues/593 -ArrayLike = Union[ - _RecursiveSequence, - _ArrayLike[ - "dtype[Any]", - Union[bool, int, float, complex, str, bytes] - ], +ArrayLike = _ArrayLike[ + "dtype[Any]", + Union[bool, int, float, complex, str, bytes], ] # `ArrayLike`: array-like objects that can be coerced into `X` @@ -104,10 +95,19 @@ def __array__(self, dtype: None = ...) -> ndarray[Any, _DType]: ... "dtype[Union[bool_, integer[Any], timedelta64]]", Union[bool, int], ] -_ArrayLikeDT64 = _NestedSequence[_SupportsArray["dtype[datetime64]"]] -_ArrayLikeObject = _NestedSequence[_SupportsArray["dtype[object_]"]] +_ArrayLikeDT64 = Union[ + _SupportsArray["dtype[datetime64]"], + NestedSequence[_SupportsArray["dtype[datetime64]"]], +] +_ArrayLikeObject = Union[ + _SupportsArray["dtype[object_]"], + NestedSequence[_SupportsArray["dtype[object_]"]], +] -_ArrayLikeVoid = _NestedSequence[_SupportsArray["dtype[void]"]] +_ArrayLikeVoid = Union[ + _SupportsArray["dtype[void]"], + NestedSequence[_SupportsArray["dtype[void]"]], +] _ArrayLikeStr = _ArrayLike[ "dtype[str_]", str, diff --git a/numpy/typing/_nested_sequence.py b/numpy/typing/_nested_sequence.py new file mode 100644 index 000000000000..4494e37380b8 --- /dev/null +++ b/numpy/typing/_nested_sequence.py @@ -0,0 +1,169 @@ +"""A module containing the `NestedSequence` protocol.""" + +import sys +from abc import abstractmethod, ABCMeta +from collections.abc import Sequence +from typing import ( + Any, + Callable, + Generic, + Iterable, + Iterator, + overload, + TYPE_CHECKING, + TypeVar, + Union, +) + +import numpy as np + +if sys.version_info >= (3, 8): + from typing import Protocol, runtime_checkable + HAVE_PROTOCOL = True +else: + try: + from typing_extensions import Protocol, runtime_checkable + except ImportError: + HAVE_PROTOCOL = False + else: + HAVE_PROTOCOL = True + +__all__ = ["NestedSequence"] + +_TT = TypeVar("_TT", bound=type) +_T_co = TypeVar("_T_co", covariant=True) + 
+_SeqOrScalar = Union[_T_co, "NestedSequence[_T_co]"] + +_NBitInt = f"_{8 * np.int_().itemsize}Bit" +_DOC = f"""A protocol for representing nested sequences. + + Runtime usage of the protocol requires either Python >= 3.8 or + the typing-extensions_ package. + + .. _typing-extensions: https://pypi.org/project/typing-extensions/ + + See Also + -------- + :class:`collections.abc.Sequence` + ABCs for read-only and mutable :term:`sequences`. + + Examples + -------- + .. code-block:: python + + >>> from __future__ import annotations + >>> from typing import Any, List, TYPE_CHECKING + >>> import numpy as np + >>> import numpy.typing as npt + + >>> def get_dtype(seq: npt.NestedSequence[int]) -> np.dtype[np.int_]: + ... return np.asarray(seq).dtype + + >>> a = get_dtype([1]) + >>> b = get_dtype([[1]]) + >>> c = get_dtype([[[1]]]) + >>> d = get_dtype([[[[1]]]]) + + >>> if TYPE_CHECKING: + ... reveal_locals() + ... # note: Revealed local types are: + ... # note: a: numpy.dtype[numpy.signedinteger[numpy.typing.{_NBitInt}]] + ... # note: b: numpy.dtype[numpy.signedinteger[numpy.typing.{_NBitInt}]] + ... # note: c: numpy.dtype[numpy.signedinteger[numpy.typing.{_NBitInt}]] + ... 
# note: d: numpy.dtype[numpy.signedinteger[numpy.typing.{_NBitInt}]] + +""" + + +def _set_module_and_doc(module: str, doc: str) -> Callable[[_TT], _TT]: + """A decorator for setting ``__module__`` and ``__doc__``.""" + def decorator(func): + func.__module__ = module + func.__doc__ = doc + return func + return decorator + + +# E: Error message for `_ProtocolMeta` and `_ProtocolMixin` +_ERR_MSG = ( + "runtime usage of `NestedSequence` requires " + "either typing-extensions or Python >= 3.8" +) + + +class _ProtocolMeta(ABCMeta): + """Metaclass of `_ProtocolMixin`.""" + + def __instancecheck__(self, params): + raise RuntimeError(_ERR_MSG) + + def __subclasscheck__(self, params): + raise RuntimeError(_ERR_MSG) + + +class _ProtocolMixin(Generic[_T_co], metaclass=_ProtocolMeta): + """A mixin that raises upon executing methods that require `typing.Protocol`.""" + + __slots__ = () + + def __init__(self): + raise RuntimeError(_ERR_MSG) + + def __init_subclass__(cls): + if cls is not NestedSequence: + raise RuntimeError(_ERR_MSG) + super().__init_subclass__() + + +# Plan B in case `typing.Protocol` is unavailable. +# +# A `RuntimeError` will be raised if one attempts to execute +# methods that absolutely require `typing.Protocol`. +if not TYPE_CHECKING and not HAVE_PROTOCOL: + Protocol = _ProtocolMixin + + +@_set_module_and_doc("numpy.typing", doc=_DOC) +@runtime_checkable +class NestedSequence(Protocol[_T_co]): + if not TYPE_CHECKING: + __slots__ = () + + # Can't directly inherit from `collections.abc.Sequence` + # (as it is not a Protocol), but we can forward to its methods + def __contains__(self, x: object) -> bool: + """Return ``x in self``.""" + return Sequence.__contains__(self, x) # type: ignore[operator] + + @overload + @abstractmethod + def __getitem__(self, i: int) -> _SeqOrScalar[_T_co]: ... + @overload + @abstractmethod + def __getitem__(self, s: slice) -> "NestedSequence[_T_co]": ...
+ @abstractmethod + def __getitem__(self, s): + """Return ``self[s]``.""" + raise NotImplementedError("Trying to call an abstract method") + + def __iter__(self) -> Iterator[_SeqOrScalar[_T_co]]: + """Return ``iter(self)``.""" + return Sequence.__iter__(self) # type: ignore[arg-type] + + @abstractmethod + def __len__(self) -> int: + """Return ``len(self)``.""" + raise NotImplementedError("Trying to call an abstract method") + + def __reversed__(self) -> Iterator[_SeqOrScalar[_T_co]]: + """Return ``reversed(self)``.""" + return Sequence.__reversed__(self) # type: ignore[arg-type] + + def count(self, value: Any) -> int: + """Return the number of occurrences of `value`.""" + return Sequence.count(self, value) # type: ignore[arg-type] + + def index(self, value: Any, start: int = 0, stop: int = sys.maxsize) -> int: + """Return the first index of `value`.""" + return Sequence.index(self, value, start, stop) # type: ignore[arg-type] diff --git a/numpy/typing/tests/data/fail/nested_sequence.py b/numpy/typing/tests/data/fail/nested_sequence.py new file mode 100644 index 000000000000..34cbc6c5581a --- /dev/null +++ b/numpy/typing/tests/data/fail/nested_sequence.py @@ -0,0 +1,17 @@ +from typing import Sequence, Tuple, List +import numpy.typing as npt + +a: Sequence[float] +b: List[complex] +c: Tuple[str, ...] +d: int +e: str + +def func(a: npt.NestedSequence[int]) -> None: + ... 
+ +reveal_type(func(a)) # E: incompatible type +reveal_type(func(b)) # E: incompatible type +reveal_type(func(c)) # E: incompatible type +reveal_type(func(d)) # E: incompatible type +reveal_type(func(e)) # E: incompatible type diff --git a/numpy/typing/tests/data/reveal/nbit_base_example.py b/numpy/typing/tests/data/reveal/examples.py similarity index 60% rename from numpy/typing/tests/data/reveal/nbit_base_example.py rename to numpy/typing/tests/data/reveal/examples.py index 99fb71560a24..b5560aee61b3 100644 --- a/numpy/typing/tests/data/reveal/nbit_base_example.py +++ b/numpy/typing/tests/data/reveal/examples.py @@ -16,3 +16,12 @@ def add(a: np.floating[T], b: np.integer[T]) -> np.floating[T]: reveal_type(add(f4, i8)) # E: {float64} reveal_type(add(f8, i4)) # E: {float64} reveal_type(add(f4, i4)) # E: {float32} + + +def get_dtype(seq: npt.NestedSequence[int]) -> np.dtype[np.int_]: + return np.asarray(seq).dtype + + +reveal_type(get_dtype([1])) # E: numpy.dtype[{int_}] +reveal_type(get_dtype([[1]])) # E: numpy.dtype[{int_}] +reveal_type(get_dtype([[[1]]])) # E: numpy.dtype[{int_}] diff --git a/numpy/typing/tests/data/reveal/nested_sequence.py b/numpy/typing/tests/data/reveal/nested_sequence.py new file mode 100644 index 000000000000..25e697eb589e --- /dev/null +++ b/numpy/typing/tests/data/reveal/nested_sequence.py @@ -0,0 +1,25 @@ +from typing import Sequence, Tuple, List, Any +import numpy.typing as npt + +a: Sequence[int] +b: Sequence[Sequence[int]] +c: Sequence[Sequence[Sequence[int]]] +d: Sequence[Sequence[Sequence[Sequence[int]]]] +e: Sequence[bool] +f: Tuple[int, ...] +g: List[int] +h: Sequence[Any] + +def func(a: npt.NestedSequence[int]) -> None: + ... 
+ +reveal_type(func(a)) # E: None +reveal_type(func(b)) # E: None +reveal_type(func(c)) # E: None +reveal_type(func(d)) # E: None +reveal_type(func(e)) # E: None +reveal_type(func(f)) # E: None +reveal_type(func(g)) # E: None +reveal_type(func(h)) # E: None + +reveal_type(isinstance(1, npt.NestedSequence)) # E: bool diff --git a/numpy/typing/tests/test_nested_sequence.py b/numpy/typing/tests/test_nested_sequence.py new file mode 100644 index 000000000000..351acc6e830c --- /dev/null +++ b/numpy/typing/tests/test_nested_sequence.py @@ -0,0 +1,81 @@ +"""A module with runtime tests for `numpy.typing.NestedSequence`.""" + +import sys +from typing import Callable, Any +from collections.abc import Sequence + +import pytest +import numpy as np +from numpy.typing import NestedSequence +from numpy.typing._nested_sequence import _ProtocolMixin + +if sys.version_info >= (3, 8): + from typing import Protocol + HAVE_PROTOCOL = True +else: + try: + from typing_extensions import Protocol + except ImportError: + HAVE_PROTOCOL = False + else: + HAVE_PROTOCOL = True + +if HAVE_PROTOCOL: + class _SubClass(NestedSequence[int]): + def __init__(self, seq): + self._seq = seq + + def __getitem__(self, s): + return self._seq[s] + + def __len__(self): + return len(self._seq) + + SEQ = _SubClass([0, 0, 1]) +else: + SEQ = NotImplemented + + +class TestNestedSequence: + """Runtime tests for `numpy.typing.NestedSequence`.""" + + @pytest.mark.parametrize( + "name,func", + [ + ("__instancecheck__", lambda: isinstance(1, _ProtocolMixin)), + ("__subclasscheck__", lambda: issubclass(int, _ProtocolMixin)), + ("__init__", lambda: _ProtocolMixin()), + ("__init_subclass__", lambda: type("SubClass", (_ProtocolMixin,), {})), + ] + ) + def test_raises(self, name: str, func: Callable[[], Any]) -> None: + """Test that the `_ProtocolMixin` methods successfully raise.""" + with pytest.raises(RuntimeError): + func() + + @pytest.mark.parametrize( + "name,ref,func", + [ + ("__contains__", True, lambda: 0 in SEQ), 
+ ("__getitem__", 0, lambda: SEQ[0]), + ("__getitem__", [0, 0, 1], lambda: SEQ[:]), + ("__iter__", 0, lambda: next(iter(SEQ))), + ("__len__", 3, lambda: len(SEQ)), + ("__reversed__", 1, lambda: next(reversed(SEQ))), + ("count", 2, lambda: SEQ.count(0)), + ("index", 0, lambda: SEQ.index(0)), + ("index", 1, lambda: SEQ.index(0, start=1)), + ("__instancecheck__", True, lambda: isinstance([1], NestedSequence)), + ("__instancecheck__", False, lambda: isinstance(1, NestedSequence)), + ("__subclasscheck__", True, lambda: issubclass(Sequence, NestedSequence)), + ("__subclasscheck__", False, lambda: issubclass(int, NestedSequence)), + ("__class_getitem__", True, lambda: bool(NestedSequence[int])), + ("__abstractmethods__", Sequence.__abstractmethods__, + lambda: NestedSequence.__abstractmethods__), + ] + ) + @pytest.mark.skipif(not HAVE_PROTOCOL, reason="requires the `Protocol` class") + def test_method(self, name: str, ref: Any, func: Callable[[], Any]) -> None: + """Test that the ``NestedSequence`` methods return the intended values.""" + value = func() + assert value == ref