Commit c3815c9

thomasjpfan, amueller, ogrisel, and lorentzenchr authored and committed

API Implements get_feature_names_out for transformers that support get_feature_names (#18444)

Co-authored-by: Andreas Mueller <andreas.mueller@columbia.edu>
Co-authored-by: Andreas Mueller <andreasmuellerml@gmail.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Olivier Grisel <olivier.grisel@gmail.com>
Co-authored-by: Christian Lorentzen <lorentzen.ch@gmail.com>

1 parent 73327dc; commit c3815c9

33 files changed: +1446 -160 lines

doc/glossary.rst (+12)

@@ -894,6 +894,7 @@ Class APIs and Estimator Types
         * :term:`fit`
         * :term:`transform`
         * :term:`get_feature_names`
+        * :term:`get_feature_names_out`

     meta-estimator
     meta-estimators
@@ -1262,6 +1263,17 @@ Methods
        to the names of input columns from which output column names can
        be generated. By default input features are named x0, x1, ....

+    ``get_feature_names_out``
+        Primarily for :term:`feature extractors`, but also used for other
+        transformers to provide string names for each column in the output of
+        the estimator's :term:`transform` method. It outputs an array of
+        strings and may take an array-like of strings as input, corresponding
+        to the names of input columns from which output column names can
+        be generated. If `input_features` is not passed in, then the
+        `feature_names_in_` attribute will be used. If the
+        `feature_names_in_` attribute is not defined, then the
+        input names are named `[x0, x1, ..., x(n_features_in_)]`.
+
     ``get_n_splits``
        On a :term:`CV splitter` (not an estimator), returns the number of
        elements one would get if iterating through the return value of
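A quick sketch of the naming rules this glossary entry describes (not from the commit itself; it assumes simple preprocessing transformers gain the default one-to-one naming at this commit):

    >>> import pandas as pd
    >>> from sklearn.preprocessing import StandardScaler
    >>> df = pd.DataFrame({"a": [1.0, 2.0], "b": [3.0, 4.0]})
    >>> StandardScaler().fit(df).get_feature_names_out()   # uses feature_names_in_
    array(['a', 'b'], dtype=object)
    >>> StandardScaler().fit([[1.0, 2.0], [3.0, 4.0]]).get_feature_names_out()
    array(['x0', 'x1'], dtype=object)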

doc/modules/compose.rst (+29 -9)

@@ -139,6 +139,27 @@ or by name::
     >>> pipe['reduce_dim']
     PCA()

+To enable model inspection, :class:`~sklearn.pipeline.Pipeline` has a
+``get_feature_names_out()`` method, just like all transformers. You can use
+pipeline slicing to get the feature names going into each step::
+
+    >>> from sklearn.datasets import load_iris
+    >>> from sklearn.feature_selection import SelectKBest
+    >>> iris = load_iris()
+    >>> pipe = Pipeline(steps=[
+    ...    ('select', SelectKBest(k=2)),
+    ...    ('clf', LogisticRegression())])
+    >>> pipe.fit(iris.data, iris.target)
+    Pipeline(steps=[('select', SelectKBest(...)), ('clf', LogisticRegression(...))])
+    >>> pipe[:-1].get_feature_names_out()
+    array(['x2', 'x3'], ...)
+
+You can also provide custom feature names for the input data using
+``get_feature_names_out``::
+
+    >>> pipe[:-1].get_feature_names_out(iris.feature_names)
+    array(['petal length (cm)', 'petal width (cm)'], ...)
+
 .. topic:: Examples:

  * :ref:`sphx_glr_auto_examples_feature_selection_plot_feature_selection_pipeline.py`
@@ -426,21 +447,20 @@ By default, the remaining rating columns are ignored (``remainder='drop'``)::
     >>> from sklearn.feature_extraction.text import CountVectorizer
     >>> from sklearn.preprocessing import OneHotEncoder
     >>> column_trans = ColumnTransformer(
-    ...     [('city_category', OneHotEncoder(dtype='int'),['city']),
+    ...     [('categories', OneHotEncoder(dtype='int'), ['city']),
     ...      ('title_bow', CountVectorizer(), 'title')],
-    ...     remainder='drop')
+    ...     remainder='drop', prefix_feature_names_out=False)

     >>> column_trans.fit(X)
-    ColumnTransformer(transformers=[('city_category', OneHotEncoder(dtype='int'),
+    ColumnTransformer(prefix_feature_names_out=False,
+                      transformers=[('categories', OneHotEncoder(dtype='int'),
                                      ['city']),
                                     ('title_bow', CountVectorizer(), 'title')])

-    >>> column_trans.get_feature_names()
-    ['city_category__x0_London', 'city_category__x0_Paris', 'city_category__x0_Sallisaw',
-    'title_bow__bow', 'title_bow__feast', 'title_bow__grapes', 'title_bow__his',
-    'title_bow__how', 'title_bow__last', 'title_bow__learned', 'title_bow__moveable',
-    'title_bow__of', 'title_bow__the', 'title_bow__trick', 'title_bow__watson',
-    'title_bow__wrath']
+    >>> column_trans.get_feature_names_out()
+    array(['city_London', 'city_Paris', 'city_Sallisaw', 'bow', 'feast',
+           'grapes', 'his', 'how', 'last', 'learned', 'moveable', 'of', 'the',
+           'trick', 'watson', 'wrath'], ...)

     >>> column_trans.transform(X).toarray()
     array([[1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0],
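For contrast with the ``prefix_feature_names_out=False`` output above, a sketch of the default behavior (the flag name follows this commit, and the double-underscore prefix format is an assumption based on the old ``get_feature_names`` output shown in the removed lines):

    >>> column_trans = ColumnTransformer(
    ...     [('categories', OneHotEncoder(dtype='int'), ['city']),
    ...      ('title_bow', CountVectorizer(), 'title')],
    ...     remainder='drop')
    >>> column_trans.fit(X).get_feature_names_out()
    array(['categories__city_London', 'categories__city_Paris',
           'categories__city_Sallisaw', 'title_bow__bow', ...], ...)

Prefixing each name with the transformer that produced it keeps names unambiguous when two transformers emit the same column name.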

doc/modules/feature_extraction.rst (+20 -24)

@@ -1,4 +1,4 @@
-.. _feature_extraction:
+.. _feature_extraction:

 ==================
 Feature extraction
@@ -53,8 +53,8 @@ is a traditional numerical feature::
            [ 0., 1., 0., 12.],
            [ 0., 0., 1., 18.]])

-  >>> vec.get_feature_names()
-  ['city=Dubai', 'city=London', 'city=San Francisco', 'temperature']
+  >>> vec.get_feature_names_out()
+  array(['city=Dubai', 'city=London', 'city=San Francisco', 'temperature'], ...)

 :class:`DictVectorizer` accepts multiple string values for one
 feature, like, e.g., multiple categories for a movie.
@@ -69,10 +69,9 @@ and its year of release.
   array([[0.000e+00, 1.000e+00, 0.000e+00, 1.000e+00, 2.003e+03],
          [1.000e+00, 0.000e+00, 1.000e+00, 0.000e+00, 2.011e+03],
          [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 1.974e+03]])
- >>> vec.get_feature_names() == ['category=animation', 'category=drama',
- ...                             'category=family', 'category=thriller',
- ...                             'year']
- True
+ >>> vec.get_feature_names_out()
+ array(['category=animation', 'category=drama', 'category=family',
+        'category=thriller', 'year'], ...)
  >>> vec.transform({'category': ['thriller'],
  ...                'unseen_feature': '3'}).toarray()
  array([[0., 0., 0., 1., 0.]])
@@ -111,8 +110,9 @@ suitable for feeding into a classifier (maybe after being piped into a
     with 6 stored elements in Compressed Sparse ... format>
  >>> pos_vectorized.toarray()
  array([[1., 1., 1., 1., 1., 1.]])
- >>> vec.get_feature_names()
- ['pos+1=PP', 'pos-1=NN', 'pos-2=DT', 'word+1=on', 'word-1=cat', 'word-2=the']
+ >>> vec.get_feature_names_out()
+ array(['pos+1=PP', 'pos-1=NN', 'pos-2=DT', 'word+1=on', 'word-1=cat',
+        'word-2=the'], ...)

 As you can imagine, if one extracts such a context around each individual
 word of a corpus of documents the resulting matrix will be very wide
@@ -340,10 +340,9 @@ Each term found by the analyzer during the fit is assigned a unique
 integer index corresponding to a column in the resulting matrix. This
 interpretation of the columns can be retrieved as follows::

- >>> vectorizer.get_feature_names() == (
- ...     ['and', 'document', 'first', 'is', 'one',
- ...      'second', 'the', 'third', 'this'])
- True
+ >>> vectorizer.get_feature_names_out()
+ array(['and', 'document', 'first', 'is', 'one', 'second', 'the',
+        'third', 'this'], ...)

 >>> X.toarray()
 array([[0, 1, 1, 1, 0, 0, 1, 0, 1],
@@ -406,8 +405,8 @@ however, similar words are useful for prediction, such as in classifying
 writing style or personality.

 There are several known issues in our provided 'english' stop word list. It
-does not aim to be a general, 'one-size-fits-all' solution as some tasks
-may require a more custom solution. See [NQY18]_ for more details.
+does not aim to be a general, 'one-size-fits-all' solution as some tasks
+may require a more custom solution. See [NQY18]_ for more details.

 Please take care in choosing a stop word list.
 Popular stop word lists may include words that are highly informative to
@@ -742,9 +741,8 @@ decide better::

 >>> ngram_vectorizer = CountVectorizer(analyzer='char_wb', ngram_range=(2, 2))
 >>> counts = ngram_vectorizer.fit_transform(['words', 'wprds'])
->>> ngram_vectorizer.get_feature_names() == (
-...     [' w', 'ds', 'or', 'pr', 'rd', 's ', 'wo', 'wp'])
-True
+>>> ngram_vectorizer.get_feature_names_out()
+array([' w', 'ds', 'or', 'pr', 'rd', 's ', 'wo', 'wp'], ...)
 >>> counts.toarray().astype(int)
 array([[1, 1, 1, 0, 1, 1, 1, 0],
        [1, 1, 0, 1, 1, 1, 0, 1]])
@@ -758,17 +756,15 @@ span across words::
 >>> ngram_vectorizer.fit_transform(['jumpy fox'])
 <1x4 sparse matrix of type '<... 'numpy.int64'>'
    with 4 stored elements in Compressed Sparse ... format>
->>> ngram_vectorizer.get_feature_names() == (
-...     [' fox ', ' jump', 'jumpy', 'umpy '])
-True
+>>> ngram_vectorizer.get_feature_names_out()
+array([' fox ', ' jump', 'jumpy', 'umpy '], ...)

 >>> ngram_vectorizer = CountVectorizer(analyzer='char', ngram_range=(5, 5))
 >>> ngram_vectorizer.fit_transform(['jumpy fox'])
 <1x5 sparse matrix of type '<... 'numpy.int64'>'
    with 5 stored elements in Compressed Sparse ... format>
->>> ngram_vectorizer.get_feature_names() == (
-...     ['jumpy', 'mpy f', 'py fo', 'umpy ', 'y fox'])
-True
+>>> ngram_vectorizer.get_feature_names_out()
+array(['jumpy', 'mpy f', 'py fo', 'umpy ', 'y fox'], ...)

 The word boundaries-aware variant ``char_wb`` is especially interesting
 for languages that use white-spaces for word separation as it generates
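One practical consequence visible in these doctest updates: ``get_feature_names_out`` returns an ndarray rather than a Python list, so the old ``== [...]`` comparisons would now broadcast element-wise instead of producing a single ``True``. A sketch of the difference:

    >>> from sklearn.feature_extraction.text import CountVectorizer
    >>> v = CountVectorizer().fit(['the cat sat'])
    >>> v.get_feature_names_out()
    array(['cat', 'sat', 'the'], dtype=object)
    >>> list(v.get_feature_names_out()) == ['cat', 'sat', 'the']
    True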

doc/whats_new/v1.0.rst (+7)

@@ -134,6 +134,9 @@ Changelog
 - |API| `np.matrix` usage is deprecated in 1.0 and will raise a `TypeError` in
   1.2. :pr:`20165` by `Thomas Fan`_.

+- |API| :term:`get_feature_names_out` has been added to the transformer API
+  to get the names of the output features. :pr:`18444` by `Thomas Fan`_.
+
 - |API| All estimators store `feature_names_in_` when fitted on pandas Dataframes.
   These feature names are compared to names seen in `non-fit` methods,
   `i.e.` `transform` and will raise a `FutureWarning` if they are not consistent.
@@ -225,6 +228,10 @@ Changelog
 :mod:`sklearn.compose`
 ......................

+- |API| Adds `prefix_feature_names_out` to :class:`compose.ColumnTransformer`.
+  This flag controls the prefixing of feature names out in
+  :term:`get_feature_names_out`. :pr:`18444` by `Thomas Fan`_.
+
 - |Enhancement| :class:`compose.ColumnTransformer` now records the output
   of each transformer in `output_indices_`. :pr:`18393` by
   :user:`Luca Bittarello <lbittarello>`.

examples/applications/plot_topics_extraction_with_nmf_lda.py (+3 -3)

@@ -103,7 +103,7 @@ def plot_top_words(model, feature_names, n_top_words, title):
 print("done in %0.3fs." % (time() - t0))


-tfidf_feature_names = tfidf_vectorizer.get_feature_names()
+tfidf_feature_names = tfidf_vectorizer.get_feature_names_out()
 plot_top_words(nmf, tfidf_feature_names, n_top_words,
                'Topics in NMF model (Frobenius norm)')

@@ -117,7 +117,7 @@ def plot_top_words(model, feature_names, n_top_words, title):
                 l1_ratio=.5).fit(tfidf)
 print("done in %0.3fs." % (time() - t0))

-tfidf_feature_names = tfidf_vectorizer.get_feature_names()
+tfidf_feature_names = tfidf_vectorizer.get_feature_names_out()
 plot_top_words(nmf, tfidf_feature_names, n_top_words,
                'Topics in NMF model (generalized Kullback-Leibler divergence)')

@@ -132,5 +132,5 @@ def plot_top_words(model, feature_names, n_top_words, title):
 lda.fit(tf)
 print("done in %0.3fs." % (time() - t0))

-tf_feature_names = tf_vectorizer.get_feature_names()
+tf_feature_names = tf_vectorizer.get_feature_names_out()
 plot_top_words(lda, tf_feature_names, n_top_words, 'Topics in LDA model')

examples/bicluster/plot_bicluster_newsgroups.py (+1 -1)

@@ -89,7 +89,7 @@ def build_tokenizer(self):
               time() - start_time,
               v_measure_score(y_kmeans, y_true)))

-feature_names = vectorizer.get_feature_names()
+feature_names = vectorizer.get_feature_names_out()
 document_names = list(newsgroups.target_names[i] for i in newsgroups.target)


examples/inspection/plot_linear_model_coefficient_interpretation.py (+4 -8)

@@ -133,7 +133,9 @@
 numerical_columns = ["EDUCATION", "EXPERIENCE", "AGE"]

 preprocessor = make_column_transformer(
-    (OneHotEncoder(drop="if_binary"), categorical_columns), remainder="passthrough"
+    (OneHotEncoder(drop="if_binary"), categorical_columns),
+    remainder="passthrough",
+    prefix_feature_names_out=False,
 )

 # %%
@@ -199,13 +201,7 @@
 #
 # First of all, we can take a look to the values of the coefficients of the
 # regressor we have fitted.
-
-feature_names = (
-    model.named_steps["columntransformer"]
-    .named_transformers_["onehotencoder"]
-    .get_feature_names(input_features=categorical_columns)
-)
-feature_names = np.concatenate([feature_names, numerical_columns])
+feature_names = model[:-1].get_feature_names_out()

 coefs = pd.DataFrame(
     model.named_steps["transformedtargetregressor"].regressor_.coef_,
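The one-line replacement works because ``model[:-1]`` is itself a Pipeline containing only the preprocessing steps, and its ``get_feature_names_out()`` chains names through each transformer, including the passthrough numerical columns. A quick consistency check one could add (a sketch reusing the variable names from this example):

    feature_names = model[:-1].get_feature_names_out()
    coefs = model.named_steps["transformedtargetregressor"].regressor_.coef_
    assert len(feature_names) == len(coefs)  # one name per fitted coefficient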

examples/inspection/plot_permutation_importance.py (+1 -1)

@@ -120,7 +120,7 @@
 # capacity).
 ohe = (rf.named_steps['preprocess']
          .named_transformers_['cat'])
-feature_names = ohe.get_feature_names(input_features=categorical_columns)
+feature_names = ohe.get_feature_names_out(categorical_columns)
 feature_names = np.r_[feature_names, numerical_columns]

 tree_feature_importances = (

examples/text/plot_document_classification_20newsgroups.py (+3 -7)

@@ -174,7 +174,7 @@ def size_mb(docs):
 if opts.use_hashing:
     feature_names = None
 else:
-    feature_names = vectorizer.get_feature_names()
+    feature_names = vectorizer.get_feature_names_out()

 if opts.select_chi2:
     print("Extracting %d best features by a chi-squared test" %
@@ -183,16 +183,12 @@ def size_mb(docs):
     ch2 = SelectKBest(chi2, k=opts.select_chi2)
     X_train = ch2.fit_transform(X_train, y_train)
     X_test = ch2.transform(X_test)
-    if feature_names:
+    if feature_names is not None:
         # keep selected feature names
-        feature_names = [feature_names[i] for i
-                         in ch2.get_support(indices=True)]
+        feature_names = feature_names[ch2.get_support()]
     print("done in %fs" % (time() - t0))
     print()

-if feature_names:
-    feature_names = np.asarray(feature_names)
-

 def trim(s):
     """Trim string to fit on terminal (assuming 80-column display)"""

examples/text/plot_document_clustering.py (+1 -1)

@@ -217,7 +217,7 @@ def is_interactive():
 else:
     order_centroids = km.cluster_centers_.argsort()[:, ::-1]

-terms = vectorizer.get_feature_names()
+terms = vectorizer.get_feature_names_out()
 for i in range(true_k):
     print("Cluster %d:" % i, end='')
     for ind in order_centroids[i, :10]:

examples/text/plot_hashing_vs_dict_vectorizer.py (+1 -1)

@@ -89,7 +89,7 @@ def token_freqs(doc):
 vectorizer.fit_transform(token_freqs(d) for d in raw_data)
 duration = time() - t0
 print("done in %fs at %0.3fMB/s" % (duration, data_size_mb / duration))
-print("Found %d unique terms" % len(vectorizer.get_feature_names()))
+print("Found %d unique terms" % len(vectorizer.get_feature_names_out()))
 print()

 print("FeatureHasher on frequency dicts")

sklearn/base.py (+30)

@@ -23,6 +23,7 @@
 from .utils.validation import check_array
 from .utils.validation import _check_y
 from .utils.validation import _num_features
+from .utils.validation import _check_feature_names_in
 from .utils._estimator_html_repr import estimator_html_repr
 from .utils.validation import _get_feature_names

@@ -846,6 +847,35 @@ def fit_transform(self, X, y=None, **fit_params):
         return self.fit(X, y, **fit_params).transform(X)


+class _OneToOneFeatureMixin:
+    """Provides `get_feature_names_out` for simple transformers.
+
+    Assumes there's a 1-to-1 correspondence between input features
+    and output features.
+    """
+
+    def get_feature_names_out(self, input_features=None):
+        """Get output feature names for transformation.
+
+        Parameters
+        ----------
+        input_features : array-like of str or None, default=None
+            Input features.
+
+            - If `input_features` is `None`, then `feature_names_in_` is
+              used as feature names in. If `feature_names_in_` is not defined,
+              then names are generated: `[x0, x1, ..., x(n_features_in_)]`.
+            - If `input_features` is an array-like, then `input_features` must
+              match `feature_names_in_` if `feature_names_in_` is defined.
+
+        Returns
+        -------
+        feature_names_out : ndarray of str objects
+            Same as input features.
+        """
+        return _check_feature_names_in(self, input_features)
+
+
 class DensityMixin:
     """Mixin class for all density estimators in scikit-learn."""

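A minimal sketch of how a transformer would use the new mixin (hypothetical class, not part of this commit; it assumes ``_validate_data`` records ``n_features_in_`` and, for DataFrames, ``feature_names_in_``):

    import numpy as np
    from sklearn.base import BaseEstimator, TransformerMixin, _OneToOneFeatureMixin

    class CenterTransformer(_OneToOneFeatureMixin, TransformerMixin, BaseEstimator):
        """Subtracts the per-column mean; output columns map 1-to-1 to inputs."""

        def fit(self, X, y=None):
            X = self._validate_data(X)   # sets n_features_in_ / feature_names_in_
            self.mean_ = X.mean(axis=0)
            return self

        def transform(self, X):
            return X - self.mean_

    ct = CenterTransformer().fit(np.array([[1.0, 2.0], [3.0, 4.0]]))
    print(ct.get_feature_names_out())    # ['x0' 'x1'] since no names were seen

Because the mixin delegates entirely to ``_check_feature_names_in``, any transformer whose output columns correspond one-to-one with its inputs gets the method for free.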