FIX _check_reg_targets with Array API and multioutput argument #29143

Merged on Jun 20, 2024 (9 commits)

Tialo
Contributor

@Tialo Tialo commented May 30, 2024

code to reproduce

from sklearn.metrics import mean_absolute_error
from sklearn import set_config

from array_api_strict import asarray

set_config(array_api_dispatch=True)

m = asarray([2, 3, 4])
t1 = asarray([[1., 2., 3.], [3., 3., 3]])
t2 = asarray([[1, 2, 10], [5, 2, 1]])

print(mean_absolute_error(t1, t2, multioutput=m))

traceback:

Traceback (most recent call last):
  File "/Users/y/Desktop/scikit-learn/chec_s.py", line 12, in <module>
    print(mean_absolute_error(t1, t2, multioutput=m))
  File "/Users/y/Desktop/scikit-learn/sklearn/utils/_param_validation.py", line 213, in wrapper
    return func(*args, **kwargs)
  File "/Users/y/Desktop/scikit-learn/sklearn/metrics/_regression.py", line 221, in mean_absolute_error
    _, y_true, y_pred, multioutput = _check_reg_targets(
  File "/Users/y/Desktop/scikit-learn/sklearn/metrics/_regression.py", line 142, in _check_reg_targets
    elif n_outputs != len(multioutput):
TypeError: object of type 'Array' has no len()
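The traceback points at `len(multioutput)`: the array in the reproduction does not implement `__len__`, while its `shape` attribute is still available. The sketch below illustrates the difference with a hypothetical stand-in class (`NoLenArray` and both helper functions are invented for illustration and are not part of scikit-learn or array_api_strict); it is not necessarily the exact change made in this PR.

```python
class NoLenArray:
    """Hypothetical stand-in for an array that exposes .shape but not __len__."""

    def __init__(self, values):
        self._values = list(values)
        self.shape = (len(self._values),)


def count_weights_with_len(multioutput):
    # Mirrors the failing check: relies on __len__ being defined.
    return len(multioutput)


def count_weights_with_shape(multioutput):
    # Alternative check: only relies on the .shape attribute.
    return multioutput.shape[0]


weights = NoLenArray([2, 3, 4])

try:
    count_weights_with_len(weights)
except TypeError as exc:
    print(f"len() failed: {exc}")

print(count_weights_with_shape(weights))
```

Reading the first dimension from `shape` works for NumPy arrays as well, so a check written this way would not need a separate code path per array library.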

Reference Issues/PRs

What does this implement/fix? Explain your changes.

Any other comments?

Also, there is a function, check_array_api_multioutput_regression_metric, for testing multioutput regression metrics, but it only tests multioutput="raw_values" and not an array of weights. Maybe an array-valued multioutput should also be tested?
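To make concrete what a test with an array-valued multioutput would compare against, here is a pure-Python sketch of the expected value for the reproduction data above. It assumes the weights are used to form a weighted average of the per-output errors, normalised by the sum of the weights; the function name `weighted_mean_absolute_error` is hypothetical and this is not scikit-learn's implementation.

```python
def weighted_mean_absolute_error(y_true, y_pred, weights):
    """Weighted average of per-output mean absolute errors (illustrative only)."""
    n_samples = len(y_true)
    n_outputs = len(y_true[0])
    if len(weights) != n_outputs:
        raise ValueError("There must be one weight per output.")

    # Mean absolute error of each output column, averaged over samples.
    per_output = [
        sum(abs(y_true[i][j] - y_pred[i][j]) for i in range(n_samples)) / n_samples
        for j in range(n_outputs)
    ]
    # Assumption: weights are normalised by their sum.
    return sum(w * e for w, e in zip(weights, per_output)) / sum(weights)


y_true = [[1.0, 2.0, 3.0], [3.0, 3.0, 3.0]]
y_pred = [[1, 2, 10], [5, 2, 1]]
weights = [2, 3, 4]

result = weighted_mean_absolute_error(y_true, y_pred, weights)
print(result)  # roughly 2.39
```

The per-output errors here are 1.0, 0.5 and 4.5, so the weighted result is 21.5 / 9.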

github-actions bot commented May 30, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 2535a13. Link to the linter CI: here

@Tialo Tialo changed the title BUG _check_reg_targets breaks with Array API FIX _check_reg_targets breaks with Array API May 31, 2024
Member

@ogrisel ogrisel left a comment

LGTM.

@betatim @OmarManzoor any opinion on the need to add a non-regression test for such a specific fix?

@OmarManzoor
Contributor

OmarManzoor commented Jun 7, 2024

Since mean_absolute_error has only been added in the current main branch, I don't think a regression test is necessary, but it might be useful to extend check_array_api_multioutput_regression_metric to cover an array-valued multioutput.


def check_array_api_multioutput_regression_metric(
Member

Before, there were two functions: check_array_api_multioutput_regression_metric (multioutput in the middle) and check_array_api_regression_metric_multioutput (multioutput at the end). I merged them into one.

@@ -1856,8 +1862,8 @@ def check_array_api_regression_metric(metric, array_namespace, device, dtype_nam
 def check_array_api_regression_metric_multioutput(
     metric, array_namespace, device, dtype_name
 ):
-    y_true_np = np.array([[1, 3], [1, 2]], dtype=dtype_name)
-    y_pred_np = np.array([[1, 4], [1, 1]], dtype=dtype_name)
+    y_true_np = np.array([[1, 3, 2], [1, 2, 2]], dtype=dtype_name)
Member

I added more values to avoid having n_rows == n_columns
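A small sketch of why non-square test data helps: with a square matrix, reducing over the wrong axis still yields a result of the expected length, so a shape-based check cannot catch the mistake. The two helper functions below are hypothetical and written in pure Python for illustration only.

```python
def mean_over_samples(matrix):
    """Per-output means: one value per column (the intended reduction)."""
    n_rows = len(matrix)
    return [sum(row[j] for row in matrix) / n_rows for j in range(len(matrix[0]))]


def mean_over_outputs(matrix):
    """Per-sample means: one value per row (the wrong axis for this purpose)."""
    return [sum(row) / len(row) for row in matrix]


square = [[1, 3], [1, 2]]
rectangular = [[1, 3, 2], [1, 2, 2]]

for name, data in [("square", square), ("rectangular", rectangular)]:
    intended = mean_over_samples(data)
    wrong_axis = mean_over_outputs(data)
    # Lengths match for the square case and differ for the rectangular one.
    print(name, len(intended), len(wrong_axis))
```

With the 2x3 data, the two reductions differ in length, so comparing against a vector of 3 weights or expected values would expose an axis mix-up immediately.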

@lesteve
Member

lesteve commented Jun 17, 2024

I have added a test; I think this should make codecov happy. Edit: not quite, since the exception case is still not covered. I think this should be dealt with in a separate PR. At least the behaviour is tested now.

@lesteve
Member

lesteve commented Jun 18, 2024

I have added a changelog entry. I added the PR to the already existing mean_absolute_error array API entry, although I guess this may also have fixed other metrics (not 100% sure).

@lesteve lesteve changed the title FIX _check_reg_targets breaks with Array API FIX Fix _check_reg_targets with Array API and multioutput argument Jun 18, 2024
@lesteve lesteve changed the title FIX Fix _check_reg_targets with Array API and multioutput argument FIX _check_reg_targets with Array API and multioutput argument Jun 18, 2024
@ogrisel
Member

ogrisel commented Jun 18, 2024

I launched the CUDA GPU CI at:

Member

@ogrisel ogrisel left a comment

Assuming green CI, LGTM with the latest changes.

Contributor

@OmarManzoor OmarManzoor left a comment

Thanks @Tialo @lesteve. LGTM otherwise.

lesteve and others added 2 commits June 20, 2024 09:33
Co-authored-by: Omar Salman <omar.salman2007@gmail.com>
@lesteve
Member

lesteve commented Jun 20, 2024

OK merging the PR since there are two approvals, thanks @Tialo

@lesteve lesteve merged commit e82c14b into scikit-learn:main Jun 20, 2024
30 checks passed
@Tialo Tialo deleted the fix_check_reg_targets branch June 20, 2024 09:00
@jeremiedbb jeremiedbb mentioned this pull request Jul 2, 2024
11 tasks
Projects
Status: Done
4 participants