TST use global_dtype in sklearn/neighbors/tests/test_neighbors.py #22663

Merged (6 commits, Mar 30, 2022)

Conversation

jjerphan
Member

@jjerphan jjerphan commented Mar 3, 2022

Reference Issues/PRs

Partially addresses #22881
Precedes #22590

What does this implement/fix? Explain your changes.

This parametrizes tests from test_neighbors.py to run on 32bit datasets.

Any other comments?

We could introduce a mechanism to remove tests' execution on 32bit datasets if this takes too much time to complete.

@jjerphan jjerphan marked this pull request as ready for review March 3, 2022 15:01
@jeremiedbb
Member

A general remark regarding the batch of similar PRs: I think we need to be careful not to parametrize too many tests with dtype, because it doubles the time for running them. The test suite already takes a lot of time.

@jjerphan
Member Author

jjerphan commented Mar 3, 2022

Yes -- I left a remark about having a mechanism to generally only test for dtype=np.float64. I think I'll create an issue for this topic.

@thomasjpfan
Member

Yes -- I put a remark regarding having a mechanism to only test generally for dtype=np.float64

I have a mechanism in mind. It comes down to a global fixture that is either triggered with a pytest mark or an ENV variable. (A custom command line option can work but it requires a little more work to make it work with --pyargs)

The ENV variable is easiest, because we can just set it in azure_pipeline.yml without additional logic. The idea looks like:

import pytest
import numpy as np
from os import environ

_SKIP32_MARK = pytest.mark.skipif(
    environ.get("SKLEARN_SKIP_FLOAT32", "1") != "0",
    reason="Set SKLEARN_SKIP_FLOAT32=0 to run float32 dtype tests",
)

# place this in `conftest.py` in scikit-learn
@pytest.fixture(params=[pytest.param(np.float32, marks=_SKIP32_MARK), np.float64])
def dtype(request):
    yield request.param

def test_dtype(dtype):
    a = np.asarray([1, 2], dtype=dtype)
    assert a.dtype == dtype

Skips float32 by default

pytest test_script.py -v

To run float32

SKLEARN_SKIP_FLOAT32=0 pytest test_script.py -v

@jjerphan
Member Author

jjerphan commented Mar 3, 2022

You were faster than me, @thomasjpfan!
I like the idea of using pytest fixtures even if it looks like an automagic trick to me.

In the meantime, I created #22680 to discuss it and to pin the group of PRs for testing on 32bit datasets.

@ogrisel
Member

ogrisel commented Mar 3, 2022

general remark regarding the batch of similar PRs. I think we need to be careful to not parametrize with dtype too many tests because it doubles the time for running the tests. The test suite takes a lot of time already.

On top of this, I think it's important to add extra assertions about the expected impact of fitting a model with a specific input dtype.

For instance, a fitted attribute that is an array of fitted parameters could have a dtype that depends on the input. Or, on the contrary, we can check that the dtype of such a fitted attribute is always float64 even when the input is float32, if there is a good reason for that (e.g. to avoid a known numerical stability problem). We should add a comment in the test to explain when this is the case, since in general I would expect the precision of the fitted attributes to be lower when the input data has lower precision.

Similarly for the dtype of the arrays returned by transform for transformers, predict for regressors, or predict_proba for classifiers.
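The pattern described above can be sketched with a hypothetical minimal estimator (MeanEstimator is invented here for illustration; it is not a scikit-learn class):

```python
import numpy as np

class MeanEstimator:
    """Hypothetical estimator: stores per-feature means as the fitted attribute `mean_`."""

    def fit(self, X):
        # np.mean preserves the floating dtype of X, so `mean_` follows
        # the input dtype (float32 in -> float32 out).
        self.mean_ = np.mean(X, axis=0)
        return self

def test_fitted_attribute_dtype():
    for dtype in (np.float32, np.float64):
        X = np.asarray([[1.0, 2.0], [3.0, 4.0]], dtype=dtype)
        est = MeanEstimator().fit(X)
        # Explicit dtype assertion, in addition to any value checks.
        assert est.mean_.dtype == dtype

test_fitted_attribute_dtype()
```

An estimator that deliberately upcasts for numerical stability would instead assert `mean_.dtype == np.float64` for both input dtypes, with a comment explaining why.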

@ogrisel
Member

ogrisel commented Mar 3, 2022

In particular, assert_array_equal(a, b) and assert_allclose(a, b) can pass even when a.dtype == np.float64 and b.dtype == np.float32, so it's important to explicitly check the expected dtype in those cases.
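A minimal illustration of this pitfall with plain NumPy (equal values, different dtypes):

```python
import numpy as np
from numpy.testing import assert_allclose, assert_array_equal

a = np.asarray([1.0, 2.0], dtype=np.float64)
b = np.asarray([1.0, 2.0], dtype=np.float32)

# Both checks compare values only, so they pass despite the dtype mismatch.
assert_array_equal(a, b)
assert_allclose(a, b)

# Only an explicit dtype assertion catches the mismatch.
assert a.dtype != b.dtype
```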

@jeremiedbb
Member

Similarly for the dtype of the arrays returned by transform for transformers,

For that there's the common check check_transformer_preserve_dtypes

@jeremiedbb
Member

For the attributes it's estimator- and attribute-specific. Some are meant to have the same dtype as the input, but others are integer arrays, scalars, or something else. For estimators that are supposed to preserve some dtypes, we usually have a dedicated test to check the dtype of the appropriate attributes.

It's however possible that we don't have tests for all of them. We should.

@ogrisel
Member

ogrisel commented Mar 3, 2022

For that there's the common check check_transformer_preserve_dtypes

Good point. No need to duplicate this check then. We don't have anything similar for the predict method of regressors or the predict_proba method of classifiers, right?

sklearn/kernel_approximation.py:                "check_transformer_preserve_dtypes": (
sklearn/manifold/tests/test_spectral_embedding.py:    `check_transformer_preserve_dtypes`. However, this test only run
sklearn/utils/estimator_checks.py:        yield check_transformer_preserve_dtypes
sklearn/utils/estimator_checks.py:def check_transformer_preserve_dtypes(name, transformer_orig):

We probably should.

@thomasjpfan
Member

We don't have anything similar for the predict method of regressors or the predict_proba method of classifier, right?

We do not. It likely deserves an issue to define what the behavior should be. For example, after regressor.fit(X_32, y_64), should regressor.predict(X_32) return float32 or float64?
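The ambiguity comes from NumPy's promotion rules when fitted parameters and input have different precisions. A sketch with a hypothetical coefficient vector (coef_64 is invented for illustration, not a real fitted attribute):

```python
import numpy as np

# Hypothetical linear model: predictions are X @ coef_.
X_32 = np.asarray([[1.0, 2.0]], dtype=np.float32)    # float32 input
coef_64 = np.asarray([0.5, 0.25], dtype=np.float64)  # float64 coefficients

pred = X_32 @ coef_64
# NumPy promotes mixed-precision operands, so the predictions come out
# as float64 even though X is float32 -- hence the open question above.
assert pred.dtype == np.float64
```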

@ogrisel
Member

ogrisel commented Mar 3, 2022

@thomasjpfan I created an issue: #22682

Feel free to edit. If you agree with my proposal, please remove the Needs Triage label and add a new Help Wanted label.

@jjerphan jjerphan changed the title TST Adapt test_neighbors.py to test implementations on 32bit datasets TST use global_dtype in sklearn/neighbors/tests/test_neighbors.py Mar 17, 2022
Member

@jeremiedbb jeremiedbb left a comment

Here are some comments. In addition, all astype calls should use copy=False.

@@ -1317,7 +1346,7 @@ def test_kneighbors_graph():
     assert_array_equal(A.toarray(), np.eye(A.shape[0]))

     A = neighbors.kneighbors_graph(X, 1, mode="distance")
-    assert_array_almost_equal(
+    assert_allclose(
Member

Shouldn't this test (test_kneighbors_graph) use the global_dtype?

Co-authored-by: Jérémie du Boisberranger
@jeremiedbb
Member

All the astype calls are missing copy=False. Looks good otherwise.
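The copy=False remark can be illustrated with plain NumPy: astype then returns the input array unchanged when no conversion is actually needed, avoiding a useless copy in the float64 branch of the parametrized tests.

```python
import numpy as np

X = np.asarray([1.0, 2.0], dtype=np.float64)

# With copy=False, astype returns the input array itself when the
# dtype already matches -- no memory is copied.
assert X.astype(np.float64, copy=False) is X

# The default (copy=True) always allocates a new array.
assert X.astype(np.float64) is not X
```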

Member

@ogrisel ogrisel left a comment

I pushed some nitpicks + a fix for warnings raised during the tests.

There are still warnings about invalid values in division when running the tests with -Werror, but this is because _weight_func can generate invalid weights for zero-distance pairs. However, this is unrelated to the scope of this PR (it happens irrespective of the dtype of X) and would be better addressed in a dedicated PR.

Member

@jeremiedbb jeremiedbb left a comment

LGTM

@jeremiedbb jeremiedbb merged commit 7931262 into scikit-learn:main Mar 30, 2022
@jjerphan jjerphan deleted the tst/test_neighbors-32bit branch March 30, 2022 12:59
glemaitre pushed a commit to glemaitre/scikit-learn that referenced this pull request Apr 6, 2022
…ikit-learn#22663)

Co-authored-by: Jérémie du Boisberranger
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>