Skip to content

[MRG] ENH Allow handling nan during input validation #8074

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions sklearn/utils/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,11 @@ def test_check_array():
X_nan[0, 0] = np.nan
assert_raises(ValueError, check_array, X_nan)
check_array(X_inf, force_all_finite=False) # no raise
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This tests for force_all_finite=False, allow_nan=False

# allow_nan check
check_array(X_nan, force_all_finite=True, allow_nan=True) # no raise
# allow_nan check should not hinder check for inf
assert_raises(ValueError, check_array, X_inf, force_all_finite=True,
allow_nan=False)

# dtype and order enforcement.
X_C = np.arange(4).reshape(2, 2).copy("C")
Expand Down
72 changes: 46 additions & 26 deletions sklearn/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,26 +28,33 @@
warnings.simplefilter('ignore', NonBLASDotWarning)


def _assert_all_finite(X):
def _assert_all_finite(X, allow_nan=False):
"""Like assert_all_finite, but only for ndarray."""
X = np.asanyarray(X)
# First try an O(n) time, O(1) space solution for the common case that
# everything is finite; fall back to O(n) space np.isfinite to prevent
# false positives from overflow in sum method.
if (X.dtype.char in np.typecodes['AllFloat'] and not np.isfinite(X.sum())
and not np.isfinite(X).all()):
raise ValueError("Input contains NaN, infinity"
" or a value too large for %r." % X.dtype)
if allow_nan:
def any_not_isfinite(X): return np.isinf(X).any()
np_sum = np.nansum
else:
def any_not_isfinite(X): return not np.isfinite(X).all()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jnothman I fixed your comment in this PR...

np_sum = np.sum

if (X.dtype.char in np.typecodes['AllFloat'] and
not np.isfinite(np_sum(X)) and any_not_isfinite(X)):
raise ValueError("Input contains %sinfinity or a value too large for "
"%r." % ("" if allow_nan else "NaN, ", X.dtype))


def assert_all_finite(X):
"""Throw a ValueError if X contains NaN or infinity.
def assert_all_finite(X, allow_nan=False):
"""Throw a ValueError if X contains infinity or NaN (if allow_nan is False)

Input MUST be an np.ndarray instance or a scipy.sparse matrix."""
_assert_all_finite(X.data if sp.issparse(X) else X)
_assert_all_finite(X.data if sp.issparse(X) else X, allow_nan)


def as_float_array(X, copy=True, force_all_finite=True):
def as_float_array(X, copy=True, force_all_finite=True, allow_nan=False):
"""Converts an array-like to an array of floats

The new dtype will be np.float32 or np.float64, depending on the original
Expand All @@ -65,6 +72,9 @@ def as_float_array(X, copy=True, force_all_finite=True):
force_all_finite : boolean (default=True)
Whether to raise an error on np.inf and np.nan in X.

allow_nan : boolean (default=False)
Whether to allow nan values in X.

Returns
-------
XT : {array, sparse matrix}
Expand All @@ -74,7 +84,7 @@ def as_float_array(X, copy=True, force_all_finite=True):
and not sp.issparse(X)):
return check_array(X, ['csr', 'csc', 'coo'], dtype=np.float64,
copy=copy, force_all_finite=force_all_finite,
ensure_2d=False)
allow_nan=allow_nan, ensure_2d=False)
elif sp.issparse(X) and X.dtype in [np.float32, np.float64]:
return X.copy() if copy else X
elif X.dtype in [np.float32, np.float64]: # is numpy array
Expand Down Expand Up @@ -190,7 +200,7 @@ def indexable(*iterables):


def _ensure_sparse_format(spmatrix, accept_sparse, dtype, copy,
force_all_finite):
force_all_finite, allow_nan):
"""Convert a sparse matrix to a given format.

Checks the sparse format of spmatrix and converts if necessary.
Expand All @@ -214,7 +224,11 @@ def _ensure_sparse_format(spmatrix, accept_sparse, dtype, copy,
be triggered by a conversion.

force_all_finite : boolean (default=True)
Whether to raise an error on np.inf and np.nan in X.
Whether to raise an error on np.inf and np.nan (if allow_nan is False)
in X.

allow_nan : boolean (default=True)
Whether to allow nan.

Returns
-------
Expand Down Expand Up @@ -247,14 +261,14 @@ def _ensure_sparse_format(spmatrix, accept_sparse, dtype, copy,
warnings.warn("Can't check %s sparse matrix for nan or inf."
% spmatrix.format)
else:
_assert_all_finite(spmatrix.data)
_assert_all_finite(spmatrix.data, allow_nan)
return spmatrix


def check_array(array, accept_sparse=None, dtype="numeric", order=None,
copy=False, force_all_finite=True, ensure_2d=True,
allow_nd=False, ensure_min_samples=1, ensure_min_features=1,
warn_on_dtype=False, estimator=None):
copy=False, force_all_finite=True, allow_nan=False,
ensure_2d=True, allow_nd=False, ensure_min_samples=1,
ensure_min_features=1, warn_on_dtype=False, estimator=None):
"""Input validation on an array, list, sparse matrix or similar.

By default, the input is converted to an at least 2D numpy array.
Expand Down Expand Up @@ -290,7 +304,10 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None,
be triggered by a conversion.

force_all_finite : boolean (default=True)
Whether to raise an error on np.inf and np.nan in X.
Whether to raise an error on np.inf in X.

allow_nan : boolean (default=False)
Whether to allow nan values in X.

ensure_2d : boolean (default=True)
Whether to raise a value error if X is not 2d.
Expand Down Expand Up @@ -359,7 +376,7 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None,

if sp.issparse(array):
array = _ensure_sparse_format(array, accept_sparse, dtype, copy,
force_all_finite)
force_all_finite, allow_nan)
else:
array = np.array(array, dtype=dtype, order=order, copy=copy)

Expand All @@ -380,7 +397,7 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None,
raise ValueError("Found array with dim %d. %s expected <= 2."
% (array.ndim, estimator_name))
if force_all_finite:
_assert_all_finite(array)
_assert_all_finite(array, allow_nan)

shape_repr = _shape_repr(array.shape)
if ensure_min_samples > 0:
Expand All @@ -407,9 +424,9 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None,


def check_X_y(X, y, accept_sparse=None, dtype="numeric", order=None,
copy=False, force_all_finite=True, ensure_2d=True,
allow_nd=False, multi_output=False, ensure_min_samples=1,
ensure_min_features=1, y_numeric=False,
copy=False, force_all_finite=True, allow_nan=False,
ensure_2d=True, allow_nd=False, multi_output=False,
ensure_min_samples=1, ensure_min_features=1, y_numeric=False,
warn_on_dtype=False, estimator=None):
"""Input validation for standard estimators.

Expand Down Expand Up @@ -450,6 +467,9 @@ def check_X_y(X, y, accept_sparse=None, dtype="numeric", order=None,
Whether to raise an error on np.inf and np.nan in X. This parameter
does not influence whether y can have np.inf or np.nan values.

allow_nan : boolean (default=False)
Whether to allow nan values in X.

ensure_2d : boolean (default=True)
Whether to make X at least 2d.

Expand Down Expand Up @@ -493,14 +513,14 @@ def check_X_y(X, y, accept_sparse=None, dtype="numeric", order=None,
The converted and validated y.
"""
X = check_array(X, accept_sparse, dtype, order, copy, force_all_finite,
ensure_2d, allow_nd, ensure_min_samples,
allow_nan, ensure_2d, allow_nd, ensure_min_samples,
ensure_min_features, warn_on_dtype, estimator)
if multi_output:
y = check_array(y, 'csr', force_all_finite=True, ensure_2d=False,
dtype=None)
y = check_array(y, 'csr', force_all_finite=True, allow_nan=allow_nan,
ensure_2d=False, dtype=None)
else:
y = column_or_1d(y, warn=True)
_assert_all_finite(y)
_assert_all_finite(y, allow_nan)
if y_numeric and y.dtype.kind == 'O':
y = y.astype(np.float64)

Expand Down