Description
Describe the workflow you want to enable
While working on #29978, we noticed that the following procedure is repeated across most regression metrics in `_regression.py` for the Array API:
```python
# Repeated at the start of most regression metrics in _regression.py
xp, _ = get_namespace(y_true, y_pred, sample_weight, multioutput)
dtype = _find_matching_floating_dtype(y_true, y_pred, xp=xp)
_, y_true, y_pred, multioutput = _check_reg_targets(
    y_true, y_pred, multioutput, dtype=dtype, xp=xp
)
```
To reduce redundancy, it would make sense to incorporate the `_find_matching_floating_dtype` logic directly into the `_check_reg_targets` function and have it also return the computed dtype. This would result in the following cleaner implementation:
```python
xp, _ = get_namespace(y_true, y_pred, sample_weight, multioutput)
_, y_true, y_pred, multioutput, dtype = _check_reg_targets(
    y_true, y_pred, multioutput, xp=xp
)
```
Describe your proposed solution
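We could introduce a new function, `_check_reg_targets_and_dtype`, that calls `_find_matching_floating_dtype` and then `_check_reg_targets`, and returns the dtype alongside the validated targets. Because `_check_reg_targets` itself would stay unchanged, this approach would let us reuse the existing tests in `test_regression.py` with minimal changes.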
Describe alternatives you've considered, if relevant
We could modify the original `_check_reg_targets` function, but this would require carefully reviewing and updating the relevant tests in `test_regression.py` to ensure everything remains consistent.
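One concrete reason the alternative touches existing code: changing the return arity of `_check_reg_targets` breaks any caller or test that unpacks exactly four values, as the first snippet above does. The sketch below uses two hypothetical stand-in functions (not scikit-learn code) to illustrate this.

```python
import numpy as np


def check_reg_targets_current(y_true, y_pred, multioutput):
    # Hypothetical stand-in for today's contract: four return values.
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    return "continuous", y_true, y_pred, multioutput


def check_reg_targets_modified(y_true, y_pred, multioutput):
    # Hypothetical modified version that also returns the dtype (five values).
    y_type, y_true, y_pred, multioutput = check_reg_targets_current(
        y_true, y_pred, multioutput
    )
    return y_type, y_true, y_pred, multioutput, y_true.dtype


def unpack_like_existing_caller(check):
    # Unpack exactly four values; report whether that still succeeds.
    try:
        y_type, y_true, y_pred, multioutput = check([1, 2], [1, 3], None)
    except ValueError:  # "too many values to unpack"
        return False
    return True


print(unpack_like_existing_caller(check_reg_targets_current))   # True
print(unpack_like_existing_caller(check_reg_targets_modified))  # False
```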
Additional context
This is part of the Array API project #26024.
ping: @ogrisel
cc: @glemaitre, @betatim, @adrinjalali.