Skip to content

Commit 68e3f92

Browse files
committed
commit 11
rebased commit 8,9 commit 10
1 parent 9808810 commit 68e3f92

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

sklearn/preprocessing/data.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import numpy as np
1717
from scipy import sparse
1818
from scipy import stats
19+
import sys
1920

2021
from ..base import BaseEstimator, TransformerMixin
2122
from ..externals import six
@@ -2655,6 +2656,8 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin):
26552656
>>> enc.inverse_transform([[0, 1, 1, 0, 0], [0, 0, 0, 1, 0]])
26562657
array([['Male', 1],
26572658
[None, 2]], dtype=object)
2659+
>>> enc.get_feature_names()
2660+
['x0_Female', 'x0_Male', 'x1_1', 'x1_2', 'x1_3']
26582661
26592662
See also
26602663
--------
@@ -2873,3 +2876,37 @@ def inverse_transform(self, X):
28732876
X_tr[mask, idx] = None
28742877

28752878
return X_tr
2879+
2880+
def get_feature_names(self, input_features=None):
2881+
"""Return feature names for output features
2882+
2883+
Parameters
2884+
----------
2885+
input_features : list of string, length n_features, optional
2886+
String names for input features if available. By default,
2887+
"x0", "x1", ... "xn_features" is used.
2888+
2889+
Returns
2890+
-------
2891+
output_feature_names : list of string, length n_output_features
2892+
2893+
"""
2894+
is_python3 = sys.version_info.major == 3
2895+
if is_python3:
2896+
unicode = str
2897+
2898+
cats = self.categories_
2899+
feature_names = []
2900+
if input_features is None:
2901+
input_features = ['x%d' % i for i in range(len(cats))]
2902+
elif(len(input_features) != len(self.categories_)):
2903+
raise ValueError('input_features should have length equal to number of features')
2904+
2905+
def to_unicode(text):
2906+
#If text is unicode, it is returned as is. If it's str, convert it to Unicode using UTF-8 encoding
2907+
return text if isinstance(text, unicode) else text.encode('utf8')
2908+
2909+
for i in range(len(cats)):
2910+
feature_names.extend( to_unicode(input_features[i] + '_' + str(t)) for t in cats[i])
2911+
2912+
return feature_names

sklearn/preprocessing/tests/test_data.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2224,3 +2224,30 @@ def test_quantile_transform_valid_axis():
22242224

22252225
assert_raises_regex(ValueError, "axis should be either equal to 0 or 1"
22262226
". Got axis=2", quantile_transform, X.T, axis=2)
2227+
2228+
2229+
def test_categorical_encoder_feature_names():
2230+
enc = CategoricalEncoder()
2231+
X = [['Male', 1, 'girl', 2, 3],
2232+
['Female', 41, 'girl', 1, 10],
2233+
['Male', 51, 'boy', 12, 3],
2234+
['Male', 91, 'girl', 21, 30]]
2235+
2236+
enc.fit(X)
2237+
feature_names = enc.get_feature_names()
2238+
2239+
assert_array_equal(['x0_Female', 'x0_Male',
2240+
'x1_1', 'x1_41', 'x1_51', 'x1_91',
2241+
'x2_boy', 'x2_girl',
2242+
'x3_1', 'x3_2', 'x3_12', 'x3_21',
2243+
'x4_3',
2244+
'x4_10', 'x4_30'], feature_names)
2245+
2246+
feature_names2 = enc.get_feature_names(['one', 'two',
2247+
'three', 'four', 'five'])
2248+
2249+
assert_array_equal(['one_Female', 'one_Male',
2250+
'two_1', 'two_41', 'two_51', 'two_91',
2251+
'three_boy', 'three_girl',
2252+
'four_1', 'four_2', 'four_12', 'four_21',
2253+
'five_3', 'five_10', 'five_30'], feature_names2)

0 commit comments

Comments
 (0)