diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst
index 54210f2453cb0..2a76955518a30 100644
--- a/doc/modules/preprocessing.rst
+++ b/doc/modules/preprocessing.rst
@@ -489,7 +489,7 @@ Continuing the example above::
     >>> enc = preprocessing.OneHotEncoder()
     >>> X = [['male', 'from US', 'uses Safari'], ['female', 'from Europe', 'uses Firefox']]
     >>> enc.fit(X)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
-    OneHotEncoder(categorical_features=None, categories=None,
+    OneHotEncoder(categorical_features=None, categories=None, drop_first=False,
            dtype=<... 'numpy.float64'>, handle_unknown='error',
            n_values=None, sparse=True)
     >>> enc.transform([['female', 'from US', 'uses Safari'],
@@ -516,7 +516,7 @@ dataset::
     >>> X = [['male', 'from US', 'uses Safari'], ['female', 'from Europe', 'uses Firefox']]
     >>> enc.fit(X)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
     OneHotEncoder(categorical_features=None,
-           categories=[...],
+           categories=[...], drop_first=False,
            dtype=<... 'numpy.float64'>, handle_unknown='error',
            n_values=None, sparse=True)
     >>> enc.transform([['female', 'from Asia', 'uses Chrome']]).toarray()
@@ -533,12 +533,31 @@ columns for this feature will be all zeros
     >>> enc = preprocessing.OneHotEncoder(handle_unknown='ignore')
     >>> X = [['male', 'from US', 'uses Safari'], ['female', 'from Europe', 'uses Firefox']]
     >>> enc.fit(X)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
-    OneHotEncoder(categorical_features=None, categories=None,
+    OneHotEncoder(categorical_features=None, categories=None, drop_first=False,
            dtype=<... 'numpy.float64'>, handle_unknown='ignore',
            n_values=None, sparse=True)
     >>> enc.transform([['female', 'from Asia', 'uses Chrome']]).toarray()
     array([[1., 0., 0., 0., 0., 0.]])
 
+Using ``drop_first=True``, each feature is encoded into ``n_categories - 1``
+columns instead of ``n_categories`` columns, by dropping the column of the
+first category. In this case ``handle_unknown`` must be set to ``'error'``.
+This is useful to avoid collinearity in the input matrix of unregularized
+linear models such as :class:`LinearRegression
+<sklearn.linear_model.LinearRegression>`, where collinear features would
+make the covariance matrix non-invertible::
+
+    >>> enc = preprocessing.OneHotEncoder(drop_first=True)
+    >>> X = [['male', 'from US', 'uses Safari'], ['female', 'from Europe', 'uses Firefox']]
+    >>> enc.fit(X)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+    OneHotEncoder(categorical_features=None, categories=None, drop_first=True,
+           dtype=<... 'numpy.float64'>, handle_unknown='error',
+           n_values=None, sparse=True)
+    >>> enc.transform([['female', 'from US', 'uses Safari'],
+    ...                ['male', 'from Europe', 'uses Safari'],
+    ...                ['female', 'from US', 'uses Firefox']]).toarray()
+    array([[0., 1., 1.],
+           [1., 0., 1.],
+           [0., 1., 0.]])
+
 See :ref:`dict_feature_extraction` for categorical features that are
 represented as a dict, not as scalars.
diff --git a/sklearn/preprocessing/_encoders.py b/sklearn/preprocessing/_encoders.py
index 2b78cd676ed0e..2e300e9be6524 100644
--- a/sklearn/preprocessing/_encoders.py
+++ b/sklearn/preprocessing/_encoders.py
@@ -175,6 +175,14 @@ class OneHotEncoder(_BaseEncoder):
         will be all zeros. In the inverse transform, an unknown category
         will be denoted as None.
 
+    drop_first : bool, default=False
+        If ``True``, the column of the first category of each feature is
+        dropped, so that each feature is encoded into ``n_categories - 1``
+        columns. This is useful in unregularized linear regression
+        (:class:`LinearRegression <sklearn.linear_model.LinearRegression>`),
+        where collinearity between the input features must be avoided. If
+        ``True``, ``handle_unknown`` must be set to ``'error'``.
+
     n_values : 'auto', int or array of ints, default='auto'
         Number of values per feature.
 
@@ -246,7 +254,7 @@ class OneHotEncoder(_BaseEncoder):
     >>> enc.fit(X)
     ... # doctest: +ELLIPSIS
     ... # doctest: +NORMALIZE_WHITESPACE
-    OneHotEncoder(categorical_features=None, categories=None,
+    OneHotEncoder(categorical_features=None, categories=None, drop_first=False,
            dtype=<... 'numpy.float64'>, handle_unknown='ignore',
            n_values=None, sparse=True)
 
@@ -278,13 +286,14 @@ class OneHotEncoder(_BaseEncoder):
 
     def __init__(self, n_values=None, categorical_features=None,
                  categories=None, sparse=True, dtype=np.float64,
-                 handle_unknown='error'):
+                 handle_unknown='error', drop_first=False):
         self.categories = categories
         self.sparse = sparse
         self.dtype = dtype
         self.handle_unknown = handle_unknown
         self.n_values = n_values
         self.categorical_features = categorical_features
+        self.drop_first = drop_first
 
     # Deprecated attributes
 
@@ -417,16 +426,30 @@ def fit(self, X, y=None):
                    "got {0}.".format(self.handle_unknown))
             raise ValueError(msg)
 
+        if self.drop_first and self.handle_unknown != 'error':
+            # With handle_unknown='ignore', an unknown category is
+            # represented as all zeros in the corresponding feature columns,
+            # which is also how the (dropped) first category is represented
+            # when drop_first=True, so the two options are incompatible.
+            raise ValueError(
+                "handle_unknown must be 'error' to use drop_first=True")
+
         self._handle_deprecations(X)
 
+        if self._legacy_mode and self.drop_first:
+            raise ValueError(
+                'Using drop_first=True requires you not to use any '
+                'deprecated parameter (categorical_features and/or '
+                'n_values).')
+
         if self._legacy_mode:
             _transform_selected(X, self._legacy_fit_transform, self.dtype,
                                 self._categorical_features,
                                 copy=True)
-            return self
         else:
             self._fit(X, handle_unknown=self.handle_unknown)
-            return self
+
+        return self
 
     def _legacy_fit_transform(self, X):
         """Assumes X contains only categorical features."""
@@ -509,6 +532,12 @@ def fit_transform(self, X, y=None):
 
         self._handle_deprecations(X)
 
+        if self._legacy_mode and self.drop_first:
+            raise ValueError(
+                'Using drop_first=True requires you not to use any '
+                'deprecated parameter (categorical_features and/or '
+                'n_values).')
+
         if self._legacy_mode:
             return _transform_selected(
                 X, self._legacy_fit_transform, self.dtype,
@@ -585,6 +614,14 @@ def _transform_new(self, X):
         out = sparse.csr_matrix((data, indices, indptr),
                                 shape=(n_samples, feature_indices[-1]),
                                 dtype=self.dtype)
+
+        if self.drop_first:
+            # Remove the columns of the first categories: feature_indices
+            # holds the offset of the first column of each feature.
+            firsts = feature_indices[:-1]
+            to_keep = np.full(out.shape[1], True, dtype=bool)
+            to_keep[firsts] = False
+            out = out[:, to_keep]
+
         if not self.sparse:
             return out.toarray()
         else:
@@ -636,11 +673,17 @@ def inverse_transform(self, X):
         n_samples, _ = X.shape
         n_features = len(self.categories_)
-        n_transformed_features = sum([len(cats) for cats in self.categories_])
+        if self.drop_first:
+            n_transformed_features = sum(len(cats) - 1
+                                         for cats in self.categories_)
+        else:
+            n_transformed_features = sum(len(cats)
+                                         for cats in self.categories_)
 
         # validate shape of passed X
         msg = ("Shape of the passed X data is not correct. Expected {0} "
Expected {0} " "columns, got {1}.") + if X.shape[1] != n_transformed_features: raise ValueError(msg.format(n_transformed_features, X.shape[1])) @@ -653,10 +696,23 @@ def inverse_transform(self, X): for i in range(n_features): n_categories = len(self.categories_[i]) + if self.drop_first: + n_categories -= 1 # we 'removed' the first one + if n_categories == 0: + # Only happens if there was a column with a unique + # category. In this case we just fill the column with this + # unique category value. + X_tr[:, i] = self.categories_[i][0] + continue sub = X[:, j:j + n_categories] # for sparse X argmax returns 2D matrix, ensure 1D array labels = np.asarray(_argmax(sub, axis=1)).flatten() + if self.drop_first: + # first category (dropped): we have a row of all zero's + not_dropped = np.asarray(sub.sum(axis=1) != 0).flatten() + labels[not_dropped] += 1 + X_tr[:, i] = self.categories_[i][labels] if self.handle_unknown == 'ignore': diff --git a/sklearn/preprocessing/tests/test_encoders.py b/sklearn/preprocessing/tests/test_encoders.py index 792de88aa37de..5e65dbd1708f0 100644 --- a/sklearn/preprocessing/tests/test_encoders.py +++ b/sklearn/preprocessing/tests/test_encoders.py @@ -97,6 +97,28 @@ def test_one_hot_encoder_sparse(): enc.fit([[0], [1]]) assert_raises(ValueError, enc.transform, [[0], [-1]]) + # test drop_first=True and handle_unkown='ignore' + enc = OneHotEncoder(handle_unknown='ignore', drop_first=True) + assert_raises_regex( + ValueError, + "handle_unkown must be 'error' to use drop_first=True", + enc.fit, [[0], [-1]]) + + # test drop_first=True and legacy_mode=True + with ignore_warnings(category=(DeprecationWarning, FutureWarning)): + enc = OneHotEncoder(drop_first=True, n_values=1) + for method in (enc.fit, enc.fit_transform): + assert_raises_regex( + ValueError, + 'Using drop_first=True requires you not to use any ', + method, [[0], [-1]]) + + enc = OneHotEncoder(drop_first=True, categorical_features='all') + assert_raises_regex( + ValueError, + 'Using drop_first=True requires you not to use any ', + method, [[0], [-1]]) + def test_one_hot_encoder_dense(): # check for sparse=False @@ -362,21 +384,25 @@ def test_one_hot_encoder(X): assert_allclose(Xtr.toarray(), [[0, 1, 1, 0, 1], [1, 0, 0, 1, 1]]) -def test_one_hot_encoder_inverse(): - for sparse_ in [True, False]: - X = [['abc', 2, 55], ['def', 1, 55], ['abc', 3, 55]] - enc = OneHotEncoder(sparse=sparse_) - X_tr = enc.fit_transform(X) - exp = np.array(X, dtype=object) - assert_array_equal(enc.inverse_transform(X_tr), exp) +@pytest.mark.parametrize('sparse_', [False, True]) +@pytest.mark.parametrize('drop_first', [False, True]) +def test_one_hot_encoder_inverse(sparse_, drop_first): + X = [['abc', 2, 55], ['def', 1, 55], ['abc', 3, 55]] + enc = OneHotEncoder(sparse=sparse_, drop_first=drop_first) + X_tr = enc.fit_transform(X) + exp = np.array(X, dtype=object) + assert_array_equal(enc.inverse_transform(X_tr), exp) - X = [[2, 55], [1, 55], [3, 55]] - enc = OneHotEncoder(sparse=sparse_, categories='auto') - X_tr = enc.fit_transform(X) - exp = np.array(X) - assert_array_equal(enc.inverse_transform(X_tr), exp) + X = [[2, 55], [1, 55], [3, 55]] + enc = OneHotEncoder(sparse=sparse_, categories='auto', + drop_first=drop_first) + X_tr = enc.fit_transform(X) + exp = np.array(X) + assert_array_equal(enc.inverse_transform(X_tr), exp) + if not drop_first: # with unknown categories + # drop_first is incompatible with handle_unknown=ignore X = [['abc', 2, 55], ['def', 1, 55], ['abc', 3, 55]] enc = OneHotEncoder(sparse=sparse_, 
                             handle_unknown='ignore',
                             categories=[['abc', 'def'], [1, 2],
@@ -396,10 +422,10 @@ def test_one_hot_encoder_inverse():
         exp[:, 1] = None
         assert_array_equal(enc.inverse_transform(X_tr), exp)
 
-        # incorrect shape raises
-        X_tr = np.array([[0, 1, 1], [1, 0, 1]])
-        msg = re.escape('Shape of the passed X data is not correct')
-        assert_raises_regex(ValueError, msg, enc.inverse_transform, X_tr)
+    # incorrect shape raises
+    X_tr = np.array([[0, 1, 1], [1, 0, 1]])
+    msg = re.escape('Shape of the passed X data is not correct')
+    assert_raises_regex(ValueError, msg, enc.inverse_transform, X_tr)
 
 
 @pytest.mark.parametrize("X, cat_exp, cat_dtype", [
@@ -487,6 +513,17 @@ def test_one_hot_encoder_specified_categories_mixed_columns():
     assert np.issubdtype(enc.categories_[1].dtype, np.object_)
 
 
+@pytest.mark.parametrize('X, exp', [
+    ([[0], [5], [5]], [[0], [1], [1]]),           # 2 categories => 1 column
+    ([[0], [1], [2]], [[0, 0], [1, 0], [0, 1]]),  # C categories => C - 1 columns
+    ([[0, 1], [1, 1], [0, 1]], [[0], [1], [0]]),  # 1 category => column dropped
+])
+def test_one_hot_encoder_drop_first(X, exp):
+    enc = OneHotEncoder(categories='auto', sparse=False, drop_first=True)
+    assert_array_equal(enc.fit_transform(X), exp)
+
+
 def test_one_hot_encoder_pandas():
     pd = pytest.importorskip('pandas')
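
As an end-to-end illustration of the behavior this patch adds, the sketch
below round-trips a small dataset through ``transform`` and
``inverse_transform`` with ``drop_first=True``. It is not part of the patch
and assumes the branch above is installed: ``drop_first`` is not part of any
released scikit-learn API.

    import numpy as np
    from sklearn.preprocessing import OneHotEncoder

    X = [['male', 'from US', 'uses Safari'],
         ['female', 'from Europe', 'uses Firefox']]

    # categories='auto' avoids the deprecated legacy mode, which this patch
    # rejects in combination with drop_first=True.
    enc = OneHotEncoder(categories='auto', sparse=False, drop_first=True)
    X_tr = enc.fit_transform(X)

    # Each feature keeps n_categories - 1 columns; the first (sorted)
    # category of each feature is encoded as all zeros.
    print(X_tr)
    # [[1. 1. 1.]
    #  [0. 0. 0.]]

    # inverse_transform maps an all-zeros block back to the dropped
    # first category of the corresponding feature.
    print(enc.inverse_transform(X_tr))
    # [['male' 'from US' 'uses Safari']
    #  ['female' 'from Europe' 'uses Firefox']]

This round trip also shows why the patch forces ``handle_unknown='error'``
whenever ``drop_first=True``: with ``handle_unknown='ignore'``, an ignored
unknown category and the dropped first category would both produce an
all-zeros block, making them indistinguishable at ``inverse_transform`` time.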