ENH Add Array API compatibility to mean_absolute_error #27736

Merged
37 commits merged on May 15, 2024

Commits
4481be6 converted mae to array api (EdAbati, Nov 2, 2023)
8b3dc8d fixes for MPS device (EdAbati, Nov 6, 2023)
451352c Merge remote-tracking branch 'origin/main' into mae-array-api (EdAbati, Nov 6, 2023)
0f6ea74 updated docs (EdAbati, Nov 6, 2023)
746d8ec returning float when scalar (EdAbati, Nov 8, 2023)
4b2604e fixed comment (EdAbati, Nov 14, 2023)
bc9e6f8 added float32 comment (EdAbati, Nov 14, 2023)
cb3613d improved error message (EdAbati, Nov 14, 2023)
0b58f4b added test with axis=0 (EdAbati, Nov 14, 2023)
002cacf added error tests (EdAbati, Nov 14, 2023)
6f452a6 fix to dtype=float32 (EdAbati, Nov 14, 2023)
48bd567 test multioutput (EdAbati, Nov 14, 2023)
855aad0 Update sklearn/metrics/_regression.py (EdAbati, Nov 14, 2023)
80450bf Merge branch 'main' into mae-array-api (EdAbati, Nov 14, 2023)
c4d7ef9 Merge branch 'main' into mae-array-api (EdAbati, Nov 15, 2023)
d3378d5 fix linting (EdAbati, Nov 15, 2023)
6447185 added new error message (EdAbati, Nov 15, 2023)
b390da3 using _convert_to_numpy in tests (EdAbati, Nov 27, 2023)
7fa04e2 Merge branch 'main' into mae-array-api (EdAbati, Nov 27, 2023)
b8fe7c4 adding xfail for cupy.array_api (EdAbati, Nov 27, 2023)
d3cf189 Update sklearn/utils/_array_api.py (EdAbati, Dec 12, 2023)
adc2809 Merge remote-tracking branch 'upstream/main' into mae-array-api (EdAbati, Dec 12, 2023)
6fc2978 Merge remote-tracking branch 'upstream/main' into mae-array-api (EdAbati, Mar 12, 2024)
fcd2078 remove redundant test (EdAbati, Mar 12, 2024)
dab04a0 updated relevant release note (EdAbati, Mar 12, 2024)
ce38872 re enabled check_array_api_multioutput_regression_metric (EdAbati, Mar 13, 2024)
fcf5ee4 removed the cast to float (EdAbati, Mar 13, 2024)
58b799e readded cast to float (EdAbati, Mar 13, 2024)
0e7ed2a match r2_score changes (EdAbati, Mar 13, 2024)
4403874 Trigger CI (ogrisel, Mar 25, 2024)
3c311b0 add missing :user: (EdAbati, Mar 25, 2024)
197fe3e Merge remote-tracking branch 'upstream/main' into mae-array-api (EdAbati, May 8, 2024)
5cc9c25 moved to 1.6 whatsnew (EdAbati, May 8, 2024)
18d3ec3 removed reduntant conversion (EdAbati, May 14, 2024)
6ab0d16 Merge remote-tracking branch 'upstream/main' into mae-array-api (EdAbati, May 14, 2024)
e0dea5b Revert unrelated change to 1.5. (ogrisel, May 15, 2024)
0fa23d3 Revert more unrelated changelog. (ogrisel, May 15, 2024)

1 change: 1 addition & 0 deletions doc/modules/array_api.rst
@@ -106,6 +106,7 @@ Metrics
 -------

 - :func:`sklearn.metrics.accuracy_score`
+- :func:`sklearn.metrics.mean_absolute_error`
 - :func:`sklearn.metrics.mean_tweedie_deviance`
 - :func:`sklearn.metrics.r2_score`
 - :func:`sklearn.metrics.zero_one_loss`

1 change: 1 addition & 0 deletions doc/whats_new/v1.6.rst
@@ -35,6 +35,7 @@ See :ref:`array_api` for more details.
 - :func:`sklearn.metrics.mean_tweedie_deviance` now supports Array API compatible
   inputs.
   :pr:`28106` by :user:`Thomas Li <lithomas1>`
+- :func:`sklearn.metrics.mean_absolute_error` :pr:`27736` by :user:`Edoardo Abati <EdAbati>`.

 **Classes:**

26 changes: 21 additions & 5 deletions sklearn/metrics/_regression.py
@@ -189,7 +189,7 @@ def mean_absolute_error(

     Returns
     -------
-    loss : float or ndarray of floats
+    loss : float or array of floats
         If multioutput is 'raw_values', then mean absolute error is returned
         for each output separately.
         If multioutput is 'uniform_average' or an ndarray of weights, then the
@@ -213,19 +213,35 @@ def mean_absolute_error(
     >>> mean_absolute_error(y_true, y_pred, multioutput=[0.3, 0.7])
     0.85...
     """
-    y_type, y_true, y_pred, multioutput = _check_reg_targets(
-        y_true, y_pred, multioutput
+    input_arrays = [y_true, y_pred, sample_weight, multioutput]
+    xp, _ = get_namespace(*input_arrays)
+
+    dtype = _find_matching_floating_dtype(y_true, y_pred, sample_weight, xp=xp)
+
+    _, y_true, y_pred, multioutput = _check_reg_targets(
+        y_true, y_pred, multioutput, dtype=dtype, xp=xp
     )
     check_consistent_length(y_true, y_pred, sample_weight)
-    output_errors = np.average(np.abs(y_pred - y_true), weights=sample_weight, axis=0)
+
+    output_errors = _average(
+        xp.abs(y_pred - y_true), weights=sample_weight, axis=0, xp=xp
+    )
     if isinstance(multioutput, str):
         if multioutput == "raw_values":
             return output_errors
         elif multioutput == "uniform_average":
             # pass None as weights to np.average: uniform mean
             multioutput = None

-    return np.average(output_errors, weights=multioutput)
+    # Average across the outputs (if needed).
+    mean_absolute_error = _average(output_errors, weights=multioutput)
+
+    # Since `y_pred.ndim <= 2` and `y_true.ndim <= 2`, the second call to _average
+    # should always return a scalar array that we convert to a Python float to
+    # consistently return the same eager evaluated value, irrespective of the
+    # Array API implementation.
+    assert mean_absolute_error.shape == ()
+    return float(mean_absolute_error)


 @validate_params(

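For readers following the control flow of this hunk, here is a minimal NumPy-only sketch of the same two-step averaging: first over samples (axis 0), then over outputs. The names `average_sketch` and `mae_sketch` are hypothetical; they stand in for the private `_average` helper and the patched function, and they skip scikit-learn's input validation. The sketch also hard-codes `float64`, whereas the PR selects a dtype with `_find_matching_floating_dtype`. The `xp` argument is assumed to be any module exposing `asarray`, `abs`, `sum`, `mean`, and `reshape`.

```python
import numpy as np


def average_sketch(values, weights=None, axis=None, xp=np):
    """Hypothetical stand-in for the private `_average` helper in the diff."""
    if weights is None:
        return xp.mean(values, axis=axis)
    weights = xp.asarray(weights, dtype=values.dtype)
    if axis is not None and values.ndim > weights.ndim:
        # Assumption: 1-D weights apply along axis 0 (one weight per sample),
        # so they are reshaped to broadcast across the remaining axes.
        weights = xp.reshape(weights, (-1,) + (1,) * (values.ndim - 1))
    return xp.sum(values * weights, axis=axis) / xp.sum(weights)


def mae_sketch(y_true, y_pred, sample_weight=None,
               multioutput="uniform_average", xp=np):
    """Illustrative mean absolute error; not the scikit-learn implementation."""
    y_true = xp.asarray(y_true, dtype=xp.float64)
    y_pred = xp.asarray(y_pred, dtype=xp.float64)
    if y_true.ndim == 1:
        # Treat single-output targets as one column.
        y_true = xp.reshape(y_true, (-1, 1))
        y_pred = xp.reshape(y_pred, (-1, 1))

    # Step 1: average the absolute errors over samples, one value per output.
    output_errors = average_sketch(
        xp.abs(y_pred - y_true), weights=sample_weight, axis=0, xp=xp
    )
    if isinstance(multioutput, str):
        if multioutput == "raw_values":
            return output_errors
        multioutput = None  # "uniform_average": unweighted mean over outputs

    # Step 2: average over outputs and return a plain Python float.
    return float(average_sketch(output_errors, weights=multioutput, xp=xp))
```

Called on the inputs from the docstring example, `mae_sketch(y_true, y_pred, multioutput=[0.3, 0.7])` weights the per-output errors `0.5` and `1.0` by `0.3` and `0.7`, which matches the `0.85...` shown in the docstring.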
11 changes: 11 additions & 0 deletions sklearn/metrics/tests/test_common.py
@@ -1879,6 +1879,13 @@ def check_array_api_regression_metric_multioutput(
     )


+def check_array_api_multioutput_regression_metric(
+    metric, array_namespace, device, dtype_name
+):
+    metric = partial(metric, multioutput="raw_values")
+    check_array_api_regression_metric(metric, array_namespace, device, dtype_name)
+
+
 array_api_metric_checkers = {
     accuracy_score: [
         check_array_api_binary_classification_metric,
@@ -1893,6 +1900,10 @@ def check_array_api_regression_metric_multioutput(
         check_array_api_regression_metric,
         check_array_api_regression_metric_multioutput,
     ],
+    mean_absolute_error: [
+        check_array_api_regression_metric,
+        check_array_api_multioutput_regression_metric,
+    ],
 }

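The new checker pins `multioutput="raw_values"` with `functools.partial` and then delegates to the existing regression check. A rough, self-contained sketch of that pattern follows. It is a simplification under stated assumptions: `toy_mae`, `check_metric_sketch`, and `check_multioutput_metric_sketch` are hypothetical names, the fixture data is invented for illustration, and the comparison is between `float64` and another NumPy dtype rather than between NumPy and a second array namespace or device as in the real test suite.

```python
from functools import partial

import numpy as np


def toy_mae(y_true, y_pred, multioutput="uniform_average"):
    """Hypothetical stand-in metric; not the scikit-learn implementation."""
    errors = np.mean(np.abs(y_pred - y_true), axis=0)
    if multioutput == "raw_values":
        return errors
    return float(np.mean(errors))


def check_metric_sketch(metric, dtype_name):
    """Compare a metric on `dtype_name` inputs against a float64 reference."""
    # Illustrative data only; the real checker builds its own fixtures.
    y_true = np.asarray([[1.0, 3.0], [1.0, 2.0]], dtype=dtype_name)
    y_pred = np.asarray([[1.0, 4.0], [1.0, 1.0]], dtype=dtype_name)

    reference = metric(y_true.astype("float64"), y_pred.astype("float64"))
    result = metric(y_true, y_pred)

    # float32 inputs get a looser tolerance than float64 ones.
    rtol = 1e-4 if dtype_name == "float32" else 1e-12
    np.testing.assert_allclose(
        np.asarray(result, dtype="float64"), reference, rtol=rtol
    )
    return result


def check_multioutput_metric_sketch(metric, dtype_name):
    """Same shape as the added checker: pin multioutput, reuse the base check."""
    metric = partial(metric, multioutput="raw_values")
    return check_metric_sketch(metric, dtype_name)
```

Binding the keyword with `partial` keeps the base check unaware of `multioutput`, so the same comparison logic covers both the scalar and the per-output return paths.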