Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion maint_tools/vendor_array_api_compat.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ set -o nounset
set -o errexit

URL="https://github.com/data-apis/array-api-compat.git"
VERSION="1.11.2"
VERSION="1.12"

ROOT_DIR=sklearn/externals/array_api_compat

Expand Down
2 changes: 1 addition & 1 deletion sklearn/externals/array_api_compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@
this implementation for the default when working with NumPy arrays.

"""
__version__ = '1.11.2'
__version__ = '1.12.0'

from .common import * # noqa: F401, F403
25 changes: 19 additions & 6 deletions sklearn/externals/array_api_compat/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@
Internal helpers
"""

from collections.abc import Callable
from functools import wraps
from inspect import signature
from types import ModuleType
from typing import TypeVar

def get_xp(xp):
_T = TypeVar("_T")


def get_xp(xp: ModuleType) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
"""
Decorator to automatically replace xp with the corresponding array module.

Expand All @@ -22,14 +28,14 @@ def func(x, /, xp, kwarg=None):

"""

def inner(f):
def inner(f: Callable[..., _T], /) -> Callable[..., _T]:
@wraps(f)
def wrapped_f(*args, **kwargs):
def wrapped_f(*args: object, **kwargs: object) -> object:
return f(*args, xp=xp, **kwargs)

sig = signature(f)
new_sig = sig.replace(
parameters=[sig.parameters[i] for i in sig.parameters if i != "xp"]
parameters=[par for i, par in sig.parameters.items() if i != "xp"]
)

if wrapped_f.__doc__ is None:
Expand All @@ -40,7 +46,14 @@ def wrapped_f(*args, **kwargs):
specification for more details.

"""
wrapped_f.__signature__ = new_sig
return wrapped_f
wrapped_f.__signature__ = new_sig # pyright: ignore[reportAttributeAccessIssue]
return wrapped_f # pyright: ignore[reportReturnType]

return inner


__all__ = ["get_xp"]


def __dir__() -> list[str]:
return __all__
2 changes: 1 addition & 1 deletion sklearn/externals/array_api_compat/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from ._helpers import * # noqa: F403
from ._helpers import * # noqa: F403
Loading
Loading