Commit 61fa315

Nirvan101 authored and amueller committed
[MRG] Add get_feature_names to OneHotEncoder (#10198)
**Reference Issues/PRs**
Fixes #10181

**What does this implement/fix? Explain your changes.**
Added the function **get_feature_names()** to the **CategoricalEncoder** class. This is in `data.py` under `sklearn.preprocessing`.
1 parent 726fa36 commit 61fa315

3 files changed, 81 additions and 0 deletions


doc/whats_new/v0.20.rst

Lines changed: 5 additions & 0 deletions
@@ -302,6 +302,11 @@ Preprocessing
   :class:`feature_extraction.text.CountVectorizer` initialized with a
   vocabulary. :issue:`10908` by :user:`Mohamed Maskani <maskani-moh>`.

+- :class:`preprocessing.OneHotEncoder` now supports the
+  :meth:`get_feature_names` method to obtain the transformed feature names.
+  :issue:`10181` by :user:`Nirvan Anjirbag <Nirvan101>` and
+  `Joris Van den Bossche`_.
+
 - The ``transform`` method of :class:`sklearn.preprocessing.MultiLabelBinarizer`
   now ignores any unknown classes. A warning is raised stating the unknown classes
   classes found which are ignored.

sklearn/preprocessing/_encoders.py

Lines changed: 34 additions & 0 deletions
@@ -240,6 +240,8 @@ class OneHotEncoder(_BaseEncoder):
     >>> enc.inverse_transform([[0, 1, 1, 0, 0], [0, 0, 0, 1, 0]])
     array([['Male', 1],
            [None, 2]], dtype=object)
+    >>> enc.get_feature_names()
+    array(['x0_Female', 'x0_Male', 'x1_1', 'x1_2', 'x1_3'], dtype=object)

     See also
     --------
@@ -639,6 +641,38 @@ def inverse_transform(self, X):

         return X_tr

+    def get_feature_names(self, input_features=None):
+        """Return feature names for output features.
+
+        Parameters
+        ----------
+        input_features : list of string, length n_features, optional
+            String names for input features if available. By default,
+            "x0", "x1", ... "xn_features" is used.
+
+        Returns
+        -------
+        output_feature_names : array of string, length n_output_features
+
+        """
+        check_is_fitted(self, 'categories_')
+        cats = self.categories_
+        if input_features is None:
+            input_features = ['x%d' % i for i in range(len(cats))]
+        elif(len(input_features) != len(self.categories_)):
+            raise ValueError(
+                "input_features should have length equal to number of "
+                "features ({}), got {}".format(len(self.categories_),
+                                               len(input_features)))
+
+        feature_names = []
+        for i in range(len(cats)):
+            names = [
+                input_features[i] + '_' + six.text_type(t) for t in cats[i]]
+            feature_names.extend(names)
+
+        return np.array(feature_names, dtype=object)
+

 class OrdinalEncoder(_BaseEncoder):
     """Encode categorical features as an integer array.

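The naming rule that `get_feature_names` introduces in the hunk above can be sketched without scikit-learn. This is only a minimal illustration, assuming the fitted categories are supplied as plain lists. The helper name `one_hot_feature_names` is hypothetical and not part of the library, it returns a list rather than a NumPy array, and `str` stands in for `six.text_type` under Python 3.

```python
def one_hot_feature_names(categories, input_features=None):
    """Build '<feature>_<category>' names, mirroring the logic in the diff.

    `categories` plays the role of the encoder's fitted `categories_`:
    one sequence of category values per input feature.
    """
    if input_features is None:
        # Default names are x0, x1, ... (one per input feature).
        input_features = ['x%d' % i for i in range(len(categories))]
    elif len(input_features) != len(categories):
        raise ValueError(
            "input_features should have length equal to number of "
            "features ({}), got {}".format(len(categories),
                                           len(input_features)))

    feature_names = []
    for name, cats in zip(input_features, categories):
        # str() stands in for six.text_type (Python 3 assumption).
        feature_names.extend(name + '_' + str(cat) for cat in cats)
    return feature_names


names = one_hot_feature_names([['Female', 'Male'], [1, 2, 3]])
print(names)
```

With the categories from the docstring example, this yields the same five names shown there (`x0_Female` through `x1_3`). Passing a list of the wrong length as `input_features` raises `ValueError`, as in the committed method.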
sklearn/preprocessing/tests/test_encoders.py

Lines changed: 42 additions & 0 deletions
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
 from __future__ import division

 import re
@@ -455,6 +456,47 @@ def test_one_hot_encoder_pandas():
     assert_allclose(Xtr, [[1, 0, 1, 0], [0, 1, 0, 1]])


+def test_one_hot_encoder_feature_names():
+    enc = OneHotEncoder()
+    X = [['Male', 1, 'girl', 2, 3],
+         ['Female', 41, 'girl', 1, 10],
+         ['Male', 51, 'boy', 12, 3],
+         ['Male', 91, 'girl', 21, 30]]
+
+    enc.fit(X)
+    feature_names = enc.get_feature_names()
+    assert isinstance(feature_names, np.ndarray)
+
+    assert_array_equal(['x0_Female', 'x0_Male',
+                        'x1_1', 'x1_41', 'x1_51', 'x1_91',
+                        'x2_boy', 'x2_girl',
+                        'x3_1', 'x3_2', 'x3_12', 'x3_21',
+                        'x4_3',
+                        'x4_10', 'x4_30'], feature_names)
+
+    feature_names2 = enc.get_feature_names(['one', 'two',
+                                            'three', 'four', 'five'])
+
+    assert_array_equal(['one_Female', 'one_Male',
+                        'two_1', 'two_41', 'two_51', 'two_91',
+                        'three_boy', 'three_girl',
+                        'four_1', 'four_2', 'four_12', 'four_21',
+                        'five_3', 'five_10', 'five_30'], feature_names2)
+
+    with pytest.raises(ValueError, match="input_features should have length"):
+        enc.get_feature_names(['one', 'two'])
+
+
+def test_one_hot_encoder_feature_names_unicode():
+    enc = OneHotEncoder()
+    X = np.array([[u'c❤t1', u'dat2']], dtype=object).T
+    enc.fit(X)
+    feature_names = enc.get_feature_names()
+    assert_array_equal([u'x0_c❤t1', u'x0_dat2'], feature_names)
+    feature_names = enc.get_feature_names(input_features=[u'n👍me'])
+    assert_array_equal([u'n👍me_c❤t1', u'n👍me_dat2'], feature_names)
+
+
 @pytest.mark.parametrize("X", [
     [['abc', 2, 55], ['def', 1, 55]],
     np.array([[10, 2, 55], [20, 1, 55]]),
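
The expected names in `test_one_hot_encoder_feature_names` list each column's categories in sorted order, not in order of appearance. That is why the fourth column yields `1, 2, 12, 21`. A rough pure-Python sketch of that derivation follows. It assumes each column holds values of a single comparable type, and that sorted unique values approximate what the encoder stores in `categories_`. The helper `sorted_column_categories` is hypothetical.

```python
def sorted_column_categories(rows):
    """Return the sorted unique values of each column of `rows`."""
    columns = zip(*rows)  # transpose rows into columns
    return [sorted(set(col)) for col in columns]


X = [['Male', 1, 'girl', 2, 3],
     ['Female', 41, 'girl', 1, 10],
     ['Male', 51, 'boy', 12, 3],
     ['Male', 91, 'girl', 21, 30]]

categories = sorted_column_categories(X)

# Prefix every category with the default name of its column: x0, x1, ...
feature_names = ['x%d_%s' % (i, cat)
                 for i, cats in enumerate(categories)
                 for cat in cats]
print(feature_names)
```

Integer columns sort numerically here because the values are compared before they are converted to text, so `12` follows `2` rather than preceding it.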
