[MRG+1] FIX Pipelined fitting of Clustering algorithms, scoring of K-Means in pipelines #4064


Merged: 1 commit, Feb 6, 2015
12 changes: 8 additions & 4 deletions doc/developers/index.rst
@@ -716,8 +716,11 @@ is not met, an exception of type ``ValueError`` should be raised.
 ``y`` might be ignored in the case of unsupervised learning. However, to
 make it possible to use the estimator as part of a pipeline that can
 mix both supervised and unsupervised transformers, even unsupervised
-estimators are kindly asked to accept a ``y=None`` keyword argument in
+estimators need to accept a ``y=None`` keyword argument in
Review comment (member): Indeed, this doc was already here. I guess I only glanced at what I thought was an issue, but it was not really one.

 the second position that is just ignored by the estimator.
+For the same reason, ``fit_predict``, ``fit_transform``, ``score``
+and ``partial_fit`` methods need to accept a ``y`` argument in
+the second place if they are implemented.

The method should return the object (``self``). This pattern is useful
to be able to implement quick one liners in an IPython session such as::
@@ -857,9 +860,10 @@ last step, it needs to provide a ``fit`` or ``fit_transform`` function.
 To be able to evaluate the pipeline on any data but the training set,
 it also needs to provide a ``transform`` function.
 There are no special requirements for the last step in a pipeline, except that
-it has a ``fit`` function. All ``fit`` and ``fit_transform`` functions must
-take arguments ``X, y``, even if y is not used.
-
+it has a ``fit`` function. All ``fit`` and ``fit_transform`` functions must
+take arguments ``X, y``, even if y is not used. Similarly, for ``score`` to be
+usable, the last step of the pipeline needs to have a ``score`` function that
+accepts an optional ``y``.

Working notes
-------------
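To make the documented ``y=None`` convention concrete, here is a minimal sketch of an unsupervised transformer that follows it. This is not code from the PR; the class name and behaviour are made up for illustration:

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin


class MeanCenterer(BaseEstimator, TransformerMixin):
    """Hypothetical transformer illustrating the ``y=None`` convention."""

    def fit(self, X, y=None):
        # y is accepted but ignored: this keeps the signature compatible
        # with Pipeline, which always passes a (possibly None) y.
        self.mean_ = np.asarray(X).mean(axis=0)
        return self  # fit must return the object itself

    def transform(self, X):
        return np.asarray(X) - self.mean_
```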
3 changes: 3 additions & 0 deletions doc/whats_new.rst
@@ -176,6 +176,9 @@ Enhancements
- Parallelized calculation of :func:`pairwise_distances` is now supported
for scipy metrics and custom callables. By `Joel Nothman`_.

+- Allow the fitting and scoring of all clustering algorithms in
+  :class:`pipeline.Pipeline`. By `Andreas Müller`_.
+
Documentation improvements
..........................

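As a rough usage sketch of what this changelog entry enables (assuming the post-merge behaviour), a clustering estimator can now be fitted and scored as the final step of a pipeline:

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, _ = make_blobs(n_samples=100, centers=3, random_state=0)
pipe = make_pipeline(StandardScaler(), KMeans(n_clusters=3, random_state=0))
pipe.fit(X)           # Pipeline passes y=None to every step
print(pipe.score(X))  # delegates to KMeans.score: negative inertia
```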
2 changes: 1 addition & 1 deletion sklearn/cluster/affinity_propagation_.py
@@ -269,7 +269,7 @@ def __init__(self, damping=.5, max_iter=200, convergence_iter=15,
     def _pairwise(self):
         return self.affinity == "precomputed"
 
-    def fit(self, X):
+    def fit(self, X, y=None):
         """ Create affinity matrix from negative euclidean distances, then
         apply affinity propagation clustering.

4 changes: 2 additions & 2 deletions sklearn/cluster/dbscan_.py
@@ -189,7 +189,7 @@ class DBSCAN(BaseEstimator, ClusterMixin):
         of the construction and query, as well as the memory required
         to store the tree. The optimal value depends
         on the nature of the problem.
 
     Attributes
     ----------
     core_sample_indices_ : array, shape = [n_core_samples]
@@ -224,7 +224,7 @@ def __init__(self, eps=0.5, min_samples=5, metric='euclidean',
         self.p = p
         self.random_state = random_state
 
-    def fit(self, X, sample_weight=None):
+    def fit(self, X, y=None, sample_weight=None):
         """Perform DBSCAN clustering from features or distance matrix.
 
         Parameters
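The same one-line widening is applied to the other clustering estimators touched below (AgglomerativeClustering, MeanShift, SpectralClustering): once ``fit`` takes ``y=None``, a label-free pipeline fit composes cleanly. A hedged sketch with DBSCAN; the step name follows ``make_pipeline``'s lowercased-class convention:

```python
from sklearn.cluster import DBSCAN
from sklearn.datasets import make_blobs
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, _ = make_blobs(n_samples=100, centers=2, random_state=0)
pipe = make_pipeline(StandardScaler(), DBSCAN())
pipe.fit(X)  # works now that DBSCAN.fit accepts (and ignores) y
labels = pipe.named_steps['dbscan'].labels_
```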
2 changes: 1 addition & 1 deletion sklearn/cluster/hierarchical.py
@@ -683,7 +683,7 @@ def __init__(self, n_clusters=2, affinity="euclidean",
         self.affinity = affinity
         self.pooling_func = pooling_func
 
-    def fit(self, X):
+    def fit(self, X, y=None):
         """Fit the hierarchical clustering on the data
 
         Parameters
4 changes: 2 additions & 2 deletions sklearn/cluster/k_means_.py
@@ -795,7 +795,7 @@ def fit(self, X, y=None):
                n_jobs=self.n_jobs)
        return self
 
-    def fit_predict(self, X):
+    def fit_predict(self, X, y=None):
        """Compute cluster centers and predict cluster index for each sample.
 
        Convenience method; equivalent to calling fit(X) followed by
@@ -864,7 +864,7 @@ def predict(self, X):
         x_squared_norms = row_norms(X, squared=True)
         return _labels_inertia(X, x_squared_norms, self.cluster_centers_)[0]
 
-    def score(self, X):
+    def score(self, X, y=None):
         """Opposite of the value of X on the K-means objective.
 
         Parameters
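With the widened signatures, both methods tolerate the ``y`` argument that pipeline and cross-validation plumbing pass in. An assumed sketch of direct use:

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

X, y = make_blobs(n_samples=60, centers=3, random_state=0)
km = KMeans(n_clusters=3, random_state=0)
labels = km.fit_predict(X, y)  # y is accepted and ignored
print(km.score(X, y))          # likewise; returns the opposite of the inertia
```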
2 changes: 1 addition & 1 deletion sklearn/cluster/mean_shift_.py
@@ -320,7 +320,7 @@ def __init__(self, bandwidth=None, seeds=None, bin_seeding=False,
         self.cluster_all = cluster_all
         self.min_bin_freq = min_bin_freq
 
-    def fit(self, X):
+    def fit(self, X, y=None):
         """Perform clustering.
 
         Parameters
2 changes: 1 addition & 1 deletion sklearn/cluster/spectral.py
@@ -405,7 +405,7 @@ def __init__(self, n_clusters=8, eigen_solver=None, random_state=None,
         self.coef0 = coef0
         self.kernel_params = kernel_params
 
-    def fit(self, X):
+    def fit(self, X, y=None):
         """Creates an affinity matrix for X using the selected affinity,
         then applies spectral clustering to this affinity matrix.

11 changes: 4 additions & 7 deletions sklearn/decomposition/dict_learning.py
@@ -412,7 +412,6 @@ def dict_learning(X, n_components, alpha, max_iter=100, tol=1e-8,
     SparsePCA
     MiniBatchSparsePCA
     """
-
    if method not in ('lars', 'cd'):
        raise ValueError('Coding method %r not supported as a fit algorithm.'
                         % method)
@@ -604,6 +603,8 @@ def dict_learning_online(X, n_components=2, alpha=1, n_iter=100,
     MiniBatchSparsePCA
 
     """
+    if n_components is None:
+        n_components = X.shape[1]
 
    if method not in ('lars', 'cd'):
        raise ValueError('Coding method not supported as a fit algorithm.')
@@ -750,7 +751,7 @@ def transform(self, X, y=None):
            Transformed data
 
        """
-        check_is_fitted(self, 'components_')
+        check_is_fitted(self, 'components_')
 
        # XXX : kwargs is not documented
        X = check_array(X)
@@ -1159,13 +1160,9 @@ def fit(self, X, y=None):
        """
        random_state = check_random_state(self.random_state)
        X = check_array(X)
-        if self.n_components is None:
-            n_components = X.shape[1]
-        else:
-            n_components = self.n_components
 
        U, (A, B), self.n_iter_ = dict_learning_online(
-            X, n_components, self.alpha,
+            X, self.n_components, self.alpha,
            n_iter=self.n_iter, return_code=False,
            method=self.fit_algorithm,
            n_jobs=self.n_jobs, dict_init=self.dict_init,
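The effect of moving the default into the function: ``dict_learning_online`` now resolves ``n_components=None`` to ``n_features`` itself, so callers such as ``MiniBatchDictionaryLearning.fit`` can pass ``self.n_components`` through unchanged. A sketch of the assumed behaviour:

```python
import numpy as np
from sklearn.decomposition import dict_learning_online

X = np.random.RandomState(0).randn(50, 8)
# return_code defaults to True, so both the code and dictionary come back
code, dictionary = dict_learning_online(X, n_components=None, random_state=0)
print(dictionary.shape)  # expected (8, 8): one atom per feature when None
```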
2 changes: 1 addition & 1 deletion sklearn/decomposition/incremental_pca.py
@@ -174,7 +174,7 @@ def fit(self, X, y=None):
            self.partial_fit(X[batch])
        return self
 
-    def partial_fit(self, X):
+    def partial_fit(self, X, y=None):
        """Incremental fit with X. All of X is processed as a single batch.
 
        Parameters
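``partial_fit`` keeps its one-batch-at-a-time semantics; it merely gains an ignored ``y``. A usage sketch (shapes are illustrative):

```python
import numpy as np
from sklearn.decomposition import IncrementalPCA

X = np.random.RandomState(0).randn(100, 10)
ipca = IncrementalPCA(n_components=3)
for batch in np.array_split(X, 5):
    ipca.partial_fit(batch)  # same as ipca.partial_fit(batch, y=None)
print(ipca.transform(X[:2]).shape)  # (2, 3)
```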
4 changes: 2 additions & 2 deletions sklearn/neural_network/rbm.py
@@ -217,7 +217,7 @@ def gibbs(self, v):
 
        return v_
 
-    def partial_fit(self, X):
+    def partial_fit(self, X, y=None):
        """Fit the model to the data X which should contain a partial
        segment of the data.

@@ -301,7 +301,7 @@ def score_samples(self, X):
        returns the log of the logistic function of the difference.
        """
        check_is_fitted(self, "components_")
 
        v = check_array(X, accept_sparse='csr')
        rng = check_random_state(self.random_state)

5 changes: 5 additions & 0 deletions sklearn/tests/test_common.py
@@ -55,7 +55,9 @@
     check_regressor_data_not_an_array,
     check_transformer_data_not_an_array,
     check_transformer_n_iter,
+    check_fit_score_takes_y,
     check_non_transformer_estimators_n_iter,
+    check_pipeline_consistency,
     CROSS_DECOMPOSITION)


@@ -87,6 +89,9 @@ def test_non_meta_estimators():
    estimators = all_estimators(type_filter=['classifier', 'regressor',
                                             'transformer', 'cluster'])
    for name, Estimator in estimators:
+        if name not in CROSS_DECOMPOSITION:
+            yield check_fit_score_takes_y, name, Estimator
+            yield check_pipeline_consistency, name, Estimator
        if name not in CROSS_DECOMPOSITION + ['Imputer']:
            # Test that all estimators check their input for NaN's and infs
            yield check_estimators_nan_inf, name, Estimator
82 changes: 55 additions & 27 deletions sklearn/utils/estimator_checks.py
@@ -22,15 +22,16 @@
 from sklearn.utils.testing import assert_greater
 from sklearn.utils.testing import SkipTest
 from sklearn.utils.testing import check_skip_travis
+from sklearn.utils.testing import ignore_warnings
 
-from sklearn.base import (clone, ClusterMixin, ClassifierMixin, RegressorMixin,
-                          TransformerMixin)
+from sklearn.base import clone, ClassifierMixin
 from sklearn.metrics import accuracy_score, adjusted_rand_score, f1_score
 
 from sklearn.lda import LDA
 from sklearn.random_projection import BaseRandomProjection
 from sklearn.feature_selection import SelectKBest
 from sklearn.svm.base import BaseLibSVM
+from sklearn.pipeline import make_pipeline
 
 from sklearn.utils.validation import DataConversionWarning, NotFittedError
 from sklearn.cross_validation import train_test_split
@@ -44,13 +45,6 @@
 CROSS_DECOMPOSITION = ['PLSCanonical', 'PLSRegression', 'CCA', 'PLSSVD']
 
 
-def is_supervised(estimator):
-    return (isinstance(estimator, ClassifierMixin)
-            or isinstance(estimator, RegressorMixin)
-            # transformers can all take a y
-            or isinstance(estimator, TransformerMixin))
-
-
 def _boston_subset(n_samples=200):
     global BOSTON
     if BOSTON is None:
@@ -88,6 +82,10 @@ def set_fast_parameters(estimator):
        # K-Means
        estimator.set_params(n_init=2)
 
+    if estimator.__class__.__name__ == "SelectFdr":
+        # avoid not selecting any features
+        estimator.set_params(alpha=.5)
+
    if isinstance(estimator, BaseRandomProjection):
        # Due to the jl lemma and often very few samples, the number
        # of components of the random matrix projection will be probably
@@ -131,10 +129,7 @@ def check_estimator_sparse_data(name, Estimator):
    set_fast_parameters(estimator)
    # fit and predict
    try:
-        if is_supervised(estimator):
-            estimator.fit(X, y)
-        else:
-            estimator.fit(X)
+        estimator.fit(X, y)
        if hasattr(estimator, "predict"):
            estimator.predict(X)
        if hasattr(estimator, 'predict_proba'):
@@ -252,6 +247,50 @@ def _check_transformer(name, Transformer, X, y):
    assert_raises(ValueError, transformer.transform, X.T)
 
 
+@ignore_warnings
+def check_pipeline_consistency(name, Estimator):
+    # check that make_pipeline(est) gives same score as est
+    X, y = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]],
+                      random_state=0, n_features=2, cluster_std=0.1)
+    X -= X.min()
+    y = multioutput_estimator_convert_y_2d(name, y)
+    estimator = Estimator()
+    pipeline = make_pipeline(estimator)
+    set_fast_parameters(estimator)
+    set_random_state(estimator)
+    estimator.fit(X, y)
+    pipeline.fit(X, y)
+    funcs = ["score", "fit_transform"]
+    for func_name in funcs:
+        func = getattr(estimator, func_name, None)
+        if func is not None:
+            func_pipeline = getattr(pipeline, func_name)
+            result = func(X, y)
+            result_pipe = func_pipeline(X, y)
+            assert_array_almost_equal(result, result_pipe)
+
+
+@ignore_warnings
+def check_fit_score_takes_y(name, Estimator):
+    # check that all estimators accept an optional y
+    # in fit and score so they can be used in pipelines
+    rnd = np.random.RandomState(0)
+    X = rnd.uniform(size=(10, 3))
+    y = (X[:, 0] * 4).astype(np.int)
+    y = multioutput_estimator_convert_y_2d(name, y)
+    estimator = Estimator()
+    set_fast_parameters(estimator)
+    set_random_state(estimator)
+    funcs = ["fit", "score", "partial_fit", "fit_predict", "fit_transform"]
+
+    for func_name in funcs:
+        func = getattr(estimator, func_name, None)
+        if func is not None:
+            func(X, y)
+            args = inspect.getargspec(func).args
+            assert_true(args[2] in ["y", "Y"])
+
+
 def check_estimators_nan_inf(name, Estimator):
    rnd = np.random.RandomState(0)
    X_train_finite = rnd.uniform(size=(10, 3))
@@ -275,10 +314,7 @@ def check_estimators_nan_inf(name, Estimator):
        set_random_state(estimator, 1)
        # try to fit
        try:
-            if issubclass(Estimator, ClusterMixin):
-                estimator.fit(X_train)
-            else:
-                estimator.fit(X_train, y)
+            estimator.fit(X_train, y)
        except ValueError as e:
            if 'inf' not in repr(e) and 'NaN' not in repr(e):
                print(error_string_fit, Estimator, e)
@@ -291,12 +327,7 @@
        else:
            raise AssertionError(error_string_fit, Estimator)
        # actually fit
-        if issubclass(Estimator, ClusterMixin):
-            # All estimators except clustering algorithm
-            # support fitting with (optional) y
-            estimator.fit(X_train_finite)
-        else:
-            estimator.fit(X_train_finite, y)
+        estimator.fit(X_train_finite, y)
 
        # predict
        if hasattr(estimator, "predict"):
@@ -833,10 +864,7 @@ def check_estimators_overwrite_params(name, Estimator):
    set_random_state(estimator)
 
    params = estimator.get_params()
-    if is_supervised(estimator):
-        estimator.fit(X, y)
-    else:
-        estimator.fit(X)
+    estimator.fit(X, y)
    new_params = estimator.get_params()
    for k, v in params.items():
        assert_false(np.any(new_params[k] != v),
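For reference, the two new checks can also be invoked directly against a single estimator, which is roughly what the nose generator in test_common.py does with each yielded tuple. A sketch assuming the import path from this diff (the signatures of these helpers changed in later scikit-learn releases):

```python
from sklearn.cluster import KMeans
from sklearn.utils.estimator_checks import (check_fit_score_takes_y,
                                            check_pipeline_consistency)

# Each check raises an AssertionError if the estimator violates the contract.
check_fit_score_takes_y('KMeans', KMeans)
check_pipeline_consistency('KMeans', KMeans)
```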