From 89c1018c6413bde20b0e74dfedf740870d808676 Mon Sep 17 00:00:00 2001 From: TomDLT Date: Thu, 30 Apr 2015 18:48:52 +0200 Subject: [PATCH] ENH improve check_array ENH improve check_array to warn on dtype conversions ENH make check_array accept several dtypes ENH change validation with improved check_array ENH change astype to avoid copy if possible ENH remove warn_if_not_float --- doc/developers/utilities.rst | 3 - sklearn/cluster/k_means_.py | 9 +- sklearn/cluster/tests/test_k_means.py | 3 +- sklearn/linear_model/stochastic_gradient.py | 3 +- sklearn/manifold/locally_linear.py | 9 +- sklearn/manifold/spectral_embedding_.py | 10 +- sklearn/naive_bayes.py | 6 +- sklearn/preprocessing/data.py | 32 +++--- sklearn/preprocessing/tests/test_data.py | 9 +- sklearn/utils/__init__.py | 3 +- sklearn/utils/graph.py | 3 +- sklearn/utils/random.py | 4 +- sklearn/utils/testing.py | 6 +- sklearn/utils/tests/test_validation.py | 80 +++++++++++++- sklearn/utils/validation.py | 116 +++++++++++++------- 15 files changed, 200 insertions(+), 96 deletions(-) diff --git a/doc/developers/utilities.rst b/doc/developers/utilities.rst index c5740632337fa..cea51d8532077 100644 --- a/doc/developers/utilities.rst +++ b/doc/developers/utilities.rst @@ -43,9 +43,6 @@ should be used when applicable. be sliced or indexed using safe_index. This is used to validate input for cross-validation. -- :func:`warn_if_not_float`: Warn if input is not a floating-point value. - the input ``X`` is assumed to have ``X.dtype``. - If your code relies on a random number generator, it should never use functions like ``numpy.random.random`` or ``numpy.random.normal``. This approach can lead to repeatability issues in unit tests. Instead, a diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 9cae6338f1a76..ee450ed302ca0 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -27,6 +27,7 @@ from ..utils import as_float_array from ..utils import gen_batches from ..utils.validation import check_is_fitted +from ..utils.validation import FLOAT_DTYPES from ..utils.random import choice from ..externals.joblib import Parallel from ..externals.joblib import delayed @@ -759,18 +760,14 @@ def _check_fit_data(self, X): return X def _check_test_data(self, X): - X = check_array(X, accept_sparse='csr') + X = check_array(X, accept_sparse='csr', dtype=FLOAT_DTYPES, + warn_on_dtype=True) n_samples, n_features = X.shape expected_n_features = self.cluster_centers_.shape[1] if not n_features == expected_n_features: raise ValueError("Incorrect number of features. " "Got %d features, expected %d" % ( n_features, expected_n_features)) - if X.dtype.kind != 'f': - warnings.warn("Got data type %s, converted to float " - "to avoid overflows" % X.dtype, - RuntimeWarning, stacklevel=2) - X = X.astype(np.float) return X diff --git a/sklearn/cluster/tests/test_k_means.py b/sklearn/cluster/tests/test_k_means.py index f2d086523970d..3b3dec551ff12 100644 --- a/sklearn/cluster/tests/test_k_means.py +++ b/sklearn/cluster/tests/test_k_means.py @@ -17,6 +17,7 @@ from sklearn.utils.testing import assert_warns from sklearn.utils.testing import if_not_mac_os +from sklearn.utils.validation import DataConversionWarning from sklearn.utils.extmath import row_norms from sklearn.metrics.cluster import v_measure_score from sklearn.cluster import KMeans, k_means @@ -45,7 +46,7 @@ def test_kmeans_dtype(): X = rnd.normal(size=(40, 2)) X = (X * 10).astype(np.uint8) km = KMeans(n_init=1).fit(X) - pred_x = assert_warns(RuntimeWarning, km.predict, X) + pred_x = assert_warns(DataConversionWarning, km.predict, X) assert_array_equal(km.labels_, pred_x) diff --git a/sklearn/linear_model/stochastic_gradient.py b/sklearn/linear_model/stochastic_gradient.py index 6eb9c47648070..f4a38431513ea 100644 --- a/sklearn/linear_model/stochastic_gradient.py +++ b/sklearn/linear_model/stochastic_gradient.py @@ -22,6 +22,7 @@ from ..externals import six from .sgd_fast import plain_sgd, average_sgd +from ..utils.fixes import astype from ..utils.seq_dataset import ArrayDataset, CSRDataset from ..utils import compute_class_weight from .sgd_fast import Hinge @@ -867,7 +868,7 @@ def _partial_fit(self, X, y, alpha, C, loss, learning_rate, n_iter, sample_weight, coef_init, intercept_init): X, y = check_X_y(X, y, "csr", copy=False, order='C', dtype=np.float64) - y = y.astype(np.float64) + y = astype(y, np.float64, copy=False) n_samples, n_features = X.shape diff --git a/sklearn/manifold/locally_linear.py b/sklearn/manifold/locally_linear.py index 5d78713115146..081be38e41e87 100644 --- a/sklearn/manifold/locally_linear.py +++ b/sklearn/manifold/locally_linear.py @@ -11,6 +11,7 @@ from ..utils import check_random_state, check_array from ..utils.arpack import eigsh from ..utils.validation import check_is_fitted +from ..utils.validation import FLOAT_DTYPES from ..neighbors import NearestNeighbors @@ -38,14 +39,10 @@ def barycenter_weights(X, Z, reg=1e-3): ----- See developers note for more information. """ - X = np.asarray(X) - Z = np.asarray(Z) + X = check_array(X, dtype=FLOAT_DTYPES) + Z = check_array(Z, dtype=FLOAT_DTYPES, allow_nd=True) n_samples, n_neighbors = X.shape[0], Z.shape[1] - if X.dtype.kind == 'i': - X = X.astype(np.float) - if Z.dtype.kind == 'i': - Z = Z.astype(np.float) B = np.empty((n_samples, n_neighbors), dtype=X.dtype) v = np.ones(n_neighbors, dtype=X.dtype) diff --git a/sklearn/manifold/spectral_embedding_.py b/sklearn/manifold/spectral_embedding_.py index 6cd588faa3740..c7aafbda05911 100644 --- a/sklearn/manifold/spectral_embedding_.py +++ b/sklearn/manifold/spectral_embedding_.py @@ -263,7 +263,9 @@ def spectral_embedding(adjacency, n_components=8, eigen_solver=None, # problem. if not sparse.issparse(laplacian): warnings.warn("AMG works better for sparse matrices") - laplacian = laplacian.astype(np.float) # lobpcg needs native floats + # lobpcg needs double precision floats + laplacian = check_array(laplacian, dtype=np.float64, + accept_sparse=True) laplacian = _set_diag(laplacian, 1) ml = smoothed_aggregation_solver(check_array(laplacian, 'csr')) M = ml.aspreconditioner() @@ -276,7 +278,9 @@ def spectral_embedding(adjacency, n_components=8, eigen_solver=None, raise ValueError elif eigen_solver == "lobpcg": - laplacian = laplacian.astype(np.float) # lobpcg needs native floats + # lobpcg needs double precision floats + laplacian = check_array(laplacian, dtype=np.float64, + accept_sparse=True) if n_nodes < 5 * n_components + 1: # see note above under arpack why lobpcg has problems with small # number of nodes @@ -286,8 +290,6 @@ def spectral_embedding(adjacency, n_components=8, eigen_solver=None, lambdas, diffusion_map = eigh(laplacian) embedding = diffusion_map.T[:n_components] * dd else: - # lobpcg needs native floats - laplacian = laplacian.astype(np.float) laplacian = _set_diag(laplacian, 1) # We increase the number of eigenvectors requested, as lobpcg # doesn't behave well in low dimension diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 1ecf4ad5bcc5f..7161a2a99f44e 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -472,7 +472,8 @@ def partial_fit(self, X, y, classes=None, sample_weight=None): msg = "X.shape[0]=%d and y.shape[0]=%d are incompatible." raise ValueError(msg % (X.shape[0], y.shape[0])) - # convert to float to support sample weight consistently + # label_binarize() returns arrays with dtype=np.int64. + # We convert it to np.float64 to support sample_weight consistently Y = Y.astype(np.float64) if sample_weight is not None: Y *= check_array(sample_weight).T @@ -520,7 +521,8 @@ def fit(self, X, y, sample_weight=None): if Y.shape[1] == 1: Y = np.concatenate((1 - Y, Y), axis=1) - # convert to float to support sample weight consistently; + # LabelBinarizer().fit_transform() returns arrays with dtype=np.int64. + # We convert it to np.float64 to support sample_weight consistently; # this means we also don't have to cast X to floating point Y = Y.astype(np.float64) if sample_weight is not None: diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index fa268280e1241..642329f851055 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -15,16 +15,14 @@ from ..base import BaseEstimator, TransformerMixin from ..externals import six from ..utils import check_array -from ..utils import warn_if_not_float from ..utils.extmath import row_norms -from ..utils.fixes import (combinations_with_replacement as combinations_w_r, - bincount) -from ..utils.fixes import isclose +from ..utils.fixes import combinations_with_replacement as combinations_w_r from ..utils.sparsefuncs_fast import (inplace_csr_row_normalize_l1, inplace_csr_row_normalize_l2) from ..utils.sparsefuncs import (inplace_column_scale, mean_variance_axis, min_max_axis) -from ..utils.validation import check_is_fitted +from ..utils.validation import check_is_fitted, FLOAT_DTYPES + zip = six.moves.zip map = six.moves.map @@ -115,8 +113,9 @@ def scale(X, axis=0, with_mean=True, with_std=True, copy=True): scaling using the ``Transformer`` API (e.g. as part of a preprocessing :class:`sklearn.pipeline.Pipeline`) """ - X = check_array(X, accept_sparse='csr', copy=copy, ensure_2d=False) - warn_if_not_float(X, estimator='The scale function') + X = check_array(X, accept_sparse='csr', copy=copy, ensure_2d=False, + warn_on_dtype=True, estimator='the scale function', + dtype=FLOAT_DTYPES) if sparse.issparse(X): if with_mean: raise ValueError( @@ -224,8 +223,8 @@ def fit(self, X, y=None): The data used to compute the per-feature minimum and maximum used for later scaling along the features axis. """ - X = check_array(X, copy=self.copy, ensure_2d=False) - warn_if_not_float(X, estimator=self) + X = check_array(X, copy=self.copy, ensure_2d=False, warn_on_dtype=True, + estimator=self, dtype=FLOAT_DTYPES) feature_range = self.feature_range if feature_range[0] >= feature_range[1]: raise ValueError("Minimum of desired feature range must be smaller" @@ -346,9 +345,8 @@ def fit(self, X, y=None): used for later scaling along the features axis. """ X = check_array(X, accept_sparse='csr', copy=self.copy, - ensure_2d=False) - if warn_if_not_float(X, estimator=self): - X = X.astype(np.float) + ensure_2d=False, warn_on_dtype=True, + estimator=self, dtype=FLOAT_DTYPES) if sparse.issparse(X): if self.with_mean: raise ValueError( @@ -379,9 +377,9 @@ def transform(self, X, y=None, copy=None): check_is_fitted(self, 'std_') copy = copy if copy is not None else self.copy - X = check_array(X, accept_sparse='csr', copy=copy, ensure_2d=False) - if warn_if_not_float(X, estimator=self): - X = X.astype(np.float) + X = check_array(X, accept_sparse='csr', copy=copy, + ensure_2d=False, warn_on_dtype=True, + estimator=self, dtype=FLOAT_DTYPES) if sparse.issparse(X): if self.with_mean: raise ValueError( @@ -600,8 +598,8 @@ def normalize(X, norm='l2', axis=1, copy=True): else: raise ValueError("'%d' is not a supported axis" % axis) - X = check_array(X, sparse_format, copy=copy) - warn_if_not_float(X, 'The normalize function') + X = check_array(X, sparse_format, copy=copy, warn_on_dtype=True, + estimator='the normalize function', dtype=FLOAT_DTYPES) if axis == 0: X = X.T diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index fe536517837d3..df49196f91d70 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -29,6 +29,7 @@ from sklearn.preprocessing.data import MinMaxScaler from sklearn.preprocessing.data import add_dummy_feature from sklearn.preprocessing.data import PolynomialFeatures +from sklearn.utils.validation import DataConversionWarning from sklearn import datasets @@ -499,12 +500,12 @@ def test_warning_scaling_integers(): X = np.array([[1, 2, 0], [0, 0, 0]], dtype=np.uint8) - w = "assumes floating point values as input, got uint8" + w = "Data with input dtype uint8 was converted to float64" clean_warning_registry() - assert_warns_message(UserWarning, w, scale, X) - assert_warns_message(UserWarning, w, StandardScaler().fit, X) - assert_warns_message(UserWarning, w, MinMaxScaler().fit, X) + assert_warns_message(DataConversionWarning, w, scale, X) + assert_warns_message(DataConversionWarning, w, StandardScaler().fit, X) + assert_warns_message(DataConversionWarning, w, MinMaxScaler().fit, X) def test_normalizer_l1(): diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index b11feafacf405..a3539c1c6c0df 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -9,7 +9,7 @@ from .murmurhash import murmurhash3_32 from .validation import (as_float_array, - assert_all_finite, warn_if_not_float, + assert_all_finite, check_random_state, column_or_1d, check_array, check_consistent_length, check_X_y, indexable, check_symmetric, DataConversionWarning) @@ -19,7 +19,6 @@ __all__ = ["murmurhash3_32", "as_float_array", "assert_all_finite", "check_array", - "warn_if_not_float", "check_random_state", "compute_class_weight", "compute_sample_weight", "column_or_1d", "safe_indexing", diff --git a/sklearn/utils/graph.py b/sklearn/utils/graph.py index 650e71841d359..595e0a7e15408 100644 --- a/sklearn/utils/graph.py +++ b/sklearn/utils/graph.py @@ -13,6 +13,7 @@ import numpy as np from scipy import sparse +from .validation import check_array from .graph_shortest_path import graph_shortest_path @@ -113,7 +114,7 @@ def graph_laplacian(csgraph, normed=False, return_diag=False): if normed and (np.issubdtype(csgraph.dtype, np.int) or np.issubdtype(csgraph.dtype, np.uint)): - csgraph = csgraph.astype(np.float) + csgraph = check_array(csgraph, dtype=np.float64, accept_sparse=True) if sparse.isspmatrix(csgraph): return _laplacian_sparse(csgraph, normed=normed, diff --git a/sklearn/utils/random.py b/sklearn/utils/random.py index 781983cf4541c..a52671bd4514b 100644 --- a/sklearn/utils/random.py +++ b/sklearn/utils/random.py @@ -8,7 +8,7 @@ import array from sklearn.utils import check_random_state - +from sklearn.utils.fixes import astype from ._random import sample_without_replacement __all__ = ['sample_without_replacement', 'choice'] @@ -238,7 +238,7 @@ def random_choice_csc(n_samples, classes, class_probability=None, if classes[j].dtype.kind != 'i': raise ValueError("class dtype %s is not supported" % classes[j].dtype) - classes[j] = classes[j].astype(int) + classes[j] = astype(classes[j], np.int64, copy=False) # use uniform distribution if no class_probability is given if class_probability is None: diff --git a/sklearn/utils/testing.py b/sklearn/utils/testing.py index d9d749dce0e9c..df5c598a375ea 100644 --- a/sklearn/utils/testing.py +++ b/sklearn/utils/testing.py @@ -212,7 +212,7 @@ def assert_warns_message(warning_class, message, func, *args, **kw): raise AssertionError("No warning raised when calling %s" % func.__name__) - found = [warning.category is warning_class for warning in w] + found = [issubclass(warning.category, warning_class) for warning in w] if not any(found): raise AssertionError("No warning raised for %s with class " "%s" @@ -235,8 +235,8 @@ def assert_warns_message(warning_class, message, func, *args, **kw): if not message_found: raise AssertionError("Did not receive the message you expected " - "('%s') for <%s>." - % (message, func.__name__)) + "('%s') for <%s>, got: '%s'" + % (message, func.__name__, msg)) return result diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 17f7ba57bf773..216825a4c1976 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -6,11 +6,14 @@ from itertools import product import numpy as np -from numpy.testing import assert_array_equal, assert_warns +from numpy.testing import assert_array_equal import scipy.sparse as sp from nose.tools import assert_raises, assert_true, assert_false, assert_equal from sklearn.utils.testing import assert_raises_regexp +from sklearn.utils.testing import assert_no_warnings +from sklearn.utils.testing import assert_warns_message +from sklearn.utils.testing import assert_warns from sklearn.utils import as_float_array, check_array, check_symmetric from sklearn.utils import check_X_y from sklearn.utils.mocking import MockDataFrame @@ -25,7 +28,9 @@ NotFittedError, has_fit_parameter, check_is_fitted, - check_consistent_length) + check_consistent_length, + DataConversionWarning, +) from sklearn.utils.testing import assert_raise_message @@ -234,6 +239,77 @@ def test_check_array_dtype_stability(): assert_equal(check_array(X, ensure_2d=False).dtype.kind, "i") +def test_check_array_dtype_warning(): + X_int_list = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] + X_float64 = np.asarray(X_int_list, dtype=np.float64) + X_float32 = np.asarray(X_int_list, dtype=np.float32) + X_int64 = np.asarray(X_int_list, dtype=np.int64) + X_csr_float64 = sp.csr_matrix(X_float64) + X_csr_float32 = sp.csr_matrix(X_float32) + X_csc_float32 = sp.csc_matrix(X_float32) + X_csc_int32 = sp.csc_matrix(X_int64, dtype=np.int32) + y = [0, 0, 1] + integer_data = [X_int64, X_csc_int32] + float64_data = [X_float64, X_csr_float64] + float32_data = [X_float32, X_csr_float32, X_csc_float32] + for X in integer_data: + X_checked = assert_no_warnings(check_array, X, dtype=np.float64, + accept_sparse=True) + assert_equal(X_checked.dtype, np.float64) + + X_checked = assert_warns(DataConversionWarning, check_array, X, + dtype=np.float64, + accept_sparse=True, warn_on_dtype=True) + assert_equal(X_checked.dtype, np.float64) + + # Check that the warning message includes the name of the Estimator + X_checked = assert_warns_message(DataConversionWarning, + 'SomeEstimator', + check_array, X, + dtype=[np.float64, np.float32], + accept_sparse=True, + warn_on_dtype=True, + estimator='SomeEstimator') + assert_equal(X_checked.dtype, np.float64) + + X_checked, y_checked = assert_warns_message( + DataConversionWarning, 'KNeighborsClassifier', + check_X_y, X, y, dtype=np.float64, accept_sparse=True, + warn_on_dtype=True, estimator=KNeighborsClassifier()) + + assert_equal(X_checked.dtype, np.float64) + + for X in float64_data: + X_checked = assert_no_warnings(check_array, X, dtype=np.float64, + accept_sparse=True, warn_on_dtype=True) + assert_equal(X_checked.dtype, np.float64) + X_checked = assert_no_warnings(check_array, X, dtype=np.float64, + accept_sparse=True, warn_on_dtype=False) + assert_equal(X_checked.dtype, np.float64) + + for X in float32_data: + X_checked = assert_no_warnings(check_array, X, + dtype=[np.float64, np.float32], + accept_sparse=True) + assert_equal(X_checked.dtype, np.float32) + assert_true(X_checked is X) + + X_checked = assert_no_warnings(check_array, X, + dtype=[np.float64, np.float32], + accept_sparse=['csr', 'dok'], + copy=True) + assert_equal(X_checked.dtype, np.float32) + assert_false(X_checked is X) + + X_checked = assert_no_warnings(check_array, X_csc_float32, + dtype=[np.float64, np.float32], + accept_sparse=['csr', 'dok'], + copy=False) + assert_equal(X_checked.dtype, np.float32) + assert_false(X_checked is X_csc_float32) + assert_equal(X_checked.format, 'csr') + + def test_check_array_min_samples_and_features_messages(): # empty list is considered 2D by default: msg = "0 feature(s) (shape=(1, 0)) while a minimum of 1 is required." diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 0d5969b3e3643..dc8d938a8d077 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -15,6 +15,8 @@ from ..externals import six from inspect import getargspec +FLOAT_DTYPES = (np.float64, np.float32, np.float16) + class DataConversionWarning(UserWarning): """A warning on implicit data conversions happening in the code""" @@ -232,25 +234,27 @@ def _ensure_sparse_format(spmatrix, accept_sparse, dtype, copy, spmatrix_converted : scipy sparse matrix. Matrix that is ensured to have an allowed type. """ - if accept_sparse is None: + if accept_sparse in [None, False]: raise TypeError('A sparse matrix was passed, but dense ' 'data is required. Use X.toarray() to ' 'convert to a dense numpy array.') - sparse_type = spmatrix.format if dtype is None: dtype = spmatrix.dtype - if sparse_type in accept_sparse: - # correct type - if dtype == spmatrix.dtype: - # correct dtype - if copy: - spmatrix = spmatrix.copy() - else: - # convert dtype - spmatrix = spmatrix.astype(dtype) - else: - # create new - spmatrix = spmatrix.asformat(accept_sparse[0]).astype(dtype) + + changed_format = False + if (isinstance(accept_sparse, (list, tuple)) + and spmatrix.format not in accept_sparse): + # create new with correct sparse + spmatrix = spmatrix.asformat(accept_sparse[0]) + changed_format = True + + if dtype != spmatrix.dtype: + # convert dtype + spmatrix = spmatrix.astype(dtype) + elif copy and not changed_format: + # force copy + spmatrix = spmatrix.copy() + if force_all_finite: if not hasattr(spmatrix, "data"): warnings.warn("Can't check %s sparse matrix for nan or inf." @@ -262,7 +266,8 @@ def _ensure_sparse_format(spmatrix, accept_sparse, dtype, copy, def check_array(array, accept_sparse=None, dtype="numeric", order=None, copy=False, force_all_finite=True, ensure_2d=True, - allow_nd=False, ensure_min_samples=1, ensure_min_features=1): + allow_nd=False, ensure_min_samples=1, ensure_min_features=1, + warn_on_dtype=False, estimator=None): """Input validation on an array, list, sparse matrix or similar. By default, the input is converted to an at least 2nd numpy array. @@ -280,9 +285,11 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None, If the input is sparse but not in the allowed format, it will be converted to the first listed format. - dtype : string, type or None (default="numeric") + dtype : string, type, list of types or None (default="numeric") Data type of result. If None, the dtype of the input is preserved. If "numeric", dtype is preserved unless array.dtype is object. + If dtype is a list of types, conversion on the first type is only + performed if the dtype of the input is not in the list. order : 'F', 'C' or None (default=None) Whether an array will be forced to be fortran or c-style. @@ -311,6 +318,13 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None, dimensions or is originally 1D and ``ensure_2d`` is True. Setting to 0 disables this check. + warn_on_dtype : boolean (default=False) + Raise DataConversionWarning if the dtype of the input data structure + does not match the requested dtype, causing a memory copy. + + estimator : str or estimator instance (default=None) + If passed, include the name of the estimator in warning messages. + Returns ------- X_converted : object @@ -322,20 +336,34 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None, # store whether originally we wanted numeric dtype dtype_numeric = dtype == "numeric" - if sp.issparse(array): - if dtype_numeric: + dtype_orig = getattr(array, "dtype", None) + if not hasattr(dtype_orig, 'kind'): + # not a data type (e.g. a column named dtype in a pandas DataFrame) + dtype_orig = None + + if dtype_numeric: + if dtype_orig is not None and dtype_orig.kind == "O": + # if input is object, convert to float. + dtype = np.float64 + else: + dtype = None + + if isinstance(dtype, (list, tuple)): + if dtype_orig is not None and dtype_orig in dtype: + # no dtype conversion required dtype = None + else: + # dtype conversion required. Let's select the first element of the + # list of accepted types. + dtype = dtype[0] + + if sp.issparse(array): array = _ensure_sparse_format(array, accept_sparse, dtype, copy, force_all_finite) else: if ensure_2d: array = np.atleast_2d(array) - if dtype_numeric: - if hasattr(array, "dtype") and getattr(array.dtype, "kind", None) == "O": - # if input is object, convert to float. - dtype = np.float64 - else: - dtype = None + array = np.array(array, dtype=dtype, order=order, copy=copy) # make sure we actually converted to numeric: if dtype_numeric and array.dtype.kind == "O": @@ -360,13 +388,23 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None, raise ValueError("Found array with %d feature(s) (shape=%s) while" " a minimum of %d is required." % (n_features, shape_repr, ensure_min_features)) + + if warn_on_dtype and dtype_orig is not None and array.dtype != dtype_orig: + msg = ("Data with input dtype %s was converted to %s" + % (dtype_orig, array.dtype)) + if estimator is not None: + if not isinstance(estimator, six.string_types): + estimator = estimator.__class__.__name__ + msg += " by %s" % estimator + warnings.warn(msg, DataConversionWarning) return array def check_X_y(X, y, accept_sparse=None, dtype="numeric", order=None, copy=False, force_all_finite=True, ensure_2d=True, allow_nd=False, multi_output=False, ensure_min_samples=1, - ensure_min_features=1, y_numeric=False): + ensure_min_features=1, y_numeric=False, + warn_on_dtype=False, estimator=None): """Input validation for standard estimators. Checks X and y for consistent length, enforces X 2d and y 1d. @@ -389,9 +427,11 @@ def check_X_y(X, y, accept_sparse=None, dtype="numeric", order=None, copy=False, If the input is sparse but not in the allowed format, it will be converted to the first listed format. - dtype : string, type or None (default="numeric") + dtype : string, type, list of types or None (default="numeric") Data type of result. If None, the dtype of the input is preserved. If "numeric", dtype is preserved unless array.dtype is object. + If dtype is a list of types, conversion on the first type is only + performed if the dtype of the input is not in the list. order : 'F', 'C' or None (default=None) Whether an array will be forced to be fortran or c-style. @@ -429,6 +469,13 @@ def check_X_y(X, y, accept_sparse=None, dtype="numeric", order=None, copy=False, it is converted to float64. Should only be used for regression algorithms. + warn_on_dtype : boolean (default=False) + Raise DataConversionWarning if the dtype of the input data structure + does not match the requested dtype, causing a memory copy. + + estimator : str or estimator instance (default=None) + If passed, include the name of the estimator in warning messages. + Returns ------- X_converted : object @@ -436,7 +483,7 @@ def check_X_y(X, y, accept_sparse=None, dtype="numeric", order=None, copy=False, """ X = check_array(X, accept_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, - ensure_min_features) + ensure_min_features, warn_on_dtype, estimator) if multi_output: y = check_array(y, 'csr', force_all_finite=True, ensure_2d=False, dtype=None) @@ -480,21 +527,6 @@ def column_or_1d(y, warn=False): raise ValueError("bad input shape {0}".format(shape)) -def warn_if_not_float(X, estimator='This algorithm'): - """Warning utility function to check that data type is floating point. - - Returns True if a warning was raised (i.e. the input is not float) and - False otherwise, for easier input validation. - """ - if not isinstance(estimator, six.string_types): - estimator = estimator.__class__.__name__ - if X.dtype.kind != 'f': - warnings.warn("%s assumes floating point values as input, " - "got %s" % (estimator, X.dtype)) - return True - return False - - def check_random_state(seed): """Turn seed into a np.random.RandomState instance