Skip to content

array API support for mean_poisson_deviance #29227

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/cuda-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,4 @@ jobs:
run: |
source "${HOME}/conda/etc/profile.d/conda.sh"
conda activate sklearn
pytest -k 'array_api'
SCIPY_ARRAY_API=1 pytest -k 'array_api'
3 changes: 3 additions & 0 deletions azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ jobs:
# Here we make sure, that they are still run on a regular basis.
${{ if eq(variables['Build.Reason'], 'Schedule') }}:
SKLEARN_SKIP_NETWORK_TESTS: '0'
SCIPY_ARRAY_API: '1'

# Check compilation with Ubuntu 22.04 LTS (Jammy Jellyfish) and scipy from conda-forge
# By default the CI is sequential, where `Ubuntu_Jammy_Jellyfish` runs first and
Expand Down Expand Up @@ -221,6 +222,7 @@ jobs:
# makes sure that they are single threaded in each xdist subprocess.
PYTEST_XDIST_VERSION: 'none'
PIP_BUILD_ISOLATION: 'true'
SCIPY_ARRAY_API: '1'

- template: build_tools/azure/posix-docker.yml
parameters:
Expand Down Expand Up @@ -259,6 +261,7 @@ jobs:
DISTRIB: 'conda'
LOCK_FILE: './build_tools/azure/pylatest_conda_forge_mkl_osx-64_conda.lock'
SKLEARN_TESTS_GLOBAL_RANDOM_SEED: '5' # non-default seed
SCIPY_ARRAY_API: '1'
pylatest_conda_mkl_no_openmp:
DISTRIB: 'conda'
LOCK_FILE: './build_tools/azure/pylatest_conda_mkl_no_openmp_osx-64_conda.lock'
Expand Down
1 change: 1 addition & 0 deletions doc/modules/array_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ Metrics
- :func:`sklearn.metrics.mean_absolute_error`
- :func:`sklearn.metrics.mean_absolute_percentage_error`
- :func:`sklearn.metrics.mean_gamma_deviance`
- :func:`sklearn.metrics.mean_poisson_deviance` (requires `enabling array API support for SciPy <https://docs.scipy.org/doc/scipy/dev/api-dev/array_api.html#using-array-api-standard-support>`_)
- :func:`sklearn.metrics.mean_squared_error`
- :func:`sklearn.metrics.mean_tweedie_deviance`
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel`
Expand Down
1 change: 1 addition & 0 deletions doc/whats_new/v1.6.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ See :ref:`array_api` for more details.
and :pr:`29143` by :user:`Tialo <Tialo>` and :user:`Loïc Estève <lesteve>`;
- :func:`sklearn.metrics.mean_absolute_percentage_error` :pr:`29300` by :user:`Emily Chen <EmilyXinyi>`;
- :func:`sklearn.metrics.mean_gamma_deviance` :pr:`29239` by :user:`Emily Chen <EmilyXinyi>`;
- :func:`sklearn.metrics.mean_poisson_deviance` :pr:`29227` by :user:`Emily Chen <EmilyXinyi>`;
- :func:`sklearn.metrics.mean_squared_error` :pr:`29142` by :user:`Yaroslav Korobko <Tialo>`;
- :func:`sklearn.metrics.mean_tweedie_deviance` :pr:`28106` by :user:`Thomas Li <lithomas1>`;
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel` :pr:`29144` by :user:`Yaroslav Korobko <Tialo>`;
Expand Down
9 changes: 8 additions & 1 deletion sklearn/metrics/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
assert_array_less,
ignore_warnings,
)
from sklearn.utils.fixes import COO_CONTAINERS
from sklearn.utils.fixes import COO_CONTAINERS, parse_version, sp_version
from sklearn.utils.multiclass import type_of_target
from sklearn.utils.validation import _num_samples, check_random_state

Expand Down Expand Up @@ -1867,6 +1867,12 @@ def check_array_api_multilabel_classification_metric(


def check_array_api_regression_metric(metric, array_namespace, device, dtype_name):
func_name = metric.func.__name__ if isinstance(metric, partial) else metric.__name__
if func_name == "mean_poisson_deviance" and sp_version < parse_version("1.14.0"):
pytest.skip(
"mean_poisson_deviance's dependency `xlogy` is available as of scipy 1.14.0"
)

y_true_np = np.array([2.0, 0.1, 1.0, 4.0], dtype=dtype_name)
y_pred_np = np.array([0.5, 0.5, 2, 2], dtype=dtype_name)

Expand Down Expand Up @@ -2012,6 +2018,7 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
check_array_api_regression_metric,
],
paired_cosine_distances: [check_array_api_metric_pairwise],
mean_poisson_deviance: [check_array_api_regression_metric],
additive_chi2_kernel: [check_array_api_metric_pairwise],
mean_gamma_deviance: [check_array_api_regression_metric],
max_error: [check_array_api_regression_metric],
Expand Down
13 changes: 13 additions & 0 deletions sklearn/utils/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import itertools
import math
import os
import warnings
from functools import wraps

import numpy
Expand Down Expand Up @@ -106,6 +108,17 @@ def _check_array_api_dispatch(array_api_dispatch):
f"NumPy must be {min_numpy_version} or newer to dispatch array using"
" the API specification"
)
if os.environ.get("SCIPY_ARRAY_API") != "1":
warnings.warn(
(
"Some scikit-learn array API features might rely on enabling "
"SciPy's own support for array API to function properly. "
"Please set the SCIPY_ARRAY_API=1 environment variable "
"before importing sklearn or scipy. More details at: "
"https://docs.scipy.org/doc/scipy/dev/api-dev/array_api.html"
),
UserWarning,
)


def _single_array_device(array):
Expand Down
15 changes: 14 additions & 1 deletion sklearn/utils/tests/test_array_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import re
from functools import partial

Expand Down Expand Up @@ -77,7 +78,7 @@ def test_get_namespace_ndarray_with_dispatch():


@skip_if_array_api_compat_not_configured
def test_get_namespace_array_api():
def test_get_namespace_array_api(monkeypatch):
"""Test get_namespace for ArrayAPI arrays."""
xp = pytest.importorskip("array_api_strict")

Expand All @@ -90,6 +91,18 @@ def test_get_namespace_array_api():
with pytest.raises(TypeError):
xp_out, is_array_api_compliant = get_namespace(X_xp, X_np)

def mock_getenv(key):
if key == "SCIPY_ARRAY_API":
return "0"

monkeypatch.setattr("os.environ.get", mock_getenv)
assert os.environ.get("SCIPY_ARRAY_API") != "1"
with pytest.warns(
UserWarning,
match="enabling SciPy's own support for array API to function properly. ",
):
xp_out, is_array_api_compliant = get_namespace(X_xp)


class _AdjustableNameAPITestWrapper(_ArrayAPIWrapper):
"""API wrapper that has an adjustable name. Used for testing."""
Expand Down