-
-
Notifications
You must be signed in to change notification settings - Fork 26k
[MRG] EHN handle NaN value in QuantileTransformer #10437
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
38 commits
Select commit
Hold shift + click to select a range
cc3bb96
EHN handle NaN value in QuantileTransformer
glemaitre 76123c8
DOC add whats new entry
glemaitre 530c7bf
TST relax inf/nan common test
glemaitre 1f07963
FIX silent warning and raise an error for numpy version
glemaitre 91c947e
TST ensure that test raise error with older numpy
glemaitre 1c406c0
TST remove mocking
glemaitre ecc5048
EHN accept integer as missing values
glemaitre 965811f
address joel comments
glemaitre cd28883
FIX nanpercentile for python 2
glemaitre 3d0c389
TST test the output under numpy < 1.9
glemaitre a217af6
FIX nanpercentile numpy 1.8
glemaitre 85c6268
PEP8
glemaitre 1306992
TST check all missing values behaviour
glemaitre 73eed7b
TST change name for consistency
glemaitre ecdc675
Merge remote-tracking branch 'origin/master' into is/10404
glemaitre d7b6cd9
EHN only accept NaN for the moment
glemaitre f7bc642
unecessary change
glemaitre 84682b7
unecessary change
glemaitre 59dfdbe
solve issue in numpy 1.8
glemaitre d852206
address ogrisel comments
glemaitre a20ac59
Address some comments
glemaitre f8dd6a4
TST fix common test
glemaitre 8aa6059
TST common test for transformer letting pass nan
glemaitre e9b9855
Merge remote-tracking branch 'origin/master' into is/10404
glemaitre c745eab
TST add separate commont tests
glemaitre daa3a91
Merge remote-tracking branch 'origin/master' into is/10404
glemaitre ad878fa
TST improve testing
glemaitre 6784c3b
TST remove parametrization on X and n_missing
glemaitre daa40da
address joel comments
glemaitre 2c0ceb3
fix random state for the split training testing
glemaitre 004b0e3
do not force percentile to be finite
glemaitre 9ab77b6
fix
glemaitre 33cc416
TST add test for quantile transformer
glemaitre f58dcee
TST fix for older numpy version
glemaitre 0f03485
FIX for to use nanpercentile up to 1.11 for consistent behaviour
glemaitre d0a88bd
my mistake
glemaitre d554f8e
Merge branch 'master' into is/10404
glemaitre 1bb0006
Roman comments
glemaitre File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import pytest | ||
import numpy as np | ||
|
||
from sklearn.datasets import load_iris | ||
from sklearn.model_selection import train_test_split | ||
from sklearn.preprocessing import QuantileTransformer | ||
from sklearn.utils.testing import assert_array_equal | ||
from sklearn.utils.testing import assert_allclose | ||
|
||
iris = load_iris() | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"est", | ||
[QuantileTransformer(n_quantiles=10, random_state=42)] | ||
) | ||
def test_missing_value_handling(est): | ||
# check that the preprocessing method let pass nan | ||
rng = np.random.RandomState(42) | ||
X = iris.data.copy() | ||
n_missing = 50 | ||
X[rng.randint(X.shape[0], size=n_missing), | ||
rng.randint(X.shape[1], size=n_missing)] = np.nan | ||
X_train, X_test = train_test_split(X, random_state=1) | ||
# sanity check | ||
assert not np.all(np.isnan(X_train), axis=0).any() | ||
assert np.any(np.isnan(X_train), axis=0).all() | ||
assert np.any(np.isnan(X_test), axis=0).all() | ||
X_test[:, 0] = np.nan # make sure this boundary case is tested | ||
|
||
Xt = est.fit(X_train).transform(X_test) | ||
# missing values should still be missing, and only them | ||
assert_array_equal(np.isnan(Xt), np.isnan(X_test)) | ||
|
||
# check that the inverse transform keep NaN | ||
Xt_inv = est.inverse_transform(Xt) | ||
assert_array_equal(np.isnan(Xt_inv), np.isnan(X_test)) | ||
# FIXME: we can introduce equal_nan=True in recent version of numpy. | ||
# For the moment which just check that non-NaN values are almost equal. | ||
assert_allclose(Xt_inv[~np.isnan(Xt_inv)], X_test[~np.isnan(X_test)]) | ||
|
||
for i in range(X.shape[1]): | ||
# train only on non-NaN | ||
est.fit(X_train[:, [i]][~np.isnan(X_train[:, i])]) | ||
# check transforming with NaN works even when training without NaN | ||
Xt_col = est.transform(X_test[:, [i]]) | ||
assert_array_equal(Xt_col, Xt[:, [i]]) | ||
# check non-NaN is handled as before - the 1st column is all nan | ||
if not np.isnan(X_test[:, i]).all(): | ||
Xt_col_nonan = est.transform( | ||
X_test[:, [i]][~np.isnan(X_test[:, i])]) | ||
assert_array_equal(Xt_col_nonan, | ||
Xt_col[~np.isnan(Xt_col.squeeze())]) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fwiw, it's possible that np.ma would handle the none-missing case more efficiently than using an ad-hoc hoc mask. I've not checked.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Playing around, I think that it will trigger the same number of copy.