Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
4ba176c
ENH: Make roc_curve array API compatible
lithomas1 Feb 22, 2025
b7992cc
add whatsnew
lithomas1 Mar 2, 2025
8d2d989
Merge branch 'main' of github.com:scikit-learn/scikit-learn into roc-…
lithomas1 Mar 2, 2025
54892c0
try to fix for oldest numpy
lithomas1 Mar 2, 2025
e87c6f8
fix argsort for old numpy
lithomas1 Mar 2, 2025
0980486
update following code review
lithomas1 Mar 17, 2025
07a910a
fix mps
lithomas1 Mar 18, 2025
619b54d
Apply suggestions from code review
lithomas1 Mar 24, 2025
5322d94
Merge branch 'main' into roc-curve
lithomas1 Mar 25, 2025
df7bf52
fix tests
lithomas1 Mar 25, 2025
0e3b7f9
Merge branch 'main' into roc-curve
lithomas1 Mar 26, 2025
bdcd3ef
Apply suggestions from code review
lithomas1 Apr 6, 2025
11358ad
remove spurious TODO
lithomas1 Apr 6, 2025
4ce3251
Merge branch 'main' of github.com:scikit-learn/scikit-learn into roc-…
lithomas1 Apr 6, 2025
844fa49
fixes from code review
lithomas1 Apr 11, 2025
714b3b0
Merge branch 'main' into roc-curve
lithomas1 Apr 13, 2025
4ecafc3
Use indexing rather than xp.take
lesteve May 12, 2025
2c0b1fb
Merge branch 'main' of github.com:scikit-learn/scikit-learn into roc-…
lithomas1 May 16, 2025
dd066b9
add tests and address comments
lithomas1 May 18, 2025
27b16c0
address feedback
lithomas1 May 26, 2025
b267e23
Address comments
lithomas1 May 30, 2025
1a8f317
Merge branch 'main' into roc-curve
lithomas1 May 30, 2025
eb608fa
Missing device call
lithomas1 May 30, 2025
4b12ebe
address comments
lithomas1 Jun 4, 2025
6e39e9b
Merge branch 'main' into roc-curve
lithomas1 Jun 4, 2025
ca90e82
fix CI
lithomas1 Jun 4, 2025
92dc114
Update sklearn/utils/tests/test_extmath.py
OmarManzoor Jun 5, 2025
3666012
Merge branch 'main' into roc-curve
OmarManzoor Jun 16, 2025
37e64be
Use xp.cumulative_sum and add a non regression test for roc_curve pre…
OmarManzoor Jun 16, 2025
2ceae95
Update sklearn/metrics/_ranking.py
OmarManzoor Jun 18, 2025
70c5582
Update sklearn/metrics/tests/test_common.py
OmarManzoor Jun 18, 2025
e3cea0f
Remove slow test
OmarManzoor Jun 18, 2025
dbdfbd5
Merge branch 'main' into roc-curve
OmarManzoor Jun 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/modules/array_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ Metrics
- :func:`sklearn.metrics.precision_recall_fscore_support`
- :func:`sklearn.metrics.r2_score`
- :func:`sklearn.metrics.recall_score`
- :func:`sklearn.metrics.roc_curve`
- :func:`sklearn.metrics.root_mean_squared_error`
- :func:`sklearn.metrics.root_mean_squared_log_error`
- :func:`sklearn.metrics.zero_one_loss`
Expand Down
2 changes: 2 additions & 0 deletions doc/whats_new/upcoming_changes/array-api/30878.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
- :func:`sklearn.metrics.roc_curve` now supports Array API compatible inputs.
By :user:`Thomas Li <lithomas1>`
54 changes: 39 additions & 15 deletions sklearn/metrics/_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,13 @@
check_consistent_length,
column_or_1d,
)
from ..utils._array_api import (
_max_precision_float_dtype,
get_namespace_and_device,
size,
)
from ..utils._encode import _encode, _unique
from ..utils._param_validation import Interval, StrOptions, validate_params
from ..utils.extmath import stable_cumsum
from ..utils.multiclass import type_of_target
from ..utils.sparsefuncs import count_nonzero
from ..utils.validation import _check_pos_label_consistency, _check_sample_weight
Expand Down Expand Up @@ -862,6 +866,8 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):
if not (y_type == "binary" or (y_type == "multiclass" and pos_label is not None)):
raise ValueError("{0} format is not supported".format(y_type))

xp, _, device = get_namespace_and_device(y_true, y_score, sample_weight)

check_consistent_length(y_true, y_score, sample_weight)
y_true = column_or_1d(y_true)
y_score = column_or_1d(y_score)
Expand All @@ -883,7 +889,7 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):
y_true = y_true == pos_label

# sort scores and corresponding truth values
desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1]
desc_score_indices = xp.argsort(y_score, stable=True, descending=True)
y_score = y_score[desc_score_indices]
y_true = y_true[desc_score_indices]
if sample_weight is not None:
Expand All @@ -894,17 +900,27 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):
# y_score typically has many tied values. Here we extract
# the indices associated with the distinct values. We also
# concatenate a value for the end of the curve.
distinct_value_indices = np.where(np.diff(y_score))[0]
threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1]
distinct_value_indices = xp.nonzero(xp.diff(y_score))[0]
threshold_idxs = xp.concat(
[distinct_value_indices, xp.asarray([size(y_true) - 1], device=device)]
)

# accumulate the true positives with decreasing threshold
tps = stable_cumsum(y_true * weight)[threshold_idxs]
max_float_dtype = _max_precision_float_dtype(xp, device)
# Perform the weighted cumulative sum using float64 precision when possible
# to avoid numerical stability problem with tens of millions of very noisy
# predictions:
# https://github.com/scikit-learn/scikit-learn/issues/31533#issuecomment-2967062437
y_true = xp.astype(y_true, max_float_dtype)
tps = xp.cumulative_sum(y_true * weight, dtype=max_float_dtype)[threshold_idxs]
if sample_weight is not None:
# express fps as a cumsum to ensure fps is increasing even in
# the presence of floating point errors
fps = stable_cumsum((1 - y_true) * weight)[threshold_idxs]
fps = xp.cumulative_sum((1 - y_true) * weight, dtype=max_float_dtype)[
threshold_idxs
]
else:
fps = 1 + threshold_idxs - tps
fps = 1 + xp.astype(threshold_idxs, max_float_dtype) - tps
return fps, tps, y_score[threshold_idxs]


Expand Down Expand Up @@ -1160,6 +1176,7 @@ def roc_curve(
>>> thresholds
array([ inf, 0.8 , 0.4 , 0.35, 0.1 ])
"""
xp, _, device = get_namespace_and_device(y_true, y_score)
fps, tps, thresholds = _binary_clf_curve(
y_true, y_score, pos_label=pos_label, sample_weight=sample_weight
)
Expand All @@ -1173,27 +1190,34 @@ def roc_curve(
# _binary_clf_curve). This keeps all cases where the point should be kept,
# but does not drop more complicated cases like fps = [1, 3, 7],
# tps = [1, 2, 4]; there is no harm in keeping too many thresholds.
if drop_intermediate and len(fps) > 2:
optimal_idxs = np.where(
np.r_[True, np.logical_or(np.diff(fps, 2), np.diff(tps, 2)), True]
if drop_intermediate and fps.shape[0] > 2:
optimal_idxs = xp.where(
xp.concat(
[
xp.asarray([True], device=device),
xp.logical_or(xp.diff(fps, 2), xp.diff(tps, 2)),
xp.asarray([True], device=device),
]
)
)[0]
fps = fps[optimal_idxs]
tps = tps[optimal_idxs]
thresholds = thresholds[optimal_idxs]

# Add an extra threshold position
# to make sure that the curve starts at (0, 0)
tps = np.r_[0, tps]
fps = np.r_[0, fps]
tps = xp.concat([xp.asarray([0.0], device=device), tps])
fps = xp.concat([xp.asarray([0.0], device=device), fps])
# get dtype of `y_score` even if it is an array-like
thresholds = np.r_[np.inf, thresholds]
thresholds = xp.astype(thresholds, _max_precision_float_dtype(xp, device))
thresholds = xp.concat([xp.asarray([xp.inf], device=device), thresholds])

if fps[-1] <= 0:
warnings.warn(
"No negative samples in y_true, false positive value should be meaningless",
UndefinedMetricWarning,
)
fpr = np.repeat(np.nan, fps.shape)
fpr = xp.full(fps.shape, xp.nan)
else:
fpr = fps / fps[-1]

Expand All @@ -1202,7 +1226,7 @@ def roc_curve(
"No positive samples in y_true, true positive value should be meaningless",
UndefinedMetricWarning,
)
tpr = np.repeat(np.nan, tps.shape)
tpr = xp.full(tps.shape, xp.nan)
else:
tpr = tps / tps[-1]

Expand Down
21 changes: 16 additions & 5 deletions sklearn/metrics/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1928,11 +1928,19 @@ def check_array_api_metric(
with config_context(array_api_dispatch=True):
metric_xp = metric(a_xp, b_xp, **metric_kwargs)

assert_allclose(
_convert_to_numpy(xp.asarray(metric_xp), xp),
metric_np,
atol=_atol_for_type(dtype_name),
)
def _check_metric_matches(xp_val, np_val):
assert_allclose(
_convert_to_numpy(xp.asarray(xp_val), xp),
np_val,
atol=_atol_for_type(dtype_name),
)

# Handle cases where there are multiple return values, e.g. roc_curve:
if isinstance(metric_xp, tuple):
for metric_xp_val, metric_np_val in zip(metric_xp, metric_np):
_check_metric_matches(metric_xp_val, metric_np_val)
else:
_check_metric_matches(metric_xp, metric_np)


def check_array_api_binary_classification_metric(
Expand Down Expand Up @@ -2269,6 +2277,9 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
check_array_api_regression_metric_multioutput,
],
sigmoid_kernel: [check_array_api_metric_pairwise],
roc_curve: [
check_array_api_binary_classification_metric,
],
}


Expand Down
118 changes: 94 additions & 24 deletions sklearn/utils/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
deprecated,
)
from sklearn.utils._array_api import (
_convert_to_numpy,
_get_namespace_device_dtype_ids,
_is_numpy_namespace,
yield_namespace_device_dtype_combinations,
)
from sklearn.utils._mocking import (
Expand Down Expand Up @@ -66,6 +68,7 @@
_allclose_dense_sparse,
_check_feature_names_in,
_check_method_params,
_check_pos_label_consistency,
_check_psd_eigenvalues,
_check_response_method,
_check_sample_weight,
Expand Down Expand Up @@ -1593,50 +1596,117 @@ def test_check_psd_eigenvalues_invalid(lambdas, err_type, err_msg):
_check_psd_eigenvalues(lambdas)


def test_check_sample_weight():
# check array order
sample_weight = np.ones(10)[::2]
assert not sample_weight.flags["C_CONTIGUOUS"]
sample_weight = _check_sample_weight(sample_weight, X=np.ones((5, 1)))
assert sample_weight.flags["C_CONTIGUOUS"]

def _check_sample_weight_common(xp):
# Common checks between numpy/array api tests
# for check_sample_weight
# check None input
sample_weight = _check_sample_weight(None, X=np.ones((5, 2)))
assert_allclose(sample_weight, np.ones(5))
sample_weight = _check_sample_weight(None, X=xp.ones((5, 2)))
assert_allclose(_convert_to_numpy(sample_weight, xp), np.ones(5))

# check numbers input
sample_weight = _check_sample_weight(2.0, X=np.ones((5, 2)))
assert_allclose(sample_weight, 2 * np.ones(5))
sample_weight = _check_sample_weight(2.0, X=xp.ones((5, 2)))
assert_allclose(_convert_to_numpy(sample_weight, xp), 2 * np.ones(5))

# check wrong number of dimensions
with pytest.raises(ValueError, match="Sample weights must be 1D array or scalar"):
_check_sample_weight(np.ones((2, 4)), X=np.ones((2, 2)))
_check_sample_weight(xp.ones((2, 4)), X=xp.ones((2, 2)))

# check incorrect n_samples
msg = r"sample_weight.shape == \(4,\), expected \(2,\)!"
msg = re.escape(f"sample_weight.shape == {xp.ones(4).shape}, expected (2,)!")
with pytest.raises(ValueError, match=msg):
_check_sample_weight(np.ones(4), X=np.ones((2, 2)))
_check_sample_weight(xp.ones(4), X=xp.ones((2, 2)))

# float32 dtype is preserved
X = np.ones((5, 2))
sample_weight = np.ones(5, dtype=np.float32)
X = xp.ones((5, 2))
sample_weight = xp.ones(5, dtype=xp.float32)
sample_weight = _check_sample_weight(sample_weight, X)
assert sample_weight.dtype == np.float32

# int dtype will be converted to float64 instead
X = np.ones((5, 2), dtype=int)
sample_weight = _check_sample_weight(None, X, dtype=X.dtype)
assert sample_weight.dtype == np.float64
assert sample_weight.dtype == xp.float32

# check negative weight when ensure_non_negative=True
X = np.ones((5, 2))
sample_weight = np.ones(_num_samples(X))
X = xp.ones((5, 2))
sample_weight = xp.ones(_num_samples(X))
sample_weight[-1] = -10
err_msg = "Negative values in data passed to `sample_weight`"
with pytest.raises(ValueError, match=err_msg):
_check_sample_weight(sample_weight, X, ensure_non_negative=True)


def test_check_sample_weight():
# check array order
sample_weight = np.ones(10)[::2]
assert not sample_weight.flags["C_CONTIGUOUS"]
sample_weight = _check_sample_weight(sample_weight, X=np.ones((5, 1)))
assert sample_weight.flags["C_CONTIGUOUS"]

_check_sample_weight_common(np)

# int dtype will be converted to float64 instead
X = np.ones((5, 2), dtype=int)
sample_weight = _check_sample_weight(None, X, dtype=X.dtype)
assert sample_weight.dtype == np.float64


@pytest.mark.parametrize(
"array_namespace,device,dtype", yield_namespace_device_dtype_combinations()
)
def test_check_sample_weight_array_api(array_namespace, device, dtype):
xp = _array_api_for_tests(array_namespace, device)
with config_context(array_api_dispatch=True):
# check array order
sample_weight = xp.ones(10)[::2]
if _is_numpy_namespace(xp):
assert not sample_weight.flags["C_CONTIGUOUS"]
sample_weight = _check_sample_weight(sample_weight, X=xp.ones((5, 1)))
if _is_numpy_namespace(xp):
assert sample_weight.flags["C_CONTIGUOUS"]

_check_sample_weight_common(xp)


@pytest.mark.parametrize("y_true", [[0], [0, 1], [-1, 1], [1, 1, 1], [-1, -1, -1]])
def test_check_pos_label_consistency(y_true):
assert _check_pos_label_consistency(None, y_true) == 1


@pytest.mark.parametrize(
"array_namespace,device,dtype",
yield_namespace_device_dtype_combinations(),
ids=_get_namespace_device_dtype_ids,
)
@pytest.mark.parametrize("y_true", [[0], [0, 1], [-1, 1], [1, 1, 1], [-1, -1, -1]])
def test_check_pos_label_consistency_array_api(array_namespace, device, dtype, y_true):
xp = _array_api_for_tests(array_namespace, device)
with config_context(array_api_dispatch=True):
arr = xp.asarray(y_true, device=device)
assert _check_pos_label_consistency(None, arr) == 1


@pytest.mark.parametrize("y_true", [[2, 3, 4], [-10], [0, -1]])
def test_check_pos_label_consistency_invalid(y_true):
with pytest.raises(ValueError, match="y_true takes value in"):
_check_pos_label_consistency(None, y_true)
# Make sure we only raise if pos_label is None
assert _check_pos_label_consistency("a", y_true) == "a"


@pytest.mark.parametrize(
"array_namespace,device,dtype",
yield_namespace_device_dtype_combinations(),
ids=_get_namespace_device_dtype_ids,
)
@pytest.mark.parametrize("y_true", [[2, 3, 4], [-10], [0, -1]])
def test_check_pos_label_consistency_invalid_array_api(
array_namespace, device, dtype, y_true
):
xp = _array_api_for_tests(array_namespace, device)
with config_context(array_api_dispatch=True):
arr = xp.asarray(y_true, device=device)
with pytest.raises(ValueError, match="y_true takes value in"):
_check_pos_label_consistency(None, arr)
# Make sure we only raise if pos_label is None
assert _check_pos_label_consistency("a", arr) == "a"


@pytest.mark.parametrize("toarray", [np.array, sp.csr_matrix, sp.csc_matrix])
def test_allclose_dense_sparse_equals(toarray):
base = np.arange(9).reshape(3, 3)
Expand Down
29 changes: 19 additions & 10 deletions sklearn/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ..exceptions import DataConversionWarning, NotFittedError, PositiveSpectrumWarning
from ..utils._array_api import (
_asarray_with_order,
_convert_to_numpy,
_is_numpy_namespace,
_max_precision_float_dtype,
get_namespace,
Expand Down Expand Up @@ -2174,7 +2175,9 @@ def _check_sample_weight(
sample_weight : ndarray of shape (n_samples,)
Validated sample weight. It is guaranteed to be "C" contiguous.
"""
xp, _, device = get_namespace_and_device(sample_weight, X)
xp, _, device = get_namespace_and_device(
sample_weight, X, remove_types=(int, float)
)

n_samples = _num_samples(X)

Expand All @@ -2186,9 +2189,9 @@ def _check_sample_weight(
dtype = max_float_type

if sample_weight is None:
sample_weight = xp.ones(n_samples, dtype=dtype)
sample_weight = xp.ones(n_samples, dtype=dtype, device=device)
elif isinstance(sample_weight, numbers.Number):
sample_weight = xp.full(n_samples, sample_weight, dtype=dtype)
sample_weight = xp.full(n_samples, sample_weight, dtype=dtype, device=device)
else:
if dtype is None:
dtype = float_dtypes
Expand Down Expand Up @@ -2650,14 +2653,20 @@ def _check_pos_label_consistency(pos_label, y_true):
# when elements in the two arrays are not comparable.
if pos_label is None:
# Compute classes only if pos_label is not specified:
classes = np.unique(y_true)
if classes.dtype.kind in "OUS" or not (
np.array_equal(classes, [0, 1])
or np.array_equal(classes, [-1, 1])
or np.array_equal(classes, [0])
or np.array_equal(classes, [-1])
or np.array_equal(classes, [1])
xp, _, device = get_namespace_and_device(y_true)
classes = xp.unique_values(y_true)
if (
(_is_numpy_namespace(xp) and classes.dtype.kind in "OUS")
or classes.shape[0] > 2
or not (
xp.all(classes == xp.asarray([0, 1], device=device))
or xp.all(classes == xp.asarray([-1, 1], device=device))
or xp.all(classes == xp.asarray([0], device=device))
or xp.all(classes == xp.asarray([-1], device=device))
or xp.all(classes == xp.asarray([1], device=device))
)
):
classes = _convert_to_numpy(classes, xp=xp)
classes_repr = ", ".join([repr(c) for c in classes.tolist()])
raise ValueError(
f"y_true takes value in {{{classes_repr}}} and pos_label is not "
Expand Down