diff --git a/doc/modules/array_api.rst b/doc/modules/array_api.rst index 6139c8e8b2863..3c650591746f0 100644 --- a/doc/modules/array_api.rst +++ b/doc/modules/array_api.rst @@ -165,6 +165,7 @@ Metrics - :func:`sklearn.metrics.precision_recall_fscore_support` - :func:`sklearn.metrics.r2_score` - :func:`sklearn.metrics.recall_score` +- :func:`sklearn.metrics.roc_curve` - :func:`sklearn.metrics.root_mean_squared_error` - :func:`sklearn.metrics.root_mean_squared_log_error` - :func:`sklearn.metrics.zero_one_loss` diff --git a/doc/whats_new/upcoming_changes/array-api/30878.feature.rst b/doc/whats_new/upcoming_changes/array-api/30878.feature.rst new file mode 100644 index 0000000000000..fabb4c80f5713 --- /dev/null +++ b/doc/whats_new/upcoming_changes/array-api/30878.feature.rst @@ -0,0 +1,2 @@ +- :func:`sklearn.metrics.roc_curve` now supports Array API compatible inputs. + By :user:`Thomas Li <lithomas1>` diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index 2d0e5211c236c..59b6744d5778d 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -27,9 +27,13 @@ check_consistent_length, column_or_1d, ) +from ..utils._array_api import ( + _max_precision_float_dtype, + get_namespace_and_device, + size, +) from ..utils._encode import _encode, _unique from ..utils._param_validation import Interval, StrOptions, validate_params -from ..utils.extmath import stable_cumsum from ..utils.multiclass import type_of_target from ..utils.sparsefuncs import count_nonzero from ..utils.validation import _check_pos_label_consistency, _check_sample_weight @@ -862,6 +866,8 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None): if not (y_type == "binary" or (y_type == "multiclass" and pos_label is not None)): raise ValueError("{0} format is not supported".format(y_type)) + xp, _, device = get_namespace_and_device(y_true, y_score, sample_weight) + check_consistent_length(y_true, y_score, sample_weight) y_true = column_or_1d(y_true) y_score = 
column_or_1d(y_score) @@ -883,7 +889,7 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None): y_true = y_true == pos_label # sort scores and corresponding truth values - desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1] + desc_score_indices = xp.argsort(y_score, stable=True, descending=True) y_score = y_score[desc_score_indices] y_true = y_true[desc_score_indices] if sample_weight is not None: @@ -894,17 +900,27 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None): # y_score typically has many tied values. Here we extract # the indices associated with the distinct values. We also # concatenate a value for the end of the curve. - distinct_value_indices = np.where(np.diff(y_score))[0] - threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1] + distinct_value_indices = xp.nonzero(xp.diff(y_score))[0] + threshold_idxs = xp.concat( + [distinct_value_indices, xp.asarray([size(y_true) - 1], device=device)] + ) # accumulate the true positives with decreasing threshold - tps = stable_cumsum(y_true * weight)[threshold_idxs] + max_float_dtype = _max_precision_float_dtype(xp, device) + # Perform the weighted cumulative sum using float64 precision when possible + # to avoid numerical stability problem with tens of millions of very noisy + # predictions: + # https://github.com/scikit-learn/scikit-learn/issues/31533#issuecomment-2967062437 + y_true = xp.astype(y_true, max_float_dtype) + tps = xp.cumulative_sum(y_true * weight, dtype=max_float_dtype)[threshold_idxs] if sample_weight is not None: # express fps as a cumsum to ensure fps is increasing even in # the presence of floating point errors - fps = stable_cumsum((1 - y_true) * weight)[threshold_idxs] + fps = xp.cumulative_sum((1 - y_true) * weight, dtype=max_float_dtype)[ + threshold_idxs + ] else: - fps = 1 + threshold_idxs - tps + fps = 1 + xp.astype(threshold_idxs, max_float_dtype) - tps return fps, tps, y_score[threshold_idxs] @@ -1160,6 +1176,7 @@ 
def roc_curve( >>> thresholds array([ inf, 0.8 , 0.4 , 0.35, 0.1 ]) """ + xp, _, device = get_namespace_and_device(y_true, y_score) fps, tps, thresholds = _binary_clf_curve( y_true, y_score, pos_label=pos_label, sample_weight=sample_weight ) @@ -1173,9 +1190,15 @@ def roc_curve( # _binary_clf_curve). This keeps all cases where the point should be kept, # but does not drop more complicated cases like fps = [1, 3, 7], # tps = [1, 2, 4]; there is no harm in keeping too many thresholds. - if drop_intermediate and len(fps) > 2: - optimal_idxs = np.where( - np.r_[True, np.logical_or(np.diff(fps, 2), np.diff(tps, 2)), True] + if drop_intermediate and fps.shape[0] > 2: + optimal_idxs = xp.where( + xp.concat( + [ + xp.asarray([True], device=device), + xp.logical_or(xp.diff(fps, 2), xp.diff(tps, 2)), + xp.asarray([True], device=device), + ] + ) )[0] fps = fps[optimal_idxs] tps = tps[optimal_idxs] @@ -1183,17 +1206,18 @@ def roc_curve( # Add an extra threshold position # to make sure that the curve starts at (0, 0) - tps = np.r_[0, tps] - fps = np.r_[0, fps] + tps = xp.concat([xp.asarray([0.0], device=device), tps]) + fps = xp.concat([xp.asarray([0.0], device=device), fps]) # get dtype of `y_score` even if it is an array-like - thresholds = np.r_[np.inf, thresholds] + thresholds = xp.astype(thresholds, _max_precision_float_dtype(xp, device)) + thresholds = xp.concat([xp.asarray([xp.inf], device=device), thresholds]) if fps[-1] <= 0: warnings.warn( "No negative samples in y_true, false positive value should be meaningless", UndefinedMetricWarning, ) - fpr = np.repeat(np.nan, fps.shape) + fpr = xp.full(fps.shape, xp.nan) else: fpr = fps / fps[-1] @@ -1202,7 +1226,7 @@ def roc_curve( "No positive samples in y_true, true positive value should be meaningless", UndefinedMetricWarning, ) - tpr = np.repeat(np.nan, tps.shape) + tpr = xp.full(tps.shape, xp.nan) else: tpr = tps / tps[-1] diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 
be741d67e24c2..8b915fcd0c1a6 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -1928,11 +1928,19 @@ def check_array_api_metric( with config_context(array_api_dispatch=True): metric_xp = metric(a_xp, b_xp, **metric_kwargs) - assert_allclose( - _convert_to_numpy(xp.asarray(metric_xp), xp), - metric_np, - atol=_atol_for_type(dtype_name), - ) + def _check_metric_matches(xp_val, np_val): + assert_allclose( + _convert_to_numpy(xp.asarray(xp_val), xp), + np_val, + atol=_atol_for_type(dtype_name), + ) + + # Handle cases where there are multiple return values, e.g. roc_curve: + if isinstance(metric_xp, tuple): + for metric_xp_val, metric_np_val in zip(metric_xp, metric_np): + _check_metric_matches(metric_xp_val, metric_np_val) + else: + _check_metric_matches(metric_xp, metric_np) def check_array_api_binary_classification_metric( @@ -2269,6 +2277,9 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name) check_array_api_regression_metric_multioutput, ], sigmoid_kernel: [check_array_api_metric_pairwise], + roc_curve: [ + check_array_api_binary_classification_metric, + ], } diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 99db6cdfb16aa..adc5d80f591be 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -35,7 +35,9 @@ deprecated, ) from sklearn.utils._array_api import ( + _convert_to_numpy, _get_namespace_device_dtype_ids, + _is_numpy_namespace, yield_namespace_device_dtype_combinations, ) from sklearn.utils._mocking import ( @@ -66,6 +68,7 @@ _allclose_dense_sparse, _check_feature_names_in, _check_method_params, + _check_pos_label_consistency, _check_psd_eigenvalues, _check_response_method, _check_sample_weight, @@ -1593,50 +1596,117 @@ def test_check_psd_eigenvalues_invalid(lambdas, err_type, err_msg): _check_psd_eigenvalues(lambdas) -def test_check_sample_weight(): - # check array order - sample_weight = 
np.ones(10)[::2] - assert not sample_weight.flags["C_CONTIGUOUS"] - sample_weight = _check_sample_weight(sample_weight, X=np.ones((5, 1))) - assert sample_weight.flags["C_CONTIGUOUS"] - +def _check_sample_weight_common(xp): + # Common checks between numpy/array api tests + # for check_sample_weight # check None input - sample_weight = _check_sample_weight(None, X=np.ones((5, 2))) - assert_allclose(sample_weight, np.ones(5)) + sample_weight = _check_sample_weight(None, X=xp.ones((5, 2))) + assert_allclose(_convert_to_numpy(sample_weight, xp), np.ones(5)) # check numbers input - sample_weight = _check_sample_weight(2.0, X=np.ones((5, 2))) - assert_allclose(sample_weight, 2 * np.ones(5)) + sample_weight = _check_sample_weight(2.0, X=xp.ones((5, 2))) + assert_allclose(_convert_to_numpy(sample_weight, xp), 2 * np.ones(5)) # check wrong number of dimensions with pytest.raises(ValueError, match="Sample weights must be 1D array or scalar"): - _check_sample_weight(np.ones((2, 4)), X=np.ones((2, 2))) + _check_sample_weight(xp.ones((2, 4)), X=xp.ones((2, 2))) # check incorrect n_samples - msg = r"sample_weight.shape == \(4,\), expected \(2,\)!" 
+ msg = re.escape(f"sample_weight.shape == {xp.ones(4).shape}, expected (2,)!") with pytest.raises(ValueError, match=msg): - _check_sample_weight(np.ones(4), X=np.ones((2, 2))) + _check_sample_weight(xp.ones(4), X=xp.ones((2, 2))) # float32 dtype is preserved - X = np.ones((5, 2)) - sample_weight = np.ones(5, dtype=np.float32) + X = xp.ones((5, 2)) + sample_weight = xp.ones(5, dtype=xp.float32) sample_weight = _check_sample_weight(sample_weight, X) - assert sample_weight.dtype == np.float32 - - # int dtype will be converted to float64 instead - X = np.ones((5, 2), dtype=int) - sample_weight = _check_sample_weight(None, X, dtype=X.dtype) - assert sample_weight.dtype == np.float64 + assert sample_weight.dtype == xp.float32 # check negative weight when ensure_non_negative=True - X = np.ones((5, 2)) - sample_weight = np.ones(_num_samples(X)) + X = xp.ones((5, 2)) + sample_weight = xp.ones(_num_samples(X)) sample_weight[-1] = -10 err_msg = "Negative values in data passed to `sample_weight`" with pytest.raises(ValueError, match=err_msg): _check_sample_weight(sample_weight, X, ensure_non_negative=True) +def test_check_sample_weight(): + # check array order + sample_weight = np.ones(10)[::2] + assert not sample_weight.flags["C_CONTIGUOUS"] + sample_weight = _check_sample_weight(sample_weight, X=np.ones((5, 1))) + assert sample_weight.flags["C_CONTIGUOUS"] + + _check_sample_weight_common(np) + + # int dtype will be converted to float64 instead + X = np.ones((5, 2), dtype=int) + sample_weight = _check_sample_weight(None, X, dtype=X.dtype) + assert sample_weight.dtype == np.float64 + + +@pytest.mark.parametrize( + "array_namespace,device,dtype", yield_namespace_device_dtype_combinations() +) +def test_check_sample_weight_array_api(array_namespace, device, dtype): + xp = _array_api_for_tests(array_namespace, device) + with config_context(array_api_dispatch=True): + # check array order + sample_weight = xp.ones(10)[::2] + if _is_numpy_namespace(xp): + assert not 
sample_weight.flags["C_CONTIGUOUS"] + sample_weight = _check_sample_weight(sample_weight, X=xp.ones((5, 1))) + if _is_numpy_namespace(xp): + assert sample_weight.flags["C_CONTIGUOUS"] + + _check_sample_weight_common(xp) + + +@pytest.mark.parametrize("y_true", [[0], [0, 1], [-1, 1], [1, 1, 1], [-1, -1, -1]]) +def test_check_pos_label_consistency(y_true): + assert _check_pos_label_consistency(None, y_true) == 1 + + +@pytest.mark.parametrize( + "array_namespace,device,dtype", + yield_namespace_device_dtype_combinations(), + ids=_get_namespace_device_dtype_ids, +) +@pytest.mark.parametrize("y_true", [[0], [0, 1], [-1, 1], [1, 1, 1], [-1, -1, -1]]) +def test_check_pos_label_consistency_array_api(array_namespace, device, dtype, y_true): + xp = _array_api_for_tests(array_namespace, device) + with config_context(array_api_dispatch=True): + arr = xp.asarray(y_true, device=device) + assert _check_pos_label_consistency(None, arr) == 1 + + +@pytest.mark.parametrize("y_true", [[2, 3, 4], [-10], [0, -1]]) +def test_check_pos_label_consistency_invalid(y_true): + with pytest.raises(ValueError, match="y_true takes value in"): + _check_pos_label_consistency(None, y_true) + # Make sure we only raise if pos_label is None + assert _check_pos_label_consistency("a", y_true) == "a" + + +@pytest.mark.parametrize( + "array_namespace,device,dtype", + yield_namespace_device_dtype_combinations(), + ids=_get_namespace_device_dtype_ids, +) +@pytest.mark.parametrize("y_true", [[2, 3, 4], [-10], [0, -1]]) +def test_check_pos_label_consistency_invalid_array_api( + array_namespace, device, dtype, y_true +): + xp = _array_api_for_tests(array_namespace, device) + with config_context(array_api_dispatch=True): + arr = xp.asarray(y_true, device=device) + with pytest.raises(ValueError, match="y_true takes value in"): + _check_pos_label_consistency(None, arr) + # Make sure we only raise if pos_label is None + assert _check_pos_label_consistency("a", arr) == "a" + + @pytest.mark.parametrize("toarray", 
[np.array, sp.csr_matrix, sp.csc_matrix]) def test_allclose_dense_sparse_equals(toarray): base = np.arange(9).reshape(3, 3) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index d766ad16545da..acaac8c9f6c84 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -20,6 +20,7 @@ from ..exceptions import DataConversionWarning, NotFittedError, PositiveSpectrumWarning from ..utils._array_api import ( _asarray_with_order, + _convert_to_numpy, _is_numpy_namespace, _max_precision_float_dtype, get_namespace, @@ -2174,7 +2175,9 @@ def _check_sample_weight( sample_weight : ndarray of shape (n_samples,) Validated sample weight. It is guaranteed to be "C" contiguous. """ - xp, _, device = get_namespace_and_device(sample_weight, X) + xp, _, device = get_namespace_and_device( + sample_weight, X, remove_types=(int, float) + ) n_samples = _num_samples(X) @@ -2186,9 +2189,9 @@ def _check_sample_weight( dtype = max_float_type if sample_weight is None: - sample_weight = xp.ones(n_samples, dtype=dtype) + sample_weight = xp.ones(n_samples, dtype=dtype, device=device) elif isinstance(sample_weight, numbers.Number): - sample_weight = xp.full(n_samples, sample_weight, dtype=dtype) + sample_weight = xp.full(n_samples, sample_weight, dtype=dtype, device=device) else: if dtype is None: dtype = float_dtypes @@ -2650,14 +2653,20 @@ def _check_pos_label_consistency(pos_label, y_true): # when elements in the two arrays are not comparable. 
if pos_label is None: # Compute classes only if pos_label is not specified: - classes = np.unique(y_true) - if classes.dtype.kind in "OUS" or not ( - np.array_equal(classes, [0, 1]) - or np.array_equal(classes, [-1, 1]) - or np.array_equal(classes, [0]) - or np.array_equal(classes, [-1]) - or np.array_equal(classes, [1]) + xp, _, device = get_namespace_and_device(y_true) + classes = xp.unique_values(y_true) + if ( + (_is_numpy_namespace(xp) and classes.dtype.kind in "OUS") + or classes.shape[0] > 2 + or not ( + xp.all(classes == xp.asarray([0, 1], device=device)) + or xp.all(classes == xp.asarray([-1, 1], device=device)) + or xp.all(classes == xp.asarray([0], device=device)) + or xp.all(classes == xp.asarray([-1], device=device)) + or xp.all(classes == xp.asarray([1], device=device)) + ) ): + classes = _convert_to_numpy(classes, xp=xp) classes_repr = ", ".join([repr(c) for c in classes.tolist()]) raise ValueError( f"y_true takes value in {{{classes_repr}}} and pos_label is not "