Skip to content

ENH Reduce redundancy in floating type checks for Array API support in _regression.py #30128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
a5f8460
ENH Reduce redundancy in floating type checks for Array API support
virchan Oct 21, 2024
92c048e
ENH Reduce redundancy in floating type checks for Array API support
virchan Oct 21, 2024
69e00cc
Updating `_regression.py` and `test_common.py`.
virchan Oct 22, 2024
3a2ed0e
Updating `test_common.py`.
virchan Oct 22, 2024
eec32bb
Update changelog.
virchan Oct 22, 2024
e7a2602
Update `_regression.py`.
virchan Oct 22, 2024
ad5cdea
Update `_regression.py`.
virchan Oct 22, 2024
8cc2583
Update sklearn/metrics/_regression.py
virchan Oct 22, 2024
b940827
Update sklearn/metrics/_regression.py
virchan Oct 22, 2024
dafacaf
Update sklearn/metrics/_regression.py
virchan Oct 22, 2024
1ba7c37
Remove changelog and add `xp` back to `max_error`.
virchan Oct 22, 2024
f53172e
Update `_check_reg_targets` doc-string.
virchan Oct 22, 2024
0e3cfae
Update variable name to `dtype_name`.
virchan Oct 23, 2024
4b844a0
Update `r2_score`.
virchan Oct 24, 2024
08ee432
Fix linting.
virchan Oct 24, 2024
ef944fd
Merge branch 'main' into issues/30106_floating_dtype_regression_Array…
virchan Oct 24, 2024
5029c9d
Merge branch 'main' into issues/30106_floating_dtype_regression_Array…
virchan Oct 25, 2024
f9ee0f7
Update `_check_reg_targets_and_floating_dtype`.
virchan Oct 25, 2024
fc0f593
Update sklearn/metrics/_regression.py
virchan Oct 25, 2024
635c883
Update sklearn/metrics/_regression.py
virchan Oct 25, 2024
bbde963
Merge branch 'main' into issues/30106_floating_dtype_regression_Array…
virchan Oct 25, 2024
eccc7f1
Rename `_check_reg_targets_and_floating_dtype` to `_check_reg_targets…
virchan Oct 25, 2024
048551c
Merge branch 'main' into issues/30106_floating_dtype_regression_Array…
virchan Oct 29, 2024
e7533ea
Merge branch 'main' into issues/30106_floating_dtype_regression_Array…
virchan Oct 31, 2024
0959a06
Merge branch 'main' into issues/30106_floating_dtype_regression_Array…
virchan Nov 1, 2024
64d5e95
Updating `_check_reg_targets` and _check_reg_targets_with_floating_dt…
virchan Nov 2, 2024
e772fe2
Updating `_check_reg_targets_with_floating_dtype` to convert `sample_…
virchan Nov 2, 2024
ed6224c
Updating `_check_reg_targets_with_floating_dtype`
virchan Nov 2, 2024
c507513
Updating `_check_reg_targets_with_floating_dtype`.
virchan Nov 2, 2024
8d7507d
Updating `_check_reg_targets_with_floating_dtype`.
virchan Nov 2, 2024
63c34f5
Updating `_check_reg_targets_with_floating_dtype`.
virchan Nov 2, 2024
ee66bd2
Updating `_check_reg_targets_with_floating_dtype`.
virchan Nov 3, 2024
35f3e5e
Merge branch 'main' into issues/30106_floating_dtype_regression_Array…
virchan Nov 18, 2024
e868d9d
Update `_check_reg_targets_with_floating_dtype` doc-string.
virchan Nov 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 104 additions & 38 deletions sklearn/metrics/_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,16 @@
def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric", xp=None):
"""Check that y_true and y_pred belong to the same regression task.

To reduce redundancy when calling `_find_matching_floating_dtype`,
please use `_check_reg_targets_with_floating_dtype` instead.

Parameters
----------
y_true : array-like
y_true : array-like of shape (n_samples,) or (n_samples, n_outputs)
Ground truth (correct) target values.

y_pred : array-like
y_pred : array-like of shape (n_samples,) or (n_samples, n_outputs)
Estimated target values.

multioutput : array-like or string in ['raw_values', uniform_average',
'variance_weighted'] or None
Expand Down Expand Up @@ -137,6 +142,71 @@ def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric", xp=None):
return y_type, y_true, y_pred, multioutput


def _check_reg_targets_with_floating_dtype(
y_true, y_pred, sample_weight, multioutput, xp=None
):
"""Ensures that y_true, y_pred, and sample_weight correspond to the same
regression task.

Extends `_check_reg_targets` by automatically selecting a suitable floating-point
data type for inputs using `_find_matching_floating_dtype`.

Use this private method only when converting inputs to array API-compatibles.

Parameters
----------
y_true : array-like of shape (n_samples,) or (n_samples, n_outputs)
Ground truth (correct) target values.

y_pred : array-like of shape (n_samples,) or (n_samples, n_outputs)
Estimated target values.

sample_weight : array-like of shape (n_samples,)

multioutput : array-like or string in ['raw_values', 'uniform_average', \
'variance_weighted'] or None
None is accepted due to backward compatibility of r2_score().

xp : module, default=None
Precomputed array namespace module. When passed, typically from a caller
that has already performed inspection of its own inputs, skips array
namespace inspection.

Returns
-------
type_true : one of {'continuous', 'continuous-multioutput'}
The type of the true target data, as output by
'utils.multiclass.type_of_target'.

y_true : array-like of shape (n_samples, n_outputs)
Ground truth (correct) target values.

y_pred : array-like of shape (n_samples, n_outputs)
Estimated target values.

sample_weight : array-like of shape (n_samples,), default=None
Sample weights.

multioutput : array-like of shape (n_outputs) or string in ['raw_values', \
'uniform_average', 'variance_weighted'] or None
Custom output weights if ``multioutput`` is array-like or
just the corresponding argument if ``multioutput`` is a
correct keyword.
"""
dtype_name = _find_matching_floating_dtype(y_true, y_pred, sample_weight, xp=xp)

y_type, y_true, y_pred, multioutput = _check_reg_targets(
y_true, y_pred, multioutput, dtype=dtype_name, xp=xp
)

# _check_reg_targets does not accept sample_weight as input.
# Convert sample_weight's data type separately to match dtype_name.
if sample_weight is not None:
sample_weight = xp.asarray(sample_weight, dtype=dtype_name)

return y_type, y_true, y_pred, sample_weight, multioutput


@validate_params(
{
"y_true": ["array-like"],
Expand Down Expand Up @@ -201,14 +271,14 @@ def mean_absolute_error(
>>> mean_absolute_error(y_true, y_pred, multioutput=[0.3, 0.7])
0.85...
"""
input_arrays = [y_true, y_pred, sample_weight, multioutput]
xp, _ = get_namespace(*input_arrays)

dtype = _find_matching_floating_dtype(y_true, y_pred, sample_weight, xp=xp)
xp, _ = get_namespace(y_true, y_pred, sample_weight, multioutput)

_, y_true, y_pred, multioutput = _check_reg_targets(
y_true, y_pred, multioutput, dtype=dtype, xp=xp
_, y_true, y_pred, sample_weight, multioutput = (
_check_reg_targets_with_floating_dtype(
y_true, y_pred, sample_weight, multioutput, xp=xp
)
)

check_consistent_length(y_true, y_pred, sample_weight)

output_errors = _average(
Expand Down Expand Up @@ -398,19 +468,16 @@ def mean_absolute_percentage_error(
>>> mean_absolute_percentage_error(y_true, y_pred)
112589990684262.48
"""
input_arrays = [y_true, y_pred, sample_weight, multioutput]
xp, _ = get_namespace(*input_arrays)
dtype = _find_matching_floating_dtype(y_true, y_pred, sample_weight, xp=xp)

y_type, y_true, y_pred, multioutput = _check_reg_targets(
y_true, y_pred, multioutput, dtype=dtype, xp=xp
xp, _ = get_namespace(y_true, y_pred, sample_weight, multioutput)
_, y_true, y_pred, sample_weight, multioutput = (
_check_reg_targets_with_floating_dtype(
y_true, y_pred, sample_weight, multioutput, xp=xp
)
)
check_consistent_length(y_true, y_pred, sample_weight)
epsilon = xp.asarray(xp.finfo(xp.float64).eps, dtype=dtype)
y_true_abs = xp.asarray(xp.abs(y_true), dtype=dtype)
mape = xp.asarray(xp.abs(y_pred - y_true), dtype=dtype) / xp.maximum(
y_true_abs, epsilon
)
epsilon = xp.asarray(xp.finfo(xp.float64).eps, dtype=y_true.dtype)
y_true_abs = xp.abs(y_true)
mape = xp.abs(y_pred - y_true) / xp.maximum(y_true_abs, epsilon)
output_errors = _average(mape, weights=sample_weight, axis=0)
if isinstance(multioutput, str):
if multioutput == "raw_values":
Expand Down Expand Up @@ -494,10 +561,10 @@ def mean_squared_error(
0.825...
"""
xp, _ = get_namespace(y_true, y_pred, sample_weight, multioutput)
dtype = _find_matching_floating_dtype(y_true, y_pred, xp=xp)

_, y_true, y_pred, multioutput = _check_reg_targets(
y_true, y_pred, multioutput, dtype=dtype, xp=xp
_, y_true, y_pred, sample_weight, multioutput = (
_check_reg_targets_with_floating_dtype(
y_true, y_pred, sample_weight, multioutput, xp=xp
)
)
check_consistent_length(y_true, y_pred, sample_weight)
output_errors = _average((y_true - y_pred) ** 2, axis=0, weights=sample_weight)
Expand Down Expand Up @@ -670,10 +737,9 @@ def mean_squared_log_error(
0.060...
"""
xp, _ = get_namespace(y_true, y_pred)
dtype = _find_matching_floating_dtype(y_true, y_pred, xp=xp)

_, y_true, y_pred, _ = _check_reg_targets(
y_true, y_pred, multioutput, dtype=dtype, xp=xp
_, y_true, y_pred, _, _ = _check_reg_targets_with_floating_dtype(
y_true, y_pred, sample_weight, multioutput, xp=xp
)

if xp.any(y_true <= -1) or xp.any(y_pred <= -1):
Expand Down Expand Up @@ -747,10 +813,9 @@ def root_mean_squared_log_error(
0.199...
"""
xp, _ = get_namespace(y_true, y_pred)
dtype = _find_matching_floating_dtype(y_true, y_pred, xp=xp)

_, y_true, y_pred, multioutput = _check_reg_targets(
y_true, y_pred, multioutput, dtype=dtype, xp=xp
_, y_true, y_pred, _, _ = _check_reg_targets_with_floating_dtype(
y_true, y_pred, sample_weight, multioutput, xp=xp
)

if xp.any(y_true <= -1) or xp.any(y_pred <= -1):
Expand Down Expand Up @@ -1188,11 +1253,12 @@ def r2_score(
y_true, y_pred, sample_weight, multioutput
)

dtype = _find_matching_floating_dtype(y_true, y_pred, sample_weight, xp=xp)

_, y_true, y_pred, multioutput = _check_reg_targets(
y_true, y_pred, multioutput, dtype=dtype, xp=xp
_, y_true, y_pred, sample_weight, multioutput = (
_check_reg_targets_with_floating_dtype(
y_true, y_pred, sample_weight, multioutput, xp=xp
)
)

check_consistent_length(y_true, y_pred, sample_weight)

if _num_samples(y_pred) < 2:
Expand All @@ -1201,7 +1267,7 @@ def r2_score(
return float("nan")

if sample_weight is not None:
sample_weight = column_or_1d(sample_weight, dtype=dtype)
sample_weight = column_or_1d(sample_weight)
weight = sample_weight[:, None]
else:
weight = 1.0
Expand Down Expand Up @@ -1356,8 +1422,8 @@ def mean_tweedie_deviance(y_true, y_pred, *, sample_weight=None, power=0):
1.4260...
"""
xp, _ = get_namespace(y_true, y_pred)
y_type, y_true, y_pred, _ = _check_reg_targets(
y_true, y_pred, None, dtype=[xp.float64, xp.float32], xp=xp
y_type, y_true, y_pred, sample_weight, _ = _check_reg_targets_with_floating_dtype(
y_true, y_pred, sample_weight, multioutput=None, xp=xp
)
if y_type == "continuous-multioutput":
raise ValueError("Multioutput not supported in mean_tweedie_deviance")
Expand Down Expand Up @@ -1570,8 +1636,8 @@ def d2_tweedie_score(y_true, y_pred, *, sample_weight=None, power=0):
"""
xp, _ = get_namespace(y_true, y_pred)

y_type, y_true, y_pred, _ = _check_reg_targets(
y_true, y_pred, None, dtype=[xp.float64, xp.float32], xp=xp
y_type, y_true, y_pred, sample_weight, _ = _check_reg_targets_with_floating_dtype(
y_true, y_pred, sample_weight, multioutput=None, xp=xp
)
if y_type == "continuous-multioutput":
raise ValueError("Multioutput not supported in d2_tweedie_score")
Expand Down
4 changes: 2 additions & 2 deletions sklearn/metrics/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,8 +583,8 @@ def _require_positive_targets(y1, y2):
def _require_log1p_targets(y1, y2):
"""Make targets strictly larger than -1"""
offset = abs(min(y1.min(), y2.min())) - 0.99
y1 = y1.astype(float)
y2 = y2.astype(float)
y1 = y1.astype(np.float64)
y2 = y2.astype(np.float64)
y1 += offset
y2 += offset
return y1, y2
Expand Down