[MRG] EHN handle NaN value in QuantileTransformer #10437


Merged: 38 commits into scikit-learn:master on Apr 21, 2018

Conversation

glemaitre
Member

@glemaitre glemaitre commented Jan 9, 2018

Reference Issues/PRs

partially addresses #10404

What does this implement/fix? Explain your changes.

NaN are handled and ignored during processing in the QuantileTransformer.

Any other comments?

TODO:

@glemaitre
Member Author

@jnothman I have 2 questions:

  • Shall we modify check_array to handle inf/NaN separately? ensure_all_finite could also accept a string ('nan' or 'inf') to check for only one of these cases. Right now, I put this logic in the _check_inputs of the QuantileTransformer.
  • I get a RuntimeWarning due to a comparison with NaN. Since the comparison returns False, everything is fine. Shall we silence the warning, or instead use a masked array for the comparison?
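For illustration, the two options can be sketched in a few lines; a minimal sketch assuming the warning comes from an elementwise NaN comparison:

```python
import numpy as np

X = np.array([0.5, np.nan, 2.0])

# Option 1: silence the invalid-value warning around the comparison;
# NaN compares as False, which is the behavior relied upon.
with np.errstate(invalid='ignore'):
    mask = X > 1.0

# Option 2: route the comparison through a masked array instead.
mask_ma = (np.ma.masked_invalid(X) > 1.0).filled(False)

print(mask.tolist())     # [False, False, True]
print(mask_ma.tolist())  # [False, False, True]
```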

@jnothman
Member

jnothman commented Jan 9, 2018 via email

@jnothman jnothman left a comment

This is nice:)

@glemaitre
Member Author

what's the harm in silencing the warning?

I don't see any. I would go for that solution.

However, nanpercentile is only available from numpy 1.9, and Ubuntu ships only version 1.8.2. Shall we backport nanfunctions.py from numpy?

@jnothman
Member

jnothman commented Jan 9, 2018

Sure. Backport and inform other contributors at #10404


@glemaitre
Member Author

Backporting is a headache, in fact. I tried, and we would have to bring in too much code from numpy to support it. Speaking IRL with @ogrisel, we propose to raise a NotImplementedError when there are NaN values, state that we require the nan-functions (i.e. numpy >= 1.9), and ask for an upgrade.

@glemaitre glemaitre changed the title [WIP] EHN handle NaN value in QuantileTransformer [MRG] EHN handle NaN value in QuantileTransformer Jan 10, 2018
@glemaitre
Member Author

@lesteve @jnothman I wanted to mock the numpy version to be sure that the error in the test is raised properly. I used pytest-mock for the moment. Is there any problem with that, or should we use an alternative? Note that mock is in the standard library only in Python 3, not in Python 2, where it has to be installed from pip as well.

@lesteve
Member

lesteve commented Jan 10, 2018

@lesteve @jnothman I wanted to mock the version of numpy to be sure that the error in the test is raised properly.

Is pytest-mock really needed? Can we not use either:

  • an if statement in the test; numpy <= 1.9 will be tested in one of the builds on Travis.
  • the pytest monkeypatch fixture; I haven't looked at what the difference is compared to pytest-mock, to be honest.

@glemaitre
Member Author

an if statement in the test; numpy <= 1.9 will be tested in one of the builds on Travis.

+1 on that one. It seems good enough to me. Thanks

@glemaitre
Member Author

@jnothman @lesteve I think this is ready for a first look.

@jnothman jnothman left a comment

Do you think we should support other missing_values indicators?

@@ -2357,7 +2361,10 @@ def _transform_col(self, X_col, quantiles, inverse):
        X_col[lower_bounds_idx] = lower_bound_y
        # for forward transform, match the output PDF
        if not inverse:
            X_col = output_distribution.ppf(X_col)
        # comparison with NaN will raise a warning which we make silent
Member

If a numpy error, can use np.errstate

Member Author

this is a scipy warning

@glemaitre
Member Author

Do you think we should support other missing_values indicators?

I gave it some thought and tried a couple of things. I have the following questions:

  • The QuantileTransformer actually converts X to a float array. So the simplest way is to replace the missing_values by NaN, and the processing remains the same.
  • If we actually want to keep the data type of X (only at fit, since we will return a float X at transform anyway), we can compute the quantiles only on the sliced X_col. However, at transform you will need to convert X to float, and in that case the easiest way is to fall back on the first solution (replacing missing_values by NaN).

Having implemented both approaches (half-way), I think that simply replacing missing_values by NaN is sufficient.
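The first approach can be sketched as follows; replace_missing_with_nan is a hypothetical helper for illustration, not the scikit-learn implementation:

```python
import numpy as np

def replace_missing_with_nan(X, missing_values):
    """Map the missing_values indicator to NaN after converting X to float.

    Hypothetical helper sketching the approach discussed above.
    """
    X = np.array(X, dtype=np.float64)  # copy; also makes NaN representable
    if not np.isnan(missing_values):
        X[X == missing_values] = np.nan
    return X

X = np.array([[1, -999], [2, 3]])
X_nan = replace_missing_with_nan(X, missing_values=-999)
```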

def _check_inputs(self, X, accept_sparse_negative=False):
    """Check inputs before fit and transform"""
    if sparse.issparse(X):
Member

Why before check_array?

Member Author

check_array will convert the matrix to a float dtype; I wanted to compare while the data can still be int.
However, I could also use np.isclose, which can handle NaN as well.

and not np.isfinite(X[~np.isnan(X)]).all()):
    raise ValueError("Input contains infinity"
                     " or a value too large for %r." % X.dtype)
if np.count_nonzero(self._mask_missing_values):
Member

It is strange to have this in a function whose argument is X, not _mask_missing_values.

Member

I suspect you should avoid storing this mask as an attribute.

Member Author

This function will go away once #10455 is addressed.

    self._percentile_func = np.nanpercentile
else:
    raise NotImplementedError(
        'QuantileTransformer does not handle NaN value with'
Member

Is it easy enough to just implement in sklearn.utils.fixes:

def nanpercentile(a, q):
    return np.percentile(np.compress(~np.isnan(a), a), q)

seeing as we don't use the other features of nanpercentile?

It is annoying to have parts of the library with different minimum numpy requirements. It means that code is not portable across supported platforms.
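For reference, a runnable version of that sketch (note that np.compress takes the condition as its first argument):

```python
import numpy as np

def nanpercentile(a, q):
    """Minimal backport sketch: percentile over the non-NaN entries of a."""
    return np.percentile(np.compress(~np.isnan(a), a), q)

a = np.array([1.0, np.nan, 3.0, np.nan, 5.0])
print(nanpercentile(a, 50))        # 3.0 (median of [1, 3, 5])
print(nanpercentile(a, [0, 100]))  # [1. 5.]
```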

Member Author

Right

@glemaitre
Member Author

glemaitre commented Mar 18, 2018

I'm not sure that's applicable to QuantileTransformer...?

True... we will always get float as output. So it should definitely go in a separate PR.

@glemaitre
Member Author

@jnothman I added a test_common file. Could you check that the creation of the instance is OK, or would you see another way to create the estimator instance on the fly (a dict, for instance)?

@jnothman jnothman left a comment

Nice test. It essentially also tests that the estimators are feature-wise... So we could in theory remove some existing tests


@pytest.mark.parametrize(
    "est, X, n_missing",
    _generate_tuple_transformer_missing_value()
Member

I don't get why this is better than just parameterizing est directly

Member

We could even consider a list of all feature-wise preprocessors then xfail some...

Member Author

We could even consider a list of all feature-wise preprocessors then xfail some...

I agree. We could switch to this behaviour when a majority of those preprocessors support this feature.

rng.randint(X.shape[1], size=n_missing)] = np.nan
X_train, X_test = train_test_split(X)
# sanity check
assert not np.all(np.isnan(X_train), axis=0).any()
Member

Should probably also check that there are NaNs in both train and test

@jnothman jnothman left a comment

Please add a docstring note along the lines of "NaNs are treated as missing values: disregarded in fit, and maintained in transform". Perhaps that's too terse.

X_col = .5 * (np.interp(X_col, quantiles, self.references_)
              - np.interp(-X_col, -quantiles[::-1],
                          -self.references_[::-1]))
X_col[isfinite_mask] = .5 * (
Member

FWIW, it's possible that np.ma would handle the case with no missing values more efficiently than using an ad hoc mask. I've not checked.

Member Author

Playing around, I think that it will trigger the same number of copies.

X_train, X_test = train_test_split(X)
# sanity check
assert not np.all(np.isnan(X_train), axis=0).any()
assert np.any(X_train, axis=0).all()
Member

Is this the right condition?

Member Author

We check that there are some NaN in each column.

Member Author

Oops, yes, I forgot to check for NaN there :)
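The intended pair of sanity checks could look like this sketch (the data setup is illustrative, not the actual test):

```python
import numpy as np

rng = np.random.RandomState(42)
X_train = rng.uniform(size=(20, 3))
# inject exactly one NaN per column
X_train[rng.randint(20, size=3), np.arange(3)] = np.nan

# no column is entirely NaN
assert not np.all(np.isnan(X_train), axis=0).any()
# every column contains at least one NaN (the check that was missing)
assert np.any(np.isnan(X_train), axis=0).all()
```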



from sklearn.datasets import load_iris

Member

Why so much vertical space?

@glemaitre
Member Author

@lesteve We have 2 approvals here. Do you want to make a quick review before we merge this, hopefully? :)

"""Force the output of nanpercentile to be finite."""
percentile = nanpercentile(column_data, percentiles)
with np.errstate(invalid='ignore'):  # hide NaN comparison warnings
    if np.all(np.isclose(percentile, np.nan, equal_nan=True)):
Member

Why not:

if np.all(np.isnan(percentile))

If you use that I think you can remove the with np.errstate(...)

Member Author

True. No idea how I came to something so complex.

percentile = nanpercentile(column_data, percentiles)
with np.errstate(invalid='ignore'):  # hide NaN comparison warnings
    if np.all(np.isclose(percentile, np.nan, equal_nan=True)):
        warnings.warn("All samples in a column of X are NaN.")
@lesteve lesteve Mar 20, 2018

Maybe you can mention in the warning that you are returning 0 for all the quantiles?

It would be nice if you could test that you get the warning when expected.

Bonus points if you check that you do not get any warning when you don't expect a warning.
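A sketch of such a test, with a hypothetical safe_percentile stand-in for the helper under review (pytest.warns would be the idiomatic scikit-learn way; plain warnings machinery is used here to stay self-contained):

```python
import warnings
import numpy as np

def safe_percentile(column_data, percentiles):
    """Hypothetical stand-in for the helper under review."""
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", RuntimeWarning)  # all-NaN slice warning
        percentile = np.nanpercentile(column_data, percentiles)
    if np.all(np.isnan(percentile)):
        warnings.warn("All samples in a column of X are NaN; "
                      "returning 0 for all quantiles.")
        return np.zeros(len(percentiles))
    return percentile

# the warning fires for an all-NaN column...
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = safe_percentile(np.array([np.nan, np.nan]), [25, 50, 75])
assert any("All samples" in str(w.message) for w in caught)
assert result.tolist() == [0.0, 0.0, 0.0]

# ...and no warning at all for regular data
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    safe_percentile(np.array([1.0, 2.0, 3.0]), [50])
assert not caught
```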

Member Author

I am not sure anymore why I force percentile to be finite.

@glemaitre
Member Author

@jnothman @lesteve would it make sense to leave the quantiles as NaN? I think it would not affect the rest of the processing.

@glemaitre
Member Author

I added a test to the common tests to check that inverse transform behaves properly.
I am checking that the quantiles are NaN when a column is all NaN.
It seems a better behavior than forcing the column to zero; I really don't recall what my argument for doing so was.

@jnothman
Member

jnothman commented Mar 20, 2018 via email

@glemaitre
Member Author

I suppose the problem with making quantiles NaN is that for finite data passed to transform, you'd get NaN when transformed. That sort of makes sense... I suppose if NaN is in the data we assume it will be handled downstream.

That's true. But it seems more logical to map finite values to NaN if during training we did not learn anything (due to a full-NaN column). So I think the current way is OK.

@glemaitre
Member Author

@lesteve you can have a second look at it and tell us if it makes sense to you.

@glemaitre
Member Author

ping @lesteve @qinhanmin2014

@rth
Member

rth commented Apr 21, 2018

LGTM

Given that there are already two +1s and that Loic's comments were addressed, as far as I can tell, I will merge when CI is green.

@rth rth merged commit c3548a8 into scikit-learn:master Apr 21, 2018
@jnothman jnothman mentioned this pull request Jun 16, 2018