diff --git a/sklearn/compose/_column_transformer.py b/sklearn/compose/_column_transformer.py index be1057f7d7bdd..dcc7b0fb64f4e 100644 --- a/sklearn/compose/_column_transformer.py +++ b/sklearn/compose/_column_transformer.py @@ -595,6 +595,15 @@ def make_column_transformer(*transformers, **kwargs): ---------- *transformers : tuples of column selections and transformers + remainder : {'passthrough', 'drop'}, default 'passthrough' + By default, all remaining columns that were not specified in + `transformers` will be automatically passed through (default of + ``'passthrough'``). This subset of columns is concatenated with the + output of the transformers. + By using ``remainder='drop'``, only the specified columns in + `transformers` are transformed and combined in the output, and the + non-specified columns are dropped. + n_jobs : int, optional Number of jobs to run in parallel (default 1). @@ -627,8 +636,10 @@ def make_column_transformer(*transformers, **kwargs): """ n_jobs = kwargs.pop('n_jobs', 1) + remainder = kwargs.pop('remainder', 'passthrough') if kwargs: raise TypeError('Unknown keyword arguments: "{}"' .format(list(kwargs.keys())[0])) transformer_list = _get_transformer_list(transformers) - return ColumnTransformer(transformer_list, n_jobs=n_jobs) + return ColumnTransformer(transformer_list, n_jobs=n_jobs, + remainder=remainder) diff --git a/sklearn/compose/tests/test_column_transformer.py b/sklearn/compose/tests/test_column_transformer.py index 420481ff0ad4a..06f5f2d67ca8e 100644 --- a/sklearn/compose/tests/test_column_transformer.py +++ b/sklearn/compose/tests/test_column_transformer.py @@ -342,12 +342,13 @@ def test_make_column_transformer_kwargs(): scaler = StandardScaler() norm = Normalizer() ct = make_column_transformer(('first', scaler), (['second'], norm), - n_jobs=3) + n_jobs=3, remainder='drop') assert_equal( ct.transformers, make_column_transformer(('first', scaler), (['second'], norm)).transformers) assert_equal(ct.n_jobs, 3) + assert_equal(ct.remainder, 'drop') # invalid keyword parameters should raise an error message assert_raise_message( TypeError,