
Reduce redundancy in floating type checks for Array API support #30106

Closed
@virchan

Describe the workflow you want to enable

While working on #29978, we noticed that the following pattern is repeated in most of the regression metrics in _regression.py that support the Array API:

    xp, _ = get_namespace(y_true, y_pred, sample_weight, multioutput)
    dtype = _find_matching_floating_dtype(y_true, y_pred, xp=xp)

    _, y_true, y_pred, multioutput = _check_reg_targets(
        y_true, y_pred, multioutput, dtype=dtype, xp=xp
    )
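The dtype line in this pattern picks one floating dtype for both arrays before validation, so that mixed inputs (for example integer y_true and float32 y_pred) are compared in a common precision. The NumPy-only sketch below approximates that behaviour, assuming the widest floating input dtype wins and that all-integer inputs fall back to the default float. find_matching_floating_dtype is a hypothetical stand-in, not scikit-learn's implementation, and it ignores the xp namespace argument.

```python
import numpy as np


def find_matching_floating_dtype(*arrays):
    """Hypothetical NumPy-only stand-in for _find_matching_floating_dtype.

    Promotes the floating-point inputs to a common dtype; if no input is
    floating point (e.g. all integer arrays), falls back to float64.
    """
    floating = [
        a.dtype for a in map(np.asarray, arrays) if np.issubdtype(a.dtype, np.floating)
    ]
    if floating:
        # NumPy's promotion rules pick the widest floating dtype present.
        return np.result_type(*floating)
    # No floating inputs: use NumPy's default floating dtype.
    return np.dtype(np.float64)


y_true = np.array([3, -1, 2, 7])  # integer targets
y_pred = np.array([2.5, 0.0, 2.0, 8.0], dtype=np.float32)
print(find_matching_floating_dtype(y_true, y_pred))  # float32
```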

To reduce redundancy, it would make sense to move the _find_matching_floating_dtype logic into _check_reg_targets and have it return the resulting dtype. Each metric could then use the following, shorter code:

    xp, _ = get_namespace(y_true, y_pred, sample_weight, multioutput)
    _, y_true, y_pred, multioutput, dtype = _check_reg_targets(
        y_true, y_pred, multioutput, xp=xp
    )
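To illustrate the shorter call site, here is a minimal NumPy-only sketch of a merged helper that infers the dtype internally and returns it as a fifth value, used by a toy metric. check_reg_targets_merged, _floating_dtype_stub and mean_absolute_error_sketch are hypothetical names, the validation is heavily simplified compared with scikit-learn's _check_reg_targets, and the xp namespace argument is omitted.

```python
import numpy as np


def _floating_dtype_stub(*arrays):
    # Hypothetical stand-in: widest floating input dtype, else float64.
    floating = [a.dtype for a in arrays if np.issubdtype(a.dtype, np.floating)]
    return np.result_type(*floating) if floating else np.dtype(np.float64)


def check_reg_targets_merged(y_true, y_pred, multioutput):
    """Hypothetical merged helper: infers the dtype itself and returns it."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    dtype = _floating_dtype_stub(y_true, y_pred)
    y_true = y_true.astype(dtype, copy=False)
    y_pred = y_pred.astype(dtype, copy=False)
    # Treat 1-D targets as a single output column.
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
    if y_pred.ndim == 1:
        y_pred = y_pred.reshape(-1, 1)
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true and y_pred have inconsistent shapes")
    y_type = "continuous" if y_true.shape[1] == 1 else "continuous-multioutput"
    return y_type, y_true, y_pred, multioutput, dtype


def mean_absolute_error_sketch(y_true, y_pred, multioutput="uniform_average"):
    # The metric body no longer calls the dtype helper separately; dtype is
    # still available for casting other inputs, such as weights (not shown).
    _, y_true, y_pred, multioutput, dtype = check_reg_targets_merged(
        y_true, y_pred, multioutput
    )
    errors = np.mean(np.abs(y_pred - y_true), axis=0)  # one value per output
    if multioutput == "raw_values":
        return errors
    # Only uniform averaging is sketched here; array weights are not handled.
    return np.mean(errors)


score = mean_absolute_error_sketch(
    [3, 0, 2, 7], np.array([2.5, 0, 2, 8], dtype=np.float32)
)
print(score, score.dtype)  # 0.375 float32
```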

Describe your proposed solution

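We could introduce a new function, _check_reg_targets_and_dtype, that calls _find_matching_floating_dtype on y_true and y_pred, passes the resulting dtype to the existing _check_reg_targets, and returns that dtype alongside the validated targets. Because _check_reg_targets itself would stay unchanged, we could reuse the existing tests in test_regression.py with minimal changes.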

Describe alternatives you've considered, if relevant

We could instead modify _check_reg_targets itself to return the dtype. However, this would change its return signature, so we would need to carefully review and update the relevant tests in test_regression.py, as well as every existing call site, to keep everything consistent.

Additional context

This is part of the Array API project #26024.

ping: @ogrisel
cc: @glemaitre, @betatim, @adrinjalali.
