From 70d8165fae6c9d94ec8b08a01c5129052126475b Mon Sep 17 00:00:00 2001 From: Vighnesh Birodkar Date: Fri, 18 Mar 2016 04:57:10 -0400 Subject: [PATCH 01/31] Added CategoricalEncoder class - deprecating OneHotEncoder --- doc/modules/classes.rst | 2 +- doc/modules/preprocessing.rst | 49 ++-- .../ensemble/plot_feature_transformation.py | 6 +- sklearn/feature_extraction/dict_vectorizer.py | 7 +- sklearn/preprocessing/__init__.py | 2 + sklearn/preprocessing/data.py | 230 ++++++++++++++++-- sklearn/preprocessing/tests/test_data.py | 81 ++++-- sklearn/utils/testing.py | 2 +- 8 files changed, 317 insertions(+), 62 deletions(-) diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index 950a9320dd1af..601d41b682791 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -1197,7 +1197,7 @@ See the :ref:`metrics` section of the user guide for further details. preprocessing.MaxAbsScaler preprocessing.MinMaxScaler preprocessing.Normalizer - preprocessing.OneHotEncoder + preprocessing.CategoricalEncoder preprocessing.PolynomialFeatures preprocessing.QuantileTransformer preprocessing.RobustScaler diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index a4e1364a85ae6..dd85a6c6faa98 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -461,38 +461,45 @@ not desired (i.e. the set of browsers was ordered arbitrarily). One possibility to convert categorical features to features that can be used with scikit-learn estimators is to use a one-of-K or one-hot encoding, which is -implemented in :class:`OneHotEncoder`. This estimator transforms each +implemented in :class:`CategoricalEncoder`. This estimator transforms each categorical feature with ``m`` possible values into ``m`` binary features, with only one active. Continuing the example above:: - >>> enc = preprocessing.OneHotEncoder() - >>> enc.fit([[0, 0, 3], [1, 1, 0], [0, 2, 1], [1, 0, 2]]) # doctest: +ELLIPSIS - OneHotEncoder(categorical_features='all', dtype=<... 'numpy.float64'>, - handle_unknown='error', n_values='auto', sparse=True) - >>> enc.transform([[0, 1, 3]]).toarray() - array([[ 1., 0., 0., 1., 0., 0., 0., 0., 1.]]) + >>> enc = preprocessing.CategoricalEncoder() + >>> X = [['male', 'from US', 'uses Safari'], ['female', 'from Europe', 'uses Firefox']] + >>> enc.fit(X) # doctest: +ELLIPSIS + CategoricalEncoder(categorical_features='all', classes='auto', + dtype=<... 'numpy.float64'>, handle_unknown='error', + sparse=True) + >>> enc.transform([['female', 'from US', 'uses Safari']]).toarray() + array([[ 1., 0., 0., 1., 0., 1.]]) + By default, how many values each feature can take is inferred automatically from the dataset. -It is possible to specify this explicitly using the parameter ``n_values``. +It is possible to specify this explicitly using the parameter ``classes``. There are two genders, three possible continents and four web browsers in our dataset. -Then we fit the estimator, and transform a data point. -In the result, the first two numbers encode the gender, the next set of three -numbers the continent and the last four the web browser. Note that, if there is a possibilty that the training data might have missing categorical -features, one has to explicitly set ``n_values``. For example, - - >>> enc = preprocessing.OneHotEncoder(n_values=[2, 3, 4]) - >>> # Note that there are missing categorical values for the 2nd and 3rd - >>> # features - >>> enc.fit([[1, 2, 3], [0, 2, 0]]) # doctest: +ELLIPSIS - OneHotEncoder(categorical_features='all', dtype=<... 'numpy.float64'>, - handle_unknown='error', n_values=[2, 3, 4], sparse=True) - >>> enc.transform([[1, 0, 0]]).toarray() - array([[ 0., 1., 1., 0., 0., 1., 0., 0., 0.]]) +features, one has to explicitly set ``classes``. For example, + + >>> genders = ['male', 'female'] + >>> locations = ['from Europe', 'from US', 'from Africa', 'from Asia'] + >>> browsers = ['uses Safari', 'uses Firefox', 'uses IE', 'uses Chrome'] + >>> enc = preprocessing.CategoricalEncoder(classes=[genders, locations, browsers]) + >>> # Note that for there are missing categorical values for the 2nd and 3rd + >>> # feature + >>> X = [['male', 'from US', 'uses Safari'], ['female', 'from Europe', 'uses Firefox']] + >>> enc.fit(X) # doctest: +ELLIPSIS + CategoricalEncoder(categorical_features='all', + classes=[...], + dtype=<... 'numpy.float64'>, handle_unknown='error', + sparse=True) + + >>> enc.transform([['female', 'from Asia', 'uses Chrome']]).toarray() + array([[ 1., 0., 0., 0., 0., 1., 1., 0., 0., 0.]]) See :ref:`dict_feature_extraction` for categorical features that are represented as a dict, not as integers. diff --git a/examples/ensemble/plot_feature_transformation.py b/examples/ensemble/plot_feature_transformation.py index e004c167e67af..50347de57f12b 100644 --- a/examples/ensemble/plot_feature_transformation.py +++ b/examples/ensemble/plot_feature_transformation.py @@ -34,7 +34,7 @@ from sklearn.linear_model import LogisticRegression from sklearn.ensemble import (RandomTreesEmbedding, RandomForestClassifier, GradientBoostingClassifier) -from sklearn.preprocessing import OneHotEncoder +from sklearn.preprocessing import CategoricalEncoder from sklearn.model_selection import train_test_split from sklearn.metrics import roc_curve from sklearn.pipeline import make_pipeline @@ -62,7 +62,7 @@ # Supervised transformation based on random forests rf = RandomForestClassifier(max_depth=3, n_estimators=n_estimator) -rf_enc = OneHotEncoder() +rf_enc = CategoricalEncoder() rf_lm = LogisticRegression() rf.fit(X_train, y_train) rf_enc.fit(rf.apply(X_train)) @@ -72,7 +72,7 @@ fpr_rf_lm, tpr_rf_lm, _ = roc_curve(y_test, y_pred_rf_lm) grd = GradientBoostingClassifier(n_estimators=n_estimator) -grd_enc = OneHotEncoder() +grd_enc = CategoricalEncoder() grd_lm = LogisticRegression() grd.fit(X_train, y_train) grd_enc.fit(grd.apply(X_train)[:, :, 0]) diff --git a/sklearn/feature_extraction/dict_vectorizer.py b/sklearn/feature_extraction/dict_vectorizer.py index 53804ed83ac45..d01067efba19d 100644 --- a/sklearn/feature_extraction/dict_vectorizer.py +++ b/sklearn/feature_extraction/dict_vectorizer.py @@ -39,7 +39,8 @@ class DictVectorizer(BaseEstimator, TransformerMixin): However, note that this transformer will only do a binary one-hot encoding when feature values are of type string. If categorical features are represented as numeric values such as int, the DictVectorizer can be - followed by OneHotEncoder to complete binary one-hot encoding. + followed by :class:`sklearn.preprocessing.CategoricalEncoder` to complete + binary one-hot encoding. Features that do not occur in a sample (mapping) will have a zero value in the resulting array/matrix. @@ -88,8 +89,8 @@ class DictVectorizer(BaseEstimator, TransformerMixin): See also -------- FeatureHasher : performs vectorization using only a hash function. - sklearn.preprocessing.OneHotEncoder : handles nominal/categorical features - encoded as columns of integers. + sklearn.preprocessing.CategoricalEncoder : handles nominal/categorical + features encoded as columns of arbitraty data types. """ def __init__(self, dtype=np.float64, separator="=", sparse=True, diff --git a/sklearn/preprocessing/__init__.py b/sklearn/preprocessing/__init__.py index 2b105709ffe08..0f5054e57f608 100644 --- a/sklearn/preprocessing/__init__.py +++ b/sklearn/preprocessing/__init__.py @@ -22,6 +22,7 @@ from .data import minmax_scale from .data import quantile_transform from .data import OneHotEncoder +from .data import CategoricalEncoder from .data import PolynomialFeatures @@ -46,6 +47,7 @@ 'QuantileTransformer', 'Normalizer', 'OneHotEncoder', + 'CategoricalEncoder', 'RobustScaler', 'StandardScaler', 'add_dummy_feature', diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index c9de8a99a0f3d..3f84bfbae3ac8 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -29,6 +29,11 @@ min_max_axis) from ..utils.validation import (check_is_fitted, check_random_state, FLOAT_DTYPES) +from .label import LabelEncoder +from ..utils.fixes import np_version +from ..utils.deprecation import deprecated + + BOUNDS_THRESHOLD = 1e-7 @@ -1677,28 +1682,27 @@ def add_dummy_feature(X, value=1.0): return np.hstack((np.ones((n_samples, 1)) * value, X)) -def _transform_selected(X, transform, selected="all", copy=True): - """Apply a transform function to portion of selected features - +def _apply_selected(X, transform, selected="all", dtype=np.float, copy=True, + return_val=True): + """Apply a function to portion of selected features Parameters ---------- - X : {array-like, sparse matrix}, shape [n_samples, n_features] + X : {array, sparse matrix}, shape [n_samples, n_features] Dense array or sparse matrix. - transform : callable A callable transform(X) -> X_transformed - copy : boolean, optional Copy X even if it could be avoided. - selected: "all" or array of indices or mask Specify which features to apply the transform to. - + return_val : boolean, optional + Whether to return the transformed matrix. If not set `None` is + returned. Returns ------- - X : array or sparse matrix, shape=(n_samples, n_features_new) + X : array or sparse matrix, shape=(n_samples, n_features_new) """ - X = check_array(X, accept_sparse='csc', copy=copy, dtype=FLOAT_DTYPES) + X = check_array(X, accept_sparse='csc', copy=copy, dtype=None) if isinstance(selected, six.string_types) and selected == "all": return transform(X) @@ -1721,14 +1725,16 @@ def _transform_selected(X, transform, selected="all", copy=True): return transform(X) else: X_sel = transform(X[:, ind[sel]]) - X_not_sel = X[:, ind[not_sel]] + X_not_sel = X[:, ind[not_sel]].astype(dtype) - if sparse.issparse(X_sel) or sparse.issparse(X_not_sel): - return sparse.hstack((X_sel, X_not_sel)) - else: - return np.hstack((X_sel, X_not_sel)) + if return_val: + if sparse.issparse(X_sel) or sparse.issparse(X_not_sel): + return sparse.hstack((X_sel, X_not_sel)) + else: + return np.hstack((X_sel, X_not_sel)) +@deprecated('`OneHotEncoder` is deprecated, use `CategoricalEncoder` instead.') class OneHotEncoder(BaseEstimator, TransformerMixin): """Encode categorical integer features using a one-hot aka one-of-K scheme. @@ -1902,8 +1908,8 @@ def fit_transform(self, X, y=None): Equivalent to self.fit(X).transform(X), but more convenient and more efficient. See fit for the parameters, transform for the return value. """ - return _transform_selected(X, self._fit_transform, - self.categorical_features, copy=True) + return _apply_selected(X, self._fit_transform, dtype=self.dtype, + selected=self.categorical_features, copy=True) def _transform(self, X): """Assumes X contains only categorical features.""" @@ -1958,8 +1964,8 @@ def transform(self, X): X_out : sparse matrix if sparse=True else a 2-d array, dtype=int Transformed input. """ - return _transform_selected(X, self._transform, - self.categorical_features, copy=True) + return _apply_selected(X, self._transform, dtype=self.dtype, + selected=self.categorical_features, copy=True) class QuantileTransformer(BaseEstimator, TransformerMixin): @@ -2440,3 +2446,189 @@ def quantile_transform(X, axis=0, n_quantiles=1000, else: raise ValueError("axis should be either equal to 0 or 1. Got" " axis={}".format(axis)) + + +class CategoricalEncoder(BaseEstimator, TransformerMixin): + """Encode categorical features using a one-hot aka one-of-K scheme. + + The input to this transformer should be a matrix of integers or strings, + denoting the values taken on by categorical (discrete) features. The + output will be a sparse matrix where each column corresponds to one + possible value of one feature. + + This encoding is needed for feeding categorical data to many scikit-learn + estimators, notably linear models and SVMs with the standard kernels. + Read more in the :ref:`User Guide `. + + Parameters + ---------- + classes : 'auto', 2D array of ints or strings or both. + Values per feature. + + - 'auto' : Determine classes automatically from the training data. + - array: ``classes[i]`` holds the classes expected in the ith column. + + categorical_features : 'all' or array of indices or mask + Specify what features are treated as categorical. + + - 'all' (default): All features are treated as categorical. + - array of indices: Array of categorical feature indices. + - mask: Array of length n_features and with dtype=bool. + Non-categorical features are always stacked to the right of the matrix. + + dtype : number type, default=np.float + Desired dtype of output. + + sparse : boolean, default=True + Will return sparse matrix if set True else will return an array. + + handle_unknown : str, 'error' or 'ignore' + Whether to raise an error or ignore if a unknown categorical feature is + present during transform. + + Attributes + ---------- + label_encoders_ : list of size n_features. + The :class:`sklearn.preprocessing.LabelEncoder` objects used to encode + the features. ``self.label_encoders[i]_`` is the LabelEncoder object + used to encode the ith column. The unique features found on column + ``i`` can be accessed using ``self.label_encoders_[i].classes_``. + + Examples + -------- + Given a dataset with three features and two samples, we let the encoder + find the maximum value per feature and transform the data to a binary + one-hot encoding. + >>> from sklearn.preprocessing import CategoricalEncoder + >>> enc = CategoricalEncoder() + >>> enc.fit([[0, 0, 3], [1, 1, 0], [0, 2, 1], \ +[1, 0, 2]]) # doctest: +ELLIPSIS + CategoricalEncoder(categorical_features='all', classes='auto', + dtype=<... 'numpy.float64'>, handle_unknown='error', + sparse=True) + >>> enc.transform([[0, 1, 1]]).toarray() + array([[ 1., 0., 0., 1., 0., 0., 1., 0., 0.]]) + + See also + -------- + sklearn.feature_extraction.DictVectorizer : performs a one-hot encoding of + dictionary items (also handles string-valued features). + sklearn.feature_extraction.FeatureHasher : performs an approximate one-hot + encoding of dictionary items or strings. + """ + + def __init__(self, classes='auto', categorical_features="all", + dtype=np.float64, sparse=True, handle_unknown='error'): + self.classes = classes + self.categorical_features = categorical_features + self.dtype = dtype + self.sparse = sparse + self.handle_unknown = handle_unknown + + def fit(self, X, y=None): + """Fit the CategoricalEncoder to X. + + Parameters + ---------- + X : array-like, shape [n_samples, n_feature] + Input array of type int. + + Returns + ------- + self + """ + + if self.handle_unknown not in ['error', 'ignore']: + template = ("handle_unknown should be either 'error' or " + "'ignore', got %s") + raise ValueError(template % self.handle_unknown) + + X = check_array(X, dtype=np.object, accept_sparse='csc') + n_samples, n_features = X.shape + + _apply_selected(X, self._fit, dtype=self.dtype, + selected=self.categorical_features, copy=True, + return_val=False) + return self + + def _fit(self, X): + "Assumes `X` contains only cetergorical features." + + X = check_array(X, dtype=np.object) + n_samples, n_features = X.shape + + self.label_encoders_ = [LabelEncoder() for i in range(n_features)] + + for i in range(n_features): + le = self.label_encoders_[i] + if self.classes == 'auto': + le.fit(X[:, i]) + else: + le.classes_ = np.array(self.classes[i]) + + def transform(self, X, y=None): + """Encode the selected categorical features using the one-hot scheme. + """ + X = check_array(X, dtype=np.object) + return _apply_selected(X, self._transform, copy=True, + selected=self.categorical_features) + + def _transform(self, X): + "Assumes `X` contains only categorical features." + + X = check_array(X, accept_sparse='csc', dtype=np.object) + n_samples, n_features = X.shape + X_int = np.zeros_like(X, dtype=np.int) + X_mask = np.ones_like(X, dtype=np.bool) + + for i in range(n_features): + if np_version < (1, 8): + # in1d is not supported for object datatype in np < 1.8 + valid_mask = np.ones_like(X[:, i], dtype=np.bool) + found_classes = set(np.unique(X[:, i])) + valid_classes = set(self.label_encoders_[i].classes_) + invalid_classes = found_classes - valid_classes + + for item in invalid_classes: + mask = X[:, i] == item + np.logical_not(mask, mask) + np.logical_and(valid_mask, mask, valid_mask) + + else: + valid_mask = np.in1d(X[:, i], self.label_encoders_[i].classes_) + + if not np.all(valid_mask): + if self.handle_unknown == 'error': + if np_version < (1, 8): + valid_classes = set(self.label_encoders_[i].classes_) + diff = set(X[:, i]) - valid_classes + diff = list(diff) + else: + diff = np.setdiff1d(X[:, i], + self.label_encoders_[i].classes_) + msg = 'Unknown feature(s) %s in column %d' % (diff, i) + raise ValueError(msg) + else: + # Set the problematic rows to an acceptable value and + # continue `The rows are marked `X_mask` and will be + # removed later. + X_mask[:, i] = valid_mask + X[:, i][~valid_mask] = self.label_encoders_[i].classes_[0] + X_int[:, i] = self.label_encoders_[i].transform(X[:, i]) + + mask = X_mask.ravel() + n_values = [le.classes_.shape[0] for le in self.label_encoders_] + n_values = np.hstack([[0], n_values]) + indices = np.cumsum(n_values) + self.feature_indices_ = indices + + column_indices = (X_int + indices[:-1]).ravel()[mask] + row_indices = np.repeat(np.arange(n_samples, dtype=np.int32), + n_features)[mask] + data = np.ones(n_samples * n_features)[mask] + + out = sparse.coo_matrix((data, (row_indices, column_indices)), + shape=(n_samples, indices[-1]), + dtype=self.dtype).tocsr() + + return out if self.sparse else out.toarray() diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index af7f28f8162c6..8c738d5b915b8 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -6,6 +6,7 @@ from __future__ import division import warnings +import re import numpy as np import numpy.linalg as la from scipy import sparse @@ -31,13 +32,14 @@ from sklearn.utils.testing import skip_if_32bit from sklearn.utils.sparsefuncs import mean_variance_axis -from sklearn.preprocessing.data import _transform_selected +from sklearn.preprocessing.data import _apply_selected from sklearn.preprocessing.data import _handle_zeros_in_scale from sklearn.preprocessing.data import Binarizer from sklearn.preprocessing.data import KernelCenterer from sklearn.preprocessing.data import Normalizer from sklearn.preprocessing.data import normalize from sklearn.preprocessing.data import OneHotEncoder +from sklearn.preprocessing.data import CategoricalEncoder from sklearn.preprocessing.data import StandardScaler from sklearn.preprocessing.data import scale from sklearn.preprocessing.data import MinMaxScaler @@ -1851,29 +1853,29 @@ def test_one_hot_encoder_dense(): [1., 0., 1., 0., 1.]])) -def _check_transform_selected(X, X_expected, sel): +def _check_apply_selected(X, X_expected, sel): for M in (X, sparse.csr_matrix(X)): - Xtr = _transform_selected(M, Binarizer().transform, sel) + Xtr = _apply_selected(M, Binarizer().transform, sel) assert_array_equal(toarray(Xtr), X_expected) -def test_transform_selected(): - X = [[3, 2, 1], [0, 1, 1]] +def test_apply_selected(): + X = np.array([[3, 2, 1], [0, 1, 1]]) X_expected = [[1, 2, 1], [0, 1, 1]] - _check_transform_selected(X, X_expected, [0]) - _check_transform_selected(X, X_expected, [True, False, False]) + _check_apply_selected(X, X_expected, [0]) + _check_apply_selected(X, X_expected, [True, False, False]) X_expected = [[1, 1, 1], [0, 1, 1]] - _check_transform_selected(X, X_expected, [0, 1, 2]) - _check_transform_selected(X, X_expected, [True, True, True]) - _check_transform_selected(X, X_expected, "all") + _check_apply_selected(X, X_expected, [0, 1, 2]) + _check_apply_selected(X, X_expected, [True, True, True]) + _check_apply_selected(X, X_expected, "all") - _check_transform_selected(X, X, []) - _check_transform_selected(X, X, [False, False, False]) + _check_apply_selected(X, X, []) + _check_apply_selected(X, X, [False, False, False]) -def test_transform_selected_copy_arg(): +def test_apply_selected_copy_arg(): # transformer that alters X def _mutating_transformer(X): X[0, 0] = X[0, 0] + 1 @@ -1883,7 +1885,7 @@ def _mutating_transformer(X): expected_Xtr = [[2, 2], [3, 4]] X = original_X.copy() - Xtr = _transform_selected(X, _mutating_transformer, copy=True, + Xtr = _apply_selected(X, _mutating_transformer, copy=True, selected='all') assert_array_equal(toarray(X), toarray(original_X)) @@ -1952,6 +1954,57 @@ def test_one_hot_encoder_unknown_transform(): assert_raises(ValueError, oh.transform, y) +def check_categorical(X, cat_mask): + cat_idx = np.where(cat_mask) + + enc = CategoricalEncoder(categorical_features=cat_mask) + Xtr1 = enc.fit_transform(X) + + enc = CategoricalEncoder(categorical_features=cat_idx) + Xtr2 = enc.fit_transform(X) + + enc = CategoricalEncoder(categorical_features=cat_mask, sparse=False) + Xtr3 = enc.fit_transform(X) + + assert_allclose(Xtr1.toarray(), Xtr2.toarray()) + assert_allclose(Xtr1.toarray(), Xtr3) + + assert sparse.issparse(Xtr1) + assert sparse.issparse(Xtr2) + return Xtr1.toarray() + + +def test_categorical_encoder(): + X = [['abc', 1, 55], ['def', 2, 55]] + + Xtr = check_categorical(X, [True, False, False]) + assert_allclose(Xtr, [[1, 0, 1, 55], [0, 1, 2, 55]]) + + Xtr = check_categorical(X, [True, True, False]) + assert_allclose(Xtr, [[1, 0, 1, 0, 55], [0, 1, 0, 1, 55]]) + + Xtr = CategoricalEncoder().fit_transform(X) + assert_allclose(Xtr.toarray(), [[1, 0, 1, 0, 1], [0, 1, 0, 1, 1]]) + + +def test_categorical_encoder_errors(): + + enc = CategoricalEncoder() + X = [[1, 2, 3], [4, 5, 6]] + enc.fit(X) + + X[0][0] = -1 + msg = re.escape('Unknown feature(s) [-1] in column 0') + assert_raises_regex(ValueError, msg, enc.transform, X) + + enc = CategoricalEncoder(handle_unknown='ignore') + X = [[1, 2, 3], [4, 5, 6]] + enc.fit(X) + X[0][0] = -1 + Xtr = enc.transform(X) + assert_allclose(Xtr.toarray(), [[0, 0, 1, 0, 1, 0], [0, 1, 0, 1, 0, 1]]) + + def test_fit_cold_start(): X = iris.data X_2d = X[:, :2] diff --git a/sklearn/utils/testing.py b/sklearn/utils/testing.py index 20731a9458885..10df72aebc0f6 100644 --- a/sklearn/utils/testing.py +++ b/sklearn/utils/testing.py @@ -519,7 +519,7 @@ def uninstall_mldata_mock(): 'LabelBinarizer', 'LabelEncoder', 'MultiLabelBinarizer', 'TfidfTransformer', 'TfidfVectorizer', 'IsotonicRegression', - 'OneHotEncoder', 'RandomTreesEmbedding', + 'OneHotEncoder', 'RandomTreesEmbedding', 'CategoricalEncoder', 'FeatureHasher', 'DummyClassifier', 'DummyRegressor', 'TruncatedSVD', 'PolynomialFeatures', 'GaussianRandomProjectionHash', 'HashingVectorizer', From bea23a5497459d4534eff232dd1be4a9fdcabc7b Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Tue, 20 Jun 2017 00:56:25 +0200 Subject: [PATCH 02/31] First round of updates - remove compat code for numpy < 1.8 - remove categorical_features keyword - make label_encoders_ private - rename classes to categories --- sklearn/preprocessing/data.py | 79 ++++++------------------ sklearn/preprocessing/tests/test_data.py | 18 +++--- 2 files changed, 28 insertions(+), 69 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 3f84bfbae3ac8..ccffae358cd14 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2462,20 +2462,12 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): Parameters ---------- - classes : 'auto', 2D array of ints or strings or both. + categories : 'auto', 2D array of ints or strings or both. Values per feature. - 'auto' : Determine classes automatically from the training data. - array: ``classes[i]`` holds the classes expected in the ith column. - categorical_features : 'all' or array of indices or mask - Specify what features are treated as categorical. - - - 'all' (default): All features are treated as categorical. - - array of indices: Array of categorical feature indices. - - mask: Array of length n_features and with dtype=bool. - Non-categorical features are always stacked to the right of the matrix. - dtype : number type, default=np.float Desired dtype of output. @@ -2486,14 +2478,6 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): Whether to raise an error or ignore if a unknown categorical feature is present during transform. - Attributes - ---------- - label_encoders_ : list of size n_features. - The :class:`sklearn.preprocessing.LabelEncoder` objects used to encode - the features. ``self.label_encoders[i]_`` is the LabelEncoder object - used to encode the ith column. The unique features found on column - ``i`` can be accessed using ``self.label_encoders_[i].classes_``. - Examples -------- Given a dataset with three features and two samples, we let the encoder @@ -2503,9 +2487,8 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): >>> enc = CategoricalEncoder() >>> enc.fit([[0, 0, 3], [1, 1, 0], [0, 2, 1], \ [1, 0, 2]]) # doctest: +ELLIPSIS - CategoricalEncoder(categorical_features='all', classes='auto', - dtype=<... 'numpy.float64'>, handle_unknown='error', - sparse=True) + CategoricalEncoder(categories='auto', dtype=<... 'numpy.float64'>, + handle_unknown='error', sparse=True) >>> enc.transform([[0, 1, 1]]).toarray() array([[ 1., 0., 0., 1., 0., 0., 1., 0., 0.]]) @@ -2517,10 +2500,9 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): encoding of dictionary items or strings. """ - def __init__(self, classes='auto', categorical_features="all", - dtype=np.float64, sparse=True, handle_unknown='error'): - self.classes = classes - self.categorical_features = categorical_features + def __init__(self, categories='auto', dtype=np.float64, sparse=True, + handle_unknown='error'): + self.categories = categories self.dtype = dtype self.sparse = sparse self.handle_unknown = handle_unknown @@ -2543,12 +2525,10 @@ def fit(self, X, y=None): "'ignore', got %s") raise ValueError(template % self.handle_unknown) - X = check_array(X, dtype=np.object, accept_sparse='csc') + X = check_array(X, dtype=np.object, accept_sparse='csc', copy=True) n_samples, n_features = X.shape - _apply_selected(X, self._fit, dtype=self.dtype, - selected=self.categorical_features, copy=True, - return_val=False) + self._fit(X) return self def _fit(self, X): @@ -2557,21 +2537,20 @@ def _fit(self, X): X = check_array(X, dtype=np.object) n_samples, n_features = X.shape - self.label_encoders_ = [LabelEncoder() for i in range(n_features)] + self._label_encoders_ = [LabelEncoder() for i in range(n_features)] for i in range(n_features): - le = self.label_encoders_[i] - if self.classes == 'auto': + le = self._label_encoders_[i] + if self.categories == 'auto': le.fit(X[:, i]) else: - le.classes_ = np.array(self.classes[i]) + le.classes_ = np.array(self.categories[i]) def transform(self, X, y=None): """Encode the selected categorical features using the one-hot scheme. """ - X = check_array(X, dtype=np.object) - return _apply_selected(X, self._transform, copy=True, - selected=self.categorical_features) + X = check_array(X, dtype=np.object, copy=True) + return self._transform(X) def _transform(self, X): "Assumes `X` contains only categorical features." @@ -2582,30 +2561,12 @@ def _transform(self, X): X_mask = np.ones_like(X, dtype=np.bool) for i in range(n_features): - if np_version < (1, 8): - # in1d is not supported for object datatype in np < 1.8 - valid_mask = np.ones_like(X[:, i], dtype=np.bool) - found_classes = set(np.unique(X[:, i])) - valid_classes = set(self.label_encoders_[i].classes_) - invalid_classes = found_classes - valid_classes - - for item in invalid_classes: - mask = X[:, i] == item - np.logical_not(mask, mask) - np.logical_and(valid_mask, mask, valid_mask) - - else: - valid_mask = np.in1d(X[:, i], self.label_encoders_[i].classes_) + valid_mask = np.in1d(X[:, i], self._label_encoders_[i].classes_) if not np.all(valid_mask): if self.handle_unknown == 'error': - if np_version < (1, 8): - valid_classes = set(self.label_encoders_[i].classes_) - diff = set(X[:, i]) - valid_classes - diff = list(diff) - else: - diff = np.setdiff1d(X[:, i], - self.label_encoders_[i].classes_) + diff = np.setdiff1d(X[:, i], + self._label_encoders_[i].classes_) msg = 'Unknown feature(s) %s in column %d' % (diff, i) raise ValueError(msg) else: @@ -2613,11 +2574,11 @@ def _transform(self, X): # continue `The rows are marked `X_mask` and will be # removed later. X_mask[:, i] = valid_mask - X[:, i][~valid_mask] = self.label_encoders_[i].classes_[0] - X_int[:, i] = self.label_encoders_[i].transform(X[:, i]) + X[:, i][~valid_mask] = self._label_encoders_[i].classes_[0] + X_int[:, i] = self._label_encoders_[i].transform(X[:, i]) mask = X_mask.ravel() - n_values = [le.classes_.shape[0] for le in self.label_encoders_] + n_values = [le.classes_.shape[0] for le in self._label_encoders_] n_values = np.hstack([[0], n_values]) indices = np.cumsum(n_values) self.feature_indices_ = indices diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 8c738d5b915b8..8c84408516580 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -1954,16 +1954,14 @@ def test_one_hot_encoder_unknown_transform(): assert_raises(ValueError, oh.transform, y) -def check_categorical(X, cat_mask): - cat_idx = np.where(cat_mask) - - enc = CategoricalEncoder(categorical_features=cat_mask) +def check_categorical(X): + enc = CategoricalEncoder() Xtr1 = enc.fit_transform(X) - enc = CategoricalEncoder(categorical_features=cat_idx) + enc = CategoricalEncoder() Xtr2 = enc.fit_transform(X) - enc = CategoricalEncoder(categorical_features=cat_mask, sparse=False) + enc = CategoricalEncoder(sparse=False) Xtr3 = enc.fit_transform(X) assert_allclose(Xtr1.toarray(), Xtr2.toarray()) @@ -1977,11 +1975,11 @@ def check_categorical(X, cat_mask): def test_categorical_encoder(): X = [['abc', 1, 55], ['def', 2, 55]] - Xtr = check_categorical(X, [True, False, False]) - assert_allclose(Xtr, [[1, 0, 1, 55], [0, 1, 2, 55]]) + Xtr = check_categorical(np.array(X)[:, [0]]) + assert_allclose(Xtr, [[1, 0], [0, 1]]) - Xtr = check_categorical(X, [True, True, False]) - assert_allclose(Xtr, [[1, 0, 1, 0, 55], [0, 1, 0, 1, 55]]) + Xtr = check_categorical(np.array(X)[:, [0, 1]]) + assert_allclose(Xtr, [[1, 0, 1, 0], [0, 1, 0, 1]]) Xtr = CategoricalEncoder().fit_transform(X) assert_allclose(Xtr.toarray(), [[1, 0, 1, 0, 1], [0, 1, 0, 1, 1]]) From fda6d27f3d6acd9656687cb0ea53483e2b634995 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Mon, 26 Jun 2017 15:24:32 +0200 Subject: [PATCH 03/31] fix + test specifying of categories --- sklearn/preprocessing/data.py | 18 +++++++++++--- sklearn/preprocessing/tests/test_data.py | 30 ++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 3 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index ccffae358cd14..5ddd5e594be9b 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2462,11 +2462,12 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): Parameters ---------- - categories : 'auto', 2D array of ints or strings or both. + categories : 'auto' or a list of lists/arrays of values. Values per feature. - 'auto' : Determine classes automatically from the training data. - - array: ``classes[i]`` holds the classes expected in the ith column. + - list : ``categories[i]`` holds the categories expected in the ith + column. dtype : number type, default=np.float Desired dtype of output. @@ -2544,7 +2545,18 @@ def _fit(self, X): if self.categories == 'auto': le.fit(X[:, i]) else: - le.classes_ = np.array(self.categories[i]) + if not np.all(np.in1d(X[:, i], self.categories[i])): + if self.handle_unknown == 'error': + diff = np.setdiff1d(X[:, i], self.categories[i]) + msg = 'Unknown feature(s) %s in column %d' % (diff, i) + raise ValueError(msg) + le.classes_ = np.array(np.sort(self.categories[i])) + + @staticmethod + def _check_unknown_categories(values, categories): + """Returns False if not all categories in the values are known""" + valid_mask = np.in1d(values, categories) + return np.all(valid_mask) def transform(self, X, y=None): """Encode the selected categorical features using the one-hot scheme. diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 8c84408516580..9c7429b0435a1 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -2003,6 +2003,36 @@ def test_categorical_encoder_errors(): assert_allclose(Xtr.toarray(), [[0, 0, 1, 0, 1, 0], [0, 1, 0, 1, 0, 1]]) +def test_categorical_encoder_specified_categories(): + X = np.array([['a', 'b']], dtype=object).T + + enc = CategoricalEncoder(categories=[['a', 'b', 'c']]) + exp = np.array([[1., 0., 0.], + [0., 1., 0.]]) + assert_array_equal(enc.fit_transform(X).toarray(), exp) + + # don't follow order of passed categories, but sort them + enc = CategoricalEncoder(categories=[['c', 'b', 'a']]) + exp = np.array([[1., 0., 0.], + [0., 1., 0.]]) + assert_array_equal(enc.fit_transform(X).toarray(), exp) + + # multiple columns + X = np.array([['a', 'b'], ['A', 'C']], dtype=object).T + enc = CategoricalEncoder(categories=[['a', 'b', 'c'], ['A', 'B', 'C']]) + exp = np.array([[1., 0., 0., 1., 0., 0.], + [0., 1., 0., 0., 0., 1.]]) + assert_array_equal(enc.fit_transform(X).toarray(), exp) + + # when specifying categories manually, unknown categories should already + # raise when fitting + X = np.array([['a', 'b', 'c']]).T + enc = CategoricalEncoder(categories=[['a', 'b']]) + assert_raises(ValueError, enc.fit, X) + enc = CategoricalEncoder(categories=[['a', 'b']], handle_unknown='ignore') + enc.fit(X) + + def test_fit_cold_start(): X = iris.data X_2d = X[:, :2] From 5f2b40314b87cd7da0ec3eaa21741ff1340af6bd Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Tue, 27 Jun 2017 18:57:04 +0200 Subject: [PATCH 04/31] further clean-up + tests - check that it works on pandas frames - fix doctests - un-deprecate OneHotEncoder - undo changes in _transform_selected (as we no longer need those changes for CategoricalEncoder) - add see also to OneHotEncoder and vice versa - for now remove the self.feature_indices_ attribute --- doc/modules/classes.rst | 1 + doc/modules/preprocessing.rst | 12 ++-- sklearn/preprocessing/data.py | 87 +++++++++--------------- sklearn/preprocessing/tests/test_data.py | 39 +++++++---- 4 files changed, 66 insertions(+), 73 deletions(-) diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index 601d41b682791..d8899a22308bd 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -1197,6 +1197,7 @@ See the :ref:`metrics` section of the user guide for further details. preprocessing.MaxAbsScaler preprocessing.MinMaxScaler preprocessing.Normalizer + preprocessing.OneHotEncoder preprocessing.CategoricalEncoder preprocessing.PolynomialFeatures preprocessing.QuantileTransformer diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index dd85a6c6faa98..8645c6bce0410 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -470,9 +470,8 @@ Continuing the example above:: >>> enc = preprocessing.CategoricalEncoder() >>> X = [['male', 'from US', 'uses Safari'], ['female', 'from Europe', 'uses Firefox']] >>> enc.fit(X) # doctest: +ELLIPSIS - CategoricalEncoder(categorical_features='all', classes='auto', - dtype=<... 'numpy.float64'>, handle_unknown='error', - sparse=True) + CategoricalEncoder(categories='auto', dtype=<... 'numpy.float64'>, + handle_unknown='error', sparse=True) >>> enc.transform([['female', 'from US', 'uses Safari']]).toarray() array([[ 1., 0., 0., 1., 0., 1.]]) @@ -488,18 +487,17 @@ features, one has to explicitly set ``classes``. For example, >>> genders = ['male', 'female'] >>> locations = ['from Europe', 'from US', 'from Africa', 'from Asia'] >>> browsers = ['uses Safari', 'uses Firefox', 'uses IE', 'uses Chrome'] - >>> enc = preprocessing.CategoricalEncoder(classes=[genders, locations, browsers]) + >>> enc = preprocessing.CategoricalEncoder(categories=[genders, locations, browsers]) >>> # Note that for there are missing categorical values for the 2nd and 3rd >>> # feature >>> X = [['male', 'from US', 'uses Safari'], ['female', 'from Europe', 'uses Firefox']] >>> enc.fit(X) # doctest: +ELLIPSIS - CategoricalEncoder(categorical_features='all', - classes=[...], + CategoricalEncoder(categories=[...], dtype=<... 'numpy.float64'>, handle_unknown='error', sparse=True) >>> enc.transform([['female', 'from Asia', 'uses Chrome']]).toarray() - array([[ 1., 0., 0., 0., 0., 1., 1., 0., 0., 0.]]) + array([[ 1., 0., 0., 1., 0., 0., 1., 0., 0., 0.]]) See :ref:`dict_feature_extraction` for categorical features that are represented as a dict, not as integers. diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 5ddd5e594be9b..1562f74b988bd 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -30,8 +30,6 @@ from ..utils.validation import (check_is_fitted, check_random_state, FLOAT_DTYPES) from .label import LabelEncoder -from ..utils.fixes import np_version -from ..utils.deprecation import deprecated BOUNDS_THRESHOLD = 1e-7 @@ -1682,12 +1680,11 @@ def add_dummy_feature(X, value=1.0): return np.hstack((np.ones((n_samples, 1)) * value, X)) -def _apply_selected(X, transform, selected="all", dtype=np.float, copy=True, - return_val=True): - """Apply a function to portion of selected features +def _transform_selected(X, transform, selected="all", copy=True): + """Apply a transform function to portion of selected features Parameters ---------- - X : {array, sparse matrix}, shape [n_samples, n_features] + X : {array-like, sparse matrix}, shape [n_samples, n_features] Dense array or sparse matrix. transform : callable A callable transform(X) -> X_transformed @@ -1695,14 +1692,11 @@ def _apply_selected(X, transform, selected="all", dtype=np.float, copy=True, Copy X even if it could be avoided. selected: "all" or array of indices or mask Specify which features to apply the transform to. - return_val : boolean, optional - Whether to return the transformed matrix. If not set `None` is - returned. Returns ------- - X : array or sparse matrix, shape=(n_samples, n_features_new) + X : array or sparse matrix, shape=(n_samples, n_features_new) """ - X = check_array(X, accept_sparse='csc', copy=copy, dtype=None) + X = check_array(X, accept_sparse='csc', copy=copy, dtype=FLOAT_DTYPES) if isinstance(selected, six.string_types) and selected == "all": return transform(X) @@ -1725,24 +1719,23 @@ def _apply_selected(X, transform, selected="all", dtype=np.float, copy=True, return transform(X) else: X_sel = transform(X[:, ind[sel]]) - X_not_sel = X[:, ind[not_sel]].astype(dtype) + X_not_sel = X[:, ind[not_sel]] - if return_val: - if sparse.issparse(X_sel) or sparse.issparse(X_not_sel): - return sparse.hstack((X_sel, X_not_sel)) - else: - return np.hstack((X_sel, X_not_sel)) + if sparse.issparse(X_sel) or sparse.issparse(X_not_sel): + return sparse.hstack((X_sel, X_not_sel)) + else: + return np.hstack((X_sel, X_not_sel)) -@deprecated('`OneHotEncoder` is deprecated, use `CategoricalEncoder` instead.') class OneHotEncoder(BaseEstimator, TransformerMixin): - """Encode categorical integer features using a one-hot aka one-of-K scheme. + """Encode ordinal integer features using a one-hot aka one-of-K scheme. The input to this transformer should be a matrix of integers, denoting the values taken on by categorical (discrete) features. The output will be a sparse matrix where each column corresponds to one possible value of one feature. It is assumed that input features take on values in the range - [0, n_values). + [0, n_values). For an encoder based on the unique values of the input + features, see the :class:`sklearn.preprocessing.CategoricalEncoder`. This encoding is needed for feeding categorical data to many scikit-learn estimators, notably linear models and SVMs with the standard kernels. @@ -1819,6 +1812,9 @@ class OneHotEncoder(BaseEstimator, TransformerMixin): See also -------- + sklearn.preprocessing.CategoricalEncoder : performs a one-hot encoding of + all features (also handles string-valued features). This encoder + derives the categories based on the unique values in the features. sklearn.feature_extraction.DictVectorizer : performs a one-hot encoding of dictionary items (also handles string-valued features). sklearn.feature_extraction.FeatureHasher : performs an approximate one-hot @@ -1908,8 +1904,8 @@ def fit_transform(self, X, y=None): Equivalent to self.fit(X).transform(X), but more convenient and more efficient. See fit for the parameters, transform for the return value. """ - return _apply_selected(X, self._fit_transform, dtype=self.dtype, - selected=self.categorical_features, copy=True) + return _transform_selected(X, self._fit_transform, + self.categorical_features, copy=True) def _transform(self, X): """Assumes X contains only categorical features.""" @@ -1964,8 +1960,8 @@ def transform(self, X): X_out : sparse matrix if sparse=True else a 2-d array, dtype=int Transformed input. """ - return _apply_selected(X, self._transform, dtype=self.dtype, - selected=self.categorical_features, copy=True) + return _transform_selected(X, self._transform, + self.categorical_features, copy=True) class QuantileTransformer(BaseEstimator, TransformerMixin): @@ -2465,7 +2461,7 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): categories : 'auto' or a list of lists/arrays of values. Values per feature. - - 'auto' : Determine classes automatically from the training data. + - 'auto' : Determine categories automatically from the training data. - list : ``categories[i]`` holds the categories expected in the ith column. @@ -2484,10 +2480,11 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): Given a dataset with three features and two samples, we let the encoder find the maximum value per feature and transform the data to a binary one-hot encoding. + >>> from sklearn.preprocessing import CategoricalEncoder >>> enc = CategoricalEncoder() - >>> enc.fit([[0, 0, 3], [1, 1, 0], [0, 2, 1], \ -[1, 0, 2]]) # doctest: +ELLIPSIS + >>> enc.fit([[0, 0, 3], [1, 1, 0], [0, 2, 1], [1, 0, 2]]) + ... # doctest: +ELLIPSIS CategoricalEncoder(categories='auto', dtype=<... 'numpy.float64'>, handle_unknown='error', sparse=True) >>> enc.transform([[0, 1, 1]]).toarray() @@ -2495,6 +2492,9 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): See also -------- + sklearn.preprocessing.OneHotEncoder : performs a one-hot encoding of + integer ordinal features. This transformer assumes that input features + take on values in the range [0, max(feature)]. sklearn.feature_extraction.DictVectorizer : performs a one-hot encoding of dictionary items (also handles string-valued features). sklearn.feature_extraction.FeatureHasher : performs an approximate one-hot @@ -2529,45 +2529,27 @@ def fit(self, X, y=None): X = check_array(X, dtype=np.object, accept_sparse='csc', copy=True) n_samples, n_features = X.shape - self._fit(X) - return self - - def _fit(self, X): - "Assumes `X` contains only cetergorical features." - - X = check_array(X, dtype=np.object) - n_samples, n_features = X.shape - - self._label_encoders_ = [LabelEncoder() for i in range(n_features)] + self._label_encoders_ = [LabelEncoder() for _ in range(n_features)] for i in range(n_features): le = self._label_encoders_[i] + Xi = X[:, i] if self.categories == 'auto': - le.fit(X[:, i]) + le.fit(Xi) else: - if not np.all(np.in1d(X[:, i], self.categories[i])): + if not np.all(np.in1d(Xi, self.categories[i])): if self.handle_unknown == 'error': - diff = np.setdiff1d(X[:, i], self.categories[i]) + diff = np.setdiff1d(Xi, self.categories[i]) msg = 'Unknown feature(s) %s in column %d' % (diff, i) raise ValueError(msg) le.classes_ = np.array(np.sort(self.categories[i])) - @staticmethod - def _check_unknown_categories(values, categories): - """Returns False if not all categories in the values are known""" - valid_mask = np.in1d(values, categories) - return np.all(valid_mask) + return self def transform(self, X, y=None): """Encode the selected categorical features using the one-hot scheme. """ - X = check_array(X, dtype=np.object, copy=True) - return self._transform(X) - - def _transform(self, X): - "Assumes `X` contains only categorical features." - - X = check_array(X, accept_sparse='csc', dtype=np.object) + X = check_array(X, accept_sparse='csc', dtype=np.object, copy=True) n_samples, n_features = X.shape X_int = np.zeros_like(X, dtype=np.int) X_mask = np.ones_like(X, dtype=np.bool) @@ -2593,7 +2575,6 @@ def _transform(self, X): n_values = [le.classes_.shape[0] for le in self._label_encoders_] n_values = np.hstack([[0], n_values]) indices = np.cumsum(n_values) - self.feature_indices_ = indices column_indices = (X_int + indices[:-1]).ravel()[mask] row_indices = np.repeat(np.arange(n_samples, dtype=np.int32), diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 9c7429b0435a1..f23bf52d38000 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -32,7 +32,7 @@ from sklearn.utils.testing import skip_if_32bit from sklearn.utils.sparsefuncs import mean_variance_axis -from sklearn.preprocessing.data import _apply_selected +from sklearn.preprocessing.data import _transform_selected from sklearn.preprocessing.data import _handle_zeros_in_scale from sklearn.preprocessing.data import Binarizer from sklearn.preprocessing.data import KernelCenterer @@ -1853,29 +1853,29 @@ def test_one_hot_encoder_dense(): [1., 0., 1., 0., 1.]])) -def _check_apply_selected(X, X_expected, sel): +def _check_transform_selected(X, X_expected, sel): for M in (X, sparse.csr_matrix(X)): - Xtr = _apply_selected(M, Binarizer().transform, sel) + Xtr = _transform_selected(M, Binarizer().transform, sel) assert_array_equal(toarray(Xtr), X_expected) -def test_apply_selected(): +def test_transform_selected(): X = np.array([[3, 2, 1], [0, 1, 1]]) X_expected = [[1, 2, 1], [0, 1, 1]] - _check_apply_selected(X, X_expected, [0]) - _check_apply_selected(X, X_expected, [True, False, False]) + _check_transform_selected(X, X_expected, [0]) + _check_transform_selected(X, X_expected, [True, False, False]) X_expected = [[1, 1, 1], [0, 1, 1]] - _check_apply_selected(X, X_expected, [0, 1, 2]) - _check_apply_selected(X, X_expected, [True, True, True]) - _check_apply_selected(X, X_expected, "all") + _check_transform_selected(X, X_expected, [0, 1, 2]) + _check_transform_selected(X, X_expected, [True, True, True]) + _check_transform_selected(X, X_expected, "all") - _check_apply_selected(X, X, []) - _check_apply_selected(X, X, [False, False, False]) + _check_transform_selected(X, X, []) + _check_transform_selected(X, X, [False, False, False]) -def test_apply_selected_copy_arg(): +def test_transform_selected_copy_arg(): # transformer that alters X def _mutating_transformer(X): X[0, 0] = X[0, 0] + 1 @@ -1885,7 +1885,7 @@ def _mutating_transformer(X): expected_Xtr = [[2, 2], [3, 4]] X = original_X.copy() - Xtr = _apply_selected(X, _mutating_transformer, copy=True, + Xtr = _transform_selected(X, _mutating_transformer, copy=True, selected='all') assert_array_equal(toarray(X), toarray(original_X)) @@ -2033,6 +2033,19 @@ def test_categorical_encoder_specified_categories(): enc.fit(X) +def test_categorical_encoder_pandas(): + + try: + import pandas as pd + except ImportError: + raise SkipTest("pandas is not installed") + + X_df = pd.DataFrame({'A': ['a', 'b'], 'B': ['c', 'd']}) + + Xtr = check_categorical(X_df) + assert_allclose(Xtr, [[1, 0, 1, 0], [0, 1, 0, 1]]) + + def test_fit_cold_start(): X = iris.data X_2d = X[:, :2] From e175e4caca35c42e60cf26bace0d4910175b3bf0 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Tue, 27 Jun 2017 20:10:02 +0200 Subject: [PATCH 05/31] fix skipping pandas test --- sklearn/preprocessing/tests/test_data.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index f23bf52d38000..f3ac11f21c533 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -30,6 +30,7 @@ from sklearn.utils.testing import assert_no_warnings from sklearn.utils.testing import assert_allclose from sklearn.utils.testing import skip_if_32bit +from sklearn.utils.testing import SkipTest from sklearn.utils.sparsefuncs import mean_variance_axis from sklearn.preprocessing.data import _transform_selected From 4f64648810ce8c3db0b2cc59dffb0bff8b6f55f6 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Tue, 1 Aug 2017 19:32:07 +0200 Subject: [PATCH 06/31] feedback andy --- sklearn/preprocessing/data.py | 48 ++++++++++++++++-------- sklearn/preprocessing/tests/test_data.py | 14 +++---- 2 files changed, 38 insertions(+), 24 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 322eba0b9ce78..13affe5c8a5a1 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -1836,14 +1836,15 @@ def _transform_selected(X, transform, selected="all", copy=True): class OneHotEncoder(BaseEstimator, TransformerMixin): - """Encode ordinal integer features using a one-hot aka one-of-K scheme. + """Encode categorical integer features using a one-hot aka one-of-K scheme. The input to this transformer should be a matrix of integers, denoting the values taken on by categorical (discrete) features. The output will be a sparse matrix where each column corresponds to one possible value of one feature. It is assumed that input features take on values in the range [0, n_values). For an encoder based on the unique values of the input - features, see the :class:`sklearn.preprocessing.CategoricalEncoder`. + features of any type, see the + :class:`~sklearn.preprocessing.CategoricalEncoder`. This encoding is needed for feeding categorical data to many scikit-learn estimators, notably linear models and SVMs with the standard kernels. @@ -2578,7 +2579,7 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): - 'auto' : Determine categories automatically from the training data. - list : ``categories[i]`` holds the categories expected in the ith - column. + column. The passed categories are sorted before encoding the data. dtype : number type, default=np.float Desired dtype of output. @@ -2588,7 +2589,10 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): handle_unknown : str, 'error' or 'ignore' Whether to raise an error or ignore if a unknown categorical feature is - present during transform. + present during transform (default is to raise). When this is parameter + is set to 'ignore' and an unknown category is encountered during + transform, the resulting one-hot encoded columns for this feature + will be all zeros. Examples -------- @@ -2608,8 +2612,9 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): See also -------- sklearn.preprocessing.OneHotEncoder : performs a one-hot encoding of - integer ordinal features. This transformer assumes that input features - take on values in the range [0, max(feature)]. + integer ordinal features. The ``OneHotEncoder assumes`` that input + features take on values in the range ``[0, max(feature)]`` instead of + using the unique values. sklearn.feature_extraction.DictVectorizer : performs a one-hot encoding of dictionary items (also handles string-valued features). sklearn.feature_extraction.FeatureHasher : performs an approximate one-hot @@ -2629,7 +2634,7 @@ def fit(self, X, y=None): Parameters ---------- X : array-like, shape [n_samples, n_feature] - Input array of type int. + The data to determine the categories of each feature. Returns ------- @@ -2652,17 +2657,30 @@ def fit(self, X, y=None): if self.categories == 'auto': le.fit(Xi) else: - if not np.all(np.in1d(Xi, self.categories[i])): + valid_mask = np.in1d(Xi, self.categories[i]) + if not np.all(valid_mask): if self.handle_unknown == 'error': - diff = np.setdiff1d(Xi, self.categories[i]) - msg = 'Unknown feature(s) %s in column %d' % (diff, i) + diff = np.unique(Xi[~valid_mask]) + msg = ("Found unknown categories {0} in column {1}" + " during fit".format(diff, i)) raise ValueError(msg) le.classes_ = np.array(np.sort(self.categories[i])) return self def transform(self, X, y=None): - """Encode the selected categorical features using the one-hot scheme. + """Transform X using one-hot encoding. + + Parameters + ---------- + X : array-like, shape [n_samples, n_features] + The data to encode. + + Returns + ------- + X_out : sparse matrix if sparse=True else a 2-d array + Transformed input. + """ X = check_array(X, accept_sparse='csc', dtype=np.object, copy=True) n_samples, n_features = X.shape @@ -2674,9 +2692,9 @@ def transform(self, X, y=None): if not np.all(valid_mask): if self.handle_unknown == 'error': - diff = np.setdiff1d(X[:, i], - self._label_encoders_[i].classes_) - msg = 'Unknown feature(s) %s in column %d' % (diff, i) + diff = np.unique(X[~valid_mask, i]) + msg = ("Found unknown categories {0} in column {1}" + " during transform".format(diff, i)) raise ValueError(msg) else: # Set the problematic rows to an acceptable value and @@ -2688,7 +2706,7 @@ def transform(self, X, y=None): mask = X_mask.ravel() n_values = [le.classes_.shape[0] for le in self._label_encoders_] - n_values = np.hstack([[0], n_values]) + n_values = np.array([0] + n_values) indices = np.cumsum(n_values) column_indices = (X_int + indices[:-1]).ravel()[mask] diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index f3ac11f21c533..d59cea14d84e5 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -1959,17 +1959,12 @@ def check_categorical(X): enc = CategoricalEncoder() Xtr1 = enc.fit_transform(X) - enc = CategoricalEncoder() - Xtr2 = enc.fit_transform(X) - enc = CategoricalEncoder(sparse=False) - Xtr3 = enc.fit_transform(X) + Xtr2 = enc.fit_transform(X) - assert_allclose(Xtr1.toarray(), Xtr2.toarray()) - assert_allclose(Xtr1.toarray(), Xtr3) + assert_allclose(Xtr1.toarray(), Xtr2) assert sparse.issparse(Xtr1) - assert sparse.issparse(Xtr2) return Xtr1.toarray() @@ -1993,7 +1988,7 @@ def test_categorical_encoder_errors(): enc.fit(X) X[0][0] = -1 - msg = re.escape('Unknown feature(s) [-1] in column 0') + msg = re.escape('unknown categories [-1] in column 0') assert_raises_regex(ValueError, msg, enc.transform, X) enc = CategoricalEncoder(handle_unknown='ignore') @@ -2031,7 +2026,8 @@ def test_categorical_encoder_specified_categories(): enc = CategoricalEncoder(categories=[['a', 'b']]) assert_raises(ValueError, enc.fit, X) enc = CategoricalEncoder(categories=[['a', 'b']], handle_unknown='ignore') - enc.fit(X) + exp = np.array([[1., 0.], [0., 1.], [0., 0.]]) + assert_array_equal(enc.fit(X).transform(X).toarray(), exp) def test_categorical_encoder_pandas(): From 01c3bd4a48b14f1f82ae5c5036b6074b8e9d0d6c Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Mon, 7 Aug 2017 17:21:05 +0200 Subject: [PATCH 07/31] add encoding keyword to support ordinal encoding --- doc/modules/preprocessing.rst | 2 +- sklearn/preprocessing/data.py | 34 +++++++++++++++++++++--- sklearn/preprocessing/tests/test_data.py | 15 +++++++++++ 3 files changed, 46 insertions(+), 5 deletions(-) diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index 6e1d1bb1abb1d..581f5df54a20c 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -471,7 +471,7 @@ Continuing the example above:: >>> X = [['male', 'from US', 'uses Safari'], ['female', 'from Europe', 'uses Firefox']] >>> enc.fit(X) # doctest: +ELLIPSIS CategoricalEncoder(categories='auto', dtype=<... 'numpy.float64'>, - handle_unknown='error', sparse=True) + encoding='onehot', handle_unknown='error', sparse=True) >>> enc.transform([['female', 'from US', 'uses Safari']]).toarray() array([[ 1., 0., 0., 1., 0., 1.]]) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 13affe5c8a5a1..8daced2af4c00 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2574,6 +2574,15 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): Parameters ---------- + encoding : str, 'onehot' or 'ordinal' + The type of encoding to use (default is 'onehot'): + + - 'onehot': encode the features using a one-hot aka one-of-K scheme + (or also called 'dummy' encoding). This creates a binary column for + each category. + - 'ordinal': encode the features as ordinal integers. This results in + a single column of integers (0 to n_categories - 1) per feature. + categories : 'auto' or a list of lists/arrays of values. Values per feature. @@ -2582,10 +2591,12 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): column. The passed categories are sorted before encoding the data. dtype : number type, default=np.float - Desired dtype of output. + Desired dtype of output. This keyword is ignored in case of + ``encoding='ordinal'`` (the output is always integer). sparse : boolean, default=True Will return sparse matrix if set True else will return an array. + This keyword is ignored in case of ``encoding='ordinal'``. handle_unknown : str, 'error' or 'ignore' Whether to raise an error or ignore if a unknown categorical feature is @@ -2593,6 +2604,8 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): is set to 'ignore' and an unknown category is encountered during transform, the resulting one-hot encoded columns for this feature will be all zeros. + Ignoring unknown categories is not supported for + ``encoding='ordinal'``. Examples -------- @@ -2605,7 +2618,7 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): >>> enc.fit([[0, 0, 3], [1, 1, 0], [0, 2, 1], [1, 0, 2]]) ... # doctest: +ELLIPSIS CategoricalEncoder(categories='auto', dtype=<... 'numpy.float64'>, - handle_unknown='error', sparse=True) + encoding='onehot', handle_unknown='error', sparse=True) >>> enc.transform([[0, 1, 1]]).toarray() array([[ 1., 0., 0., 1., 0., 0., 1., 0., 0.]]) @@ -2621,8 +2634,9 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): encoding of dictionary items or strings. """ - def __init__(self, categories='auto', dtype=np.float64, sparse=True, - handle_unknown='error'): + def __init__(self, encoding='onehot', categories='auto', dtype=np.float64, + sparse=True, handle_unknown='error'): + self.encoding = encoding self.categories = categories self.dtype = dtype self.sparse = sparse @@ -2641,11 +2655,20 @@ def fit(self, X, y=None): self """ + if self.encoding not in ['onehot', 'ordinal']: + template = ("encoding should be either 'onehot' or " + "'ordinal', got %s") + raise ValueError(template % self.handle_unknown) + if self.handle_unknown not in ['error', 'ignore']: template = ("handle_unknown should be either 'error' or " "'ignore', got %s") raise ValueError(template % self.handle_unknown) + if self.encoding == 'ordinal' and self.handle_unknown == 'ignore': + raise ValueError("handle_unknown='ignore' is not supported for" + " encoding='ordinal'") + X = check_array(X, dtype=np.object, accept_sparse='csc', copy=True) n_samples, n_features = X.shape @@ -2704,6 +2727,9 @@ def transform(self, X, y=None): X[:, i][~valid_mask] = self._label_encoders_[i].classes_[0] X_int[:, i] = self._label_encoders_[i].transform(X[:, i]) + if self.encoding == 'ordinal': + return X_int + mask = X_mask.ravel() n_values = [le.classes_.shape[0] for le in self._label_encoders_] n_values = np.array([0] + n_values) diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index d59cea14d84e5..f234123ea0887 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -2043,6 +2043,21 @@ def test_categorical_encoder_pandas(): assert_allclose(Xtr, [[1, 0, 1, 0], [0, 1, 0, 1]]) +def test_categorical_encoder_ordinal(): + X = [['abc', 2, 55], ['def', 1, 55]] + + enc = CategoricalEncoder(encoding='other') + assert_raises(ValueError, enc.fit, X) + + enc = CategoricalEncoder(encoding='ordinal', handle_unknown='ignore') + assert_raises(ValueError, enc.fit, X) + + enc = CategoricalEncoder(encoding='ordinal') + exp = np.array([[0, 1, 0], + [1, 0, 0]]) + assert_array_equal(enc.fit_transform(X), exp) + + def test_fit_cold_start(): X = iris.data X_2d = X[:, :2] From 2ed91e82155aa7232d16e569ddd9893390930795 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Wed, 9 Aug 2017 11:34:17 +0200 Subject: [PATCH 08/31] remove y from transform signature --- doc/modules/preprocessing.rst | 4 ++-- sklearn/preprocessing/data.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index 581f5df54a20c..65939afb55c99 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -493,8 +493,8 @@ features, one has to explicitly set ``classes``. For example, >>> X = [['male', 'from US', 'uses Safari'], ['female', 'from Europe', 'uses Firefox']] >>> enc.fit(X) # doctest: +ELLIPSIS CategoricalEncoder(categories=[...], - dtype=<... 'numpy.float64'>, handle_unknown='error', - sparse=True) + dtype=<... 'numpy.float64'>, encoding='onehot', + handle_unknown='error', sparse=True) >>> enc.transform([['female', 'from Asia', 'uses Chrome']]).toarray() array([[ 1., 0., 0., 1., 0., 0., 1., 0., 0., 0.]]) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 8daced2af4c00..3f63399f24f6e 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2691,7 +2691,7 @@ def fit(self, X, y=None): return self - def transform(self, X, y=None): + def transform(self, X): """Transform X using one-hot encoding. Parameters From a589dd9ed2f85c260b13dec8b579de36317be7b9 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Wed, 9 Aug 2017 12:02:20 +0200 Subject: [PATCH 09/31] Remove sparse keyword in favor of encoding='onehot-dense' --- doc/modules/preprocessing.rst | 4 ++-- sklearn/preprocessing/data.py | 23 +++++++++++------------ sklearn/preprocessing/tests/test_data.py | 14 +++++++------- 3 files changed, 20 insertions(+), 21 deletions(-) diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index 65939afb55c99..d35b8569cb11f 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -471,7 +471,7 @@ Continuing the example above:: >>> X = [['male', 'from US', 'uses Safari'], ['female', 'from Europe', 'uses Firefox']] >>> enc.fit(X) # doctest: +ELLIPSIS CategoricalEncoder(categories='auto', dtype=<... 'numpy.float64'>, - encoding='onehot', handle_unknown='error', sparse=True) + encoding='onehot', handle_unknown='error') >>> enc.transform([['female', 'from US', 'uses Safari']]).toarray() array([[ 1., 0., 0., 1., 0., 1.]]) @@ -494,7 +494,7 @@ features, one has to explicitly set ``classes``. For example, >>> enc.fit(X) # doctest: +ELLIPSIS CategoricalEncoder(categories=[...], dtype=<... 'numpy.float64'>, encoding='onehot', - handle_unknown='error', sparse=True) + handle_unknown='error') >>> enc.transform([['female', 'from Asia', 'uses Chrome']]).toarray() array([[ 1., 0., 0., 1., 0., 0., 1., 0., 0., 0.]]) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 3f63399f24f6e..c7aec700a7d63 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2579,7 +2579,9 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): - 'onehot': encode the features using a one-hot aka one-of-K scheme (or also called 'dummy' encoding). This creates a binary column for - each category. + each category and returns a sparse matrix. + - 'onehot-dense': the same as 'onehot' but returns a dense array + instead of a sparse matrix. - 'ordinal': encode the features as ordinal integers. This results in a single column of integers (0 to n_categories - 1) per feature. @@ -2594,10 +2596,6 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): Desired dtype of output. This keyword is ignored in case of ``encoding='ordinal'`` (the output is always integer). - sparse : boolean, default=True - Will return sparse matrix if set True else will return an array. - This keyword is ignored in case of ``encoding='ordinal'``. - handle_unknown : str, 'error' or 'ignore' Whether to raise an error or ignore if a unknown categorical feature is present during transform (default is to raise). When this is parameter @@ -2618,7 +2616,7 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): >>> enc.fit([[0, 0, 3], [1, 1, 0], [0, 2, 1], [1, 0, 2]]) ... # doctest: +ELLIPSIS CategoricalEncoder(categories='auto', dtype=<... 'numpy.float64'>, - encoding='onehot', handle_unknown='error', sparse=True) + encoding='onehot', handle_unknown='error') >>> enc.transform([[0, 1, 1]]).toarray() array([[ 1., 0., 0., 1., 0., 0., 1., 0., 0.]]) @@ -2635,11 +2633,10 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): """ def __init__(self, encoding='onehot', categories='auto', dtype=np.float64, - sparse=True, handle_unknown='error'): + handle_unknown='error'): self.encoding = encoding self.categories = categories self.dtype = dtype - self.sparse = sparse self.handle_unknown = handle_unknown def fit(self, X, y=None): @@ -2655,7 +2652,7 @@ def fit(self, X, y=None): self """ - if self.encoding not in ['onehot', 'ordinal']: + if self.encoding not in ['onehot', 'onehot-dense', 'ordinal']: template = ("encoding should be either 'onehot' or " "'ordinal', got %s") raise ValueError(template % self.handle_unknown) @@ -2701,7 +2698,7 @@ def transform(self, X): Returns ------- - X_out : sparse matrix if sparse=True else a 2-d array + X_out : sparse matrix or a 2-d array Transformed input. """ @@ -2743,5 +2740,7 @@ def transform(self, X): out = sparse.coo_matrix((data, (row_indices, column_indices)), shape=(n_samples, indices[-1]), dtype=self.dtype).tocsr() - - return out if self.sparse else out.toarray() + if self.encoding == 'onehot-dense': + return out.toarray() + else: + return out diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index f234123ea0887..69f21b58d1763 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -1955,11 +1955,11 @@ def test_one_hot_encoder_unknown_transform(): assert_raises(ValueError, oh.transform, y) -def check_categorical(X): - enc = CategoricalEncoder() +def check_categorical_onehot(X): + enc = CategoricalEncoder(encoding='onehot') Xtr1 = enc.fit_transform(X) - enc = CategoricalEncoder(sparse=False) + enc = CategoricalEncoder(encoding='onehot-dense') Xtr2 = enc.fit_transform(X) assert_allclose(Xtr1.toarray(), Xtr2) @@ -1968,13 +1968,13 @@ def check_categorical(X): return Xtr1.toarray() -def test_categorical_encoder(): +def test_categorical_encoder_onehot(): X = [['abc', 1, 55], ['def', 2, 55]] - Xtr = check_categorical(np.array(X)[:, [0]]) + Xtr = check_categorical_onehot(np.array(X)[:, [0]]) assert_allclose(Xtr, [[1, 0], [0, 1]]) - Xtr = check_categorical(np.array(X)[:, [0, 1]]) + Xtr = check_categorical_onehot(np.array(X)[:, [0, 1]]) assert_allclose(Xtr, [[1, 0, 1, 0], [0, 1, 0, 1]]) Xtr = CategoricalEncoder().fit_transform(X) @@ -2039,7 +2039,7 @@ def test_categorical_encoder_pandas(): X_df = pd.DataFrame({'A': ['a', 'b'], 'B': ['c', 'd']}) - Xtr = check_categorical(X_df) + Xtr = check_categorical_onehot(X_df) assert_allclose(Xtr, [[1, 0, 1, 0], [0, 1, 0, 1]]) From 17e5e69a4a7726e170ca0e32f9616d518541b3c3 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Wed, 9 Aug 2017 12:11:20 +0200 Subject: [PATCH 10/31] Let encoding='ordinal' follow dtype keyword --- sklearn/preprocessing/data.py | 2 +- sklearn/preprocessing/tests/test_data.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index c7aec700a7d63..cd01350352d85 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2725,7 +2725,7 @@ def transform(self, X): X_int[:, i] = self._label_encoders_[i].transform(X[:, i]) if self.encoding == 'ordinal': - return X_int + return X_int.astype(self.dtype, copy=False) mask = X_mask.ravel() n_values = [le.classes_.shape[0] for le in self._label_encoders_] diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 69f21b58d1763..61ca8a0df8b88 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -2054,7 +2054,9 @@ def test_categorical_encoder_ordinal(): enc = CategoricalEncoder(encoding='ordinal') exp = np.array([[0, 1, 0], - [1, 0, 0]]) + [1, 0, 0]], dtype='int64') + assert_array_equal(enc.fit_transform(X), exp.astype('float64')) + enc = CategoricalEncoder(encoding='ordinal', dtype='int64') assert_array_equal(enc.fit_transform(X), exp) From 47a88dd6405cd9e80b87690bded2723462931e7b Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Wed, 9 Aug 2017 14:41:20 +0200 Subject: [PATCH 11/31] add categories_ attribute --- sklearn/preprocessing/data.py | 31 ++++++++++++++++-------- sklearn/preprocessing/tests/test_data.py | 14 +++++++++++ 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index cd01350352d85..bc9da001ae67f 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2561,7 +2561,8 @@ def quantile_transform(X, axis=0, n_quantiles=1000, class CategoricalEncoder(BaseEstimator, TransformerMixin): - """Encode categorical features using a one-hot aka one-of-K scheme. + """Encode categorical features using specified encoding scheme (one-hot + or ordinal encoding). The input to this transformer should be a matrix of integers or strings, denoting the values taken on by categorical (discrete) features. The @@ -2590,11 +2591,11 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): - 'auto' : Determine categories automatically from the training data. - list : ``categories[i]`` holds the categories expected in the ith - column. The passed categories are sorted before encoding the data. + column. The passed categories are sorted before encoding the data + (used categories can be found in the ``categories_`` attribute). - dtype : number type, default=np.float - Desired dtype of output. This keyword is ignored in case of - ``encoding='ordinal'`` (the output is always integer). + dtype : number type, default=np.float64 + Desired dtype of output. handle_unknown : str, 'error' or 'ignore' Whether to raise an error or ignore if a unknown categorical feature is @@ -2605,6 +2606,13 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): Ignoring unknown categories is not supported for ``encoding='ordinal'``. + Attributes + ---------- + categories_ : list of arrays + The categories of each feature determined during fitting. When + categories were specified manually, this holds the sorted categories + (in order corresponding with output of `transform`). + Examples -------- Given a dataset with three features and two samples, we let the encoder @@ -2617,8 +2625,9 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): ... # doctest: +ELLIPSIS CategoricalEncoder(categories='auto', dtype=<... 'numpy.float64'>, encoding='onehot', handle_unknown='error') - >>> enc.transform([[0, 1, 1]]).toarray() - array([[ 1., 0., 0., 1., 0., 0., 1., 0., 0.]]) + >>> enc.transform([[0, 1, 1], [1, 0, 3]]).toarray() + array([[ 1., 0., 0., 1., 0., 0., 1., 0., 0.], + [ 0., 1., 1., 0., 0., 0., 0., 0., 1.]]) See also -------- @@ -2686,6 +2695,8 @@ def fit(self, X, y=None): raise ValueError(msg) le.classes_ = np.array(np.sort(self.categories[i])) + self.categories_ = [le.classes_ for le in self._label_encoders_] + return self def transform(self, X): @@ -2708,7 +2719,7 @@ def transform(self, X): X_mask = np.ones_like(X, dtype=np.bool) for i in range(n_features): - valid_mask = np.in1d(X[:, i], self._label_encoders_[i].classes_) + valid_mask = np.in1d(X[:, i], self.categories_[i]) if not np.all(valid_mask): if self.handle_unknown == 'error': @@ -2721,14 +2732,14 @@ def transform(self, X): # continue `The rows are marked `X_mask` and will be # removed later. X_mask[:, i] = valid_mask - X[:, i][~valid_mask] = self._label_encoders_[i].classes_[0] + X[:, i][~valid_mask] = self.categories_[i][0] X_int[:, i] = self._label_encoders_[i].transform(X[:, i]) if self.encoding == 'ordinal': return X_int.astype(self.dtype, copy=False) mask = X_mask.ravel() - n_values = [le.classes_.shape[0] for le in self._label_encoders_] + n_values = [cats.shape[0] for cats in self.categories_] n_values = np.array([0] + n_values) indices = np.cumsum(n_values) diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 61ca8a0df8b88..21934cc4a67d0 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -1999,6 +1999,18 @@ def test_categorical_encoder_errors(): assert_allclose(Xtr.toarray(), [[0, 0, 1, 0, 1, 0], [0, 1, 0, 1, 0, 1]]) +def test_categorical_encoder_categories(): + X = [['abc', 1, 55], ['def', 2, 55]] + enc = CategoricalEncoder() + enc.fit(X) + + assert enc.categories == 'auto' + assert isinstance(enc.categories_, list) + cat_exp = [['abc', 'def'], [1, 2], [55]] + for res, exp in zip(enc.categories_, cat_exp): + assert res.tolist() == exp + + def test_categorical_encoder_specified_categories(): X = np.array([['a', 'b']], dtype=object).T @@ -2012,6 +2024,8 @@ def test_categorical_encoder_specified_categories(): exp = np.array([[1., 0., 0.], [0., 1., 0.]]) assert_array_equal(enc.fit_transform(X).toarray(), exp) + assert enc.categories[0] == ['c', 'b', 'a'] + assert enc.categories_[0].tolist() == ['a', 'b', 'c'] # multiple columns X = np.array([['a', 'b'], ['A', 'C']], dtype=object).T From 7b5b476e1e8e8e2719fd7beb531fe382b66e064e Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Fri, 25 Aug 2017 16:24:16 +0200 Subject: [PATCH 12/31] expand docs on ordinal + feedback --- doc/modules/preprocessing.rst | 41 ++++++++++++++++------- sklearn/preprocessing/data.py | 40 +++++++++++----------- sklearn/preprocessing/tests/test_data.py | 42 +++++++++++++----------- 3 files changed, 73 insertions(+), 50 deletions(-) diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index d35b8569cb11f..7cdb08eeae6d2 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -455,15 +455,29 @@ Such features can be efficiently coded as integers, for instance ``[0, 1, 3]`` while ``["female", "from Asia", "uses Chrome"]`` would be ``[1, 2, 1]``. -Such integer representation can not be used directly with scikit-learn estimators, as these -expect continuous input, and would interpret the categories as being ordered, which is often -not desired (i.e. the set of browsers was ordered arbitrarily). +To convert categorical features to such integer codes, we can use the +:class:`CategoricalEncoder`. When specifying that we want to perform an +ordinal encoding, the estimator transforms each categorical feature to one +new feature of integers (0 to n_categories - 1):: + + >>> enc = preprocessing.CategoricalEncoder(encoding='ordinal') + >>> X = [['male', 'from US', 'uses Safari'], ['female', 'from Europe', 'uses Firefox']] + >>> enc.fit(X) # doctest: +ELLIPSIS + CategoricalEncoder(categories='auto', dtype=<... 'numpy.float64'>, + encoding='ordinal', handle_unknown='error') + >>> enc.transform([['female', 'from US', 'uses Safari']]) + array([[ 0., 1., 1.]]) + +Such integer representation can, however, not be used directly with all +scikit-learn estimators, as these expect continuous input, and would interpret +the categories as being ordered, which is often not desired (i.e. the set of +browsers was ordered arbitrarily). One possibility to convert categorical features to features that can be used -with scikit-learn estimators is to use a one-of-K or one-hot encoding, which is -implemented in :class:`CategoricalEncoder`. This estimator transforms each -categorical feature with ``m`` possible values into ``m`` binary features, with -only one active. +with scikit-learn estimators is to use a one-of-K or one-hot encoding. This +type of encoding is the default behaviour of the :class:`CategoricalEncoder`. +The estimator then transforms each categorical feature with ``n_categories`` +possible values into ``n_categories`` binary features, with only one active. Continuing the example above:: @@ -475,14 +489,18 @@ Continuing the example above:: >>> enc.transform([['female', 'from US', 'uses Safari']]).toarray() array([[ 1., 0., 0., 1., 0., 1.]]) +By default, how many values each feature can take is inferred automatically +from the dataset and can be found in the ``categories_`` attribute:: -By default, how many values each feature can take is inferred automatically from the dataset. -It is possible to specify this explicitly using the parameter ``classes``. + >>> enc.categories_ + [array(['female', 'male'], dtype=object), array(['from Europe', 'from US'], dtype=object), array(['uses Firefox', 'uses Safari'], dtype=object)] + +It is possible to specify this explicitly using the parameter ``categories``. There are two genders, three possible continents and four web browsers in our dataset. Note that, if there is a possibilty that the training data might have missing categorical -features, one has to explicitly set ``classes``. For example, +features, one has to explicitly set ``categories``. For example, >>> genders = ['male', 'female'] >>> locations = ['from Europe', 'from US', 'from Africa', 'from Asia'] @@ -495,12 +513,11 @@ features, one has to explicitly set ``classes``. For example, CategoricalEncoder(categories=[...], dtype=<... 'numpy.float64'>, encoding='onehot', handle_unknown='error') - >>> enc.transform([['female', 'from Asia', 'uses Chrome']]).toarray() array([[ 1., 0., 0., 1., 0., 0., 1., 0., 0., 0.]]) See :ref:`dict_feature_extraction` for categorical features that are represented -as a dict, not as integers. +as a dict, not as scalars. .. _imputation: diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index bc9da001ae67f..e22f47c547bc1 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -1921,9 +1921,10 @@ class OneHotEncoder(BaseEstimator, TransformerMixin): See also -------- - sklearn.preprocessing.CategoricalEncoder : performs a one-hot encoding of - all features (also handles string-valued features). This encoder - derives the categories based on the unique values in the features. + sklearn.preprocessing.CategoricalEncoder : performs a one-hot or ordinal + encoding of all features (also handles string-valued features). This + encoder derives the categories based on the unique values in the + features. sklearn.feature_extraction.DictVectorizer : performs a one-hot encoding of dictionary items (also handles string-valued features). sklearn.feature_extraction.FeatureHasher : performs an approximate one-hot @@ -2561,21 +2562,22 @@ def quantile_transform(X, axis=0, n_quantiles=1000, class CategoricalEncoder(BaseEstimator, TransformerMixin): - """Encode categorical features using specified encoding scheme (one-hot - or ordinal encoding). + """Encode categorical features as a numeric array. The input to this transformer should be a matrix of integers or strings, - denoting the values taken on by categorical (discrete) features. The - output will be a sparse matrix where each column corresponds to one - possible value of one feature. + denoting the values taken on by categorical (discrete) features. + The features can be encoded using a one-hot aka one-of-K scheme + (``encoding='onehot'``, the default) or converted to ordinal integers + (``encoding='ordinal'``). This encoding is needed for feeding categorical data to many scikit-learn estimators, notably linear models and SVMs with the standard kernels. + Read more in the :ref:`User Guide `. Parameters ---------- - encoding : str, 'onehot' or 'ordinal' + encoding : str, 'onehot', 'onehot-dense' or 'ordinal' The type of encoding to use (default is 'onehot'): - 'onehot': encode the features using a one-hot aka one-of-K scheme @@ -2587,17 +2589,17 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): a single column of integers (0 to n_categories - 1) per feature. categories : 'auto' or a list of lists/arrays of values. - Values per feature. + Categories (unique values) per feature: - 'auto' : Determine categories automatically from the training data. - list : ``categories[i]`` holds the categories expected in the ith column. The passed categories are sorted before encoding the data (used categories can be found in the ``categories_`` attribute). - dtype : number type, default=np.float64 + dtype : number type, default np.float64 Desired dtype of output. - handle_unknown : str, 'error' or 'ignore' + handle_unknown : 'error' (default) or 'ignore' Whether to raise an error or ignore if a unknown categorical feature is present during transform (default is to raise). When this is parameter is set to 'ignore' and an unknown category is encountered during @@ -2620,14 +2622,14 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): one-hot encoding. >>> from sklearn.preprocessing import CategoricalEncoder - >>> enc = CategoricalEncoder() + >>> enc = CategoricalEncoder(handle_unknown='ignore') >>> enc.fit([[0, 0, 3], [1, 1, 0], [0, 2, 1], [1, 0, 2]]) ... # doctest: +ELLIPSIS CategoricalEncoder(categories='auto', dtype=<... 'numpy.float64'>, - encoding='onehot', handle_unknown='error') - >>> enc.transform([[0, 1, 1], [1, 0, 3]]).toarray() + encoding='onehot', handle_unknown='ignore') + >>> enc.transform([[0, 1, 1], [1, 0, 4]]).toarray() array([[ 1., 0., 0., 1., 0., 0., 1., 0., 0.], - [ 0., 1., 1., 0., 0., 0., 0., 0., 1.]]) + [ 0., 1., 1., 0., 0., 0., 0., 0., 0.]]) See also -------- @@ -2662,8 +2664,8 @@ def fit(self, X, y=None): """ if self.encoding not in ['onehot', 'onehot-dense', 'ordinal']: - template = ("encoding should be either 'onehot' or " - "'ordinal', got %s") + template = ("encoding should be either 'onehot', 'onehot-dense' " + "or 'ordinal', got %s") raise ValueError(template % self.handle_unknown) if self.handle_unknown not in ['error', 'ignore']: @@ -2748,7 +2750,7 @@ def transform(self, X): n_features)[mask] data = np.ones(n_samples * n_features)[mask] - out = sparse.coo_matrix((data, (row_indices, column_indices)), + out = sparse.csc_matrix((data, (row_indices, column_indices)), shape=(n_samples, indices[-1]), dtype=self.dtype).tocsr() if self.encoding == 'onehot-dense': diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 21934cc4a67d0..0611b006186ce 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -1964,7 +1964,7 @@ def check_categorical_onehot(X): assert_allclose(Xtr1.toarray(), Xtr2) - assert sparse.issparse(Xtr1) + assert sparse.isspmatrix_csr(Xtr1) return Xtr1.toarray() @@ -1981,34 +1981,39 @@ def test_categorical_encoder_onehot(): assert_allclose(Xtr.toarray(), [[1, 0, 1, 0, 1], [0, 1, 0, 1, 1]]) -def test_categorical_encoder_errors(): +def test_categorical_encoder_handle_unknown(): + X = [[1, 2, 3], [4, 5, 6]] + y = [[7, 5, 3]] + # Test that encoder raises error for unknown features during transform. enc = CategoricalEncoder() - X = [[1, 2, 3], [4, 5, 6]] enc.fit(X) + msg = re.escape('unknown categories [7] in column 0') + assert_raises_regex(ValueError, msg, enc.transform, y) - X[0][0] = -1 - msg = re.escape('unknown categories [-1] in column 0') - assert_raises_regex(ValueError, msg, enc.transform, X) - + # With 'ignore' you get all 0's in result enc = CategoricalEncoder(handle_unknown='ignore') - X = [[1, 2, 3], [4, 5, 6]] enc.fit(X) - X[0][0] = -1 - Xtr = enc.transform(X) - assert_allclose(Xtr.toarray(), [[0, 0, 1, 0, 1, 0], [0, 1, 0, 1, 0, 1]]) + Xtr = enc.transform(y) + assert_allclose(Xtr.toarray(), [[0, 0, 0, 1, 1, 0]]) + + # Invalid option + enc = CategoricalEncoder(handle_unknown='invalid') + assert_raises(ValueError, enc.fit, X) def test_categorical_encoder_categories(): X = [['abc', 1, 55], ['def', 2, 55]] - enc = CategoricalEncoder() - enc.fit(X) - assert enc.categories == 'auto' - assert isinstance(enc.categories_, list) - cat_exp = [['abc', 'def'], [1, 2], [55]] - for res, exp in zip(enc.categories_, cat_exp): - assert res.tolist() == exp + # order of categories should not depend on order of samples + for Xi in [X, X[::-1]]: + enc = CategoricalEncoder() + enc.fit(Xi) + assert enc.categories == 'auto' + assert isinstance(enc.categories_, list) + cat_exp = [['abc', 'def'], [1, 2], [55]] + for res, exp in zip(enc.categories_, cat_exp): + assert res.tolist() == exp def test_categorical_encoder_specified_categories(): @@ -2045,7 +2050,6 @@ def test_categorical_encoder_specified_categories(): def test_categorical_encoder_pandas(): - try: import pandas as pd except ImportError: From 3dcc07ffea4ea8376f8590a539f67c46517b51c7 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Thu, 19 Oct 2017 18:54:33 +0200 Subject: [PATCH 13/31] feedback Andy --- doc/modules/preprocessing.rst | 9 ++++---- sklearn/feature_extraction/dict_vectorizer.py | 2 +- sklearn/preprocessing/data.py | 23 +++++++++---------- sklearn/preprocessing/label.py | 4 ++-- sklearn/preprocessing/tests/test_data.py | 8 +++---- 5 files changed, 23 insertions(+), 23 deletions(-) diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index 7271baa05e362..6eefe3b9729e3 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -474,10 +474,11 @@ the categories as being ordered, which is often not desired (i.e. the set of browsers was ordered arbitrarily). One possibility to convert categorical features to features that can be used -with scikit-learn estimators is to use a one-of-K or one-hot encoding. This -type of encoding is the default behaviour of the :class:`CategoricalEncoder`. -The estimator then transforms each categorical feature with ``n_categories`` -possible values into ``n_categories`` binary features, with only one active. +with scikit-learn estimators is to use a one-of-K, one-hot or dummy encoding. +This type of encoding is the default behaviour of the :class:`CategoricalEncoder`. +The :class:`CategoricalEncoder` then transforms each categorical feature with +``n_categories`` possible values into ``n_categories`` binary features, with +one of them 1, and all others 0. Continuing the example above:: diff --git a/sklearn/feature_extraction/dict_vectorizer.py b/sklearn/feature_extraction/dict_vectorizer.py index 72d9687573baa..194e922596bf5 100644 --- a/sklearn/feature_extraction/dict_vectorizer.py +++ b/sklearn/feature_extraction/dict_vectorizer.py @@ -90,7 +90,7 @@ class DictVectorizer(BaseEstimator, TransformerMixin): -------- FeatureHasher : performs vectorization using only a hash function. sklearn.preprocessing.CategoricalEncoder : handles nominal/categorical - features encoded as columns of arbitraty data types. + features encoded as columns of arbitrary data types. """ def __init__(self, dtype=np.float64, separator="=", sparse=True, diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 62efc3dcc267e..08a5ee87c836f 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2576,8 +2576,8 @@ def quantile_transform(X, axis=0, n_quantiles=1000, class CategoricalEncoder(BaseEstimator, TransformerMixin): """Encode categorical features as a numeric array. - The input to this transformer should be a matrix of integers or strings, - denoting the values taken on by categorical (discrete) features. + The input to this transformer should be an array-like of integers or + strings, denoting the values taken on by categorical (discrete) features. The features can be encoded using a one-hot aka one-of-K scheme (``encoding='onehot'``, the default) or converted to ordinal integers (``encoding='ordinal'``). @@ -2613,7 +2613,7 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): handle_unknown : 'error' (default) or 'ignore' Whether to raise an error or ignore if a unknown categorical feature is - present during transform (default is to raise). When this is parameter + present during transform (default is to raise). When this parameter is set to 'ignore' and an unknown category is encountered during transform, the resulting one-hot encoded columns for this feature will be all zeros. @@ -2623,15 +2623,14 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): Attributes ---------- categories_ : list of arrays - The categories of each feature determined during fitting. When - categories were specified manually, this holds the sorted categories + The categories of each feature determined during fitting. If + categories are specified manually, this holds the sorted categories (in order corresponding with output of `transform`). Examples -------- - Given a dataset with three features and two samples, we let the encoder - find the maximum value per feature and transform the data to a binary - one-hot encoding. + Given a dataset with three features, we let the encoder find the unique + values per feature and transform the data to a binary one-hot encoding. >>> from sklearn.preprocessing import CategoricalEncoder >>> enc = CategoricalEncoder(handle_unknown='ignore') @@ -2667,7 +2666,7 @@ def fit(self, X, y=None): Parameters ---------- - X : array-like, shape [n_samples, n_feature] + X : array-like, shape [n_samples, n_features] The data to determine the categories of each feature. Returns @@ -2689,7 +2688,7 @@ def fit(self, X, y=None): raise ValueError("handle_unknown='ignore' is not supported for" " encoding='ordinal'") - X = check_array(X, dtype=np.object, accept_sparse='csc', copy=True) + X = check_array(X, dtype=np.object, copy=True) n_samples, n_features = X.shape self._label_encoders_ = [LabelEncoder() for _ in range(n_features)] @@ -2714,7 +2713,7 @@ def fit(self, X, y=None): return self def transform(self, X): - """Transform X using one-hot encoding. + """Transform X using specified encoding scheme. Parameters ---------- @@ -2727,7 +2726,7 @@ def transform(self, X): Transformed input. """ - X = check_array(X, accept_sparse='csc', dtype=np.object, copy=True) + X = check_array(X, dtype=np.object, copy=True) n_samples, n_features = X.shape X_int = np.zeros_like(X, dtype=np.int) X_mask = np.ones_like(X, dtype=np.bool) diff --git a/sklearn/preprocessing/label.py b/sklearn/preprocessing/label.py index 530f376c19fa9..d68ec7d6b4fad 100644 --- a/sklearn/preprocessing/label.py +++ b/sklearn/preprocessing/label.py @@ -76,8 +76,8 @@ class LabelEncoder(BaseEstimator, TransformerMixin): See also -------- - sklearn.preprocessing.OneHotEncoder : encode categorical integer features - using a one-hot aka one-of-K scheme. + sklearn.preprocessing.CategoricalEncoder : encode categorical features + using a one-hot or ordinal encoding scheme. """ def fit(self, y): diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index a03fec937caa7..30ab28891730c 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -1870,7 +1870,7 @@ def _check_transform_selected(X, X_expected, sel): def test_transform_selected(): - X = np.array([[3, 2, 1], [0, 1, 1]]) + X = [[3, 2, 1], [0, 1, 1]] X_expected = [[1, 2, 1], [0, 1, 1]] _check_transform_selected(X, X_expected, [0]) @@ -1992,18 +1992,18 @@ def test_categorical_encoder_onehot(): def test_categorical_encoder_handle_unknown(): X = [[1, 2, 3], [4, 5, 6]] - y = [[7, 5, 3]] + X2 = [[7, 5, 3]] # Test that encoder raises error for unknown features during transform. enc = CategoricalEncoder() enc.fit(X) msg = re.escape('unknown categories [7] in column 0') - assert_raises_regex(ValueError, msg, enc.transform, y) + assert_raises_regex(ValueError, msg, enc.transform, X2) # With 'ignore' you get all 0's in result enc = CategoricalEncoder(handle_unknown='ignore') enc.fit(X) - Xtr = enc.transform(y) + Xtr = enc.transform(X2) assert_allclose(Xtr.toarray(), [[0, 0, 0, 1, 1, 0]]) # Invalid option From 5f5934f43f4b7d022b66d5301b78b98715738b14 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Thu, 19 Oct 2017 19:06:39 +0200 Subject: [PATCH 14/31] add whatsnew note --- doc/whats_new/_contributors.rst | 2 ++ doc/whats_new/v0.20.rst | 13 ++++++++++++- sklearn/preprocessing/data.py | 6 +++--- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/doc/whats_new/_contributors.rst b/doc/whats_new/_contributors.rst index dfbc319da88f4..0650f6d7517cc 100644 --- a/doc/whats_new/_contributors.rst +++ b/doc/whats_new/_contributors.rst @@ -141,3 +141,5 @@ .. _Neeraj Gangwar: http://neerajgangwar.in .. _Arthur Mensch: https://amensch.fr + +.. _Joris Van den Bossche: https://github.com/jorisvandenbossche diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index 51d2fab65be81..25296308258ca 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -40,7 +40,18 @@ Classifiers and regressors - Added :class:`naive_bayes.ComplementNB`, which implements the Complement Naive Bayes classifier described in Rennie et al. (2003). By :user:`Michael A. Alcorn `. - + +Preprocessing: + +- Added :class:`preprocessing.CategoricalEncoder`, which allows to encode + categorical features as a numeric array, either using a one-hot (or + dummy) encoding scheme or by converting to ordinal integers. + Compared to the existing :class:`OneHotEncoder`, this new class handles + encoding of all feature types (also handles string-valued features) and + derives the categories based on the unique values in the features instead of + the maximum value in the features. + By :user:`Vighnesh Birodkar ` and `Joris Van den Bossche`_. + Model evaluation - Added the :func:`metrics.balanced_accuracy` metric and a corresponding diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 08a5ee87c836f..238bbc3d78243 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2578,9 +2578,9 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): The input to this transformer should be an array-like of integers or strings, denoting the values taken on by categorical (discrete) features. - The features can be encoded using a one-hot aka one-of-K scheme - (``encoding='onehot'``, the default) or converted to ordinal integers - (``encoding='ordinal'``). + The features can be encoded using a one-hot (aka one-of-K or dummy) + encoding scheme (``encoding='onehot'``, the default) or converted + to ordinal integers (``encoding='ordinal'``). This encoding is needed for feeding categorical data to many scikit-learn estimators, notably linear models and SVMs with the standard kernels. From c6a5d309a1abc0f539897c9af81044cbf8302be9 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Fri, 20 Oct 2017 11:42:01 +0200 Subject: [PATCH 15/31] for now raise on unsorted passed categories --- sklearn/preprocessing/data.py | 13 ++++++++++--- sklearn/preprocessing/tests/test_data.py | 11 +++++------ 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 238bbc3d78243..01b48ecb79b07 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2605,8 +2605,9 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): - 'auto' : Determine categories automatically from the training data. - list : ``categories[i]`` holds the categories expected in the ith - column. The passed categories are sorted before encoding the data - (used categories can be found in the ``categories_`` attribute). + column. The passed categories should be sorted. + + The used categories can be found in the ``categories_`` attribute. dtype : number type, default np.float64 Desired dtype of output. @@ -2688,6 +2689,12 @@ def fit(self, X, y=None): raise ValueError("handle_unknown='ignore' is not supported for" " encoding='ordinal'") + if self.categories != 'auto': + for cats in self.categories: + if not np.all(np.sort(cats) == np.array(cats)): + raise ValueError("Unsorted categories are not yet " + "supported") + X = check_array(X, dtype=np.object, copy=True) n_samples, n_features = X.shape @@ -2706,7 +2713,7 @@ def fit(self, X, y=None): msg = ("Found unknown categories {0} in column {1}" " during fit".format(diff, i)) raise ValueError(msg) - le.classes_ = np.array(np.sort(self.categories[i])) + le.classes_ = np.array(self.categories[i]) self.categories_ = [le.classes_ for le in self._label_encoders_] diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 30ab28891730c..ee4cd801688fc 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -2032,14 +2032,13 @@ def test_categorical_encoder_specified_categories(): exp = np.array([[1., 0., 0.], [0., 1., 0.]]) assert_array_equal(enc.fit_transform(X).toarray(), exp) + assert enc.categories[0] == ['a', 'b', 'c'] + assert enc.categories_[0].tolist() == ['a', 'b', 'c'] - # don't follow order of passed categories, but sort them + # unsorted passed categories raises for now enc = CategoricalEncoder(categories=[['c', 'b', 'a']]) - exp = np.array([[1., 0., 0.], - [0., 1., 0.]]) - assert_array_equal(enc.fit_transform(X).toarray(), exp) - assert enc.categories[0] == ['c', 'b', 'a'] - assert enc.categories_[0].tolist() == ['a', 'b', 'c'] + msg = re.escape('Unsorted categories are not yet supported') + assert_raises_regex(ValueError, msg, enc.fit_transform, X) # multiple columns X = np.array([['a', 'b'], ['A', 'C']], dtype=object).T From ad5fdc73b6cdb142a446baa380204f212030a5b7 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Fri, 20 Oct 2017 14:29:41 +0200 Subject: [PATCH 16/31] Implement inverse_transform --- sklearn/preprocessing/data.py | 58 ++++++++++++++++++++++++ sklearn/preprocessing/tests/test_data.py | 44 ++++++++++++++++++ 2 files changed, 102 insertions(+) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 01b48ecb79b07..991308e13ea65 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2775,3 +2775,61 @@ def transform(self, X): return out.toarray() else: return out + + def inverse_transform(self, X): + """Convert back the data to the original representation. + + Parameters + ---------- + X : array-like or sparse matrix, shape [n_samples, n_encoded_features] + The transformed data. + + Returns + ------- + X_tr : array-like, shape [n_samples, n_features] + Inverse transformed array. + + """ + check_is_fitted(self, 'categories_') + + n_samples, _ = X.shape + n_features = len(self.categories_) + + dt = np.find_common_type([cat.dtype for cat in self.categories_], []) + X_tr = np.empty((n_samples, n_features), dtype=dt) + + if self.encoding == 'ordinal': + for i in range(n_features): + labels = X[:, i].astype('int64') + X_tr[:, i] = self.categories_[i][labels] + + else: # encoding == 'onehot' / 'onehot-dense' + j = 0 + found_unknown = {} + + for i in range(n_features): + n_categories = len(self.categories_[i]) + sub = X[:, j:j + n_categories] + + # for sparse X argmax returns 2D matrix, ensure 1D array + labels = np.asarray(sub.argmax(axis=1)).flatten() + X_tr[:, i] = self.categories_[i][labels] + + if self.handle_unknown == 'ignore': + # ignored unknown categories: we have a row of all zero's + unknown = np.asarray(sub.sum(axis=1) == 0).flatten() + if unknown.any(): + found_unknown[i] = unknown + + j += n_categories + + # if ignored are found: potentially need to upcast result to + # insert None values + if found_unknown: + if X_tr.dtype != object: + X_tr = X_tr.astype(object) + + for idx, mask in found_unknown.items(): + X_tr[mask, idx] = None + + return X_tr diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index ee4cd801688fc..4c490805eca65 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -1990,6 +1990,42 @@ def test_categorical_encoder_onehot(): assert_allclose(Xtr.toarray(), [[1, 0, 1, 0, 1], [0, 1, 0, 1, 1]]) +def test_categorical_encoder_onehot_inverse(): + for encoding in ['onehot', 'onehot-dense']: + X = [['abc', 2, 55], ['def', 1, 55], ['abc', 3, 55]] + enc = CategoricalEncoder(encoding=encoding) + X_tr = enc.fit_transform(X) + exp = np.array(X, dtype=object) + assert_array_equal(enc.inverse_transform(X_tr), exp) + + X = [[2, 55], [1, 55], [3, 55]] + enc = CategoricalEncoder(encoding=encoding) + X_tr = enc.fit_transform(X) + exp = np.array(X) + assert_array_equal(enc.inverse_transform(X_tr), exp) + + # with unknown categories + X = [['abc', 2, 55], ['def', 1, 55], ['abc', 3, 55]] + enc = CategoricalEncoder(encoding=encoding, handle_unknown='ignore', + categories=[['abc', 'def'], [1, 2], + [54, 55, 56]]) + X_tr = enc.fit_transform(X) + exp = np.array(X, dtype=object) + exp[2, 1] = None + assert_array_equal(enc.inverse_transform(X_tr), exp) + + # with an otherwise numerical output, still object if unknown + X = [[2, 55], [1, 55], [3, 55]] + enc = CategoricalEncoder(encoding=encoding, + categories=[[1, 2], [54, 56]], + handle_unknown='ignore') + X_tr = enc.fit_transform(X) + exp = np.array(X, dtype=object) + exp[2, 0] = None + exp[:, 1] = None + assert_array_equal(enc.inverse_transform(X_tr), exp) + + def test_categorical_encoder_handle_unknown(): X = [[1, 2, 3], [4, 5, 6]] X2 = [[7, 5, 3]] @@ -2086,6 +2122,14 @@ def test_categorical_encoder_ordinal(): assert_array_equal(enc.fit_transform(X), exp) +def test_categorical_encoder_ordinal_inverse(): + X = [['abc', 2, 55], ['def', 1, 55]] + enc = CategoricalEncoder(encoding='ordinal') + X_tr = enc.fit_transform(X) + exp = np.array(X, dtype=object) + assert_array_equal(enc.inverse_transform(X_tr), exp) + + def test_fit_cold_start(): X = iris.data X_2d = X[:, :2] From eb2f4b8ea95b9f20c5a8083a7cd99d20800ba2f8 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Fri, 20 Oct 2017 21:30:00 +0200 Subject: [PATCH 17/31] fix example to have sorted categories --- doc/modules/preprocessing.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index 6eefe3b9729e3..5cda4081b30df 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -503,9 +503,9 @@ dataset. Note that, if there is a possibility that the training data might have missing categorical features, one has to explicitly set ``categories``. For example, - >>> genders = ['male', 'female'] - >>> locations = ['from Europe', 'from US', 'from Africa', 'from Asia'] - >>> browsers = ['uses Safari', 'uses Firefox', 'uses IE', 'uses Chrome'] + >>> genders = ['female', 'male'] + >>> locations = ['from Africa', 'from Asia', 'from Europe', 'from US'] + >>> browsers = ['uses Chrome', 'uses Firefox', 'uses IE', 'uses Safari'] >>> enc = preprocessing.CategoricalEncoder(categories=[genders, locations, browsers]) >>> # Note that for there are missing categorical values for the 2nd and 3rd >>> # feature From ce82c289fc46473414918da5985bc2a639bcc1d2 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Fri, 27 Oct 2017 21:04:43 +0200 Subject: [PATCH 18/31] backport scipy sparse argmax --- sklearn/preprocessing/data.py | 11 ++-- sklearn/utils/fixes.py | 102 ++++++++++++++++++++++++++++++++++ 2 files changed, 108 insertions(+), 5 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 991308e13ea65..9ac37e31cf9b2 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -23,6 +23,7 @@ from ..utils import check_array from ..utils.extmath import row_norms from ..utils.extmath import _incremental_mean_and_var +from ..utils.fixes import argmax from ..utils.sparsefuncs_fast import (inplace_csr_row_normalize_l1, inplace_csr_row_normalize_l2) from ..utils.sparsefuncs import (inplace_column_scale, @@ -2605,7 +2606,7 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): - 'auto' : Determine categories automatically from the training data. - list : ``categories[i]`` holds the categories expected in the ith - column. The passed categories should be sorted. + column. The passed categories must be sorted. The used categories can be found in the ``categories_`` attribute. @@ -2624,9 +2625,8 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): Attributes ---------- categories_ : list of arrays - The categories of each feature determined during fitting. If - categories are specified manually, this holds the sorted categories - (in order corresponding with output of `transform`). + The categories of each feature determined during fitting + (in order corresponding with output of ``transform``). Examples -------- @@ -2791,6 +2791,7 @@ def inverse_transform(self, X): """ check_is_fitted(self, 'categories_') + X = check_array(X, accept_sparse='csr') n_samples, _ = X.shape n_features = len(self.categories_) @@ -2812,7 +2813,7 @@ def inverse_transform(self, X): sub = X[:, j:j + n_categories] # for sparse X argmax returns 2D matrix, ensure 1D array - labels = np.asarray(sub.argmax(axis=1)).flatten() + labels = np.asarray(argmax(sub, axis=1)).flatten() X_tr[:, i] = self.categories_[i][labels] if self.handle_unknown == 'ignore': diff --git a/sklearn/utils/fixes.py b/sklearn/utils/fixes.py index 0e96057f929ff..cae7d9eb71a44 100644 --- a/sklearn/utils/fixes.py +++ b/sklearn/utils/fixes.py @@ -150,6 +150,108 @@ def sparse_min_max(X, axis): from scipy.misc import comb, logsumexp # noqa +if sp_version >= (0, 19): + def argmax(arr_or_spmatrix, axis=None): + return arr_or_spmatrix.argmax(axis=axis) +else: + # Backport of argmax functionality from scipy 0.19.1, can be removed + # once support for scipy 0.18 and below is dropped + + def _find_missing_index(ind, n): + for k, a in enumerate(ind): + if k != a: + return k + + k += 1 + if k < n: + return k + else: + return -1 + + def _arg_min_or_max_axis(self, axis, op, compare): + if self.shape[axis] == 0: + raise ValueError("Can't apply the operation along a zero-sized " + "dimension.") + + if axis < 0: + axis += 2 + + zero = self.dtype.type(0) + + mat = self.tocsc() if axis == 0 else self.tocsr() + mat.sum_duplicates() + + ret_size, line_size = mat._swap(mat.shape) + ret = np.zeros(ret_size, dtype=int) + + nz_lines, = np.nonzero(np.diff(mat.indptr)) + for i in nz_lines: + p, q = mat.indptr[i:i + 2] + data = mat.data[p:q] + indices = mat.indices[p:q] + am = op(data) + m = data[am] + if compare(m, zero) or q - p == line_size: + ret[i] = indices[am] + else: + zero_ind = _find_missing_index(indices, line_size) + if m == zero: + ret[i] = min(am, zero_ind) + else: + ret[i] = zero_ind + + if axis == 1: + ret = ret.reshape(-1, 1) + + return np.asmatrix(ret) + + def _arg_min_or_max(self, axis, out, op, compare): + if out is not None: + raise ValueError("Sparse matrices do not support " + "an 'out' parameter.") + + # validateaxis(axis) + + if axis is None: + if 0 in self.shape: + raise ValueError("Can't apply the operation to " + "an empty matrix.") + + if self.nnz == 0: + return 0 + else: + zero = self.dtype.type(0) + mat = self.tocoo() + mat.sum_duplicates() + am = op(mat.data) + m = mat.data[am] + + if compare(m, zero): + return mat.row[am] * mat.shape[1] + mat.col[am] + else: + size = np.product(mat.shape) + if size == mat.nnz: + return am + else: + ind = mat.row * mat.shape[1] + mat.col + zero_ind = _find_missing_index(ind, size) + if m == zero: + return min(zero_ind, am) + else: + return zero_ind + + return _arg_min_or_max_axis(self, axis, op, compare) + + def sparse_argmax(self, axis=None, out=None): + return _arg_min_or_max(self, axis, out, np.argmax, np.greater) + + def argmax(arr_or_matrix, axis=None): + if sp.issparse(arr_or_matrix): + return sparse_argmax(arr_or_matrix, axis=axis) + else: + return arr_or_matrix.argmax(axis=axis) + + def parallel_helper(obj, methodname, *args, **kwargs): """Workaround for Python 2 limitations of pickling instance methods""" return getattr(obj, methodname)(*args, **kwargs) From 64aeff5baba35be3fe2b3defb7e572a18eb0efc8 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Fri, 27 Oct 2017 21:11:02 +0200 Subject: [PATCH 19/31] check handle_unknown before computation in fit --- sklearn/preprocessing/data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 9ac37e31cf9b2..39624739fb253 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2706,9 +2706,9 @@ def fit(self, X, y=None): if self.categories == 'auto': le.fit(Xi) else: - valid_mask = np.in1d(Xi, self.categories[i]) - if not np.all(valid_mask): - if self.handle_unknown == 'error': + if self.handle_unknown == 'error': + valid_mask = np.in1d(Xi, self.categories[i]) + if not np.all(valid_mask): diff = np.unique(Xi[~valid_mask]) msg = ("Found unknown categories {0} in column {1}" " during fit".format(diff, i)) From a1c09824cc6964b488acef53b703e99359545d0b Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Fri, 27 Oct 2017 22:34:45 +0200 Subject: [PATCH 20/31] make scipy backport private --- sklearn/preprocessing/data.py | 4 ++-- sklearn/utils/fixes.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 39624739fb253..20b2efe01d34d 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -23,7 +23,7 @@ from ..utils import check_array from ..utils.extmath import row_norms from ..utils.extmath import _incremental_mean_and_var -from ..utils.fixes import argmax +from ..utils.fixes import _argmax from ..utils.sparsefuncs_fast import (inplace_csr_row_normalize_l1, inplace_csr_row_normalize_l2) from ..utils.sparsefuncs import (inplace_column_scale, @@ -2813,7 +2813,7 @@ def inverse_transform(self, X): sub = X[:, j:j + n_categories] # for sparse X argmax returns 2D matrix, ensure 1D array - labels = np.asarray(argmax(sub, axis=1)).flatten() + labels = np.asarray(_argmax(sub, axis=1)).flatten() X_tr[:, i] = self.categories_[i][labels] if self.handle_unknown == 'ignore': diff --git a/sklearn/utils/fixes.py b/sklearn/utils/fixes.py index cae7d9eb71a44..3c81a2f86d35b 100644 --- a/sklearn/utils/fixes.py +++ b/sklearn/utils/fixes.py @@ -151,7 +151,7 @@ def sparse_min_max(X, axis): if sp_version >= (0, 19): - def argmax(arr_or_spmatrix, axis=None): + def _argmax(arr_or_spmatrix, axis=None): return arr_or_spmatrix.argmax(axis=axis) else: # Backport of argmax functionality from scipy 0.19.1, can be removed @@ -242,12 +242,12 @@ def _arg_min_or_max(self, axis, out, op, compare): return _arg_min_or_max_axis(self, axis, op, compare) - def sparse_argmax(self, axis=None, out=None): + def _sparse_argmax(self, axis=None, out=None): return _arg_min_or_max(self, axis, out, np.argmax, np.greater) - def argmax(arr_or_matrix, axis=None): + def _argmax(arr_or_matrix, axis=None): if sp.issparse(arr_or_matrix): - return sparse_argmax(arr_or_matrix, axis=axis) + return _sparse_argmax(arr_or_matrix, axis=axis) else: return arr_or_matrix.argmax(axis=axis) From 85cf315f0847d09f0f9c21180a1a9cc3cdcc3570 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Mon, 30 Oct 2017 18:19:26 +0100 Subject: [PATCH 21/31] Directly construct CSR matrix --- sklearn/preprocessing/data.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 20b2efe01d34d..823dad140e76d 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2761,16 +2761,16 @@ def transform(self, X): mask = X_mask.ravel() n_values = [cats.shape[0] for cats in self.categories_] n_values = np.array([0] + n_values) - indices = np.cumsum(n_values) + feature_indices = np.cumsum(n_values) - column_indices = (X_int + indices[:-1]).ravel()[mask] - row_indices = np.repeat(np.arange(n_samples, dtype=np.int32), - n_features)[mask] + indices = (X_int + feature_indices[:-1]).ravel()[mask] + indptr = X_mask.sum(axis=1).cumsum() + indptr = np.insert(indptr, 0, 0) data = np.ones(n_samples * n_features)[mask] - out = sparse.csc_matrix((data, (row_indices, column_indices)), - shape=(n_samples, indices[-1]), - dtype=self.dtype).tocsr() + out = sparse.csr_matrix((data, indices, indptr), + shape=(n_samples, feature_indices[-1]), + dtype=self.dtype) if self.encoding == 'onehot-dense': return out.toarray() else: From b40bd8e4eb96fa9514fdb955af2fe1b40661b1fe Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Mon, 30 Oct 2017 18:38:49 +0100 Subject: [PATCH 22/31] try to preserve original dtype if resulting dtype is not string --- sklearn/preprocessing/data.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 823dad140e76d..bb394a0ac9d8a 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2695,7 +2695,12 @@ def fit(self, X, y=None): raise ValueError("Unsorted categories are not yet " "supported") - X = check_array(X, dtype=np.object, copy=True) + X_temp = check_array(X, dtype=None, copy=True) + if not hasattr(X, 'dtype') and np.issubdtype(X_temp.dtype, str): + X = check_array(X, dtype=np.object, copy=True) + else: + X = X_temp + n_samples, n_features = X.shape self._label_encoders_ = [LabelEncoder() for _ in range(n_features)] @@ -2733,7 +2738,12 @@ def transform(self, X): Transformed input. """ - X = check_array(X, dtype=np.object, copy=True) + X_temp = check_array(X, dtype=None, copy=True) + if not hasattr(X, 'dtype') and np.issubdtype(X_temp.dtype, str): + X = check_array(X, dtype=np.object, copy=True) + else: + X = X_temp + n_samples, n_features = X.shape X_int = np.zeros_like(X, dtype=np.int) X_mask = np.ones_like(X, dtype=np.bool) From a31bb2a48815a82e763df6997ce3aab3e8030543 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Tue, 31 Oct 2017 11:52:48 +0100 Subject: [PATCH 23/31] Remove copying of data, only copy when needed in transform + add test --- sklearn/preprocessing/data.py | 16 +++++++++------- sklearn/preprocessing/tests/test_data.py | 9 ++++++--- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index bb394a0ac9d8a..fee7ac0d7de5e 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2695,9 +2695,9 @@ def fit(self, X, y=None): raise ValueError("Unsorted categories are not yet " "supported") - X_temp = check_array(X, dtype=None, copy=True) + X_temp = check_array(X, dtype=None) if not hasattr(X, 'dtype') and np.issubdtype(X_temp.dtype, str): - X = check_array(X, dtype=np.object, copy=True) + X = check_array(X, dtype=np.object) else: X = X_temp @@ -2738,9 +2738,9 @@ def transform(self, X): Transformed input. """ - X_temp = check_array(X, dtype=None, copy=True) + X_temp = check_array(X, dtype=None) if not hasattr(X, 'dtype') and np.issubdtype(X_temp.dtype, str): - X = check_array(X, dtype=np.object, copy=True) + X = check_array(X, dtype=np.object) else: X = X_temp @@ -2749,7 +2749,8 @@ def transform(self, X): X_mask = np.ones_like(X, dtype=np.bool) for i in range(n_features): - valid_mask = np.in1d(X[:, i], self.categories_[i]) + Xi = X[:, i] + valid_mask = np.in1d(Xi, self.categories_[i]) if not np.all(valid_mask): if self.handle_unknown == 'error': @@ -2762,8 +2763,9 @@ def transform(self, X): # continue `The rows are marked `X_mask` and will be # removed later. X_mask[:, i] = valid_mask - X[:, i][~valid_mask] = self.categories_[i][0] - X_int[:, i] = self._label_encoders_[i].transform(X[:, i]) + Xi = Xi.copy() + Xi[~valid_mask] = self.categories_[i][0] + X_int[:, i] = self._label_encoders_[i].transform(Xi) if self.encoding == 'ordinal': return X_int.astype(self.dtype, copy=False) diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index f22d88c389f5f..60cf15ce40aea 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -2034,8 +2034,8 @@ def test_categorical_encoder_onehot_inverse(): def test_categorical_encoder_handle_unknown(): - X = [[1, 2, 3], [4, 5, 6]] - X2 = [[7, 5, 3]] + X = np.array([[1, 2, 3], [4, 5, 6]]) + X2 = np.array([[7, 5, 3]]) # Test that encoder raises error for unknown features during transform. enc = CategoricalEncoder() @@ -2046,8 +2046,11 @@ def test_categorical_encoder_handle_unknown(): # With 'ignore' you get all 0's in result enc = CategoricalEncoder(handle_unknown='ignore') enc.fit(X) - Xtr = enc.transform(X2) + X2_passed = X2.copy() + Xtr = enc.transform(X2_passed) assert_allclose(Xtr.toarray(), [[0, 0, 0, 1, 1, 0]]) + # ensure transformed data was not modified in place + assert_allclose(X2, X2_passed) # Invalid option enc = CategoricalEncoder(handle_unknown='invalid') From 2ef5fb914343827ae4febe5787e7f447f00a89d6 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Tue, 31 Oct 2017 12:45:11 +0100 Subject: [PATCH 24/31] add test for input dtypes / categories_ dtypes --- sklearn/preprocessing/tests/test_data.py | 69 ++++++++++++++++++++++-- 1 file changed, 66 insertions(+), 3 deletions(-) diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 60cf15ce40aea..a110915e37e57 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -2080,6 +2080,7 @@ def test_categorical_encoder_specified_categories(): assert_array_equal(enc.fit_transform(X).toarray(), exp) assert enc.categories[0] == ['a', 'b', 'c'] assert enc.categories_[0].tolist() == ['a', 'b', 'c'] + assert np.issubdtype(enc.categories_[0].dtype, str) # unsorted passed categories raises for now enc = CategoricalEncoder(categories=[['c', 'b', 'a']]) @@ -2087,11 +2088,15 @@ def test_categorical_encoder_specified_categories(): assert_raises_regex(ValueError, msg, enc.fit_transform, X) # multiple columns - X = np.array([['a', 'b'], ['A', 'C']], dtype=object).T - enc = CategoricalEncoder(categories=[['a', 'b', 'c'], ['A', 'B', 'C']]) + X = np.array([['a', 'b'], [0, 2]], dtype=object).T + enc = CategoricalEncoder(categories=[['a', 'b', 'c'], [0, 1, 2]]) exp = np.array([[1., 0., 0., 1., 0., 0.], [0., 1., 0., 0., 0., 1.]]) assert_array_equal(enc.fit_transform(X).toarray(), exp) + assert enc.categories_[0].tolist() == ['a', 'b', 'c'] + assert np.issubdtype(enc.categories_[0].dtype, str) + assert enc.categories_[1].tolist() == [0, 1, 2] + assert np.issubdtype(enc.categories_[1].dtype, np.integer) # when specifying categories manually, unknown categories should already # raise when fitting @@ -2109,7 +2114,7 @@ def test_categorical_encoder_pandas(): except ImportError: raise SkipTest("pandas is not installed") - X_df = pd.DataFrame({'A': ['a', 'b'], 'B': ['c', 'd']}) + X_df = pd.DataFrame({'A': ['a', 'b'], 'B': [1, 2]}) Xtr = check_categorical_onehot(X_df) assert_allclose(Xtr, [[1, 0, 1, 0], [0, 1, 0, 1]]) @@ -2140,6 +2145,64 @@ def test_categorical_encoder_ordinal_inverse(): assert_array_equal(enc.inverse_transform(X_tr), exp) +def test_categorical_encoder_dtypes(): + # check that dtypes are preserved when determining categories + enc = CategoricalEncoder() + exp = np.array([[1., 0., 1., 0.], [0., 1., 0., 1.]], dtype='float64') + + X = np.array([[1, 2], [3, 4]], dtype='int64') + enc.fit(X) + assert all([enc.categories_[i].dtype == 'int64' for i in range(2)]) + assert_array_equal(enc.transform(X).toarray(), exp) + + X = np.array([[1, 2], [3, 4]], dtype='float64') + enc.fit(X) + assert all([enc.categories_[i].dtype == 'float64' for i in range(2)]) + assert_array_equal(enc.transform(X).toarray(), exp) + + X = np.array([['a', 'b'], ['c', 'd']]) # string dtype + enc.fit(X) + assert all([enc.categories_[i].dtype == X.dtype for i in range(2)]) + assert_array_equal(enc.transform(X).toarray(), exp) + + X = np.array([[1, 'a'], [3, 'b']], dtype='object') + enc.fit(X) + assert all([enc.categories_[i].dtype == 'object' for i in range(2)]) + assert_array_equal(enc.transform(X).toarray(), exp) + + X = [[1, 2], [3, 4]] + enc.fit(X) + assert all([np.issubdtype(enc.categories_[i].dtype, np.integer) + for i in range(2)]) + assert_array_equal(enc.transform(X).toarray(), exp) + + X = [[1, 'a'], [3, 'b']] + enc.fit(X) + assert all([enc.categories_[i].dtype == 'object' for i in range(2)]) + assert_array_equal(enc.transform(X).toarray(), exp) + + +def test_categorical_encoder_dtypes_pandas(): + # check dtype (similar to test_categorical_encoder_dtypes for dataframes) + try: + import pandas as pd + except ImportError: + raise SkipTest("pandas is not installed") + + enc = CategoricalEncoder() + exp = np.array([[1., 0., 1., 0.], [0., 1., 0., 1.]], dtype='float64') + + X = pd.DataFrame({'A': [1, 2], 'B': [3, 4]}, dtype='int64') + enc.fit(X) + assert all([enc.categories_[i].dtype == 'int64' for i in range(2)]) + assert_array_equal(enc.transform(X).toarray(), exp) + + X = pd.DataFrame({'A': [1, 2], 'B': ['a', 'b']}) + enc.fit(X) + assert all([enc.categories_[i].dtype == 'object' for i in range(2)]) + assert_array_equal(enc.transform(X).toarray(), exp) + + def test_fit_cold_start(): X = iris.data X_2d = X[:, :2] From 937446ef7d691d9628dc70b8d7f9f36bfdba485c Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Tue, 31 Oct 2017 15:26:15 +0100 Subject: [PATCH 25/31] doc updates based on feedback --- doc/modules/preprocessing.rst | 27 +++++++++++++++---------- sklearn/preprocessing/data.py | 37 ++++++++++++++++++++++++----------- 2 files changed, 43 insertions(+), 21 deletions(-) diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index 77919ce7d137a..d09de4bbeda1b 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -473,8 +473,9 @@ scikit-learn estimators, as these expect continuous input, and would interpret the categories as being ordered, which is often not desired (i.e. the set of browsers was ordered arbitrarily). -One possibility to convert categorical features to features that can be used -with scikit-learn estimators is to use a one-of-K, one-hot or dummy encoding. +Another possibility to convert categorical features to features that can be used +with scikit-learn estimators is to use a one-of-K, also known as one-hot or +dummy encoding. This type of encoding is the default behaviour of the :class:`CategoricalEncoder`. The :class:`CategoricalEncoder` then transforms each categorical feature with ``n_categories`` possible values into ``n_categories`` binary features, with @@ -487,21 +488,20 @@ Continuing the example above:: >>> enc.fit(X) # doctest: +ELLIPSIS CategoricalEncoder(categories='auto', dtype=<... 'numpy.float64'>, encoding='onehot', handle_unknown='error') - >>> enc.transform([['female', 'from US', 'uses Safari']]).toarray() - array([[ 1., 0., 0., 1., 0., 1.]]) + >>> enc.transform([['female', 'from US', 'uses Safari'], + ... ['male', 'from Europe', 'uses Safari']]).toarray() + array([[ 1., 0., 0., 1., 0., 1.], + [ 0., 1., 1., 0., 0., 1.]]) -By default, how many values each feature can take is inferred automatically +By default, the values each feature can take is inferred automatically from the dataset and can be found in the ``categories_`` attribute:: >>> enc.categories_ [array(['female', 'male'], dtype=object), array(['from Europe', 'from US'], dtype=object), array(['uses Firefox', 'uses Safari'], dtype=object)] It is possible to specify this explicitly using the parameter ``categories``. -There are two genders, three possible continents and four web browsers in our -dataset. - -Note that, if there is a possibility that the training data might have missing categorical -features, one has to explicitly set ``categories``. For example, +There are two genders, four possible continents and four web browsers in our +dataset:: >>> genders = ['female', 'male'] >>> locations = ['from Africa', 'from Asia', 'from Europe', 'from US'] @@ -517,6 +517,13 @@ features, one has to explicitly set ``categories``. For example, >>> enc.transform([['female', 'from Asia', 'uses Chrome']]).toarray() array([[ 1., 0., 0., 1., 0., 0., 1., 0., 0., 0.]]) +If there is a possibility that the training data might have missing categorical +features, it is especially important to explicitly set ``categories`` manually. +Alternatively, one can specify ``handle_unknown='ignore'``. In that case, when +unknown categories are encountered during transform, no error will be raised but +the resulting one-hot encoded columns for this feature will be all zeros +(``handle_unknown='ignore'`` is only supported for one one-hot encoding). + See :ref:`dict_feature_extraction` for categorical features that are represented as a dict, not as scalars. diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index fee7ac0d7de5e..2fe9ffc907ae7 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -1803,16 +1803,21 @@ def add_dummy_feature(X, value=1.0): def _transform_selected(X, transform, selected="all", copy=True): """Apply a transform function to portion of selected features + Parameters ---------- X : {array-like, sparse matrix}, shape [n_samples, n_features] Dense array or sparse matrix. + transform : callable A callable transform(X) -> X_transformed + copy : boolean, optional Copy X even if it could be avoided. + selected: "all" or array of indices or mask Specify which features to apply the transform to. + Returns ------- X : array or sparse matrix, shape=(n_samples, n_features_new) @@ -1936,8 +1941,8 @@ class OneHotEncoder(BaseEstimator, TransformerMixin): -------- sklearn.preprocessing.CategoricalEncoder : performs a one-hot or ordinal encoding of all features (also handles string-valued features). This - encoder derives the categories based on the unique values in the - features. + encoder derives the categories based on the unique values in each + feature. sklearn.feature_extraction.DictVectorizer : performs a one-hot encoding of dictionary items (also handles string-valued features). sklearn.feature_extraction.FeatureHasher : performs an approximate one-hot @@ -2606,7 +2611,8 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): - 'auto' : Determine categories automatically from the training data. - list : ``categories[i]`` holds the categories expected in the ith - column. The passed categories must be sorted. + column. The passed categories must be sorted and should not mix + strings and numeric values. The used categories can be found in the ``categories_`` attribute. @@ -2618,7 +2624,8 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): present during transform (default is to raise). When this parameter is set to 'ignore' and an unknown category is encountered during transform, the resulting one-hot encoded columns for this feature - will be all zeros. + will be all zeros. In the inverse transform, an unknown category + will be denoted as None. Ignoring unknown categories is not supported for ``encoding='ordinal'``. @@ -2630,18 +2637,23 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): Examples -------- - Given a dataset with three features, we let the encoder find the unique + Given a dataset with two features, we let the encoder find the unique values per feature and transform the data to a binary one-hot encoding. >>> from sklearn.preprocessing import CategoricalEncoder >>> enc = CategoricalEncoder(handle_unknown='ignore') - >>> enc.fit([[0, 0, 3], [1, 1, 0], [0, 2, 1], [1, 0, 2]]) - ... # doctest: +ELLIPSIS - CategoricalEncoder(categories='auto', dtype=<... 'numpy.float64'>, + >>> X = [['Male', 1], ['Female', 3], ['Female', 2]] + >>> enc.fit(X) + CategoricalEncoder(categories='auto', dtype=, encoding='onehot', handle_unknown='ignore') - >>> enc.transform([[0, 1, 1], [1, 0, 4]]).toarray() - array([[ 1., 0., 0., 1., 0., 0., 1., 0., 0.], - [ 0., 1., 1., 0., 0., 0., 0., 0., 0.]]) + >>> enc.categories_ + [array(['Female', 'Male'], dtype=object), array([1, 2, 3], dtype=object)] + >>> enc.transform([['Female', 1], ['Male', 4]]).toarray() + array([[ 1., 0., 1., 0., 0.], + [ 0., 1., 0., 0., 0.]]) + >>> enc.inverse_transform([[0, 1, 1, 0, 0], [0, 0, 0, 1, 0]]) + array([['Male', 1], + [None, 2]], dtype=object) See also -------- @@ -2791,6 +2803,9 @@ def transform(self, X): def inverse_transform(self, X): """Convert back the data to the original representation. + In case unknown categories are encountered (all zero's in the + one-hot encoding), ``None`` is used to represent this category. + Parameters ---------- X : array-like or sparse matrix, shape [n_samples, n_encoded_features] From a83102c549f113e31a0febdd1fbfb1dd4445ae54 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Tue, 31 Oct 2017 16:54:25 +0100 Subject: [PATCH 26/31] fix docstring example for python 2 --- sklearn/preprocessing/data.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 2fe9ffc907ae7..1e40bf0136a59 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2644,7 +2644,8 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin): >>> enc = CategoricalEncoder(handle_unknown='ignore') >>> X = [['Male', 1], ['Female', 3], ['Female', 2]] >>> enc.fit(X) - CategoricalEncoder(categories='auto', dtype=, + ... # doctest: +ELLIPSIS + CategoricalEncoder(categories='auto', dtype=<... 'numpy.float64'>, encoding='onehot', handle_unknown='ignore') >>> enc.categories_ [array(['Female', 'Male'], dtype=object), array([1, 2, 3], dtype=object)] @@ -2685,8 +2686,8 @@ def fit(self, X, y=None): Returns ------- self - """ + """ if self.encoding not in ['onehot', 'onehot-dense', 'ordinal']: template = ("encoding should be either 'onehot', 'onehot-dense' " "or 'ordinal', got %s") From 21d9c0c98a431bc5cd618cbdaee10c75ac50a162 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Tue, 7 Nov 2017 10:44:21 +0100 Subject: [PATCH 27/31] add checking of shape of X in inverse_transform --- sklearn/preprocessing/data.py | 13 ++++++++++++- sklearn/preprocessing/tests/test_data.py | 10 ++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 1e40bf0136a59..b4549e09e6291 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2823,7 +2823,18 @@ def inverse_transform(self, X): n_samples, _ = X.shape n_features = len(self.categories_) - + n_transformed_features = sum([len(cats) for cats in self.categories_]) + + # validate shape of passed X + msg = ("Shape of the passed X data is not correct. Expected {0} " + "columns, got {1}.") + if self.encoding == 'ordinal' and X.shape[1] != n_features: + raise ValueError(msg.format(n_features, X.shape[1])) + elif (self.encoding.startswith('onehot') + and X.shape[1] != n_transformed_features): + raise ValueError(msg.format(n_transformed_features, X.shape[1])) + + # create resulting array of appropriate dtype dt = np.find_common_type([cat.dtype for cat in self.categories_], []) X_tr = np.empty((n_samples, n_features), dtype=dt) diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index a110915e37e57..15bfedfcd1c8a 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -2032,6 +2032,11 @@ def test_categorical_encoder_onehot_inverse(): exp[:, 1] = None assert_array_equal(enc.inverse_transform(X_tr), exp) + # incorrect shape raises + X_tr = np.array([[0, 1, 1], [1, 0, 1]]) + msg = re.escape('Shape of the passed X data is not correct') + assert_raises_regex(ValueError, msg, enc.inverse_transform, X_tr) + def test_categorical_encoder_handle_unknown(): X = np.array([[1, 2, 3], [4, 5, 6]]) @@ -2144,6 +2149,11 @@ def test_categorical_encoder_ordinal_inverse(): exp = np.array(X, dtype=object) assert_array_equal(enc.inverse_transform(X_tr), exp) + # incorrect shape raises + X_tr = np.array([[0, 1, 1, 2], [1, 0, 1, 0]]) + msg = re.escape('Shape of the passed X data is not correct') + assert_raises_regex(ValueError, msg, enc.inverse_transform, X_tr) + def test_categorical_encoder_dtypes(): # check that dtypes are preserved when determining categories From 929362fbc1f90a6374ac48a99faffc0f76c42d98 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Thu, 9 Nov 2017 09:40:26 +0100 Subject: [PATCH 28/31] loopify dtype tests --- sklearn/preprocessing/tests/test_data.py | 26 +++++++----------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 15bfedfcd1c8a..e715ceacfac25 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -2160,25 +2160,13 @@ def test_categorical_encoder_dtypes(): enc = CategoricalEncoder() exp = np.array([[1., 0., 1., 0.], [0., 1., 0., 1.]], dtype='float64') - X = np.array([[1, 2], [3, 4]], dtype='int64') - enc.fit(X) - assert all([enc.categories_[i].dtype == 'int64' for i in range(2)]) - assert_array_equal(enc.transform(X).toarray(), exp) - - X = np.array([[1, 2], [3, 4]], dtype='float64') - enc.fit(X) - assert all([enc.categories_[i].dtype == 'float64' for i in range(2)]) - assert_array_equal(enc.transform(X).toarray(), exp) - - X = np.array([['a', 'b'], ['c', 'd']]) # string dtype - enc.fit(X) - assert all([enc.categories_[i].dtype == X.dtype for i in range(2)]) - assert_array_equal(enc.transform(X).toarray(), exp) - - X = np.array([[1, 'a'], [3, 'b']], dtype='object') - enc.fit(X) - assert all([enc.categories_[i].dtype == 'object' for i in range(2)]) - assert_array_equal(enc.transform(X).toarray(), exp) + for X in [np.array([[1, 2], [3, 4]], dtype='int64'), + np.array([[1, 2], [3, 4]], dtype='float64'), + np.array([['a', 'b'], ['c', 'd']]), # string dtype + np.array([[1, 'a'], [3, 'b']], dtype='object')]: + enc.fit(X) + assert all([enc.categories_[i].dtype == X.dtype for i in range(2)]) + assert_array_equal(enc.transform(X).toarray(), exp) X = [[1, 2], [3, 4]] enc.fit(X) From a6d55d1cd309aa13f0df6144ab8337fb415b1bf4 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Thu, 9 Nov 2017 09:53:32 +0100 Subject: [PATCH 29/31] reword example on unknown categories --- doc/modules/preprocessing.rst | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index d09de4bbeda1b..4537074e9308f 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -518,11 +518,20 @@ dataset:: array([[ 1., 0., 0., 1., 0., 0., 1., 0., 0., 0.]]) If there is a possibility that the training data might have missing categorical -features, it is especially important to explicitly set ``categories`` manually. -Alternatively, one can specify ``handle_unknown='ignore'``. In that case, when +features, it can often be better to specify ``handle_unknown='ignore'`` instead +of setting the ``categories`` manually as above. In that case, when unknown categories are encountered during transform, no error will be raised but the resulting one-hot encoded columns for this feature will be all zeros -(``handle_unknown='ignore'`` is only supported for one one-hot encoding). +(``handle_unknown='ignore'`` is only supported for one one-hot encoding):: + + >>> enc = preprocessing.CategoricalEncoder(handle_unknown='ignore') + >>> X = [['male', 'from US', 'uses Safari'], ['female', 'from Europe', 'uses Firefox']] + >>> enc.fit(X) # doctest: +ELLIPSIS + CategoricalEncoder(categories='auto', dtype=<... 'numpy.float64'>, + encoding='onehot', handle_unknown='ignore') + >>> enc.transform([['female', 'from Asia', 'uses Chrome']]).toarray() + array([[ 1., 0., 0., 0., 0., 0.]]) + See :ref:`dict_feature_extraction` for categorical features that are represented as a dict, not as scalars. From 9aeeb6d1ee919a47ea2b99d46e6a961db4fb707d Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Thu, 9 Nov 2017 13:47:12 +0100 Subject: [PATCH 30/31] clarify docs --- doc/modules/preprocessing.rst | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index 4537074e9308f..22b97054ec209 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -519,9 +519,10 @@ dataset:: If there is a possibility that the training data might have missing categorical features, it can often be better to specify ``handle_unknown='ignore'`` instead -of setting the ``categories`` manually as above. In that case, when -unknown categories are encountered during transform, no error will be raised but -the resulting one-hot encoded columns for this feature will be all zeros +of setting the ``categories`` manually as above. When +``handle_unknown='ignore'`` is specified and unknown categories are encountered +during transform, no error will be raised but the resulting one-hot encoded +columns for this feature will be all zeros (``handle_unknown='ignore'`` is only supported for one one-hot encoding):: >>> enc = preprocessing.CategoricalEncoder(handle_unknown='ignore') From c39aa0c289a632cd22d733195a672f3b93265c61 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Thu, 9 Nov 2017 23:05:44 +0100 Subject: [PATCH 31/31] remove repeated one --- doc/modules/preprocessing.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index 22b97054ec209..0c08063331c61 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -523,7 +523,7 @@ of setting the ``categories`` manually as above. When ``handle_unknown='ignore'`` is specified and unknown categories are encountered during transform, no error will be raised but the resulting one-hot encoded columns for this feature will be all zeros -(``handle_unknown='ignore'`` is only supported for one one-hot encoding):: +(``handle_unknown='ignore'`` is only supported for one-hot encoding):: >>> enc = preprocessing.CategoricalEncoder(handle_unknown='ignore') >>> X = [['male', 'from US', 'uses Safari'], ['female', 'from Europe', 'uses Firefox']]