diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 4a1740e6..c3d77f8e 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -77,7 +77,7 @@ def pad( pad_width = xp.flip(pad_width, axis=(0,)).flatten() return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] - if _delegate(xp, Backend.NUMPY, Backend.JAX_NUMPY, Backend.CUPY): + if _delegate(xp, Backend.NUMPY, Backend.JAX, Backend.CUPY): return xp.pad(x, pad_width, mode, constant_values=constant_values) return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp) diff --git a/src/array_api_extra/_lib/_backends.py b/src/array_api_extra/_lib/_backends.py index ee2e051e..f044281a 100644 --- a/src/array_api_extra/_lib/_backends.py +++ b/src/array_api_extra/_lib/_backends.py @@ -28,9 +28,9 @@ class Backend(Enum): # numpydoc ignore=PR01,PR02 # type: ignore[no-subclass-an NUMPY_READONLY = "numpy_readonly", _compat.is_numpy_namespace CUPY = "cupy", _compat.is_cupy_namespace TORCH = "torch", _compat.is_torch_namespace - DASK_ARRAY = "dask.array", _compat.is_dask_namespace + DASK = "dask.array", _compat.is_dask_namespace SPARSE = "sparse", _compat.is_pydata_sparse_namespace - JAX_NUMPY = "jax.numpy", _compat.is_jax_namespace + JAX = "jax.numpy", _compat.is_jax_namespace def __new__( cls, value: str, _is_namespace: Callable[[ModuleType], bool] diff --git a/tests/conftest.py b/tests/conftest.py index fa555018..fc2e68e2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -104,7 +104,7 @@ def xp(library: Backend) -> ModuleType: # numpydoc ignore=PR01,RT03 if library == Backend.NUMPY_READONLY: return NumPyReadOnly() # type: ignore[return-value] # pyright: ignore[reportReturnType] xp = pytest.importorskip(library.value) - if library == Backend.JAX_NUMPY: + if library == Backend.JAX: import jax jax.config.update("jax_enable_x64", True) # type: ignore[no-untyped-call] diff --git a/tests/test_at.py b/tests/test_at.py index c65a4a0d..84aa464c 100644 --- a/tests/test_at.py +++ b/tests/test_at.py @@ -34,7 +34,7 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]: @pytest.mark.skip_xp_backend( - Backend.SPARSE, reason="read-only backend without .at support" + Backend.SPARSE, reason="sparse:read-only backend without .at support" ) @pytest.mark.parametrize( ("kwargs", "expect_copy"), diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 3557642b..692d486b 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -26,7 +26,7 @@ # mypy: disable-error-code=no-untyped-usage -@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no expand_dims") +@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="sparse:no expand_dims") class TestAtLeastND: def test_0D(self, xp: ModuleType): x = xp.asarray(1.0) @@ -98,7 +98,7 @@ def test_xp(self, xp: ModuleType): xp_assert_equal(y, x) -@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype") +@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="sparse:no isdtype") class TestCov: def test_basic(self, xp: ModuleType): xp_assert_close( @@ -137,7 +137,9 @@ def test_device(self, xp: ModuleType, device: Device): x = xp.asarray([1, 2, 3], device=device) assert get_device(cov(x)) == device - @pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY) + @pytest.mark.skip_xp_backend( + Backend.NUMPY_READONLY, reason="numpy_readonly:explicit xp" + ) def test_xp(self, xp: ModuleType): xp_assert_close( cov(xp.asarray([[0.0, 2.0], [1.0, 1.0], [2.0, 0.0]]).T, xp=xp), @@ -145,7 +147,7 @@ def test_xp(self, xp: ModuleType): ) -@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no device") +@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="sparse:no device kwarg in asarray") class TestCreateDiagonal: def test_1d(self, xp: ModuleType): # from np.diag tests @@ -191,10 +193,10 @@ def test_xp(self, xp: ModuleType): xp_assert_equal(y, xp.asarray([[1, 0], [0, 2]])) -@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no sparse.expand_dims") +@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="sparse:no expand_dims") class TestExpandDims: - @pytest.mark.skip_xp_backend(Backend.DASK_ARRAY, reason="tuple index out of range") - @pytest.mark.skip_xp_backend(Backend.TORCH, reason="tuple index out of range") + @pytest.mark.skip_xp_backend(Backend.DASK, reason="dask:tuple index out of range") + @pytest.mark.skip_xp_backend(Backend.TORCH, reason="torch:tuple index out of range") def test_functionality(self, xp: ModuleType): def _squeeze_all(b: Array) -> Array: """Mimics `np.squeeze(b)`. `xpx.squeeze`?""" @@ -252,7 +254,7 @@ def test_xp(self, xp: ModuleType): assert y.shape == (1, 1, 1, 3) -@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no sparse.expand_dims") +@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="sparse:no expand_dims") class TestKron: def test_basic(self, xp: ModuleType): # Using 0-dimensional array @@ -349,7 +351,9 @@ def test_xp(self, xp: ModuleType): xp_assert_equal(nunique(a, xp=xp), xp.asarray(3)) -@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no arange, no device") +@pytest.mark.skip_xp_backend( + Backend.SPARSE, reason="sparse:no arange, no device kwarg in asarray" +) class TestPad: def test_simple(self, xp: ModuleType): a = xp.arange(1, 4) @@ -399,8 +403,8 @@ def test_list_of_tuples_width(self, xp: ModuleType): assert padded.shape == (4, 4) -@pytest.mark.skip_xp_backend(Backend.DASK_ARRAY, reason="no argsort") -@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no device") +@pytest.mark.skip_xp_backend(Backend.DASK, reason="dask:no argsort") +@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="sparse:no device kwarg in asarray") class TestSetDiff1D: @pytest.mark.skip_xp_backend( Backend.TORCH, reason="index_select not implemented for uint32" @@ -436,7 +440,9 @@ def test_device(self, xp: ModuleType, device: Device): x2 = xp.asarray([2, 3, 4], device=device) assert get_device(setdiff1d(x1, x2)) == device - @pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY) + @pytest.mark.skip_xp_backend( + Backend.NUMPY_READONLY, reason="numpy_readonly:explicit xp" + ) def test_xp(self, xp: ModuleType): x1 = xp.asarray([3, 8, 20]) x2 = xp.asarray([2, 3, 4]) @@ -445,7 +451,7 @@ def test_xp(self, xp: ModuleType): xp_assert_equal(actual, expected) -@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype") +@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="sparse:no isdtype") class TestSinc: def test_simple(self, xp: ModuleType): xp_assert_equal(sinc(xp.asarray(0.0)), xp.asarray(1.0)) diff --git a/tests/test_testing.py b/tests/test_testing.py index e0ce66ad..41fc6673 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -17,7 +17,9 @@ xp_assert_equal, pytest.param( xp_assert_close, - marks=pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype"), + marks=pytest.mark.skip_xp_backend( + Backend.SPARSE, reason="sparse:no isdtype" + ), ), ], ) @@ -38,15 +40,19 @@ def test_assert_close_equal_basic(xp: ModuleType, func: Callable[..., None]): # func(xp.asarray([1, 2]), xp.asarray([1, 3]), err_msg="hello") -@pytest.mark.skip_xp_backend(Backend.NUMPY) -@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY) +@pytest.mark.skip_xp_backend(Backend.NUMPY, reason="numpy:test other ns vs. numpy") +@pytest.mark.skip_xp_backend( + Backend.NUMPY_READONLY, reason="numpy_readonly:test other ns vs. numpy" +) @pytest.mark.parametrize( "func", [ xp_assert_equal, pytest.param( xp_assert_close, - marks=pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype"), + marks=pytest.mark.skip_xp_backend( + Backend.SPARSE, reason="sparse:no isdtype" + ), ), ], ) @@ -59,7 +65,7 @@ def test_assert_close_equal_namespace(xp: ModuleType, func: Callable[..., None]) func(xp.asarray([0]), [0]) -@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype") +@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="sparse:no isdtype") def test_assert_close_tolerance(xp: ModuleType): xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), rtol=0.03) with pytest.raises(AssertionError): diff --git a/tests/test_utils.py b/tests/test_utils.py index 981d5c03..f053f177 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -12,8 +12,10 @@ class TestIn1D: - @pytest.mark.skip_xp_backend(Backend.DASK_ARRAY, reason="no argsort") - @pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no unique_inverse, no device") + @pytest.mark.skip_xp_backend(Backend.DASK, reason="dask:no argsort") + @pytest.mark.skip_xp_backend( + Backend.SPARSE, reason="sparse:no unique_inverse, no device kwarg in asarray" + ) # cover both code paths @pytest.mark.parametrize("n", [9, 15]) def test_no_invert_assume_unique(self, xp: ModuleType, n: int): @@ -23,14 +25,20 @@ def test_no_invert_assume_unique(self, xp: ModuleType, n: int): actual = in1d(x1, x2) xp_assert_equal(actual, expected) - @pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no device") + @pytest.mark.skip_xp_backend( + Backend.SPARSE, reason="sparse: no device kwarg in asarray" + ) def test_device(self, xp: ModuleType, device: Device): x1 = xp.asarray([3, 8, 20], device=device) x2 = xp.asarray([2, 3, 4], device=device) assert get_device(in1d(x1, x2)) == device - @pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY) - @pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no arange, no device") + @pytest.mark.skip_xp_backend( + Backend.NUMPY_READONLY, reason="numpy_readonly:explicit xp" + ) + @pytest.mark.skip_xp_backend( + Backend.SPARSE, reason="sparse:no arange, no device kwarg in asarray" + ) def test_xp(self, xp: ModuleType): x1 = xp.asarray([1, 6]) x2 = xp.arange(5)