ENH Preserving dtype for np.float32 in *DictionaryLearning, SparseCoder and orthogonal_mp_gram #22002
Conversation
Just a few questions/suggestions for improvements below, but the PR already LGTM as it is.
Please do not forget to document the enhancement in doc/whats_new/v1.1.rst.
# instead of comparing directly U and V.
assert_allclose(np.matmul(U_64, V_64), np.matmul(U_32, V_32), rtol=rtol, atol=atol)
assert_allclose(np.sum(np.abs(U_64)), np.sum(np.abs(U_32)), rtol=rtol, atol=atol)
assert_allclose(np.sum(V_64 ** 2), np.sum(V_32 ** 2), rtol=rtol, atol=atol)
This is a clever way to test numerical equivalence of the solutions.
I am just wondering, is rtol really necessary if we already pass atol=1e-7?
I'd rather have only rtol than only atol. In addition, I think rtol=1e-7 is a little bit too optimistic when comparing float32s, because it's slightly lower than the machine precision. 1e-6 would be safer.
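As a quick, illustrative check of that point:

    import numpy as np

    # float32 machine epsilon is ~1.19e-07, so rtol=1e-7 asks for agreement
    # below what float32 arithmetic can reliably provide; 1e-6 leaves margin.
    print(np.finfo(np.float32).eps)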
Thanks!
I have changed it to use only rtol.
However, some tests cannot pass with rtol=1e-7 or 1e-6, so the minimum rtol that lets those tests pass is used instead.
doc/whats_new/v1.1.rst
- |Enhancement| `dict_learning` and `dict_learning_online` methods preserve dtype for numpy.float32.
  :pr:`22002` by :user:`Takeshi Oura <takoika>`.
Your changes also impact DictionaryLearning, MiniBatchDictionaryLearning and SparseCoder. Please mention them here. I think it would be nice to add similar tests for those as well.
Changed.
I have also added unit tests to verify dtype matching for DictionaryLearning, MiniBatchDictionaryLearning and SparseCoder, plus a test for numerical consistency between np.float32 and np.float64 in the sparse_encode function.
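For illustration, the kind of dtype-matching check being described looks roughly like this (names and parameters here are illustrative, not the actual test code):

    import numpy as np
    from sklearn.decomposition import SparseCoder

    rng = np.random.RandomState(0)
    dictionary = rng.randn(5, 8).astype(np.float32)
    X = rng.randn(10, 8).astype(np.float32)

    # With a float32 dictionary and float32 data, the sparse code is
    # expected to stay float32 after this change.
    code = SparseCoder(dictionary=dictionary).transform(X)
    assert code.dtype == np.float32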
I suspect that we only need to change sklearn/linear_model/_omp.py (lines 557 and 559, at 532e1bc) to make "omp" preserve dtype as well, by adding dtype=Gram.dtype there.
Would you mind trying this? If it requires more work, we can do it later in a separate PR.
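A small, illustrative check of the intended behaviour once dtype=Gram.dtype is added (not part of the actual diff):

    import numpy as np
    from sklearn.linear_model import orthogonal_mp_gram

    rng = np.random.RandomState(0)
    X = rng.randn(10, 8).astype(np.float32)
    y = rng.randn(10).astype(np.float32)

    # Gram = X.T @ X and Xy = X.T @ y are both float32, so the returned
    # coefficients should also be float32 once the allocations use Gram.dtype.
    coef = orthogonal_mp_gram(X.T @ X, X.T @ y, n_nonzero_coefs=2)
    print(coef.dtype)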
Thank you for the suggestion.
@jeremiedbb Anyway
Something that will be required as well for the common tests is to add the proper tag to the classes that preserve the dtype. It boils down to adding:

    def _more_tags(self):
        return {
            "preserves_dtype": [np.float64, np.float32],
        }

This should be added to the relevant estimator classes. Some checks are probably a bit duplicated with the ones already written, but I don't think that is a big deal, because I am not sure that we are running the common tests for all the above classes.
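For illustration, once the tag is in place it can be inspected through the (private) estimator tag machinery; this snippet is only a sketch:

    from sklearn.decomposition import DictionaryLearning

    # The common dtype-preservation check reads this tag to decide which
    # dtypes an estimator is expected to preserve.
    print(DictionaryLearning()._get_tags()["preserves_dtype"])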
I would advise doing it in a separate PR if it turns out not to be straightforward regarding the behaviour.
@glemaitre
I have added the preserves_dtype tags as suggested.
I have confirmed that the change affects only the classes in question.
LGTM on my side. @jeremiedbb, are you OK with the PR as-is?
LGTM. Thanks @takoika!
Reference Issues/PRs
This PR is part of #11000.
What does this implement/fix? Explain your changes.
This PR makes the code and dictionary produced by dictionary learning numpy.float32 when the input data is numpy.float32, in order to preserve the input data type.
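A minimal sketch of the intended behaviour (parameter values are illustrative):

    import numpy as np
    from sklearn.decomposition import DictionaryLearning

    X = np.random.RandomState(0).randn(30, 8).astype(np.float32)
    dico = DictionaryLearning(n_components=5, max_iter=10, random_state=0)
    code = dico.fit_transform(X)

    # Both the sparse code and the learned dictionary keep the float32 dtype.
    print(code.dtype, dico.components_.dtype)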
Any other comments?
I found two difficulties in testing numerical consistency between numpy.float32 and numpy.float64 for dictionary learning.
Within the scope of this PR that is acceptable, but it potentially makes it difficult to guarantee numerical consistency for downstream methods that use dictionary learning.
Further test cases may be required, because this PR does not cover all argument variations.
I used #13303, #13243 and #20155 as references when making this change.