-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
TST use global_dtype in sklearn/cluster/tests/test_birch.py #22671
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similarly to #22672, we need at least one test that checks the impact of changing the dtype of X on the fitted attribute subcluster_centers_
. I would have expected float32
but it's not the case. I am not sure why maybe this reveals a suboptimal operation in the Birch class.
Since Birch
is also a transformer, we should also check the dtype of Birch().fit_transform(X.astype(np.float32))
which I would also have expected to be float32
but it's not the case either, probably because subcluster_centers_
is always float64
.
Furthermore, while trying interactively for myself, I observed that fitting a Birch()
model now raises a warning:
>>> import numpy as np
>>> from sklearn.cluster import Birch
>>> Birch().fit(np.random.randn(100, 5).astype(np.float32)).subcluster_centers_.dtype
/Users/ogrisel/code/scikit-learn/sklearn/cluster/_birch.py:760: UserWarning: Some metric_kwargs have been passed ({'Y_norm_squared': array([ 3.86024464, 11.683084 , 12.99578754, 6.910336 , 2.55765101,
5.03749571, 5.06213863, 10.26015348, 0.90800463, 1.26334124,
3.4959797 , 6.45426856, 13.48504291, 3.90932701, 6.88367298,
3.43889466, 4.35950137, 6.41143755, 1.44802403, 0.48286628,
6.90164792, 3.67385496, 6.00607854, 6.95285525, 6.90960791,
3.7674752 , 3.52569363, 3.47566689, 5.2203662 , 2.10915227,
1.57974275, 1.7641581 , 3.3768002 , 4.25468386, 3.47676319,
4.86241842, 2.3451047 , 2.17838236, 6.30420973, 3.64096226,
8.83660004, 4.07342638, 2.08893818, 1.51923725, 6.9491891 ,
4.84401576, 7.78082366, 3.14570647, 3.43566494, 6.79652774,
5.56645993, 11.18789605, 4.60155353, 6.88679755, 0.88999525,
4.95820257, 2.69660257, 0.75948625, 4.14714094, 10.64599847,
3.07223409, 9.7867565 , 8.27261454, 2.71387669, 14.36863042,
10.79094887, 5.11707374, 2.67162805, 2.10645627, 2.35485549,
3.02763474, 6.7502932 , 2.70514103, 7.39961664, 3.18970976,
5.23735055, 1.956462 , 7.20984384, 12.08628175, 5.37923063])}) but aren'tusable for this case (FastEuclideanPairwiseDistancesArgKmin) and will be ignored.
self.labels_ = self._predict(X)
np.float64
Ideally our test should fail when we have such unexpected UserWarnings
raised by scikit-learn code but unfortunately this is not the case at the moment. We only do it for FutureWarning
on some dedicated CI runs.
For the However it does not run on Birch:
while running for either:
or
finds plenty of common tests. This is probably because |
dtype preservation for transformers is tracked in this issue: |
The second phrase tells you why it should not have the tag :) There's this long term issue #11000 to track which transformer properly preserves float32. Birch is not 1 of them yet. |
I don't understand. I think it should have the tag and the Birch code should be fixed to make sure the common test pass, no? |
Of course it should but it takes time :) |
test_birch.py
to test implementations on 32bit datasets
Actually BIRCH doesn't preserve the dtype yet so I think these changes should be delayed. |
According to some irl discussions, such tests should only be added after Let's keep this PR open in the mean time. |
Co-authored-by: Jérémie du Boisberranger <jeremiedbb@users.noreply.github.com>
Now that #22968 has been merged, I've updated this PR taking @jeremiedbb's last comments into account. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Assuming the following works as expected, LGTM.
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now that Birch does preserve dtype we can merge this one. LGTM.
Reference Issues/PRs
Partially addresses #22881
Precedes #22590
What does this implement/fix? Explain your changes.
This parametrizes tests from
test_birch.py
to run on 32bit datasets.Any other comments?
We could introduce a mechanism to be able to able to remove tests' execution on 32bit datasets if this takes too much time to complete.