Skip to content

ENH Add Array API compatibility to cosine_similarity #29014

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
May 17, 2024
1 change: 1 addition & 0 deletions doc/modules/array_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ Metrics
- :func:`sklearn.metrics.accuracy_score`
- :func:`sklearn.metrics.mean_absolute_error`
- :func:`sklearn.metrics.mean_tweedie_deviance`
- :func:`sklearn.metrics.pairwise.cosine_similarity``
- :func:`sklearn.metrics.r2_score`
- :func:`sklearn.metrics.zero_one_loss`

Expand Down
6 changes: 4 additions & 2 deletions doc/whats_new/v1.6.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ See :ref:`array_api` for more details.

- :func:`sklearn.metrics.mean_tweedie_deviance` now supports Array API compatible
inputs.
:pr:`28106` by :user:`Thomas Li <lithomas1>`
- :func:`sklearn.metrics.mean_absolute_error` :pr:`27736` by :user:`Edoardo Abati <EdAbati>`.
:pr:`28106` by :user:`Thomas Li <lithomas1>`;
- :func:`sklearn.metrics.mean_absolute_error` :pr:`27736` by :user:`Edoardo Abati <EdAbati>`;
- :func:`sklearn.metrics.pairwise.cosine_similarity` :pr:`29014` by :user:`Edoardo Abati <EdAbati>`.


**Classes:**

Expand Down
11 changes: 10 additions & 1 deletion sklearn/metrics/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
gen_batches,
gen_even_slices,
)
from ..utils._array_api import (
_find_matching_floating_dtype,
_is_numpy_namespace,
get_namespace,
)
from ..utils._chunking import get_chunk_n_rows
from ..utils._mask import _get_mask
from ..utils._missing import is_scalar_nan
Expand Down Expand Up @@ -154,7 +159,11 @@ def check_pairwise_arrays(
An array equal to Y if Y was not None, guaranteed to be a numpy array.
If Y was None, safe_Y will be a pointer to X.
"""
X, Y, dtype_float = _return_float_dtype(X, Y)
xp, _ = get_namespace(X, Y)
if any([issparse(X), issparse(Y)]) or _is_numpy_namespace(xp):
X, Y, dtype_float = _return_float_dtype(X, Y)
else:
dtype_float = _find_matching_floating_dtype(X, Y, xp=xp)

estimator = "check_pairwise_arrays"
if dtype == "infer_float":
Expand Down
64 changes: 41 additions & 23 deletions sklearn/metrics/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
zero_one_loss,
)
from sklearn.metrics._base import _average_binary_score
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import LabelBinarizer
from sklearn.utils import shuffle
from sklearn.utils._array_api import (
Expand Down Expand Up @@ -1743,20 +1744,22 @@ def test_metrics_pos_label_error_str(metric, y_pred_threshold, dtype_y_str):


def check_array_api_metric(
metric, array_namespace, device, dtype_name, y_true_np, y_pred_np, sample_weight
metric, array_namespace, device, dtype_name, a_np, b_np, **metric_kwargs
):
xp = _array_api_for_tests(array_namespace, device)

y_true_xp = xp.asarray(y_true_np, device=device)
y_pred_xp = xp.asarray(y_pred_np, device=device)
a_xp = xp.asarray(a_np, device=device)
b_xp = xp.asarray(b_np, device=device)

metric_np = metric(y_true_np, y_pred_np, sample_weight=sample_weight)
metric_np = metric(a_np, b_np, **metric_kwargs)

if sample_weight is not None:
sample_weight = xp.asarray(sample_weight, device=device)
if metric_kwargs.get("sample_weight") is not None:
metric_kwargs["sample_weight"] = xp.asarray(
metric_kwargs["sample_weight"], device=device
)

with config_context(array_api_dispatch=True):
metric_xp = metric(y_true_xp, y_pred_xp, sample_weight=sample_weight)
metric_xp = metric(a_xp, b_xp, **metric_kwargs)

assert_allclose(
_convert_to_numpy(xp.asarray(metric_xp), xp),
Expand All @@ -1776,8 +1779,8 @@ def check_array_api_binary_classification_metric(
array_namespace,
device,
dtype_name,
y_true_np=y_true_np,
y_pred_np=y_pred_np,
a_np=y_true_np,
b_np=y_pred_np,
sample_weight=None,
)

Expand All @@ -1788,8 +1791,8 @@ def check_array_api_binary_classification_metric(
array_namespace,
device,
dtype_name,
y_true_np=y_true_np,
y_pred_np=y_pred_np,
a_np=y_true_np,
b_np=y_pred_np,
sample_weight=sample_weight,
)

Expand All @@ -1805,8 +1808,8 @@ def check_array_api_multiclass_classification_metric(
array_namespace,
device,
dtype_name,
y_true_np=y_true_np,
y_pred_np=y_pred_np,
a_np=y_true_np,
b_np=y_pred_np,
sample_weight=None,
)

Expand All @@ -1817,8 +1820,8 @@ def check_array_api_multiclass_classification_metric(
array_namespace,
device,
dtype_name,
y_true_np=y_true_np,
y_pred_np=y_pred_np,
a_np=y_true_np,
b_np=y_pred_np,
sample_weight=sample_weight,
)

Expand All @@ -1832,8 +1835,8 @@ def check_array_api_regression_metric(metric, array_namespace, device, dtype_nam
array_namespace,
device,
dtype_name,
y_true_np=y_true_np,
y_pred_np=y_pred_np,
a_np=y_true_np,
b_np=y_pred_np,
sample_weight=None,
)

Expand All @@ -1844,8 +1847,8 @@ def check_array_api_regression_metric(metric, array_namespace, device, dtype_nam
array_namespace,
device,
dtype_name,
y_true_np=y_true_np,
y_pred_np=y_pred_np,
a_np=y_true_np,
b_np=y_pred_np,
sample_weight=sample_weight,
)

Expand All @@ -1861,8 +1864,8 @@ def check_array_api_regression_metric_multioutput(
array_namespace,
device,
dtype_name,
y_true_np=y_true_np,
y_pred_np=y_pred_np,
a_np=y_true_np,
b_np=y_pred_np,
sample_weight=None,
)

Expand All @@ -1873,8 +1876,8 @@ def check_array_api_regression_metric_multioutput(
array_namespace,
device,
dtype_name,
y_true_np=y_true_np,
y_pred_np=y_pred_np,
a_np=y_true_np,
b_np=y_pred_np,
sample_weight=sample_weight,
)

Expand All @@ -1886,6 +1889,20 @@ def check_array_api_multioutput_regression_metric(
check_array_api_regression_metric(metric, array_namespace, device, dtype_name)


def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name):

X_np = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=dtype_name)
Y_np = np.array([[0.2, 0.3, 0.4], [0.5, 0.6, 0.7]], dtype=dtype_name)

metric_kwargs = {}
if "dense_output" in signature(metric).parameters:
metric_kwargs["dense_output"] = True

check_array_api_metric(
metric, array_namespace, device, dtype_name, a_np=X_np, b_np=Y_np
)


array_api_metric_checkers = {
accuracy_score: [
check_array_api_binary_classification_metric,
Expand All @@ -1900,6 +1917,7 @@ def check_array_api_multioutput_regression_metric(
check_array_api_regression_metric,
check_array_api_regression_metric_multioutput,
],
cosine_similarity: [check_array_api_metric_pairwise],
mean_absolute_error: [
check_array_api_regression_metric,
check_array_api_multioutput_regression_metric,
Expand Down
Loading