TST use global_dtype in sklearn/neighbors/tests/test_neighbors.py #22663
General remark regarding the batch of similar PRs: I think we need to be careful not to parametrize too many tests with dtype, because it doubles the time for running the tests. The test suite takes a lot of time already.
Yes -- I put a remark regarding having a mechanism to only test generally for
I have a mechanism in mind. It comes down to a global fixture that is either triggered with a pytest mark or an ENV variable. (A custom command line option can work but it requires a little more work to make it work with The ENV variable is easiest, because we can just set it in

    import pytest
    import numpy as np
    from os import environ

    _SKIP32_MARK = pytest.mark.skipif(
        environ.get("SKLEARN_SKIP_FLOAT32", "1") != "0",
        reason="Set SKLEARN_SKIP_FLOAT32=0 to run float32 dtype tests",
    )

    # place this in `conftest.py` in scikit-learn
    @pytest.fixture(params=[pytest.param(np.float32, marks=_SKIP32_MARK), np.float64])
    def dtype(request):
        yield request.param

    def test_dtype(dtype):
        a = np.asarray([1, 2], dtype=dtype)
        assert a.dtype == dtype

Skips float32 by default:

    pytest test_script.py -v

To run float32:

    SKLEARN_SKIP_FLOAT32=0 pytest test_script.py -v
You were faster than me, @thomasjpfan! In the meantime, I created #22680 to discuss it and to pin the group of PRs for testing on 32bit datasets.
On top of this, I think it's important to add extra assertions on the expected impact of fitting a model with a specific input dtype. For instance, a fitted attribute that is an array of fitted parameters could have a dtype that depends on the input. Or, on the contrary, we can check that the dtype of such a fitted attribute is always float64 even if the input is float32, if there is a good reason for that (e.g. to avoid a known numerical stability problem). We should add a comment in the test to explain when this is the case, as in general I would expect the precision of the fitted attributes to be lower when the input data is of lower precision. Similarly for the dtype of the arrays returned by
In particular
For that there's the common check
For the attributes, it's estimator and attribute specific. Some are meant to have the same dtype as the input, but others are integer arrays or scalars, or whatever. For estimators that are supposed to preserve some dtypes, we usually have a dedicated test to check the dtype of the appropriate attributes. It's however possible that we don't have tests for all of them. We should.
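A dedicated dtype test of the kind described above could look roughly like the following sketch. `TinyCentroid` is a hypothetical toy estimator written only for this illustration (it is not part of scikit-learn), and the plain loop stands in for a dtype fixture.

```python
import numpy as np


class TinyCentroid:
    """Hypothetical toy estimator, used only to illustrate dtype checks."""

    def fit(self, X, y):
        self.classes_ = np.unique(y)
        # Per-class mean; for floating-point input NumPy keeps the input dtype.
        self.centroids_ = np.stack(
            [X[y == c].mean(axis=0) for c in self.classes_]
        )
        return self

    def transform(self, X):
        # Euclidean distance from each sample to each centroid.
        diff = X[:, None, :] - self.centroids_[None, :, :]
        return np.sqrt((diff ** 2).sum(axis=-1))


def check_dtype_preserved(dtype):
    rng = np.random.RandomState(0)
    X = rng.rand(20, 3).astype(dtype, copy=False)
    y = np.arange(20) % 2
    est = TinyCentroid().fit(X, y)
    # Float attributes and outputs are expected to follow the input dtype.
    assert est.centroids_.dtype == dtype
    assert est.transform(X).dtype == dtype
    # Integer attributes are not tied to the dtype of X.
    assert est.classes_.dtype.kind == "i"
    return est


for dtype in (np.float32, np.float64):
    check_dtype_preserved(dtype)
```

Which attributes should follow the input dtype is estimator specific, so the assertions here are only one possible contract.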
Good point. No need to duplicate this check then. We don't have anything similar for the predict method of
We probably should. |
We do not. Likely deserves an issue to define what the behavior should be. For example,
@thomasjpfan I created an issue: #22682 Feel free to edit. If you agree with my proposal, please remove the
test_neighbors.py
to test implementations on 32bit datasets
Here are some comments. In addition, all astype calls should use copy=False.
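As a small illustration of the `copy=False` remark: when the array already has the requested dtype, `astype(..., copy=False)` returns the input array itself, while the default makes a new copy. The array name `X64` is just an example.

```python
import numpy as np

X64 = np.random.RandomState(0).rand(5, 3)  # float64 by default

# dtype already matches: no copy is made, the same object is returned.
same = X64.astype(np.float64, copy=False)

# Default behaviour (copy=True): a new array is always allocated.
copied = X64.astype(np.float64)

# dtype differs: a conversion (and thus a new array) is unavoidable.
converted = X64.astype(np.float32, copy=False)

print(same is X64)
print(np.shares_memory(copied, X64))
print(converted.dtype)
```

In a test parametrized over float32 and float64, this avoids one needless copy for the float64 case.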
@@ -1317,7 +1346,7 @@ def test_kneighbors_graph():
     assert_array_equal(A.toarray(), np.eye(A.shape[0]))

     A = neighbors.kneighbors_graph(X, 1, mode="distance")
-    assert_array_almost_equal(
+    assert_allclose(
Shouldn't this test (test_kneighbors_graph) use the global_dtype ?
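For context, a dtype-aware comparison in such a test might be sketched as follows. The helper `rtol_for` and its tolerance values are assumptions made for this example, not the values used in the PR, and the pairwise-distance computation is a stand-in for the graph built in the real test.

```python
import numpy as np
from numpy.testing import assert_allclose


def rtol_for(dtype):
    # Hypothetical tolerances: looser for float32, tighter for float64.
    return 1e-4 if np.dtype(dtype) == np.float32 else 1e-7


def pairwise_distances(X):
    diff = X[:, None, :] - X[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))


def check_distances(dtype):
    X = np.array([[0.0, 1.0], [1.01, 1.0], [2.0, 0.0]], dtype=dtype)
    dist = pairwise_distances(X)
    # Reference computed in float64 from the same (already rounded) inputs.
    expected = pairwise_distances(X.astype(np.float64, copy=False))
    assert dist.dtype == dtype
    assert_allclose(dist, expected, rtol=rtol_for(dtype))


for dtype in (np.float32, np.float64):
    check_distances(dtype)
```

Unlike `assert_array_almost_equal`, which compares to a fixed number of decimals, `assert_allclose` takes an explicit relative tolerance that can be chosen per dtype.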
Co-authored-by: Jérémie du Boisberranger
All the
I pushed some nitpicks + a fix for warnings raised during the tests.
There are still warnings about invalid values in division when running the tests with -Werror, but this is because of _weight_func, which can generate invalid weights for zero-distance pairs. However, this is unrelated to the scope of this PR (it happens irrespective of the dtype of X) and would be better addressed in a dedicated PR.
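To illustrate the kind of warning mentioned here, the sketch below uses a simplified inverse-distance weighting. It is a stand-in written for this example and is not scikit-learn's `_weight_func`; the function names are hypothetical.

```python
import warnings

import numpy as np


def naive_inverse_weights(dist):
    # A zero distance triggers a NumPy RuntimeWarning on division.
    return 1.0 / dist


def guarded_inverse_weights(dist):
    # Same computation, with the divide-by-zero warning silenced locally.
    with np.errstate(divide="ignore"):
        return 1.0 / dist


dist = np.array([[0.0, 0.5, 2.0]])

with warnings.catch_warnings():
    warnings.simplefilter("error")  # roughly what running with -Werror does
    try:
        naive_inverse_weights(dist)
        raised = False
    except RuntimeWarning:
        raised = True
    # No warning escapes the guarded version; the zero distance maps to inf.
    weights = guarded_inverse_weights(dist)

# Normalizing rows that contain inf yields inf / inf, i.e. an invalid value.
with np.errstate(invalid="ignore"):
    normalized = weights / weights.sum(axis=1, keepdims=True)

print(raised, weights, normalized)
```

The behaviour is the same for float32 and float64 input, which matches the remark that the warning is independent of the dtype of X.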
LGTM
…ikit-learn#22663) Co-authored-by: Jérémie du Boisberranger Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Reference Issues/PRs
Partially addresses #22881
Precedes #22590
What does this implement/fix? Explain your changes.
This parametrizes tests from test_neighbors.py to run on 32bit datasets.
Any other comments?
We could introduce a mechanism to skip running the tests on 32bit datasets if this takes too much time to complete.