-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
FIX Add input array check to randomized_range_finder
#30819
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
FIX Add input array check to randomized_range_finder
#30819
Conversation
randomized_range_finder
randomized_range_finder
sklearn/utils/extmath.py
Outdated
A = check_array(A, accept_sparse=True) | ||
xp, is_array_api_compliant = get_namespace(A) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is going to basically cancel the array API work I think.
I wonder if explicitly checking for complex values is the only way.
cc @OmarManzoor maybe
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, you're right. Thanks for the catch. What if we just call check_array
but not reassign A
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think if we don't assign the new array to A, it should work like before while handling the exceptions we need to catch. By the way I think check_array supports the array api itself.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, made the change to not reassign A after check_array.
I also tried check_array on a torch Tensor. The check runs and passes but it does convert to a numpy ndarray.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
check_array
is array-api compliant so it can be used as usual.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jeremiedbb, I see, when I tried out check_array
on a torch Tensor previously, I didn't have the array api enabled.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@@ -0,0 +1,2 @@ | |||
- :func:`sklearn.utils.extmath.randomized_svd` now support Array API compatible inputs. | |||
By :user:`Connor Lane <clane9>`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Generous attribution to user @clane9 😂
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe @jeremiedbb you could add yourself as a co-author of this contribution ;)
Thanks for the help with this PR @jeremiedbb and @ogrisel! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
@jeremiedbb I let you credit your work in the changelog and merge if you wish. |
Thanks @clane9 and sorry for the late feedback on this PR. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we want to add a test to check that randomized_svd
and randomized_range_finder
raise a meaninful error on invalid data (e.g. complex data like in the original issue) or is it overkill?
The thing is that we don't have checks for complex input for any other function so I would not add it here and trust the |
…inder` (scikit-learn#30819) Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com> Co-authored-by: Jérémie du Boisberranger <jeremie@probabl.ai>
Reference Issues/PRs
Fixes #30736. See also #30737.
What does this implement/fix? Explain your changes.
Adds a check of the input array to
randomized_range_finder
so that complex valued inputs are rejected. As a result, complex inputs are also rejected fromrandomized_svd
, which callsrandomized_range_finder
.