diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index 5e53e99dcc176..f1a2e973d187f 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -1198,6 +1198,7 @@ Model validation
    preprocessing.MinMaxScaler
    preprocessing.Normalizer
    preprocessing.OneHotEncoder
+   preprocessing.CategoricalEncoder
    preprocessing.PolynomialFeatures
    preprocessing.QuantileTransformer
    preprocessing.RobustScaler
diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst
index 8bcb14363d69c..0c08063331c61 100644
--- a/doc/modules/preprocessing.rst
+++ b/doc/modules/preprocessing.rst
@@ -455,47 +455,97 @@ Such features can be efficiently coded as integers, for instance
 ``[0, 1, 3]`` while ``["female", "from Asia", "uses Chrome"]`` would be
 ``[1, 2, 1]``.
 
-Such integer representation can not be used directly with scikit-learn estimators, as these
-expect continuous input, and would interpret the categories as being ordered, which is often
-not desired (i.e. the set of browsers was ordered arbitrarily).
-
-One possibility to convert categorical features to features that can be used
-with scikit-learn estimators is to use a one-of-K or one-hot encoding, which is
-implemented in :class:`OneHotEncoder`. This estimator transforms each
-categorical feature with ``m`` possible values into ``m`` binary features, with
-only one active.
+To convert categorical features to such integer codes, we can use the
+:class:`CategoricalEncoder`. When specifying that we want to perform an
+ordinal encoding, the estimator transforms each categorical feature to one
+new feature of integers (0 to n_categories - 1)::
+
+  >>> enc = preprocessing.CategoricalEncoder(encoding='ordinal')
+  >>> X = [['male', 'from US', 'uses Safari'], ['female', 'from Europe', 'uses Firefox']]
+  >>> enc.fit(X)  # doctest: +ELLIPSIS
+  CategoricalEncoder(categories='auto', dtype=<... 'numpy.float64'>,
+            encoding='ordinal', handle_unknown='error')
+  >>> enc.transform([['female', 'from US', 'uses Safari']])
+  array([[ 0.,  1.,  1.]])
+
+Such an integer representation cannot, however, be used directly with all
+scikit-learn estimators, as these expect continuous input, and would interpret
+the categories as being ordered, which is often not desired (e.g. the set of
+browsers was ordered arbitrarily).
+
+Another possibility to convert categorical features to features that can be used
+with scikit-learn estimators is to use a one-of-K, also known as one-hot or
+dummy encoding.
+This type of encoding is the default behaviour of the :class:`CategoricalEncoder`.
+The :class:`CategoricalEncoder` then transforms each categorical feature with
+``n_categories`` possible values into ``n_categories`` binary features, with
+exactly one of them 1, and all others 0.
 
 Continuing the example above::
 
-  >>> enc = preprocessing.OneHotEncoder()
-  >>> enc.fit([[0, 0, 3], [1, 1, 0], [0, 2, 1], [1, 0, 2]])  # doctest: +ELLIPSIS
-  OneHotEncoder(categorical_features='all', dtype=<... 'numpy.float64'>,
-         handle_unknown='error', n_values='auto', sparse=True)
-  >>> enc.transform([[0, 1, 3]]).toarray()
-  array([[ 1.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  1.]])
-
-By default, how many values each feature can take is inferred automatically from the dataset.
-It is possible to specify this explicitly using the parameter ``n_values``.
-There are two genders, three possible continents and four web browsers in our
-dataset.
-Then we fit the estimator, and transform a data point.
-In the result, the first two numbers encode the gender, the next set of three
-numbers the continent and the last four the web browser.
-
-Note that, if there is a possibility that the training data might have missing categorical
-features, one has to explicitly set ``n_values``. For example,
-
-  >>> enc = preprocessing.OneHotEncoder(n_values=[2, 3, 4])
-  >>> # Note that there are missing categorical values for the 2nd and 3rd
-  >>> # features
-  >>> enc.fit([[1, 2, 3], [0, 2, 0]])  # doctest: +ELLIPSIS
-  OneHotEncoder(categorical_features='all', dtype=<... 'numpy.float64'>,
-         handle_unknown='error', n_values=[2, 3, 4], sparse=True)
-  >>> enc.transform([[1, 0, 0]]).toarray()
-  array([[ 0.,  1.,  1.,  0.,  0.,  1.,  0.,  0.,  0.]])
+  >>> enc = preprocessing.CategoricalEncoder()
+  >>> X = [['male', 'from US', 'uses Safari'], ['female', 'from Europe', 'uses Firefox']]
+  >>> enc.fit(X)  # doctest: +ELLIPSIS
+  CategoricalEncoder(categories='auto', dtype=<... 'numpy.float64'>,
+            encoding='onehot', handle_unknown='error')
+  >>> enc.transform([['female', 'from US', 'uses Safari'],
+  ...                ['male', 'from Europe', 'uses Safari']]).toarray()
+  array([[ 1.,  0.,  0.,  1.,  0.,  1.],
+         [ 0.,  1.,  1.,  0.,  0.,  1.]])
+
+By default, the values each feature can take are inferred automatically
+from the dataset and can be found in the ``categories_`` attribute::
+
+  >>> enc.categories_
+  [array(['female', 'male'], dtype=object), array(['from Europe', 'from US'], dtype=object), array(['uses Firefox', 'uses Safari'], dtype=object)]
+
+It is possible to specify this explicitly using the parameter ``categories``.
+There are two genders, four possible continents and four web browsers in our
+dataset::
+
+  >>> genders = ['female', 'male']
+  >>> locations = ['from Africa', 'from Asia', 'from Europe', 'from US']
+  >>> browsers = ['uses Chrome', 'uses Firefox', 'uses IE', 'uses Safari']
+  >>> enc = preprocessing.CategoricalEncoder(categories=[genders, locations, browsers])
+  >>> # Note that there are missing categorical values for the 2nd and 3rd
+  >>> # features
+  >>> X = [['male', 'from US', 'uses Safari'], ['female', 'from Europe', 'uses Firefox']]
+  >>> enc.fit(X)  # doctest: +ELLIPSIS
+  CategoricalEncoder(categories=[...],
+            dtype=<... 'numpy.float64'>, encoding='onehot',
+            handle_unknown='error')
+  >>> enc.transform([['female', 'from Asia', 'uses Chrome']]).toarray()
+  array([[ 1.,  0.,  0.,  1.,  0.,  0.,  1.,  0.,  0.,  0.]])
+
+If there is a possibility that the training data might be missing some
+categories, it can often be better to specify ``handle_unknown='ignore'`` instead
+of setting the ``categories`` manually as above. When
+``handle_unknown='ignore'`` is specified and unknown categories are encountered
+during transform, no error will be raised but the resulting one-hot encoded
+columns for this feature will be all zeros
+(``handle_unknown='ignore'`` is only supported for one-hot encoding)::
+
+  >>> enc = preprocessing.CategoricalEncoder(handle_unknown='ignore')
+  >>> X = [['male', 'from US', 'uses Safari'], ['female', 'from Europe', 'uses Firefox']]
+  >>> enc.fit(X)  # doctest: +ELLIPSIS
+  CategoricalEncoder(categories='auto', dtype=<... 'numpy.float64'>,
+            encoding='onehot', handle_unknown='ignore')
+  >>> enc.transform([['female', 'from Asia', 'uses Chrome']]).toarray()
+  array([[ 1.,  0.,  0.,  0.,  0.,  0.]])
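+
+The ``'onehot'`` encoding above returns a sparse matrix. If a dense array is
+preferred, the ``'onehot-dense'`` encoding gives the equivalent result as a
+regular numpy array (a quick sketch, reusing the example data from above)::
+
+  >>> enc = preprocessing.CategoricalEncoder(encoding='onehot-dense')
+  >>> X = [['male', 'from US', 'uses Safari'], ['female', 'from Europe', 'uses Firefox']]
+  >>> enc.fit_transform(X)
+  array([[ 0.,  1.,  0.,  1.,  0.,  1.],
+         [ 1.,  0.,  1.,  0.,  1.,  0.]])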
 
 See :ref:`dict_feature_extraction` for categorical features that are represented
-as a dict, not as integers.
+as a dict, not as scalars.
 
 .. _imputation:
diff --git a/doc/whats_new/_contributors.rst b/doc/whats_new/_contributors.rst
index a80c220192582..558d49b9e2548 100644
--- a/doc/whats_new/_contributors.rst
+++ b/doc/whats_new/_contributors.rst
@@ -150,3 +150,5 @@
 .. _Neeraj Gangwar: http://neerajgangwar.in
 
 .. _Arthur Mensch: https://amensch.fr
+
+.. _Joris Van den Bossche: https://github.com/jorisvandenbossche
diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst
index 0897f331ebda0..a89d482555308 100644
--- a/doc/whats_new/v0.20.rst
+++ b/doc/whats_new/v0.20.rst
@@ -44,6 +44,17 @@ Classifiers and regressors
    Naive Bayes classifier described in Rennie et al. (2003).
    By :user:`Michael A. Alcorn `.
 
+Preprocessing
+
+- Added :class:`preprocessing.CategoricalEncoder`, which allows encoding
+  categorical features as a numeric array, either using a one-hot (or
+  dummy) encoding scheme or by converting to ordinal integers.
+  Compared to the existing :class:`OneHotEncoder`, this new class handles
+  encoding of all feature types (also handles string-valued features) and
+  derives the categories based on the unique values in each feature instead
+  of the maximum value in the features.
+  By :user:`Vighnesh Birodkar ` and `Joris Van den Bossche`_.
+
 Model evaluation
 
 - Added the :func:`metrics.balanced_accuracy` metric and a corresponding
diff --git a/examples/ensemble/plot_feature_transformation.py b/examples/ensemble/plot_feature_transformation.py
index e004c167e67af..50347de57f12b 100644
--- a/examples/ensemble/plot_feature_transformation.py
+++ b/examples/ensemble/plot_feature_transformation.py
@@ -34,7 +34,7 @@
 from sklearn.linear_model import LogisticRegression
 from sklearn.ensemble import (RandomTreesEmbedding, RandomForestClassifier,
                               GradientBoostingClassifier)
-from sklearn.preprocessing import OneHotEncoder
+from sklearn.preprocessing import CategoricalEncoder
 from sklearn.model_selection import train_test_split
 from sklearn.metrics import roc_curve
 from sklearn.pipeline import make_pipeline
@@ -62,7 +62,7 @@
 
 # Supervised transformation based on random forests
 rf = RandomForestClassifier(max_depth=3, n_estimators=n_estimator)
-rf_enc = OneHotEncoder()
+rf_enc = CategoricalEncoder()
 rf_lm = LogisticRegression()
 rf.fit(X_train, y_train)
 rf_enc.fit(rf.apply(X_train))
@@ -72,7 +72,7 @@
 fpr_rf_lm, tpr_rf_lm, _ = roc_curve(y_test, y_pred_rf_lm)
 
 grd = GradientBoostingClassifier(n_estimators=n_estimator)
-grd_enc = OneHotEncoder()
+grd_enc = CategoricalEncoder()
 grd_lm = LogisticRegression()
 grd.fit(X_train, y_train)
 grd_enc.fit(grd.apply(X_train)[:, :, 0])
diff --git a/sklearn/feature_extraction/dict_vectorizer.py b/sklearn/feature_extraction/dict_vectorizer.py
index e6b52c8009cad..194e922596bf5 100644
--- a/sklearn/feature_extraction/dict_vectorizer.py
+++ b/sklearn/feature_extraction/dict_vectorizer.py
@@ -39,7 +39,8 @@ class DictVectorizer(BaseEstimator, TransformerMixin):
     However, note that this transformer will only do a binary one-hot encoding
     when feature values are of type string. If categorical features are
     represented as numeric values such as int, the DictVectorizer can be
-    followed by OneHotEncoder to complete binary one-hot encoding.
+    followed by :class:`sklearn.preprocessing.CategoricalEncoder` to complete
+    binary one-hot encoding.
 
     Features that do not occur in a sample (mapping) will have a zero value
     in the resulting array/matrix.
@@ -88,8 +89,8 @@ class DictVectorizer(BaseEstimator, TransformerMixin):
 
     See also
     --------
     FeatureHasher : performs vectorization using only a hash function.
- sklearn.preprocessing.OneHotEncoder : handles nominal/categorical features - encoded as columns of integers. + sklearn.preprocessing.CategoricalEncoder : handles nominal/categorical + features encoded as columns of arbitrary data types. """ def __init__(self, dtype=np.float64, separator="=", sparse=True, diff --git a/sklearn/preprocessing/__init__.py b/sklearn/preprocessing/__init__.py index 2b105709ffe08..0f5054e57f608 100644 --- a/sklearn/preprocessing/__init__.py +++ b/sklearn/preprocessing/__init__.py @@ -22,6 +22,7 @@ from .data import minmax_scale from .data import quantile_transform from .data import OneHotEncoder +from .data import CategoricalEncoder from .data import PolynomialFeatures @@ -46,6 +47,7 @@ 'QuantileTransformer', 'Normalizer', 'OneHotEncoder', + 'CategoricalEncoder', 'RobustScaler', 'StandardScaler', 'add_dummy_feature', diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 0d88f6c4c56f7..b4549e09e6291 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -23,6 +23,7 @@ from ..utils import check_array from ..utils.extmath import row_norms from ..utils.extmath import _incremental_mean_and_var +from ..utils.fixes import _argmax from ..utils.sparsefuncs_fast import (inplace_csr_row_normalize_l1, inplace_csr_row_normalize_l2) from ..utils.sparsefuncs import (inplace_column_scale, @@ -30,6 +31,9 @@ min_max_axis) from ..utils.validation import (check_is_fitted, check_random_state, FLOAT_DTYPES) +from .label import LabelEncoder + + BOUNDS_THRESHOLD = 1e-7 @@ -1856,7 +1860,9 @@ class OneHotEncoder(BaseEstimator, TransformerMixin): the values taken on by categorical (discrete) features. The output will be a sparse matrix where each column corresponds to one possible value of one feature. It is assumed that input features take on values in the range - [0, n_values). + [0, n_values). For an encoder based on the unique values of the input + features of any type, see the + :class:`~sklearn.preprocessing.CategoricalEncoder`. This encoding is needed for feeding categorical data to many scikit-learn estimators, notably linear models and SVMs with the standard kernels. @@ -1933,6 +1939,10 @@ class OneHotEncoder(BaseEstimator, TransformerMixin): See also -------- + sklearn.preprocessing.CategoricalEncoder : performs a one-hot or ordinal + encoding of all features (also handles string-valued features). This + encoder derives the categories based on the unique values in each + feature. sklearn.feature_extraction.DictVectorizer : performs a one-hot encoding of dictionary items (also handles string-valued features). sklearn.feature_extraction.FeatureHasher : performs an approximate one-hot @@ -2567,3 +2577,299 @@ def quantile_transform(X, axis=0, n_quantiles=1000, else: raise ValueError("axis should be either equal to 0 or 1. Got" " axis={}".format(axis)) + + +class CategoricalEncoder(BaseEstimator, TransformerMixin): + """Encode categorical features as a numeric array. + + The input to this transformer should be an array-like of integers or + strings, denoting the values taken on by categorical (discrete) features. + The features can be encoded using a one-hot (aka one-of-K or dummy) + encoding scheme (``encoding='onehot'``, the default) or converted + to ordinal integers (``encoding='ordinal'``). + + This encoding is needed for feeding categorical data to many scikit-learn + estimators, notably linear models and SVMs with the standard kernels. + + Read more in the :ref:`User Guide `. 
+
+    Parameters
+    ----------
+    encoding : str, 'onehot', 'onehot-dense' or 'ordinal'
+        The type of encoding to use (default is 'onehot'):
+
+        - 'onehot': encode the features using a one-hot aka one-of-K scheme
+          (also called 'dummy' encoding). This creates a binary column for
+          each category and returns a sparse matrix.
+        - 'onehot-dense': the same as 'onehot' but returns a dense array
+          instead of a sparse matrix.
+        - 'ordinal': encode the features as ordinal integers. This results in
+          a single column of integers (0 to n_categories - 1) per feature.
+
+    categories : 'auto' or a list of lists/arrays of values
+        Categories (unique values) per feature:
+
+        - 'auto' : Determine categories automatically from the training data.
+        - list : ``categories[i]`` holds the categories expected in the ith
+          column. The passed categories must be sorted and should not mix
+          strings and numeric values.
+
+        The used categories can be found in the ``categories_`` attribute.
+
+    dtype : number type, default np.float64
+        Desired dtype of output.
+
+    handle_unknown : 'error' (default) or 'ignore'
+        Whether to raise an error or ignore if an unknown categorical feature
+        is present during transform (default is to raise). When this parameter
+        is set to 'ignore' and an unknown category is encountered during
+        transform, the resulting one-hot encoded columns for this feature
+        will be all zeros. In the inverse transform, an unknown category
+        will be denoted as None.
+        Ignoring unknown categories is not supported for
+        ``encoding='ordinal'``.
+
+    Attributes
+    ----------
+    categories_ : list of arrays
+        The categories of each feature determined during fitting
+        (in order corresponding with output of ``transform``).
+
+    Examples
+    --------
+    Given a dataset with two features, we let the encoder find the unique
+    values per feature and transform the data to a binary one-hot encoding.
+
+    >>> from sklearn.preprocessing import CategoricalEncoder
+    >>> enc = CategoricalEncoder(handle_unknown='ignore')
+    >>> X = [['Male', 1], ['Female', 3], ['Female', 2]]
+    >>> enc.fit(X)
+    ... # doctest: +ELLIPSIS
+    CategoricalEncoder(categories='auto', dtype=<... 'numpy.float64'>,
+              encoding='onehot', handle_unknown='ignore')
+    >>> enc.categories_
+    [array(['Female', 'Male'], dtype=object), array([1, 2, 3], dtype=object)]
+    >>> enc.transform([['Female', 1], ['Male', 4]]).toarray()
+    array([[ 1.,  0.,  1.,  0.,  0.],
+           [ 0.,  1.,  0.,  0.,  0.]])
+    >>> enc.inverse_transform([[0, 1, 1, 0, 0], [0, 0, 0, 1, 0]])
+    array([['Male', 1],
+           [None, 2]], dtype=object)
+
+    See also
+    --------
+    sklearn.preprocessing.OneHotEncoder : performs a one-hot encoding of
+        integer ordinal features. The ``OneHotEncoder`` assumes that input
+        features take on values in the range ``[0, max(feature)]`` instead of
+        using the unique values.
+    sklearn.feature_extraction.DictVectorizer : performs a one-hot encoding of
+        dictionary items (also handles string-valued features).
+    sklearn.feature_extraction.FeatureHasher : performs an approximate one-hot
+        encoding of dictionary items or strings.
+    """
+
+    def __init__(self, encoding='onehot', categories='auto', dtype=np.float64,
+                 handle_unknown='error'):
+        self.encoding = encoding
+        self.categories = categories
+        self.dtype = dtype
+        self.handle_unknown = handle_unknown
+
+    def fit(self, X, y=None):
+        """Fit the CategoricalEncoder to X.
+
+        Parameters
+        ----------
+        X : array-like, shape [n_samples, n_features]
+            The data to determine the categories of each feature.
+
+        Returns
+        -------
+        self
+
+        """
+        if self.encoding not in ['onehot', 'onehot-dense', 'ordinal']:
+            template = ("encoding should be either 'onehot', 'onehot-dense' "
+                        "or 'ordinal', got %s")
+            raise ValueError(template % self.encoding)
+
+        if self.handle_unknown not in ['error', 'ignore']:
+            template = ("handle_unknown should be either 'error' or "
+                        "'ignore', got %s")
+            raise ValueError(template % self.handle_unknown)
+
+        if self.encoding == 'ordinal' and self.handle_unknown == 'ignore':
+            raise ValueError("handle_unknown='ignore' is not supported for"
+                             " encoding='ordinal'")
+
+        if self.categories != 'auto':
+            for cats in self.categories:
+                if not np.all(np.sort(cats) == np.array(cats)):
+                    raise ValueError("Unsorted categories are not yet "
+                                     "supported")
+
+        X_temp = check_array(X, dtype=None)
+        if not hasattr(X, 'dtype') and np.issubdtype(X_temp.dtype, str):
+            X = check_array(X, dtype=np.object)
+        else:
+            X = X_temp
+
+        n_samples, n_features = X.shape
+
+        self._label_encoders_ = [LabelEncoder() for _ in range(n_features)]
+
+        for i in range(n_features):
+            le = self._label_encoders_[i]
+            Xi = X[:, i]
+            if self.categories == 'auto':
+                le.fit(Xi)
+            else:
+                if self.handle_unknown == 'error':
+                    valid_mask = np.in1d(Xi, self.categories[i])
+                    if not np.all(valid_mask):
+                        diff = np.unique(Xi[~valid_mask])
+                        msg = ("Found unknown categories {0} in column {1}"
+                               " during fit".format(diff, i))
+                        raise ValueError(msg)
+                le.classes_ = np.array(self.categories[i])
+
+        self.categories_ = [le.classes_ for le in self._label_encoders_]
+
+        return self
+
+    def transform(self, X):
+        """Transform X using the specified encoding scheme.
+
+        Parameters
+        ----------
+        X : array-like, shape [n_samples, n_features]
+            The data to encode.
+
+        Returns
+        -------
+        X_out : sparse matrix or a 2-d array
+            Transformed input.
+
+        """
+        X_temp = check_array(X, dtype=None)
+        if not hasattr(X, 'dtype') and np.issubdtype(X_temp.dtype, str):
+            X = check_array(X, dtype=np.object)
+        else:
+            X = X_temp
+
+        n_samples, n_features = X.shape
+        X_int = np.zeros_like(X, dtype=np.int)
+        X_mask = np.ones_like(X, dtype=np.bool)
+
+        for i in range(n_features):
+            Xi = X[:, i]
+            valid_mask = np.in1d(Xi, self.categories_[i])
+
+            if not np.all(valid_mask):
+                if self.handle_unknown == 'error':
+                    diff = np.unique(X[~valid_mask, i])
+                    msg = ("Found unknown categories {0} in column {1}"
+                           " during transform".format(diff, i))
+                    raise ValueError(msg)
+                else:
+                    # Set the problematic rows to an acceptable value and
+                    # continue. The rows are marked in `X_mask` and will be
+                    # removed later.
+                    X_mask[:, i] = valid_mask
+                    Xi = Xi.copy()
+                    Xi[~valid_mask] = self.categories_[i][0]
+            X_int[:, i] = self._label_encoders_[i].transform(Xi)
+
+        if self.encoding == 'ordinal':
+            return X_int.astype(self.dtype, copy=False)
+
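+        # Build the one-hot result directly in CSR format: feature i maps to
+        # the block of columns [feature_indices[i], feature_indices[i + 1]),
+        # and each sample contributes a single 1 per feature at column offset
+        # feature_indices[i] + integer code. Entries masked out above
+        # (ignored unknown categories) contribute no nonzero entry.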
+        mask = X_mask.ravel()
+        n_values = [cats.shape[0] for cats in self.categories_]
+        n_values = np.array([0] + n_values)
+        feature_indices = np.cumsum(n_values)
+
+        indices = (X_int + feature_indices[:-1]).ravel()[mask]
+        indptr = X_mask.sum(axis=1).cumsum()
+        indptr = np.insert(indptr, 0, 0)
+        data = np.ones(n_samples * n_features)[mask]
+
+        out = sparse.csr_matrix((data, indices, indptr),
+                                shape=(n_samples, feature_indices[-1]),
+                                dtype=self.dtype)
+        if self.encoding == 'onehot-dense':
+            return out.toarray()
+        else:
+            return out
+
+    def inverse_transform(self, X):
+        """Convert back the data to the original representation.
+
+        In case unknown categories are encountered (all zeros in the
+        one-hot encoding), ``None`` is used to represent this category.
+
+        Parameters
+        ----------
+        X : array-like or sparse matrix, shape [n_samples, n_encoded_features]
+            The transformed data.
+
+        Returns
+        -------
+        X_tr : array-like, shape [n_samples, n_features]
+            Inverse transformed array.
+
+        """
+        check_is_fitted(self, 'categories_')
+        X = check_array(X, accept_sparse='csr')
+
+        n_samples, _ = X.shape
+        n_features = len(self.categories_)
+        n_transformed_features = sum([len(cats) for cats in self.categories_])
+
+        # validate shape of passed X
+        msg = ("Shape of the passed X data is not correct. Expected {0} "
+               "columns, got {1}.")
+        if self.encoding == 'ordinal' and X.shape[1] != n_features:
+            raise ValueError(msg.format(n_features, X.shape[1]))
+        elif (self.encoding.startswith('onehot')
+                and X.shape[1] != n_transformed_features):
+            raise ValueError(msg.format(n_transformed_features, X.shape[1]))
+
+        # create resulting array of appropriate dtype
+        dt = np.find_common_type([cat.dtype for cat in self.categories_], [])
+        X_tr = np.empty((n_samples, n_features), dtype=dt)
+
+        if self.encoding == 'ordinal':
+            for i in range(n_features):
+                labels = X[:, i].astype('int64')
+                X_tr[:, i] = self.categories_[i][labels]
+
+        else:  # encoding == 'onehot' / 'onehot-dense'
+            j = 0
+            found_unknown = {}
+
+            for i in range(n_features):
+                n_categories = len(self.categories_[i])
+                sub = X[:, j:j + n_categories]
+
+                # for sparse X argmax returns 2D matrix, ensure 1D array
+                labels = np.asarray(_argmax(sub, axis=1)).flatten()
+                X_tr[:, i] = self.categories_[i][labels]
+
+                if self.handle_unknown == 'ignore':
+                    # ignored unknown categories: we have a row of all zeros
+                    unknown = np.asarray(sub.sum(axis=1) == 0).flatten()
+                    if unknown.any():
+                        found_unknown[i] = unknown
+
+                j += n_categories
+
+            # if unknown categories were ignored: potentially need to upcast
+            # the result to insert None values
+            if found_unknown:
+                if X_tr.dtype != object:
+                    X_tr = X_tr.astype(object)
+
+                for idx, mask in found_unknown.items():
+                    X_tr[mask, idx] = None
+
+        return X_tr
diff --git a/sklearn/preprocessing/label.py b/sklearn/preprocessing/label.py
index 530f376c19fa9..d68ec7d6b4fad 100644
--- a/sklearn/preprocessing/label.py
+++ b/sklearn/preprocessing/label.py
@@ -76,8 +76,8 @@ class LabelEncoder(BaseEstimator, TransformerMixin):
 
     See also
     --------
-    sklearn.preprocessing.OneHotEncoder : encode categorical integer features
-        using a one-hot aka one-of-K scheme.
+    sklearn.preprocessing.CategoricalEncoder : encode categorical features
+        using a one-hot or ordinal encoding scheme.
""" def fit(self, y): diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index e777fb5ffe98b..e715ceacfac25 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -6,6 +6,7 @@ from __future__ import division import warnings +import re import numpy as np import numpy.linalg as la from scipy import sparse @@ -30,6 +31,7 @@ from sklearn.utils.testing import assert_no_warnings from sklearn.utils.testing import assert_allclose from sklearn.utils.testing import skip_if_32bit +from sklearn.utils.testing import SkipTest from sklearn.utils.sparsefuncs import mean_variance_axis from sklearn.preprocessing.data import _transform_selected @@ -39,6 +41,7 @@ from sklearn.preprocessing.data import Normalizer from sklearn.preprocessing.data import normalize from sklearn.preprocessing.data import OneHotEncoder +from sklearn.preprocessing.data import CategoricalEncoder from sklearn.preprocessing.data import StandardScaler from sklearn.preprocessing.data import scale from sklearn.preprocessing.data import MinMaxScaler @@ -1968,6 +1971,236 @@ def test_one_hot_encoder_unknown_transform(): assert_raises(ValueError, oh.transform, y) +def check_categorical_onehot(X): + enc = CategoricalEncoder(encoding='onehot') + Xtr1 = enc.fit_transform(X) + + enc = CategoricalEncoder(encoding='onehot-dense') + Xtr2 = enc.fit_transform(X) + + assert_allclose(Xtr1.toarray(), Xtr2) + + assert sparse.isspmatrix_csr(Xtr1) + return Xtr1.toarray() + + +def test_categorical_encoder_onehot(): + X = [['abc', 1, 55], ['def', 2, 55]] + + Xtr = check_categorical_onehot(np.array(X)[:, [0]]) + assert_allclose(Xtr, [[1, 0], [0, 1]]) + + Xtr = check_categorical_onehot(np.array(X)[:, [0, 1]]) + assert_allclose(Xtr, [[1, 0, 1, 0], [0, 1, 0, 1]]) + + Xtr = CategoricalEncoder().fit_transform(X) + assert_allclose(Xtr.toarray(), [[1, 0, 1, 0, 1], [0, 1, 0, 1, 1]]) + + +def test_categorical_encoder_onehot_inverse(): + for encoding in ['onehot', 'onehot-dense']: + X = [['abc', 2, 55], ['def', 1, 55], ['abc', 3, 55]] + enc = CategoricalEncoder(encoding=encoding) + X_tr = enc.fit_transform(X) + exp = np.array(X, dtype=object) + assert_array_equal(enc.inverse_transform(X_tr), exp) + + X = [[2, 55], [1, 55], [3, 55]] + enc = CategoricalEncoder(encoding=encoding) + X_tr = enc.fit_transform(X) + exp = np.array(X) + assert_array_equal(enc.inverse_transform(X_tr), exp) + + # with unknown categories + X = [['abc', 2, 55], ['def', 1, 55], ['abc', 3, 55]] + enc = CategoricalEncoder(encoding=encoding, handle_unknown='ignore', + categories=[['abc', 'def'], [1, 2], + [54, 55, 56]]) + X_tr = enc.fit_transform(X) + exp = np.array(X, dtype=object) + exp[2, 1] = None + assert_array_equal(enc.inverse_transform(X_tr), exp) + + # with an otherwise numerical output, still object if unknown + X = [[2, 55], [1, 55], [3, 55]] + enc = CategoricalEncoder(encoding=encoding, + categories=[[1, 2], [54, 56]], + handle_unknown='ignore') + X_tr = enc.fit_transform(X) + exp = np.array(X, dtype=object) + exp[2, 0] = None + exp[:, 1] = None + assert_array_equal(enc.inverse_transform(X_tr), exp) + + # incorrect shape raises + X_tr = np.array([[0, 1, 1], [1, 0, 1]]) + msg = re.escape('Shape of the passed X data is not correct') + assert_raises_regex(ValueError, msg, enc.inverse_transform, X_tr) + + +def test_categorical_encoder_handle_unknown(): + X = np.array([[1, 2, 3], [4, 5, 6]]) + X2 = np.array([[7, 5, 3]]) + + # Test that encoder raises error for unknown features during transform. 
+ enc = CategoricalEncoder() + enc.fit(X) + msg = re.escape('unknown categories [7] in column 0') + assert_raises_regex(ValueError, msg, enc.transform, X2) + + # With 'ignore' you get all 0's in result + enc = CategoricalEncoder(handle_unknown='ignore') + enc.fit(X) + X2_passed = X2.copy() + Xtr = enc.transform(X2_passed) + assert_allclose(Xtr.toarray(), [[0, 0, 0, 1, 1, 0]]) + # ensure transformed data was not modified in place + assert_allclose(X2, X2_passed) + + # Invalid option + enc = CategoricalEncoder(handle_unknown='invalid') + assert_raises(ValueError, enc.fit, X) + + +def test_categorical_encoder_categories(): + X = [['abc', 1, 55], ['def', 2, 55]] + + # order of categories should not depend on order of samples + for Xi in [X, X[::-1]]: + enc = CategoricalEncoder() + enc.fit(Xi) + assert enc.categories == 'auto' + assert isinstance(enc.categories_, list) + cat_exp = [['abc', 'def'], [1, 2], [55]] + for res, exp in zip(enc.categories_, cat_exp): + assert res.tolist() == exp + + +def test_categorical_encoder_specified_categories(): + X = np.array([['a', 'b']], dtype=object).T + + enc = CategoricalEncoder(categories=[['a', 'b', 'c']]) + exp = np.array([[1., 0., 0.], + [0., 1., 0.]]) + assert_array_equal(enc.fit_transform(X).toarray(), exp) + assert enc.categories[0] == ['a', 'b', 'c'] + assert enc.categories_[0].tolist() == ['a', 'b', 'c'] + assert np.issubdtype(enc.categories_[0].dtype, str) + + # unsorted passed categories raises for now + enc = CategoricalEncoder(categories=[['c', 'b', 'a']]) + msg = re.escape('Unsorted categories are not yet supported') + assert_raises_regex(ValueError, msg, enc.fit_transform, X) + + # multiple columns + X = np.array([['a', 'b'], [0, 2]], dtype=object).T + enc = CategoricalEncoder(categories=[['a', 'b', 'c'], [0, 1, 2]]) + exp = np.array([[1., 0., 0., 1., 0., 0.], + [0., 1., 0., 0., 0., 1.]]) + assert_array_equal(enc.fit_transform(X).toarray(), exp) + assert enc.categories_[0].tolist() == ['a', 'b', 'c'] + assert np.issubdtype(enc.categories_[0].dtype, str) + assert enc.categories_[1].tolist() == [0, 1, 2] + assert np.issubdtype(enc.categories_[1].dtype, np.integer) + + # when specifying categories manually, unknown categories should already + # raise when fitting + X = np.array([['a', 'b', 'c']]).T + enc = CategoricalEncoder(categories=[['a', 'b']]) + assert_raises(ValueError, enc.fit, X) + enc = CategoricalEncoder(categories=[['a', 'b']], handle_unknown='ignore') + exp = np.array([[1., 0.], [0., 1.], [0., 0.]]) + assert_array_equal(enc.fit(X).transform(X).toarray(), exp) + + +def test_categorical_encoder_pandas(): + try: + import pandas as pd + except ImportError: + raise SkipTest("pandas is not installed") + + X_df = pd.DataFrame({'A': ['a', 'b'], 'B': [1, 2]}) + + Xtr = check_categorical_onehot(X_df) + assert_allclose(Xtr, [[1, 0, 1, 0], [0, 1, 0, 1]]) + + +def test_categorical_encoder_ordinal(): + X = [['abc', 2, 55], ['def', 1, 55]] + + enc = CategoricalEncoder(encoding='other') + assert_raises(ValueError, enc.fit, X) + + enc = CategoricalEncoder(encoding='ordinal', handle_unknown='ignore') + assert_raises(ValueError, enc.fit, X) + + enc = CategoricalEncoder(encoding='ordinal') + exp = np.array([[0, 1, 0], + [1, 0, 0]], dtype='int64') + assert_array_equal(enc.fit_transform(X), exp.astype('float64')) + enc = CategoricalEncoder(encoding='ordinal', dtype='int64') + assert_array_equal(enc.fit_transform(X), exp) + + +def test_categorical_encoder_ordinal_inverse(): + X = [['abc', 2, 55], ['def', 1, 55]] + enc = 
CategoricalEncoder(encoding='ordinal') + X_tr = enc.fit_transform(X) + exp = np.array(X, dtype=object) + assert_array_equal(enc.inverse_transform(X_tr), exp) + + # incorrect shape raises + X_tr = np.array([[0, 1, 1, 2], [1, 0, 1, 0]]) + msg = re.escape('Shape of the passed X data is not correct') + assert_raises_regex(ValueError, msg, enc.inverse_transform, X_tr) + + +def test_categorical_encoder_dtypes(): + # check that dtypes are preserved when determining categories + enc = CategoricalEncoder() + exp = np.array([[1., 0., 1., 0.], [0., 1., 0., 1.]], dtype='float64') + + for X in [np.array([[1, 2], [3, 4]], dtype='int64'), + np.array([[1, 2], [3, 4]], dtype='float64'), + np.array([['a', 'b'], ['c', 'd']]), # string dtype + np.array([[1, 'a'], [3, 'b']], dtype='object')]: + enc.fit(X) + assert all([enc.categories_[i].dtype == X.dtype for i in range(2)]) + assert_array_equal(enc.transform(X).toarray(), exp) + + X = [[1, 2], [3, 4]] + enc.fit(X) + assert all([np.issubdtype(enc.categories_[i].dtype, np.integer) + for i in range(2)]) + assert_array_equal(enc.transform(X).toarray(), exp) + + X = [[1, 'a'], [3, 'b']] + enc.fit(X) + assert all([enc.categories_[i].dtype == 'object' for i in range(2)]) + assert_array_equal(enc.transform(X).toarray(), exp) + + +def test_categorical_encoder_dtypes_pandas(): + # check dtype (similar to test_categorical_encoder_dtypes for dataframes) + try: + import pandas as pd + except ImportError: + raise SkipTest("pandas is not installed") + + enc = CategoricalEncoder() + exp = np.array([[1., 0., 1., 0.], [0., 1., 0., 1.]], dtype='float64') + + X = pd.DataFrame({'A': [1, 2], 'B': [3, 4]}, dtype='int64') + enc.fit(X) + assert all([enc.categories_[i].dtype == 'int64' for i in range(2)]) + assert_array_equal(enc.transform(X).toarray(), exp) + + X = pd.DataFrame({'A': [1, 2], 'B': ['a', 'b']}) + enc.fit(X) + assert all([enc.categories_[i].dtype == 'object' for i in range(2)]) + assert_array_equal(enc.transform(X).toarray(), exp) + + def test_fit_cold_start(): X = iris.data X_2d = X[:, :2] diff --git a/sklearn/utils/fixes.py b/sklearn/utils/fixes.py index 0e96057f929ff..3c81a2f86d35b 100644 --- a/sklearn/utils/fixes.py +++ b/sklearn/utils/fixes.py @@ -150,6 +150,108 @@ def sparse_min_max(X, axis): from scipy.misc import comb, logsumexp # noqa +if sp_version >= (0, 19): + def _argmax(arr_or_spmatrix, axis=None): + return arr_or_spmatrix.argmax(axis=axis) +else: + # Backport of argmax functionality from scipy 0.19.1, can be removed + # once support for scipy 0.18 and below is dropped + + def _find_missing_index(ind, n): + for k, a in enumerate(ind): + if k != a: + return k + + k += 1 + if k < n: + return k + else: + return -1 + + def _arg_min_or_max_axis(self, axis, op, compare): + if self.shape[axis] == 0: + raise ValueError("Can't apply the operation along a zero-sized " + "dimension.") + + if axis < 0: + axis += 2 + + zero = self.dtype.type(0) + + mat = self.tocsc() if axis == 0 else self.tocsr() + mat.sum_duplicates() + + ret_size, line_size = mat._swap(mat.shape) + ret = np.zeros(ret_size, dtype=int) + + nz_lines, = np.nonzero(np.diff(mat.indptr)) + for i in nz_lines: + p, q = mat.indptr[i:i + 2] + data = mat.data[p:q] + indices = mat.indices[p:q] + am = op(data) + m = data[am] + if compare(m, zero) or q - p == line_size: + ret[i] = indices[am] + else: + zero_ind = _find_missing_index(indices, line_size) + if m == zero: + ret[i] = min(am, zero_ind) + else: + ret[i] = zero_ind + + if axis == 1: + ret = ret.reshape(-1, 1) + + return np.asmatrix(ret) + + def 
_arg_min_or_max(self, axis, out, op, compare): + if out is not None: + raise ValueError("Sparse matrices do not support " + "an 'out' parameter.") + + # validateaxis(axis) + + if axis is None: + if 0 in self.shape: + raise ValueError("Can't apply the operation to " + "an empty matrix.") + + if self.nnz == 0: + return 0 + else: + zero = self.dtype.type(0) + mat = self.tocoo() + mat.sum_duplicates() + am = op(mat.data) + m = mat.data[am] + + if compare(m, zero): + return mat.row[am] * mat.shape[1] + mat.col[am] + else: + size = np.product(mat.shape) + if size == mat.nnz: + return am + else: + ind = mat.row * mat.shape[1] + mat.col + zero_ind = _find_missing_index(ind, size) + if m == zero: + return min(zero_ind, am) + else: + return zero_ind + + return _arg_min_or_max_axis(self, axis, op, compare) + + def _sparse_argmax(self, axis=None, out=None): + return _arg_min_or_max(self, axis, out, np.argmax, np.greater) + + def _argmax(arr_or_matrix, axis=None): + if sp.issparse(arr_or_matrix): + return _sparse_argmax(arr_or_matrix, axis=axis) + else: + return arr_or_matrix.argmax(axis=axis) + + def parallel_helper(obj, methodname, *args, **kwargs): """Workaround for Python 2 limitations of pickling instance methods""" return getattr(obj, methodname)(*args, **kwargs) diff --git a/sklearn/utils/testing.py b/sklearn/utils/testing.py index 035a2e3175add..1406b74d7e3e0 100644 --- a/sklearn/utils/testing.py +++ b/sklearn/utils/testing.py @@ -521,7 +521,7 @@ def uninstall_mldata_mock(): 'LabelBinarizer', 'LabelEncoder', 'MultiLabelBinarizer', 'TfidfTransformer', 'TfidfVectorizer', 'IsotonicRegression', - 'OneHotEncoder', 'RandomTreesEmbedding', + 'OneHotEncoder', 'RandomTreesEmbedding', 'CategoricalEncoder', 'FeatureHasher', 'DummyClassifier', 'DummyRegressor', 'TruncatedSVD', 'PolynomialFeatures', 'GaussianRandomProjectionHash', 'HashingVectorizer',