Skip to content

BUG: lazy_xp_function crashes with Cython ufuncs #153

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 39 additions & 19 deletions src/array_api_extra/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
from __future__ import annotations

from collections.abc import Callable, Iterable, Sequence
import contextlib
from collections.abc import Callable, Iterable, Iterator, Sequence
from functools import wraps
from types import ModuleType
from typing import TYPE_CHECKING, Any, TypeVar, cast
Expand Down Expand Up @@ -42,6 +43,8 @@ def override(func: Callable[P, T]) -> Callable[P, T]:

T = TypeVar("T")

_ufuncs_tags: dict[object, dict[str, Any]] = {} # type: ignore[no-any-explicit]


def lazy_xp_function( # type: ignore[no-any-explicit]
func: Callable[..., Any],
Expand Down Expand Up @@ -132,12 +135,16 @@ def test_myfunc(xp):
a = xp.asarray([1, 2]) b = myfunc(a) # This is jitted when xp=jax.numpy c =
mymodule.myfunc(a) # This is not
"""
func.allow_dask_compute = allow_dask_compute # type: ignore[attr-defined] # pyright: ignore[reportFunctionMemberAccess]
if jax_jit:
func.lazy_jax_jit_kwargs = { # type: ignore[attr-defined] # pyright: ignore[reportFunctionMemberAccess]
"static_argnums": static_argnums,
"static_argnames": static_argnames,
}
tags = {
"allow_dask_compute": allow_dask_compute,
"jax_jit": jax_jit,
"static_argnums": static_argnums,
"static_argnames": static_argnames,
}
try:
func._lazy_xp_function = tags # type: ignore[attr-defined] # pylint: disable=protected-access # pyright: ignore[reportFunctionMemberAccess]
except AttributeError: # @cython.vectorize
_ufuncs_tags[func] = tags


def patch_lazy_xp_functions(
Expand Down Expand Up @@ -179,24 +186,37 @@ def xp(request, monkeypatch):
"""
globals_ = cast("dict[str, Any]", request.module.__dict__) # type: ignore[no-any-explicit]

if is_dask_namespace(xp):
def iter_tagged() -> Iterator[tuple[str, Callable[..., Any], dict[str, Any]]]: # type: ignore[no-any-explicit]
for name, func in globals_.items():
n = getattr(func, "allow_dask_compute", None)
if n is not None:
assert isinstance(n, int)
wrapped = _allow_dask_compute(func, n)
monkeypatch.setitem(globals_, name, wrapped)
tags: dict[str, Any] | None = None # type: ignore[no-any-explicit]
with contextlib.suppress(AttributeError):
tags = func._lazy_xp_function # pylint: disable=protected-access
if tags is None:
with contextlib.suppress(KeyError, TypeError):
tags = _ufuncs_tags[func]
if tags is not None:
yield name, func, tags

if is_dask_namespace(xp):
for name, func, tags in iter_tagged():
n = tags["allow_dask_compute"]
wrapped = _allow_dask_compute(func, n)
monkeypatch.setitem(globals_, name, wrapped)

elif is_jax_namespace(xp):
import jax

for name, func in globals_.items():
kwargs = cast( # type: ignore[no-any-explicit]
"dict[str, Any] | None", getattr(func, "lazy_jax_jit_kwargs", None)
)
if kwargs is not None:
for name, func, tags in iter_tagged():
if tags["jax_jit"]:
# suppress unused-ignore to run mypy in -e lint as well as -e dev
wrapped = cast(Callable[..., Any], jax.jit(func, **kwargs)) # type: ignore[no-any-explicit,no-untyped-call,unused-ignore]
wrapped = cast( # type: ignore[no-any-explicit]
Callable[..., Any],
jax.jit(
func,
static_argnums=tags["static_argnums"],
static_argnames=tags["static_argnames"],
),
)
monkeypatch.setitem(globals_, name, wrapped)


Expand Down
30 changes: 30 additions & 0 deletions tests/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,33 @@ def test_lazy_xp_function_static_params(xp: ModuleType, func: Callable[..., Arra
xp_assert_equal(func(x, 0, False), xp.asarray([3.0, 6.0]))
xp_assert_equal(func(x, 1, flag=True), xp.asarray([2.0, 4.0]))
xp_assert_equal(func(x, n=1, flag=True), xp.asarray([2.0, 4.0]))


try:
# Test an arbitrary Cython ufunc (@cython.vectorize).
# When SCIPY_ARRAY_API is not set, this is the same as
# scipy.special.erf.
from scipy.special._ufuncs import erf # type: ignore[import-not-found]

lazy_xp_function(erf) # pyright: ignore[reportUnknownArgumentType]
except ImportError:
erf = None


@pytest.mark.filterwarnings("ignore:__array_wrap__:DeprecationWarning") # torch
def test_lazy_xp_function_cython_ufuncs(xp: ModuleType, library: Backend):
pytest.importorskip("scipy")
assert erf is not None
x = xp.asarray([6.0, 7.0])
if library in (Backend.ARRAY_API_STRICT, Backend.JAX):
# array-api-strict arrays are auto-converted to numpy
# which results in an assertion error for mismatched namespaces
# eager jax arrays are auto-converted to numpy in eager jax
# and fail in jax.jit (which lazy_xp_function tests here)
with pytest.raises((TypeError, AssertionError)):
xp_assert_equal(erf(x), xp.asarray([1.0, 1.0]))
else:
# cupy, dask and sparse define __array_ufunc__ and dispatch accordingly
# note that when sparse reduces to scalar it returns a np.generic, which
# would make xp_assert_equal fail.
xp_assert_equal(erf(x), xp.asarray([1.0, 1.0]))