Skip to content

ENH: Add Array API support to NDCG/DCG score #31152

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/modules/array_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ Metrics
- :func:`sklearn.metrics.cluster.entropy`
- :func:`sklearn.metrics.accuracy_score`
- :func:`sklearn.metrics.d2_tweedie_score`
- :func:`sklearn.metrics.dcg_score`
- :func:`sklearn.metrics.explained_variance_score`
- :func:`sklearn.metrics.f1_score`
- :func:`sklearn.metrics.fbeta_score`
Expand All @@ -148,6 +149,7 @@ Metrics
- :func:`sklearn.metrics.mean_squared_log_error`
- :func:`sklearn.metrics.mean_tweedie_deviance`
- :func:`sklearn.metrics.multilabel_confusion_matrix`
- :func:`sklearn.metrics.ndcg_score`
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel`
- :func:`sklearn.metrics.pairwise.chi2_kernel`
- :func:`sklearn.metrics.pairwise.cosine_similarity`
Expand Down
2 changes: 2 additions & 0 deletions doc/whats_new/upcoming_changes/array-api/31152.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
- :func:`sklearn.metrics.ndcg_score` and :func:`sklearn.metrics.dcg_score` now support Array API compatible inputs.
By :user:`Thomas Li <lithomas1>`
67 changes: 46 additions & 21 deletions sklearn/metrics/_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@
check_consistent_length,
column_or_1d,
)
from ..utils._array_api import (
_average,
_bincount,
_flip,
_max_precision_float_dtype,
get_namespace,
get_namespace_and_device,
)
from ..utils._encode import _encode, _unique
from ..utils._param_validation import Interval, StrOptions, validate_params
from ..utils.extmath import stable_cumsum
Expand Down Expand Up @@ -1487,20 +1495,27 @@ def _dcg_sample_scores(y_true, y_score, k=None, log_base=2, ignore_ties=False):
Cumulative Gain (the DCG obtained for a perfect ranking), in order to
have a score between 0 and 1.
"""
discount = 1 / (np.log(np.arange(y_true.shape[1]) + 2) / np.log(log_base))
xp, _, device = get_namespace_and_device(y_true, y_score)
max_float_dtype = _max_precision_float_dtype(xp, device)
log_base = xp.asarray(log_base, device=device, dtype=max_float_dtype)
discount = 1 / (
xp.log(xp.arange(y_true.shape[1], dtype=max_float_dtype, device=device) + 2)
/ xp.log(log_base)
)
if k is not None:
discount[k:] = 0
if ignore_ties:
ranking = np.argsort(y_score)[:, ::-1]
ranked = y_true[np.arange(ranking.shape[0])[:, np.newaxis], ranking]
cumulative_gains = discount.dot(ranked.T)
ranking = _flip(xp.argsort(y_score), axis=1)
ranked = xp.take_along_axis(y_true, ranking, axis=1)
cumulative_gains = discount @ xp.asarray(ranked.T, dtype=max_float_dtype)
else:
discount_cumsum = np.cumsum(discount)
cumulative_gains = [
_tie_averaged_dcg(y_t, y_s, discount_cumsum)
for y_t, y_s in zip(y_true, y_score)
]
cumulative_gains = np.asarray(cumulative_gains)
discount_cumsum = xp.cumulative_sum(discount)
cumulative_gains = xp.empty(y_true.shape[0], device=device)
for i in range(y_true.shape[0]):
cumulative_gains[i] = _tie_averaged_dcg(
y_true[i, :], y_score[i, :], discount_cumsum
)
cumulative_gains = xp.asarray(cumulative_gains, device=device)
return cumulative_gains


Expand Down Expand Up @@ -1541,15 +1556,21 @@ def _tie_averaged_dcg(y_true, y_score, discount_cumsum):
European conference on information retrieval (pp. 414-421). Springer,
Berlin, Heidelberg.
"""
_, inv, counts = np.unique(-y_score, return_inverse=True, return_counts=True)
ranked = np.zeros(len(counts))
np.add.at(ranked, inv, y_true)
ranked /= counts
groups = np.cumsum(counts) - 1
discount_sums = np.empty(len(counts))
xp, _, device = get_namespace_and_device(y_true, y_score)
# TODO: use unique_all when pytorch supports it
# _, _, inv, counts = xp.unique_all(-y_score)
_, inv = xp.unique_inverse(-y_score)
_, counts = xp.unique_counts(-y_score)
ranked = _bincount(inv, y_true, xp=xp, minlength=0)

max_float_dtype = _max_precision_float_dtype(xp, device)
ranked = xp.asarray(ranked, dtype=max_float_dtype, device=device)
ranked /= xp.asarray(counts, dtype=max_float_dtype)
groups = xp.cumulative_sum(counts) - 1
discount_sums = xp.empty(counts.shape[0], device=device)
discount_sums[0] = discount_cumsum[groups[0]]
discount_sums[1:] = np.diff(discount_cumsum[groups])
return (ranked * discount_sums).sum()
discount_sums[1:] = xp.diff(discount_cumsum[groups])
return xp.sum(ranked * discount_sums)


def _check_dcg_target_type(y_true):
Expand Down Expand Up @@ -1675,14 +1696,16 @@ def dcg_score(
"""
y_true = check_array(y_true, ensure_2d=False)
y_score = check_array(y_score, ensure_2d=False)
xp, _ = get_namespace(y_true, y_score)
check_consistent_length(y_true, y_score, sample_weight)
_check_dcg_target_type(y_true)
return float(
np.average(
_average(
_dcg_sample_scores(
y_true, y_score, k=k, log_base=log_base, ignore_ties=ignore_ties
),
weights=sample_weight,
xp=xp,
)
)

Expand Down Expand Up @@ -1846,7 +1869,9 @@ def ndcg_score(y_true, y_score, *, k=None, sample_weight=None, ignore_ties=False
y_score = check_array(y_score, ensure_2d=False)
check_consistent_length(y_true, y_score, sample_weight)

if y_true.min() < 0:
xp, _ = get_namespace(y_true, y_score)

if xp.min(y_true) < 0:
raise ValueError("ndcg_score should not be used on negative y_true values.")
if y_true.ndim > 1 and y_true.shape[1] <= 1:
raise ValueError(
Expand All @@ -1855,7 +1880,7 @@ def ndcg_score(y_true, y_score, *, k=None, sample_weight=None, ignore_ties=False
)
_check_dcg_target_type(y_true)
gain = _ndcg_sample_scores(y_true, y_score, k=k, ignore_ties=ignore_ties)
return float(np.average(gain, weights=sample_weight))
return float(_average(gain, weights=sample_weight, xp=xp))


@validate_params(
Expand Down
79 changes: 77 additions & 2 deletions sklearn/metrics/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
_atol_for_type,
_convert_to_numpy,
_get_namespace_device_dtype_ids,
_max_precision_float_dtype,
yield_namespace_device_dtype_combinations,
)
from sklearn.utils._testing import (
Expand Down Expand Up @@ -1845,8 +1846,20 @@ def check_array_api_metric(
):
xp = _array_api_for_tests(array_namespace, device)

a_xp = xp.asarray(a_np, device=device)
b_xp = xp.asarray(b_np, device=device)
def _get_device_arr(arr_np):
# Gets the equivalent device array for input numpy array
# Downcasts to a lower float precision type if float64 isn't
# supported (e.g. on MPS)
if np.isdtype(arr_np.dtype, "real floating"):
max_float_dtype = _max_precision_float_dtype(xp, device)
arr_xp = xp.asarray(arr_np, dtype=max_float_dtype, device=device)
arr_np = _convert_to_numpy(arr_xp, xp)
return arr_np, arr_xp
arr_xp = xp.asarray(arr_np, device=device)
return arr_np, arr_xp

a_np, a_xp = _get_device_arr(a_np)
b_np, b_xp = _get_device_arr(b_np)

metric_np = metric(a_np, b_np, **metric_kwargs)

Expand Down Expand Up @@ -2133,6 +2146,66 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
)


def check_array_api_ranking_metric(metric, array_namespace, device, dtype_name):
    """Check array API support for a ranking metric such as (n)dcg_score.

    Runs the generic array API metric check without sample weights, with
    sample weights, and — when the metric accepts them — with the ``k``
    and ``ignore_ties`` keyword arguments.
    """
    # Two samples, five documents each: true relevance and predicted scores.
    true_relevance = np.array(
        [
            [10, 0, 0, 1, 5],
            [0, 0, 10, 5, 1],
        ]
    )
    predicted_scores = np.array(
        [
            [0.1, 0.2, 0.3, 4, 70],
            [5, 1, 3, 0, 50],
        ]
    )

    metric_params = signature(metric).parameters
    # Keyword-argument variants to exercise, in order: unweighted, weighted,
    # then the optional ranking-specific parameters if the metric has them.
    kwargs_variants = [
        {"sample_weight": None},
        {"sample_weight": np.array([0.1, 0.9], dtype=dtype_name)},
    ]
    if "k" in metric_params:
        kwargs_variants.append({"sample_weight": None, "k": 2})
    if "ignore_ties" in metric_params:
        kwargs_variants.append({"sample_weight": None, "ignore_ties": True})

    for extra_kwargs in kwargs_variants:
        check_array_api_metric(
            metric,
            array_namespace,
            device,
            dtype_name,
            a_np=true_relevance,
            b_np=predicted_scores,
            **extra_kwargs,
        )


array_api_metric_checkers = {
accuracy_score: [
check_array_api_binary_classification_metric,
Expand Down Expand Up @@ -2229,6 +2302,8 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
check_array_api_regression_metric_multioutput,
],
sigmoid_kernel: [check_array_api_metric_pairwise],
ndcg_score: [check_array_api_ranking_metric],
dcg_score: [check_array_api_ranking_metric],
}


Expand Down
16 changes: 16 additions & 0 deletions sklearn/utils/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .._config import get_config
from ..externals import array_api_compat
from ..externals import array_api_extra as xpx
from ..externals.array_api_compat import is_torch_namespace
from ..externals.array_api_compat import numpy as np_compat
from .fixes import parse_version

Expand Down Expand Up @@ -1002,3 +1003,18 @@ def _tolist(array, xp=None):
return array.tolist()
array_np = _convert_to_numpy(array, xp=xp)
return [element.item() for element in array_np]


def _flip(array, axis, xp=None):
    """Reverse the order of the elements of ``array`` along the given axis.

    Workaround for PyTorch not supporting the ``::-1`` slice syntax
    (https://github.com/pytorch/pytorch/issues/59786).
    """
    xp, _ = get_namespace(array, xp=xp)
    if is_torch_namespace(xp):
        import torch

        return torch.flip(array, (axis,))
    # Build an indexing tuple that reverses only the requested axis.
    # A list is used first so that negative ``axis`` values also work.
    indexer = [slice(None)] * array.ndim
    indexer[axis] = slice(None, None, -1)
    return array[tuple(indexer)]
13 changes: 8 additions & 5 deletions sklearn/utils/multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,11 @@ def _raise_or_return():
try:
# TODO(1.7): Change to ValueError when byte labels is deprecated.
# labels in bytes format
first_row_or_val = y[[0], :] if issparse(y) else y[0]
if isinstance(first_row_or_val, bytes):
first_row_or_val = y[[0], :] if issparse(y) else y[0, ...]
if (
hasattr(first_row_or_val.dtype, "kind")
and first_row_or_val.dtype.kind == "S"
):
warnings.warn(
(
"Support for labels represented as bytes is deprecated in v1.5 and"
Expand All @@ -372,9 +375,9 @@ def _raise_or_return():
)
# The old sequence of sequences format
if (
not hasattr(first_row_or_val, "__array__")
and isinstance(first_row_or_val, Sequence)
and not isinstance(first_row_or_val, str)
not hasattr(y[0], "__array__")
and isinstance(y[0], Sequence)
and not isinstance(y[0], str)
):
raise ValueError(
"You appear to be using a legacy multi-label data"
Expand Down
Loading