TST use global_dtype in sklearn/cluster/tests/test_birch.py #22671

Merged
jeremiedbb merged 9 commits into scikit-learn:main from tst/test_birch-32bit
Nov 18, 2022

Conversation

jjerphan
Member

@jjerphan jjerphan commented Mar 3, 2022

Reference Issues/PRs

Partially addresses #22881
Precedes #22590

What does this implement/fix? Explain your changes.

This parametrizes tests from test_birch.py to run on 32bit datasets.

Any other comments?

We could introduce a mechanism to skip the tests on 32-bit datasets if they take too long to complete.
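One possible shape for such a mechanism, sketched under assumptions: an environment variable decides whether float32 is added to the list of dtypes that a dtype-parametrized test runs on. The helper name `dtypes_to_test` and the variable name `RUN_FLOAT32_TESTS` are hypothetical and only serve the illustration; they are not taken from this PR.

```python
import os

import numpy as np


def dtypes_to_test(environ=os.environ):
    """Return the dtypes a dtype-parametrized test should run on.

    float64 is always tested. float32 is only added when the
    (hypothetical) RUN_FLOAT32_TESTS variable is set to something
    other than "0", so the extra runs can be switched off if they
    make the test suite too slow.
    """
    dtypes = [np.float64]
    if environ.get("RUN_FLOAT32_TESTS", "0") != "0":
        dtypes.append(np.float32)
    return dtypes
```

The returned list could then be passed to `pytest.mark.parametrize("dtype", dtypes_to_test())` or used to parametrize a shared fixture.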

@jjerphan jjerphan marked this pull request as ready for review March 3, 2022 16:26
Member

@ogrisel ogrisel left a comment

Similarly to #22672, we need at least one test that checks the impact of changing the dtype of X on the fitted attribute subcluster_centers_. I would have expected float32 but it's not the case. I am not sure why; maybe this reveals a suboptimal operation in the Birch class.

Since Birch is also a transformer, we should also check the dtype of Birch().fit_transform(X.astype(np.float32)) which I would also have expected to be float32 but it's not the case either, probably because subcluster_centers_ is always float64.
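The two checks described above could be sketched roughly as follows. This is a minimal illustration rather than the test added in this PR: the helper name `fitted_dtypes`, the random data, and the choice of `n_clusters=None` are assumptions. What it reports for float32 input depends on the installed scikit-learn version.

```python
import numpy as np
from sklearn.cluster import Birch


def fitted_dtypes(dtype, seed=0):
    """Fit Birch on data cast to `dtype` and report the resulting dtypes.

    Returns the dtype of the fitted `subcluster_centers_` attribute and
    the dtype of the array returned by `fit_transform`.
    """
    rng = np.random.RandomState(seed)
    X = rng.randn(100, 5).astype(dtype)
    # n_clusters=None skips the final global clustering step, which is
    # not needed to inspect the subcluster centers (illustrative choice).
    brc = Birch(n_clusters=None)
    X_trans = brc.fit_transform(X)
    return brc.subcluster_centers_.dtype, X_trans.dtype
```

A dtype-preservation test would then assert that both returned dtypes equal the input dtype for float32 as well as float64.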

Furthermore, while trying interactively for myself, I observed that fitting a Birch() model now raises a warning:

>>> import numpy as np
>>> from sklearn.cluster import Birch
>>> Birch().fit(np.random.randn(100, 5).astype(np.float32)).subcluster_centers_.dtype
/Users/ogrisel/code/scikit-learn/sklearn/cluster/_birch.py:760: UserWarning: Some metric_kwargs have been passed ({'Y_norm_squared': array([ 3.86024464, 11.683084  , 12.99578754,  6.910336  ,  2.55765101,
        5.03749571,  5.06213863, 10.26015348,  0.90800463,  1.26334124,
        3.4959797 ,  6.45426856, 13.48504291,  3.90932701,  6.88367298,
        3.43889466,  4.35950137,  6.41143755,  1.44802403,  0.48286628,
        6.90164792,  3.67385496,  6.00607854,  6.95285525,  6.90960791,
        3.7674752 ,  3.52569363,  3.47566689,  5.2203662 ,  2.10915227,
        1.57974275,  1.7641581 ,  3.3768002 ,  4.25468386,  3.47676319,
        4.86241842,  2.3451047 ,  2.17838236,  6.30420973,  3.64096226,
        8.83660004,  4.07342638,  2.08893818,  1.51923725,  6.9491891 ,
        4.84401576,  7.78082366,  3.14570647,  3.43566494,  6.79652774,
        5.56645993, 11.18789605,  4.60155353,  6.88679755,  0.88999525,
        4.95820257,  2.69660257,  0.75948625,  4.14714094, 10.64599847,
        3.07223409,  9.7867565 ,  8.27261454,  2.71387669, 14.36863042,
       10.79094887,  5.11707374,  2.67162805,  2.10645627,  2.35485549,
        3.02763474,  6.7502932 ,  2.70514103,  7.39961664,  3.18970976,
        5.23735055,  1.956462  ,  7.20984384, 12.08628175,  5.37923063])}) but aren'tusable for this case (FastEuclideanPairwiseDistancesArgKmin) and will be ignored.
  self.labels_ = self._predict(X)
np.float64

Ideally our test should fail when we have such unexpected UserWarnings raised by scikit-learn code but unfortunately this is not the case at the moment. We only do it for FutureWarning on some dedicated CI runs.
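One way to make a single test fail on such warnings, sketched with the standard library only: run the call under a filter that turns UserWarning into an exception. The helper name `call_failing_on_user_warning` is hypothetical and this is not how the scikit-learn CI is configured.

```python
import warnings


def call_failing_on_user_warning(func, *args, **kwargs):
    """Call `func` and raise if it emits a UserWarning.

    The filter is installed inside `catch_warnings`, so the global
    warning configuration is restored when the call returns.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("error", UserWarning)
        return func(*args, **kwargs)
```

Wrapping a call such as `Birch().fit(X)` in this helper would surface the warning above as a test error instead of letting it pass silently.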

@ogrisel
Member

ogrisel commented Mar 3, 2022

For the transform method there is already a common test: #22663 (comment)

However it does not run on Birch:

pytest -vk "check_transformer_preserve_dtypes and Birch" sklearn/tests/test_common.py
[...]
=================================================================================== 8836 deselected, 57 warnings in 1.69s ====================================================================================

while running either:

pytest -vk "check_transformer_preserve_dtypes"

or

pytest -vk "Birch"

finds plenty of common tests.

This is probably because Birch does not have the preserves_dtype estimator tag and I think it should.

@ogrisel
Member

ogrisel commented Mar 3, 2022

dtype preservation for transformers is tracked in this issue:

@jeremiedbb
Member

This is probably because Birch does not have the preserves_dtype estimator tags and I think it should.
Birch().fit_transform(X.astype(np.float32)) which I would also have expected to be float32 but it's not the case either

The second phrase tells you why it should not have the tag :)

There's a long-term issue, #11000, to track which transformers properly preserve float32. Birch is not one of them yet.

@ogrisel
Member

ogrisel commented Mar 17, 2022

The second phrase tells you why it should not have the tag :)

I don't understand. I think it should have the tag and the Birch code should be fixed to make sure the common test passes, no?

@jeremiedbb
Member

I don't understand. I think it should have the tag and the Birch code should be fixed to make sure the common test passes, no?

Of course it should, but it takes time :)
It does not have it yet because it does not preserve float32 yet.

@jjerphan jjerphan changed the title TST Adapt test_birch.py to test implementations on 32bit datasets TST use global_dtype in sklearn/cluster/tests/test_birch.py Mar 18, 2022
@jeremiedbb
Member

Actually, Birch doesn't preserve the dtype yet, so I think these changes should be delayed.

@jeremiedbb
Member

jeremiedbb commented Jun 8, 2022

According to some in-person discussions, such tests should only be added after Birch preserves float32 (ref #11000).

Let's keep this PR open in the meantime.

jjerphan and others added 2 commits October 21, 2022 17:50
Co-authored-by: Jérémie du Boisberranger <jeremiedbb@users.noreply.github.com>
@jjerphan
Member Author

Now that #22968 has been merged, I've updated this PR taking @jeremiedbb's last comments into account.

Member

@ogrisel ogrisel left a comment

Assuming the following works as expected, LGTM.

Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
@jjerphan jjerphan added the Quick Review and Waiting for Second Reviewer labels Nov 3, 2022
Member

@jeremiedbb jeremiedbb left a comment

Now that Birch does preserve dtype we can merge this one. LGTM.

@jjerphan jjerphan removed the Quick Review and Waiting for Second Reviewer labels Nov 18, 2022
@jeremiedbb jeremiedbb merged commit 964189d into scikit-learn:main Nov 18, 2022
@jjerphan jjerphan deleted the tst/test_birch-32bit branch November 18, 2022 13:44

3 participants