[WIP] Add feature_extraction.ColumnTransformer #3886

Closed
1 change: 1 addition & 0 deletions doc/modules/classes.rst
@@ -392,6 +392,7 @@ partial dependence

feature_extraction.DictVectorizer
feature_extraction.FeatureHasher
feature_extraction.ColumnTransformer

From images
-----------
56 changes: 56 additions & 0 deletions doc/modules/feature_extraction.rst
@@ -100,6 +100,62 @@ of the time. So as to make the resulting data structure able to fit in
memory the ``DictVectorizer`` class uses a ``scipy.sparse`` matrix by
default instead of a ``numpy.ndarray``.

.. _column_transformer:

Columnar Data
=============
Many datasets contain features of different types, say text, floats and dates,
where each type of feature requires separate preprocessing.
Often it is easiest to preprocess data before applying scikit-learn methods, for example using
pandas.
If the preprocessing has parameters that you want to adjust within a
grid-search, however, they need to be inside a transformer. This can be
achieved very simply with the :class:`ColumnTransformer`. The
:class:`ColumnTransformer` works on pandas dataframes, dictionaries, and other
objects that implement ``__getitem__`` to select a certain column.

.. note::
    :class:`ColumnTransformer` expects a very different data format from the numpy arrays usually used in scikit-learn.
    For a numpy array ``X_array``, ``X_array[1]`` gives a single sample with all of its features (``X_array[1].shape == (n_features,)``).
    For columnar data such as a dict or pandas dataframe ``X_columns``, ``X_columns[1]`` is expected to give the feature named
    ``1`` across all samples (``X_columns[1].shape == (n_samples,)``).
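
To illustrate the difference, compare indexing a numpy array with indexing a
columnar dict holding the same values::

    >>> import numpy as np
    >>> X_array = np.array([[1, 10], [2, 20], [3, 30]])
    >>> X_array[1]  # the second sample, with all of its features
    array([ 2, 20])
    >>> X_columns = {0: np.array([1, 2, 3]), 1: np.array([10, 20, 30])}
    >>> X_columns[1]  # the feature named 1, across all samples
    array([10, 20, 30])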

To each column, a different transformation can be applied, such as
preprocessing or a specific feature extraction method::

>>> X = {'city': ['London', 'London', 'Paris', 'New York'],
... 'title': ["His Last Bow", "How Watson Learned the Trick", "A Moveable Feast", "The Great Gatsby"]}

In contrast to :class:`DictVectorizer`, here the whole dataset is a single dict,
with each value having the same length ``n_samples``.
For this data, we might want to one-hot encode the ``'city'`` column while
applying a :class:`CountVectorizer` to the ``'title'`` column (below, the
one-hot encoding is emulated with a :class:`CountVectorizer` and a custom
analyzer, as :class:`OneHotEncoder` does not handle string-valued features).
As we might use multiple feature extraction methods on the same column, we give each
transformer a unique name, say ``'city_category'`` and ``'title_bow'``::

>>> from sklearn.feature_extraction import ColumnTransformer
>>> from sklearn.feature_extraction.text import CountVectorizer
>>> column_trans = ColumnTransformer({'city_category': (CountVectorizer(analyzer=lambda x: [x]), 'city'),
... 'title_bow': (CountVectorizer(), 'title')})

>>> column_trans.fit(X) # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
ColumnTransformer(n_jobs=1, transformer_weights=None,
transformers=...)

>>> column_trans.get_feature_names() == [
... 'city_category__London', 'city_category__New York', 'city_category__Paris',
... 'title_bow__bow', 'title_bow__feast', 'title_bow__gatsby',
... 'title_bow__great', 'title_bow__his', 'title_bow__how', 'title_bow__last',
... 'title_bow__learned', 'title_bow__moveable', 'title_bow__the',
... 'title_bow__trick', 'title_bow__watson']
True

>>> column_trans.transform(X).toarray() # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
array([[1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1],
[0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
[0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0]]...)
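
Since the transformer names prefix the contained transformers' parameters in
``get_params``, following the usual ``<name>__<parameter>`` convention, the
individual transformers can also be tuned, e.g. within a grid-search over a
:class:`pipeline.Pipeline` containing the :class:`ColumnTransformer`. A quick
check of the naming scheme::

    >>> 'title_bow__min_df' in column_trans.get_params()
    True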

.. _feature_hashing:

10 changes: 6 additions & 4 deletions doc/modules/pipeline.rst
@@ -122,9 +122,11 @@ FeatureUnion: composite feature spaces
:class:`FeatureUnion` combines several transformer objects into a new
transformer that combines their output. A :class:`FeatureUnion` takes
a list of transformer objects. During fitting, each of these
is fit to the data independently. It can also be used to apply different
transformations to each field of the data, producing a homogeneous feature
matrix from a heterogeneous data source.
The transformers are applied in parallel, and the feature matrices they output
are concatenated side-by-side into a larger matrix.
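
For instance, combining two single-component decompositions produces a
two-column output (a minimal sketch; the estimators and data are purely
illustrative)::

    >>> import numpy as np
    >>> from sklearn.pipeline import FeatureUnion
    >>> from sklearn.decomposition import PCA, TruncatedSVD
    >>> union = FeatureUnion([("pca", PCA(n_components=1)),
    ...                       ("svd", TruncatedSVD(n_components=1))])
    >>> X = np.array([[0., 1., 3.], [2., 2., 5.]])
    >>> union.fit_transform(X).shape  # 1 + 1 features, side by side
    (2, 2)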

:class:`FeatureUnion` serves the same purposes as :class:`Pipeline` -
convenience and joint parameter estimation and validation.
@@ -166,4 +168,4 @@ Like pipelines, feature unions have a shorthand constructor called
.. topic:: Examples:

* :ref:`example_feature_stacker.py`
* :ref:`example_column_transformer.py` illustrates the use of :class:`feature_extraction.ColumnTransformer` on heterogeneous data.
3 changes: 3 additions & 0 deletions doc/whats_new.rst
@@ -217,6 +217,9 @@ New features
By `Alexandre Gramfort`_, `Jan Hendrik Metzen`_, `Mathieu Blondel`_
and `Balazs Kegl`_.

- Added :class:`feature_extraction.ColumnTransformer`, which allows applying
  different transformers to particular features of dictionaries or pandas
  dataframes. By `Andreas Müller`_.

Enhancements
............
72 changes: 13 additions & 59 deletions examples/hetero_feature_union.py → examples/column_transformer.py
Original file line number Diff line number Diff line change
@@ -38,48 +38,9 @@
from sklearn.feature_extraction import DictVectorizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import classification_report
from sklearn.pipeline import FeatureUnion
from sklearn.feature_extraction import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC


class ItemSelector(BaseEstimator, TransformerMixin):
"""For data grouped by feature, select subset of data at a provided key.

The data is expected to be stored in a 2D data structure, where the first
index is over features and the second is over samples. i.e.

>> len(data[key]) == n_samples

    Please note that this is the opposite convention to sklearn feature
    matrices (where the first index corresponds to samples).

ItemSelector only requires that the collection implement getitem
(data[key]). Examples include: a dict of lists, 2D numpy array, Pandas
DataFrame, numpy record array, etc.

>> data = {'a': [1, 5, 2, 5, 2, 8],
'b': [9, 4, 1, 4, 1, 3]}
>> ds = ItemSelector(key='a')
>> data['a'] == ds.transform(data)

ItemSelector is not designed to handle data grouped by sample. (e.g. a
list of dicts). If your data is structured this way, consider a
transformer along the lines of `sklearn.feature_extraction.DictVectorizer`.

Parameters
----------
key : hashable, required
The key corresponding to the desired value in a mappable.
"""
def __init__(self, key):
self.key = key

def fit(self, x, y=None):
return self

def transform(self, data_dict):
return data_dict[self.key]
from sklearn.svm import LinearSVC


class TextStats(BaseEstimator, TransformerMixin):
@@ -128,41 +89,34 @@ def transform(self, posts):
('subjectbody', SubjectBodyExtractor()),

    # Use ColumnTransformer to combine the features from subject and body
    ('union', ColumnTransformer(
        {
            # Pulling features from the post's subject line
            'subject': (TfidfVectorizer(min_df=50), 'subject'),

            # Pipeline for standard bag-of-words model for body
            'body_bow': (Pipeline([
                ('tfidf', TfidfVectorizer()),
                ('best', TruncatedSVD(n_components=50)),
            ]), 'body'),

            # Pipeline for pulling ad hoc features from post's body
            'body_stats': (Pipeline([
                ('stats', TextStats()),  # returns a list of dicts
                ('vect', DictVectorizer()),  # list of dicts -> feature matrix
            ]), 'body'),
        },

        # weight components in the ColumnTransformer
        transformer_weights={
            'subject': 0.8,
            'body_bow': 0.5,
            'body_stats': 1.0,
        }
    )),

    # Use a linear SVC classifier on the combined features
    ('svc', LinearSVC(dual=False)),
])

# limit the list of categories to make running this example faster.
3 changes: 2 additions & 1 deletion sklearn/feature_extraction/__init__.py
@@ -5,9 +5,10 @@
"""

from .dict_vectorizer import DictVectorizer
from .heterogeneous import ColumnTransformer
from .hashing import FeatureHasher
from .image import img_to_graph, grid_to_graph
from . import text

__all__ = ['DictVectorizer', 'image', 'img_to_graph', 'grid_to_graph', 'text',
'FeatureHasher']
'FeatureHasher', 'ColumnTransformer']
148 changes: 148 additions & 0 deletions sklearn/feature_extraction/heterogeneous.py
@@ -0,0 +1,148 @@
from scipy import sparse
import numpy as np

from ..base import BaseEstimator, TransformerMixin
from ..pipeline import _fit_one_transformer, _fit_transform_one, _transform_one
from ..externals.joblib import Parallel, delayed
from ..externals.six import iteritems


class ColumnTransformer(BaseEstimator, TransformerMixin):
"""Applies transformers to columns of a dataframe / dict.

    Review (Contributor): While I see the point of this transformer for
    dataframes and dicts, I find it too bad that we cannot apply it to numpy
    arrays. I would love to see a built-in way to apply transformers to
    selected columns only.

    Review (Contributor): (Coming late to the party, this might have been
    discussed before...)

    Review (Author): That would be pretty easy with the FunctionTransformer
    #4798

    Review (Contributor): Indeed, +1

This estimator applies transformer objects to columns or fields of the
input, then concatenates the results. This is useful for heterogeneous or
columnar data, to combine several feature extraction mechanisms into a
single transformer.

Read more in the :ref:`User Guide <column_transformer>`.

Parameters
----------
    transformers : dict from string to (transformer, column) tuples
        Keys are arbitrary transformer names; values are tuples of a
        transformer object and the name of the column it is applied to.

    Review (Contributor): The implementation expects the dict values to be
    (transformer, string) tuples, and not (string, transformer) as originally
    documented here.

    Review (Contributor): Also, does the key used to access the column always
    need to be a string? E.g. what if I use an int to access the n-th column,
    or even a list to access several columns at once?

n_jobs : int, optional
Number of jobs to run in parallel (default 1).

transformer_weights : dict, optional
Multiplicative weights for features per transformer.
Keys are transformer names, values the weights.

Examples
--------
    >>> from sklearn.feature_extraction import ColumnTransformer
    >>> from sklearn.preprocessing import Normalizer
>>> union = ColumnTransformer({"norm1": (Normalizer(norm='l1'), 'subset1'), \
"norm2": (Normalizer(norm='l1'), 'subset2')})
>>> X = {'subset1': [[0., 1.], [2., 2.]], 'subset2': [[1., 1.], [0., 1.]]}
>>> union.fit_transform(X) # doctest: +NORMALIZE_WHITESPACE
array([[ 0. , 1. , 0.5, 0.5],
[ 0.5, 0.5, 0. , 1. ]])

"""
def __init__(self, transformers, n_jobs=1, transformer_weights=None):
self.transformers = transformers
self.n_jobs = n_jobs
self.transformer_weights = transformer_weights

def get_feature_names(self):
"""Get feature names from all transformers.

Returns
-------
feature_names : list of strings
Names of the features produced by transform.
"""
feature_names = []
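        # Iterate in sorted(name) order, matching the column order that
        # fit_transform and transform produce.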
for name, (trans, column) in sorted(self.transformers.items()):
if not hasattr(trans, 'get_feature_names'):
raise AttributeError("Transformer %s does not provide"
" get_feature_names." % str(name))
feature_names.extend([name + "__" + f for f in
trans.get_feature_names()])
return feature_names

def get_params(self, deep=True):
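        """Get parameters of this estimator and its transformers.

        Parameters of the contained transformers are exposed with the
        ``<name>__<parameter>`` naming convention, as in Pipeline and
        FeatureUnion, so they can be addressed e.g. in a grid search.
        """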
if not deep:
return super(ColumnTransformer, self).get_params(deep=False)
else:
out = dict(self.transformers)
for name, (trans, _) in self.transformers.items():
for key, value in iteritems(trans.get_params(deep=True)):
out['%s__%s' % (name, key)] = value
out.update(super(ColumnTransformer, self).get_params(deep=False))
return out

def fit(self, X, y=None):
"""Fit all transformers using X.

Parameters
----------
        X : dict of string to array-like, or pandas DataFrame
            Input data, used to fit transformers.
"""
transformers = Parallel(n_jobs=self.n_jobs)(
delayed(_fit_one_transformer)(trans, X[column], y)
            # Review (author): should use .iloc if it exists, otherwise
            # slice in the second direction, and allow multiple columns.
for name, (trans, column) in sorted(self.transformers.items()))
self._update_transformers(transformers)
return self

def fit_transform(self, X, y=None, **fit_params):
"""Fit all transformers using X, transform the data and concatenate
results.

Parameters
----------
        X : dict of string to array-like, or pandas DataFrame
Input data to be transformed.

Returns
-------
X_t : array-like or sparse matrix, shape (n_samples, sum_n_components)
hstack of results of transformers. sum_n_components is the
sum of n_components (output dimension) over transformers.
"""
result = Parallel(n_jobs=self.n_jobs)(
delayed(_fit_transform_one)(trans, name, X[column], y,
self.transformer_weights,
**fit_params)
for name, (trans, column) in sorted(self.transformers.items()))

Xs, transformers = zip(*result)
self._update_transformers(transformers)
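        # Stack the per-transformer outputs side by side: sparse hstack
        # (converted to CSR) if any output is sparse, dense hstack otherwise.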
if any(sparse.issparse(f) for f in Xs):
Xs = sparse.hstack(Xs).tocsr()
else:
Xs = np.hstack(Xs)
return Xs

def transform(self, X):
"""Transform X separately by each transformer, concatenate results.

Parameters
----------
        X : dict of string to array-like, or pandas DataFrame
Input data to be transformed.

Returns
-------
X_t : array-like or sparse matrix, shape (n_samples, sum_n_components)
hstack of results of transformers. sum_n_components is the
sum of n_components (output dimension) over transformers.
"""
Xs = Parallel(n_jobs=self.n_jobs)(
delayed(_transform_one)(trans, name, X[column], self.transformer_weights)
for name, (trans, column) in sorted(self.transformers.items()))
if any(sparse.issparse(f) for f in Xs):
Xs = sparse.hstack(Xs).tocsr()
else:
Xs = np.hstack(Xs)
return Xs

def _update_transformers(self, transformers):
        # use a dict constructor instead of a dict comprehension for python2.6
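        # `transformers` holds the newly fitted transformers in the same
        # order as sorted(self.transformers.items()), so zipping pairs each
        # fitted transformer with its original name and column.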
self.transformers.update(dict(
(name, (new, column))
for ((name, (old, column)), new) in zip(sorted(self.transformers.items()), transformers))
)