From 4481be6c15d4e4a2444c6d9ca62b42fc8cfc7597 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Thu, 2 Nov 2023 19:27:01 +0100 Subject: [PATCH 01/29] converted mae to array api --- sklearn/metrics/_regression.py | 11 +++--- sklearn/metrics/tests/test_common.py | 19 +++++++++++ sklearn/utils/_array_api.py | 49 +++++++++++++++++++++++++++ sklearn/utils/tests/test_array_api.py | 32 +++++++++++++++++ 4 files changed, 107 insertions(+), 4 deletions(-) diff --git a/sklearn/metrics/_regression.py b/sklearn/metrics/_regression.py index 0259a3f41620c..80a456f2e1103 100644 --- a/sklearn/metrics/_regression.py +++ b/sklearn/metrics/_regression.py @@ -34,6 +34,7 @@ from scipy.special import xlogy from ..exceptions import UndefinedMetricWarning +from ..utils._array_api import _average, get_namespace from ..utils._param_validation import Hidden, Interval, StrOptions, validate_params from ..utils.stats import _weighted_percentile from ..utils.validation import ( @@ -99,15 +100,16 @@ def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric"): just the corresponding argument if ``multioutput`` is a correct keyword. """ + xp, _ = get_namespace(y_true, y_pred) check_consistent_length(y_true, y_pred) y_true = check_array(y_true, ensure_2d=False, dtype=dtype) y_pred = check_array(y_pred, ensure_2d=False, dtype=dtype) if y_true.ndim == 1: - y_true = y_true.reshape((-1, 1)) + y_true = xp.reshape(y_true, (-1, 1)) if y_pred.ndim == 1: - y_pred = y_pred.reshape((-1, 1)) + y_pred = xp.reshape(y_pred, (-1, 1)) if y_true.shape[1] != y_pred.shape[1]: raise ValueError( @@ -204,11 +206,12 @@ def mean_absolute_error( >>> mean_absolute_error(y_true, y_pred, multioutput=[0.3, 0.7]) 0.85... """ + xp, _ = get_namespace(y_true, y_pred) y_type, y_true, y_pred, multioutput = _check_reg_targets( y_true, y_pred, multioutput ) check_consistent_length(y_true, y_pred, sample_weight) - output_errors = np.average(np.abs(y_pred - y_true), weights=sample_weight, axis=0) + output_errors = _average(xp.abs(y_pred - y_true), weights=sample_weight, axis=0) if isinstance(multioutput, str): if multioutput == "raw_values": return output_errors @@ -216,7 +219,7 @@ def mean_absolute_error( # pass None as weights to np.average: uniform mean multioutput = None - return np.average(output_errors, weights=multioutput) + return _average(output_errors, weights=multioutput) @validate_params( diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index af652d1c90b41..a9a10fb1faed5 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -1793,6 +1793,24 @@ def check_array_api_multiclass_classification_metric( ) +def check_array_api_regression_metric(metric, array_namespace, device, dtype): + y_true_np = np.array([3, -0.5, 2, 7]) + y_pred_np = np.array([2.5, 0.0, 2, 8]) + check_array_api_metric( + metric, array_namespace, device, dtype, y_true_np=y_true_np, y_pred_np=y_pred_np + ) + if "sample_weight" in signature(metric).parameters: + check_array_api_metric( + metric, + array_namespace, + device, + dtype, + y_true_np=y_true_np, + y_pred_np=y_pred_np, + sample_weight=np.array([0.0, 0.1, 2.0, 1.0]), + ) + + metric_checkers = { accuracy_score: [ check_array_api_binary_classification_metric, @@ -1802,6 +1820,7 @@ def check_array_api_multiclass_classification_metric( check_array_api_binary_classification_metric, check_array_api_multiclass_classification_metric, ], + mean_absolute_error: [check_array_api_regression_metric], } diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 24534faa931e8..64a61fd3cdd5f 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -595,3 +595,52 @@ def _estimator_with_converted_arrays(estimator, converter): def _atol_for_type(dtype): """Return the absolute tolerance for a given dtype.""" return numpy.finfo(dtype).eps * 100 + + +def _average(X, axis=None, weights=None): + """Compute the weighted average along the specified axis. + + This function is a port of numpy.average that supports the Array API. + + Please see the original documentation for more details: + https://numpy.org/doc/stable/reference/generated/numpy.average.html + + Parameters + ---------- + X : array object + Array containing data to be averaged. + axis : None or int or tuple of ints, optional + Axis or axes along which to average `X`. + weights : array_like, optional + An array of weights associated with the values in `X`. Each value in + `X` contributes to the average according to its associated weight. + + Returns + ------- + retval : array object + Return the average along the specified axis. + """ + xp, _ = get_namespace(X) + + if _is_numpy_namespace(xp): + return xp.asarray(numpy.average(X, axis=axis, weights=weights)) + + if weights is None: + return xp.mean(X, axis=axis) + + weights = xp.asarray(weights, device=device(X)) + if X.shape != weights.shape: + if axis is None: + raise TypeError( + "Axis must be specified when shapes of a and weights differ." + ) + if weights.ndim != 1: + raise TypeError("1D weights expected when shapes of a and weights differ.") + if weights.shape[0] != X.shape[axis]: + raise ValueError("Length of weights not compatible with specified axis.") + weights = xp.broadcast_to(weights, (X.ndim - 1) * (1,) + weights.shape) + weights = xp.swapaxes(weights, -1, axis) + weights_sum = xp.sum(weights, axis=axis) + if xp.any(weights_sum == 0): + raise ZeroDivisionError("Weights sum to zero, can't be normalized") + return xp.sum(xp.multiply(X, weights), axis=axis) / xp.sum(weights, axis=axis) diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index 866fd0e1d56f3..161b2786dafdc 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -10,6 +10,7 @@ _ArrayAPIWrapper, _asarray_with_order, _atol_for_type, + _average, _convert_to_numpy, _estimator_with_converted_arrays, _nanmax, @@ -371,3 +372,34 @@ def test_get_namespace_array_api_isdtype(wrapper): with pytest.raises(ValueError, match="Unrecognized data type"): assert xp.isdtype(xp.int16, "unknown") + + +@skip_if_array_api_compat_not_configured +@pytest.mark.parametrize( + "library", ["numpy", "numpy.array_api", "cupy", "cupy.array_api", "torch"] +) +@pytest.mark.parametrize( + "X,weights,axis,expected", + [ + [[0.0, 1.0, 2.0, 3.0], None, 0, 1.5], + [[0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 1.0, 2.0], 0, 2.25], + [[[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]], None, None, 3.5], + [[[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]], None, 1, [1.5, 5.5]], + [ + [[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]], + [0.0, 1.0, 1.0, 2.0], + 1, + [2.25, 6.25], + ], + ], +) +def test__average(library, X, axis, weights, expected): + xp = pytest.importorskip(library) + + if isinstance(expected, list): + expected = xp.asarray(expected) + + with config_context(array_api_dispatch=True): + result = _average(xp.asarray(X), axis=axis, weights=weights) + + assert_allclose(result, expected) From 8b3dc8dba7826d582edf5ebb2ab882c89cd29616 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Mon, 6 Nov 2023 23:27:46 +0100 Subject: [PATCH 02/29] fixes for MPS device --- sklearn/metrics/tests/test_common.py | 11 ++++++++--- sklearn/utils/_array_api.py | 8 +++++--- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index a9a10fb1faed5..0daf890dd2f65 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -1743,9 +1743,14 @@ def check_array_api_metric( with config_context(array_api_dispatch=True): if sample_weight is not None: - sample_weight = xp.asarray(sample_weight, device=device) + sample_weight = xp.asarray(sample_weight.astype(dtype), device=device) metric_xp = metric(y_true_xp, y_pred_xp, sample_weight=sample_weight) + if not isinstance(metric_xp, float): + # if the result is not a scalar, the array has to be in the CPU + # before transforming it to numpy + metric_xp = xp.asarray(metric_xp, device="cpu") + assert_allclose( metric_xp, metric_np, @@ -1794,8 +1799,8 @@ def check_array_api_multiclass_classification_metric( def check_array_api_regression_metric(metric, array_namespace, device, dtype): - y_true_np = np.array([3, -0.5, 2, 7]) - y_pred_np = np.array([2.5, 0.0, 2, 8]) + y_true_np = np.array([3, -0.5, 2, 7], dtype="float32") + y_pred_np = np.array([2.5, 0.0, 2, 8], dtype="float32") check_array_api_metric( metric, array_namespace, device, dtype, y_true_np=y_true_np, y_pred_np=y_pred_np ) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 64a61fd3cdd5f..73333c2158dc4 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -462,13 +462,15 @@ def _weighted_sum(sample_score, sample_weight, normalize=False, xp=None): sample_weight_np = None return float(numpy.average(sample_score_np, weights=sample_weight_np)) + # We move to cpu device ahead of time since certain devices may not support + # float64, but we want the same precision for all devices and namespaces. if not xp.isdtype(sample_score.dtype, "real floating"): - # We move to cpu device ahead of time since certain devices may not support - # float64, but we want the same precision for all devices and namespaces. sample_score = xp.astype(xp.asarray(sample_score, device="cpu"), xp.float64) if sample_weight is not None: - sample_weight = xp.asarray(sample_weight, dtype=sample_score.dtype) + sample_weight = xp.asarray( + xp.asarray(sample_weight, device="cpu"), dtype=sample_score.dtype + ) if not xp.isdtype(sample_weight.dtype, "real floating"): sample_weight = xp.astype(sample_weight, xp.float64) From 0f6ea74edd71a5a0cbfa8b6c687e0c4a2d0a4357 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Mon, 6 Nov 2023 23:39:40 +0100 Subject: [PATCH 03/29] updated docs --- doc/modules/array_api.rst | 1 + doc/whats_new/v1.4.rst | 11 ++++++++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/doc/modules/array_api.rst b/doc/modules/array_api.rst index 777f0d1e2f17c..0a860533f13a3 100644 --- a/doc/modules/array_api.rst +++ b/doc/modules/array_api.rst @@ -104,6 +104,7 @@ Metrics ------- - :func:`sklearn.metrics.accuracy_score` +- :func:`sklearn.metrics.mean_absolute_error` - :func:`sklearn.metrics.zero_one_loss` Tools diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index 00539260ffed6..acae24b6c8e2b 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -386,9 +386,14 @@ Changelog - |Enhancement| Added `neg_root_mean_squared_log_error_scorer` as scorer :pr:`26734` by :user:`Alejandro Martin Gil <101AlexMartin>`. -- |Enhancement| :func:`sklearn.metrics.accuracy_score` and - :func:`sklearn.metrics.zero_one_loss` now support Array API compatible inputs. - :pr:`27137` by :user:`Edoardo Abati `. +- |MajorFeature| The following metrics now support the + `Array API `_. See + :ref:`array_api` for more details. + + - :func:`sklearn.metrics.accuracy_score` :pr:`27137`` by :user:`Edoardo Abati ` + - :func:`sklearn.metrics.mean_absolute_error` :pr:`27736` by :user:`Edoardo Abati ` + - :func:`sklearn.metrics.zero_one_loss` :pr:`27137` by :user:`Edoardo Abati ` + - |API| Deprecated `needs_threshold` and `needs_proba` from :func:`metrics.make_scorer`. These parameters will be removed in version 1.6. Instead, use `response_method` that From 746d8eca1b9c7ad29c2aef1de9c2eb04fe7b877b Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Wed, 8 Nov 2023 08:40:56 +0100 Subject: [PATCH 04/29] returning float when scalar --- sklearn/metrics/_regression.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_regression.py b/sklearn/metrics/_regression.py index 1dcb3e9a2e443..840edfae58c4d 100644 --- a/sklearn/metrics/_regression.py +++ b/sklearn/metrics/_regression.py @@ -182,7 +182,7 @@ def mean_absolute_error( Returns ------- - loss : float or ndarray of floats + loss : float or array of floats If multioutput is 'raw_values', then mean absolute error is returned for each output separately. If multioutput is 'uniform_average' or an ndarray of weights, then the @@ -219,7 +219,10 @@ def mean_absolute_error( # pass None as weights to np.average: uniform mean multioutput = None - return _average(output_errors, weights=multioutput) + mean_absolute_error = _average(output_errors, weights=multioutput) + if mean_absolute_error.shape == (): + return float(mean_absolute_error) + return mean_absolute_error @validate_params( From 4b2604e8cf3fda4d79ac4e4f61519cdd476e275b Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Tue, 14 Nov 2023 19:56:59 +0100 Subject: [PATCH 05/29] fixed comment Co-authored-by: Olivier Grisel --- sklearn/metrics/tests/test_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 0daf890dd2f65..6f80f8763d68a 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -1747,8 +1747,8 @@ def check_array_api_metric( metric_xp = metric(y_true_xp, y_pred_xp, sample_weight=sample_weight) if not isinstance(metric_xp, float): - # if the result is not a scalar, the array has to be in the CPU - # before transforming it to numpy + # If the result is not a scalar, the array has to be in the CPU + # before transforming it to a numpy array. metric_xp = xp.asarray(metric_xp, device="cpu") assert_allclose( From bc9e6f86809cb3180f8129ca8fb4f7a41e53917b Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Tue, 14 Nov 2023 19:57:36 +0100 Subject: [PATCH 06/29] added float32 comment Co-authored-by: Olivier Grisel --- sklearn/metrics/tests/test_common.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 6f80f8763d68a..1ee74374c1965 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -1799,6 +1799,8 @@ def check_array_api_multiclass_classification_metric( def check_array_api_regression_metric(metric, array_namespace, device, dtype): + # Not all Array API / device combinations support `float64` values, hence + # limit this test to the `float32` case for now. y_true_np = np.array([3, -0.5, 2, 7], dtype="float32") y_pred_np = np.array([2.5, 0.0, 2, 8], dtype="float32") check_array_api_metric( From cb3613dd0b65d0e8f61babb6edac9d8bc856f2a8 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Tue, 14 Nov 2023 20:05:22 +0100 Subject: [PATCH 07/29] improved error message --- sklearn/utils/_array_api.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 73333c2158dc4..9571e2a2bac8a 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -634,10 +634,14 @@ def _average(X, axis=None, weights=None): if X.shape != weights.shape: if axis is None: raise TypeError( - "Axis must be specified when shapes of a and weights differ." + "Axis must be specified when shapes of X and weights differ. " + f"Got {X.shape = } and {weights.shape = }." ) if weights.ndim != 1: - raise TypeError("1D weights expected when shapes of a and weights differ.") + raise TypeError( + "1D weights expected when shapes of X and weights differ. " + f"Got {X.shape = } and {weights.shape = }." + ) if weights.shape[0] != X.shape[axis]: raise ValueError("Length of weights not compatible with specified axis.") weights = xp.broadcast_to(weights, (X.ndim - 1) * (1,) + weights.shape) From 0b58f4bbad7d5569318a47579de87a63f1428f16 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Tue, 14 Nov 2023 20:14:35 +0100 Subject: [PATCH 08/29] added test with axis=0 --- sklearn/utils/tests/test_array_api.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index 161b2786dafdc..d94190780ea4e 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -385,12 +385,19 @@ def test_get_namespace_array_api_isdtype(wrapper): [[0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 1.0, 2.0], 0, 2.25], [[[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]], None, None, 3.5], [[[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]], None, 1, [1.5, 5.5]], + [[[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]], None, 0, [2.0, 3.0, 4.0, 5.0]], [ [[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]], [0.0, 1.0, 1.0, 2.0], 1, [2.25, 6.25], ], + [ + [[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]], + [0.5, 2.0], + 0, + [3.2, 4.2, 5.2, 6.2], + ], ], ) def test__average(library, X, axis, weights, expected): From 002cacf047dfc7ea4bf57f116fffbcf16cedcaeb Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Tue, 14 Nov 2023 20:39:32 +0100 Subject: [PATCH 09/29] added error tests --- sklearn/utils/tests/test_array_api.py | 36 +++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index d94190780ea4e..47fe9c9fdd68b 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -410,3 +410,39 @@ def test__average(library, X, axis, weights, expected): result = _average(xp.asarray(X), axis=axis, weights=weights) assert_allclose(result, expected) + + +@skip_if_array_api_compat_not_configured +@pytest.mark.parametrize("library", ["cupy", "cupy.array_api", "torch"]) +@pytest.mark.parametrize( + "axis,weights,expected_error,expected_msg", + [ + ( + None, + [0.5, 2.0], + TypeError, + "Axis must be specified when shapes of X and weights differ.", + ), + ( + 0, + [[0.5], [2.0]], + TypeError, + "1D weights expected when shapes of X and weights differ.", + ), + ( + 1, + [0.5, 2.0], + ValueError, + "Length of weights not compatible with specified axis.", + ), + (0, [0.5, -0.5], ZeroDivisionError, "Weights sum to zero, can't be normalized"), + ], +) +def test__average_error(library, axis, weights, expected_error, expected_msg): + xp = pytest.importorskip(library) + + X = xp.asarray([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], dtype=xp.float32) + + with config_context(array_api_dispatch=True): + with pytest.raises(expected_error, match=expected_msg): + _average(X, axis=axis, weights=weights) From 6f452a67a7649172fbd1add4a18f6b39157643b7 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Tue, 14 Nov 2023 20:53:53 +0100 Subject: [PATCH 10/29] fix to dtype=float32 --- sklearn/metrics/tests/test_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 1ee74374c1965..1f9d76a9a8da4 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -1814,7 +1814,7 @@ def check_array_api_regression_metric(metric, array_namespace, device, dtype): dtype, y_true_np=y_true_np, y_pred_np=y_pred_np, - sample_weight=np.array([0.0, 0.1, 2.0, 1.0]), + sample_weight=np.array([0.0, 0.1, 2.0, 1.0], dtype="float32"), ) From 48bd567fd5403ccbec65f69189906240df16bca4 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Wed, 15 Nov 2023 00:01:51 +0100 Subject: [PATCH 11/29] test multioutput --- sklearn/metrics/tests/test_common.py | 30 +++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 1f9d76a9a8da4..411c9217616e0 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -1818,6 +1818,31 @@ def check_array_api_regression_metric(metric, array_namespace, device, dtype): ) +def check_array_api_multioutput_regression_metric( + metric, array_namespace, device, dtype +): + # Not all Array API / device combinations support `float64` values, hence + # limit this test to the `float32` case for now. + y_true_np = np.array([[0.5, 1], [-1, 1], [7, -6]], dtype="float32") + y_pred_np = np.array([[0, 2], [-1, 2], [8, -5]], dtype="float32") + + metric = partial(metric, multioutput="raw_values") + + check_array_api_metric( + metric, array_namespace, device, dtype, y_true_np=y_true_np, y_pred_np=y_pred_np + ) + if "sample_weight" in signature(metric).parameters: + check_array_api_metric( + metric, + array_namespace, + device, + dtype, + y_true_np=y_true_np, + y_pred_np=y_pred_np, + sample_weight=np.array([0.0, 0.1, 2.0], dtype="float32"), + ) + + metric_checkers = { accuracy_score: [ check_array_api_binary_classification_metric, @@ -1827,7 +1852,10 @@ def check_array_api_regression_metric(metric, array_namespace, device, dtype): check_array_api_binary_classification_metric, check_array_api_multiclass_classification_metric, ], - mean_absolute_error: [check_array_api_regression_metric], + mean_absolute_error: [ + check_array_api_regression_metric, + check_array_api_multioutput_regression_metric, + ], } From 855aad03a6044f3ab091f71f771731d46ca90adb Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Wed, 15 Nov 2023 00:15:26 +0100 Subject: [PATCH 12/29] Update sklearn/metrics/_regression.py Co-authored-by: Olivier Grisel --- sklearn/metrics/_regression.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/_regression.py b/sklearn/metrics/_regression.py index 840edfae58c4d..3c97ae9a809b7 100644 --- a/sklearn/metrics/_regression.py +++ b/sklearn/metrics/_regression.py @@ -219,10 +219,15 @@ def mean_absolute_error( # pass None as weights to np.average: uniform mean multioutput = None + # Average across the outputs (if needed). mean_absolute_error = _average(output_errors, weights=multioutput) - if mean_absolute_error.shape == (): - return float(mean_absolute_error) - return mean_absolute_error + + # Since `y_pred.ndim <= 2` and `y_true.ndim <= 2`, the second call to _average + # should always return a scalar array that we convert to a Python float to + # consistently return the same eager evaluated value, irrespective of the + # Array API implementation. + assert mean_absolute_error.shape == () + return float(mean_absolute_error) @validate_params( From d3378d5ecd1a3763832db821337e1a0d311f72a4 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Wed, 15 Nov 2023 08:52:33 +0100 Subject: [PATCH 13/29] fix linting --- sklearn/metrics/_regression.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_regression.py b/sklearn/metrics/_regression.py index 3c97ae9a809b7..a7e6b7b1b18eb 100644 --- a/sklearn/metrics/_regression.py +++ b/sklearn/metrics/_regression.py @@ -221,7 +221,7 @@ def mean_absolute_error( # Average across the outputs (if needed). mean_absolute_error = _average(output_errors, weights=multioutput) - + # Since `y_pred.ndim <= 2` and `y_true.ndim <= 2`, the second call to _average # should always return a scalar array that we convert to a Python float to # consistently return the same eager evaluated value, irrespective of the From 64471851945c23df79c93b20e0fc7b119b5c0ef5 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Wed, 15 Nov 2023 23:23:31 +0100 Subject: [PATCH 14/29] added new error message --- sklearn/utils/_array_api.py | 5 ++++- sklearn/utils/tests/test_array_api.py | 7 +------ 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 9571e2a2bac8a..c48a9e80e43ff 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -643,7 +643,10 @@ def _average(X, axis=None, weights=None): f"Got {X.shape = } and {weights.shape = }." ) if weights.shape[0] != X.shape[axis]: - raise ValueError("Length of weights not compatible with specified axis.") + raise ValueError( + f"Length of weights ({weights.shape=}) not compatible with " + f"{X.shape=} and {axis=}." + ) weights = xp.broadcast_to(weights, (X.ndim - 1) * (1,) + weights.shape) weights = xp.swapaxes(weights, -1, axis) weights_sum = xp.sum(weights, axis=axis) diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index 47fe9c9fdd68b..e67828202d60e 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -429,12 +429,7 @@ def test__average(library, X, axis, weights, expected): TypeError, "1D weights expected when shapes of X and weights differ.", ), - ( - 1, - [0.5, 2.0], - ValueError, - "Length of weights not compatible with specified axis.", - ), + (1, [0.5, 2.0], ValueError, "Length of weights"), (0, [0.5, -0.5], ZeroDivisionError, "Weights sum to zero, can't be normalized"), ], ) From b390da32cd7963fd8e24262de300aa863795e459 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Mon, 27 Nov 2023 22:45:13 +0100 Subject: [PATCH 15/29] using _convert_to_numpy in tests --- sklearn/metrics/tests/test_common.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 411c9217616e0..84284da2be3a9 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -55,6 +55,7 @@ from sklearn.utils import shuffle from sklearn.utils._array_api import ( _atol_for_type, + _convert_to_numpy, yield_namespace_device_dtype_combinations, ) from sklearn.utils._testing import ( @@ -1747,9 +1748,7 @@ def check_array_api_metric( metric_xp = metric(y_true_xp, y_pred_xp, sample_weight=sample_weight) if not isinstance(metric_xp, float): - # If the result is not a scalar, the array has to be in the CPU - # before transforming it to a numpy array. - metric_xp = xp.asarray(metric_xp, device="cpu") + metric_xp = _convert_to_numpy(metric_xp, xp) assert_allclose( metric_xp, From b8fe7c496de2b07712c1d00854bb3c73b4cdbb66 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Mon, 27 Nov 2023 23:39:06 +0100 Subject: [PATCH 16/29] adding xfail for cupy.array_api --- sklearn/metrics/tests/test_common.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 84284da2be3a9..6d420641f1d04 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -1869,4 +1869,10 @@ def yield_metric_checker_combinations(metric_checkers=metric_checkers): ) @pytest.mark.parametrize("metric, check_func", yield_metric_checker_combinations()) def test_array_api_compliance(metric, array_namespace, device, dtype, check_func): + if ( + metric == mean_absolute_error + and check_func == check_array_api_regression_metric + and array_namespace == "cupy.array_api" + ): + pytest.xfail(reason="module 'cupy.array_api' has no attribute 'swapaxes'") check_func(metric, array_namespace, device, dtype) From d3cf189cdd1344839648de935bdbd138d18b43d5 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Tue, 12 Dec 2023 11:12:01 +0100 Subject: [PATCH 17/29] Update sklearn/utils/_array_api.py Co-authored-by: Olivier Grisel --- sklearn/utils/_array_api.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index c48a9e80e43ff..543aca6c5f266 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -469,6 +469,9 @@ def _weighted_sum(sample_score, sample_weight, normalize=False, xp=None): if sample_weight is not None: sample_weight = xp.asarray( + # TODO: remove the successive xp.asarray calls once + # https://github.com/data-apis/array-api/issues/647 is widely + # implemented by all supported Array API libraries. xp.asarray(sample_weight, device="cpu"), dtype=sample_score.dtype ) if not xp.isdtype(sample_weight.dtype, "real floating"): From fcd20784663e37748efc3b600ee1cd98d2766a59 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Tue, 12 Mar 2024 17:46:29 +0100 Subject: [PATCH 18/29] remove redundant test --- sklearn/utils/tests/test_array_api.py | 69 --------------------------- 1 file changed, 69 deletions(-) diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index c79abf806e7a0..6b5f600fb01c2 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -473,72 +473,3 @@ def test_get_namespace_array_api_isdtype(wrapper): with pytest.raises(ValueError, match="Unrecognized data type"): assert xp.isdtype(xp.int16, "unknown") - - -@skip_if_array_api_compat_not_configured -@pytest.mark.parametrize( - "library", ["numpy", "numpy.array_api", "cupy", "cupy.array_api", "torch"] -) -@pytest.mark.parametrize( - "X,weights,axis,expected", - [ - [[0.0, 1.0, 2.0, 3.0], None, 0, 1.5], - [[0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 1.0, 2.0], 0, 2.25], - [[[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]], None, None, 3.5], - [[[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]], None, 1, [1.5, 5.5]], - [[[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]], None, 0, [2.0, 3.0, 4.0, 5.0]], - [ - [[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]], - [0.0, 1.0, 1.0, 2.0], - 1, - [2.25, 6.25], - ], - [ - [[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]], - [0.5, 2.0], - 0, - [3.2, 4.2, 5.2, 6.2], - ], - ], -) -def test__average(library, X, axis, weights, expected): - xp = pytest.importorskip(library) - - if isinstance(expected, list): - expected = xp.asarray(expected) - - with config_context(array_api_dispatch=True): - result = _average(xp.asarray(X), axis=axis, weights=weights) - - assert_allclose(result, expected) - - -@skip_if_array_api_compat_not_configured -@pytest.mark.parametrize("library", ["cupy", "cupy.array_api", "torch"]) -@pytest.mark.parametrize( - "axis,weights,expected_error,expected_msg", - [ - ( - None, - [0.5, 2.0], - TypeError, - "Axis must be specified when shapes of X and weights differ.", - ), - ( - 0, - [[0.5], [2.0]], - TypeError, - "1D weights expected when shapes of X and weights differ.", - ), - (1, [0.5, 2.0], ValueError, "Length of weights"), - (0, [0.5, -0.5], ZeroDivisionError, "Weights sum to zero, can't be normalized"), - ], -) -def test__average_error(library, axis, weights, expected_error, expected_msg): - xp = pytest.importorskip(library) - - X = xp.asarray([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], dtype=xp.float32) - - with config_context(array_api_dispatch=True): - with pytest.raises(expected_error, match=expected_msg): - _average(X, axis=axis, weights=weights) From dab04a026326e52a925a3a2c2a0116650b3ddf9b Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Tue, 12 Mar 2024 17:52:52 +0100 Subject: [PATCH 19/29] updated relevant release note --- doc/whats_new/v1.4.rst | 4 +--- doc/whats_new/v1.5.rst | 3 ++- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index 790ebda9d9ee9..fdea14b41d056 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -434,9 +434,7 @@ See :ref:`array_api` for more details. :pr:`27137` by :user:`Edoardo Abati `; - :func:`sklearn.model_selection.train_test_split` in :pr:`26855` by `Tim Head`_; - :func:`~utils.multiclass.is_multilabel` in :pr:`27601` by - :user:`Yaroslav Korobko `; -- :func:`sklearn.metrics.mean_absolute_error` :pr:`27736` by :user:`Edoardo Abati `. - + :user:`Yaroslav Korobko `. **Classes:** diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst index aef2d6d8af2a9..53f8812fa2de7 100644 --- a/doc/whats_new/v1.5.rst +++ b/doc/whats_new/v1.5.rst @@ -34,7 +34,8 @@ See :ref:`array_api` for more details. - :func:`sklearn.metrics.r2_score` now supports Array API compliant inputs. :pr:`27904` by :user:`Eric Lindgren `, `Franck Charras `, - `Olivier Grisel ` and `Tim Head `. + `Olivier Grisel ` and `Tim Head `; +- :func:`sklearn.metrics.mean_absolute_error` :pr:`27736` by :user:`Edoardo Abati `. **Classes:** From ce388726e5c36885c17a6d1bc7d71422ada68bf0 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Wed, 13 Mar 2024 08:21:57 +0100 Subject: [PATCH 20/29] re enabled check_array_api_multioutput_regression_metric --- sklearn/metrics/tests/test_common.py | 56 ++++------------------------ 1 file changed, 8 insertions(+), 48 deletions(-) diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 9dc9c74cfc223..95c84db660342 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -1817,53 +1817,6 @@ def check_array_api_multiclass_classification_metric( ) -# def check_array_api_regression_metric(metric, array_namespace, device, dtype): -# # Not all Array API / device combinations support `float64` values, hence -# # limit this test to the `float32` case for now. -# y_true_np = np.array([3, -0.5, 2, 7], dtype="float32") -# y_pred_np = np.array([2.5, 0.0, 2, 8], dtype="float32") -# check_array_api_metric( -# metric, array_namespace, device, dtype, y_true_np=y_true_np, -# y_pred_np=y_pred_np -# ) -# if "sample_weight" in signature(metric).parameters: -# check_array_api_metric( -# metric, -# array_namespace, -# device, -# dtype, -# y_true_np=y_true_np, -# y_pred_np=y_pred_np, -# sample_weight=np.array([0.0, 0.1, 2.0, 1.0], dtype="float32"), -# ) - - -# def check_array_api_multioutput_regression_metric( -# metric, array_namespace, device, dtype -# ): -# # Not all Array API / device combinations support `float64` values, hence -# # limit this test to the `float32` case for now. -# y_true_np = np.array([[0.5, 1], [-1, 1], [7, -6]], dtype="float32") -# y_pred_np = np.array([[0, 2], [-1, 2], [8, -5]], dtype="float32") - -# metric = partial(metric, multioutput="raw_values") - -# check_array_api_metric( -# metric, array_namespace, device, dtype, y_true_np=y_true_np, -# y_pred_np=y_pred_np -# ) -# if "sample_weight" in signature(metric).parameters: -# check_array_api_metric( -# metric, -# array_namespace, -# device, -# dtype, -# y_true_np=y_true_np, -# y_pred_np=y_pred_np, -# sample_weight=np.array([0.0, 0.1, 2.0], dtype="float32"), -# ) - - def check_array_api_regression_metric(metric, array_namespace, device, dtype_name): y_true_np = np.array([[1, 3], [1, 2]], dtype=dtype_name) y_pred_np = np.array([[1, 4], [1, 1]], dtype=dtype_name) @@ -1891,6 +1844,13 @@ def check_array_api_regression_metric(metric, array_namespace, device, dtype_nam ) +def check_array_api_multioutput_regression_metric( + metric, array_namespace, device, dtype_name +): + metric = partial(metric, multioutput="raw_values") + check_array_api_regression_metric(metric, array_namespace, device, dtype_name) + + array_api_metric_checkers = { accuracy_score: [ check_array_api_binary_classification_metric, @@ -1902,7 +1862,7 @@ def check_array_api_regression_metric(metric, array_namespace, device, dtype_nam ], mean_absolute_error: [ check_array_api_regression_metric, - # check_array_api_multioutput_regression_metric, + check_array_api_multioutput_regression_metric, ], r2_score: [check_array_api_regression_metric], } From fcf5ee4a1ecc142c1b4453279440db2a3b0ac88b Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Wed, 13 Mar 2024 08:28:48 +0100 Subject: [PATCH 21/29] removed the cast to float --- sklearn/metrics/_regression.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/sklearn/metrics/_regression.py b/sklearn/metrics/_regression.py index 4811903808731..e16b3fab95584 100644 --- a/sklearn/metrics/_regression.py +++ b/sklearn/metrics/_regression.py @@ -226,14 +226,7 @@ def mean_absolute_error( multioutput = None # Average across the outputs (if needed). - mean_absolute_error = _average(output_errors, weights=multioutput) - - # Since `y_pred.ndim <= 2` and `y_true.ndim <= 2`, the second call to _average - # should always return a scalar array that we convert to a Python float to - # consistently return the same eager evaluated value, irrespective of the - # Array API implementation. - assert mean_absolute_error.shape == () - return float(mean_absolute_error) + return _average(output_errors, weights=multioutput) @validate_params( From 58b799eba31a4b05898329779ed20359d5c44064 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Wed, 13 Mar 2024 18:08:53 +0100 Subject: [PATCH 22/29] readded cast to float --- sklearn/metrics/_regression.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/_regression.py b/sklearn/metrics/_regression.py index e16b3fab95584..4811903808731 100644 --- a/sklearn/metrics/_regression.py +++ b/sklearn/metrics/_regression.py @@ -226,7 +226,14 @@ def mean_absolute_error( multioutput = None # Average across the outputs (if needed). - return _average(output_errors, weights=multioutput) + mean_absolute_error = _average(output_errors, weights=multioutput) + + # Since `y_pred.ndim <= 2` and `y_true.ndim <= 2`, the second call to _average + # should always return a scalar array that we convert to a Python float to + # consistently return the same eager evaluated value, irrespective of the + # Array API implementation. + assert mean_absolute_error.shape == () + return float(mean_absolute_error) @validate_params( From 0e7ed2a91f192742d585f70e28ca599ce0d1cc7c Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Wed, 13 Mar 2024 18:36:34 +0100 Subject: [PATCH 23/29] match r2_score changes --- sklearn/metrics/_regression.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/sklearn/metrics/_regression.py b/sklearn/metrics/_regression.py index 4811903808731..1205ee8e48f75 100644 --- a/sklearn/metrics/_regression.py +++ b/sklearn/metrics/_regression.py @@ -212,12 +212,19 @@ def mean_absolute_error( >>> mean_absolute_error(y_true, y_pred, multioutput=[0.3, 0.7]) 0.85... """ - xp, _ = get_namespace(y_true, y_pred) - y_type, y_true, y_pred, multioutput = _check_reg_targets( - y_true, y_pred, multioutput + input_arrays = [y_true, y_pred, sample_weight, multioutput] + xp, _ = get_namespace(*input_arrays) + + dtype = _find_matching_floating_dtype(y_true, y_pred, sample_weight, xp=xp) + + _, y_true, y_pred, multioutput = _check_reg_targets( + y_true, y_pred, multioutput, dtype=dtype, xp=xp ) check_consistent_length(y_true, y_pred, sample_weight) - output_errors = _average(xp.abs(y_pred - y_true), weights=sample_weight, axis=0) + + output_errors = _average( + xp.abs(y_pred - y_true), weights=sample_weight, axis=0, xp=xp + ) if isinstance(multioutput, str): if multioutput == "raw_values": return output_errors From 44038743c21348ac9f64145346eea92974bae537 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Mon, 25 Mar 2024 18:19:26 +0100 Subject: [PATCH 24/29] Trigger CI From 3c311b02f5e6828200d03f0a2c3bc8151292e257 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Mon, 25 Mar 2024 18:31:45 +0100 Subject: [PATCH 25/29] add missing :user: --- doc/whats_new/v1.5.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst index 53f8812fa2de7..d203707b54ab3 100644 --- a/doc/whats_new/v1.5.rst +++ b/doc/whats_new/v1.5.rst @@ -33,8 +33,8 @@ See :ref:`array_api` for more details. **Functions:** - :func:`sklearn.metrics.r2_score` now supports Array API compliant inputs. - :pr:`27904` by :user:`Eric Lindgren `, `Franck Charras `, - `Olivier Grisel ` and `Tim Head `; + :pr:`27904` by :user:`Eric Lindgren `, :user:`Franck Charras `, + :user:`Olivier Grisel ` and :user:`Tim Head `; - :func:`sklearn.metrics.mean_absolute_error` :pr:`27736` by :user:`Edoardo Abati `. **Classes:** From 5cc9c25826ce3e30f636bafa0d008ad17e86d917 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Wed, 8 May 2024 12:16:49 +0200 Subject: [PATCH 26/29] moved to 1.6 whatsnew --- doc/whats_new/v1.5.rst | 1 - doc/whats_new/v1.6.rst | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst index 05d2f1e036eb1..9ba894e0e4157 100644 --- a/doc/whats_new/v1.5.rst +++ b/doc/whats_new/v1.5.rst @@ -80,7 +80,6 @@ See :ref:`array_api` for more details. - :func:`sklearn.metrics.r2_score` now supports Array API compliant inputs. :pr:`27904` by :user:`Eric Lindgren `, :user:`Franck Charras `, :user:`Olivier Grisel ` and :user:`Tim Head `; -- :func:`sklearn.metrics.mean_absolute_error` :pr:`27736` by :user:`Edoardo Abati `. **Classes:** diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst index 6eda6717b3d1b..24fc62beb15e1 100644 --- a/doc/whats_new/v1.6.rst +++ b/doc/whats_new/v1.6.rst @@ -35,10 +35,11 @@ See :ref:`array_api` for more details. - :func:`sklearn.metrics.mean_tweedie_deviance` now supports Array API compatible inputs. :pr:`28106` by :user:`Thomas Li ` +- :func:`sklearn.metrics.mean_absolute_error` :pr:`27736` by :user:`Edoardo Abati `. **Classes:** -- +- Changelog --------- From 18d3ec320a0214296754505840e333754a4bf6f2 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Tue, 14 May 2024 18:20:46 +0200 Subject: [PATCH 27/29] removed reduntant conversion --- sklearn/metrics/tests/test_common.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index f4c8f6a0eeb92..ae47ffe3d6a56 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -1758,9 +1758,6 @@ def check_array_api_metric( with config_context(array_api_dispatch=True): metric_xp = metric(y_true_xp, y_pred_xp, sample_weight=sample_weight) - if not isinstance(metric_xp, float): - metric_xp = _convert_to_numpy(metric_xp, xp) - assert_allclose( _convert_to_numpy(xp.asarray(metric_xp), xp), metric_np, From e0dea5b0c0a33e58ed6b8cb243fbe6ce0eb3c2d3 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 15 May 2024 11:06:34 +0200 Subject: [PATCH 28/29] Revert unrelated change to 1.5. --- doc/whats_new/v1.5.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst index 1b9a8c6d61f3e..4eb6ecd9264f1 100644 --- a/doc/whats_new/v1.5.rst +++ b/doc/whats_new/v1.5.rst @@ -82,7 +82,7 @@ See :ref:`array_api` for more details. **Functions:** - :func:`sklearn.metrics.r2_score` now supports Array API compliant inputs. - :pr:`27904` by :user:`Eric Lindgren `, :user:`Franck Charras `, + :pr:`27904` by :user:`Eric Lindgren `, `Franck Charras `, :user:`Olivier Grisel ` and :user:`Tim Head `; **Classes:** From 0fa23d311cf2ca2a4c355917bbc2dc363ecd5a6c Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 15 May 2024 11:08:05 +0200 Subject: [PATCH 29/29] Revert more unrelated changelog. --- doc/whats_new/v1.5.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst index 4eb6ecd9264f1..5fdc0707ffbee 100644 --- a/doc/whats_new/v1.5.rst +++ b/doc/whats_new/v1.5.rst @@ -83,7 +83,7 @@ See :ref:`array_api` for more details. - :func:`sklearn.metrics.r2_score` now supports Array API compliant inputs. :pr:`27904` by :user:`Eric Lindgren `, `Franck Charras `, - :user:`Olivier Grisel ` and :user:`Tim Head `; + `Olivier Grisel ` and `Tim Head `. **Classes:**