diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 03b48290..e0fa5f54 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -1,5 +1,8 @@ """Public API Functions.""" +# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972 +from __future__ import annotations + import operator import warnings @@ -719,7 +722,7 @@ def __init__( self._x = x self._idx = idx - def __getitem__(self, idx: Index, /) -> "at": # numpydoc ignore=PR01,RT01 + def __getitem__(self, idx: Index, /) -> at: # numpydoc ignore=PR01,RT01 """ Allow for the alternate syntax ``at(x)[start:stop:step]``. diff --git a/src/array_api_extra/_lib/_compat.pyi b/src/array_api_extra/_lib/_compat.pyi index f65a28fc..4d06a7f1 100644 --- a/src/array_api_extra/_lib/_compat.pyi +++ b/src/array_api_extra/_lib/_compat.pyi @@ -1,5 +1,8 @@ """Static type stubs for `_compat.py`.""" +# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972 +from __future__ import annotations + from types import ModuleType from ._typing import Array, Device diff --git a/src/array_api_extra/_lib/_utils.py b/src/array_api_extra/_lib/_utils.py index 523c21b8..1191b4f3 100644 --- a/src/array_api_extra/_lib/_utils.py +++ b/src/array_api_extra/_lib/_utils.py @@ -1,5 +1,8 @@ """Utility functions used by `array_api_extra/_funcs.py`.""" +# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972 +from __future__ import annotations + from . import _compat from ._typing import Array, ModuleType