-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
FIX _check_reg_targets with Array API and multioutput argument #29143
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
@betatim @OmarManzoor any opinion on the need to add a non-regression test for such a specific fix?
Since mean_absolute_error has only been added in the current main branch I don't think it might be necessary to add a regression test but it might be useful to enhance the check_array_api_multioutput_regression_metric with multioutput==array |
|
||
def check_array_api_multioutput_regression_metric( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Before, there was check_array_api_multioutput_regression_metric
(multioutput in the middle) and check_array_api_regression_metric_multioutput
(multioutput at the end), I merged both functions into one.
@@ -1856,8 +1862,8 @@ def check_array_api_regression_metric(metric, array_namespace, device, dtype_nam | |||
def check_array_api_regression_metric_multioutput( | |||
metric, array_namespace, device, dtype_name | |||
): | |||
y_true_np = np.array([[1, 3], [1, 2]], dtype=dtype_name) | |||
y_pred_np = np.array([[1, 4], [1, 1]], dtype=dtype_name) | |||
y_true_np = np.array([[1, 3, 2], [1, 2, 2]], dtype=dtype_name) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added more values to avoid having n_rows == n_columns
I have added a test, I think this should make codecov happy. Edit: not quite since the exception case is still not covered. I think this should be dealt with in a separate PR. At least the behaviour is tested now. |
…nto fix_check_reg_targets
I have added a changelog, I added the PR in the already existing |
I launched the CUDA GPU CI at: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Assuming green CI, LGTM with the latest changes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Co-authored-by: Omar Salman <omar.salman2007@gmail.com>
OK merging the PR since there are two approvals, thanks @Tialo |
code to reproduce
traceback:
Reference Issues/PRs
What does this implement/fix? Explain your changes.
Any other comments?
Also, there is a function
check_array_api_multioutput_regression_metric
for testing regression multioutput metrics. But it only tests for multioutput=="raw_values", and not an array. Maybe array multioutput should also be tested?