Skip to content

ENH Reduce redundancy in floating type checks for Array API support in _regression.py #30128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

virchan
Copy link
Member

@virchan virchan commented Oct 22, 2024

Reference Issues/PRs

Fixes #30106, and unpauses #29978.

What does this implement/fix? Explain your changes.

This PR introduces a new function, _check_reg_targets_and_floating_dtype, to streamline the floating type checks in _regression.py.

The new function integrates the _find_matching_floating_dtype logic directly into the _check_reg_targets function, and has the following signature:

y_type, y_true, y_pred, sample_weight, multioutput = _check_reg_targets_and_floating_dtype(
    y_true, y_pred, sample_weight, multioutput, xp
)

To inspect the resulting floating-point data type, users can access the .dtype attribute of the returned arrays, e.g., y_true.dtype or y_pred.dtype.

Additionally, it passes xp to avoid redundant namespace inspection, extending the work done in #30092.

The following regression metrics remain unchanged by this PR:

Any other comments?

This is part of the Array API project (#26024).

Ping: @ogrisel

Cc: @adrinjalali, @betatim, @glemaitre, @sqali.

Copy link

github-actions bot commented Oct 22, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: e868d9d. Link to the linter CI: here

@virchan virchan marked this pull request as ready for review October 22, 2024 02:15
Copy link
Member

@adrinjalali adrinjalali left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd also add a note in _check_reg_targets's docstring that we should probably use the new method. Otherwise LGTM.

virchan and others added 5 commits October 22, 2024 00:19
Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
@adrinjalali
Copy link
Member

cc @ogrisel

Copy link
Member

@ogrisel ogrisel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall, LGTM. Here a few suggestions for further improvement.

Copy link
Member

@ogrisel ogrisel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is some possible simplification of the code:

@ogrisel
Copy link
Member

ogrisel commented Oct 25, 2024

Actually, I do not have the permission to accept my own suggestions or to directly push to your branch. So please review and accept the suggestions first.

Could you also please rename the _check_reg_targets_and_floating_dtype function to _check_reg_targets_with_floating_dtype?

virchan and others added 4 commits October 25, 2024 09:24
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
@virchan
Copy link
Member Author

virchan commented Oct 25, 2024

Actually, I do not have the permission to accept my own suggestions or to directly push to your branch. So please review and accept the suggestions first.

Could you also please rename the _check_reg_targets_and_floating_dtype function to _check_reg_targets_with_floating_dtype?

I've made the changes. Please let me know if there’s anything else you'd like me to update!

Copy link
Contributor

@OmarManzoor OmarManzoor left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks @virchan

@OmarManzoor OmarManzoor merged commit 33f08f1 into scikit-learn:main Dec 2, 2024
30 checks passed
@virchan
Copy link
Member Author

virchan commented Dec 2, 2024

Thank you everyone for your time!

jeremiedbb pushed a commit to jeremiedbb/scikit-learn that referenced this pull request Dec 4, 2024
…n `_regression.py` (scikit-learn#30128)

Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
jeremiedbb pushed a commit that referenced this pull request Dec 6, 2024
…n `_regression.py` (#30128)

Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
virchan added a commit to virchan/scikit-learn that referenced this pull request Dec 9, 2024
…n `_regression.py` (scikit-learn#30128)

Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Reduce redundancy in floating type checks for Array API support
5 participants