
[WIP] Consistent and informative error message for partial_fit when n_features changes #12465


Closed · wants to merge 6 commits
9 changes: 3 additions & 6 deletions sklearn/cluster/birch.py
@@ -14,7 +14,7 @@
from ..externals.six.moves import xrange
from ..utils import check_array
from ..utils.extmath import row_norms, safe_sparse_dot
from ..utils.validation import check_is_fitted
from ..utils.validation import check_is_fitted, check_partial_fit_n_features
from ..exceptions import NotFittedError, ConvergenceWarning
from .hierarchical import AgglomerativeClustering

@@ -546,11 +546,8 @@ def _check_fit(self, X):
# Should raise an error if one does not fit before predicting.
if not (is_fitted or has_partial_fit):
raise NotFittedError("Fit training data before predicting")

if is_fitted and X.shape[1] != self.subcluster_centers_.shape[1]:
raise ValueError(
"Training data and predicted data do "
"not have same number of features.")
if is_fitted:
check_partial_fit_n_features(X, self.subcluster_centers_, self)

def predict(self, X):
"""
3 changes: 2 additions & 1 deletion sklearn/cluster/k_means_.py
@@ -27,7 +27,7 @@
from ..utils import check_array
from ..utils import gen_batches
from ..utils import check_random_state
from ..utils.validation import check_is_fitted
from ..utils.validation import check_is_fitted, check_partial_fit_n_features
from ..utils.validation import FLOAT_DTYPES
from ..utils import Parallel
from ..utils import delayed
@@ -1694,6 +1694,7 @@ def partial_fit(self, X, y=None, sample_weight=None):
random_reassign = False
distances = None
else:
check_partial_fit_n_features(X, self.cluster_centers_, self)
# The lower the minimum count is, the more we do random
# reassignment, however, we don't want to do random
# reassignment too often, to allow for building up counts
3 changes: 2 additions & 1 deletion sklearn/decomposition/dict_learning.py
@@ -19,7 +19,7 @@
from ..utils import (check_array, check_random_state, gen_even_slices,
gen_batches)
from ..utils.extmath import randomized_svd, row_norms
from ..utils.validation import check_is_fitted
from ..utils.validation import check_is_fitted, check_partial_fit_n_features
from ..linear_model import Lasso, orthogonal_mp_gram, LassoLars, Lars


@@ -1412,6 +1412,7 @@ def partial_fit(self, X, y=None, iter_offset=None):
self.random_state_ = check_random_state(self.random_state)
X = check_array(X)
if hasattr(self, 'components_'):
check_partial_fit_n_features(X, self.components_, self)
dict_init = self.components_
else:
dict_init = self.dict_init
5 changes: 4 additions & 1 deletion sklearn/decomposition/incremental_pca.py
@@ -11,6 +11,7 @@
from .base import _BasePCA
from ..utils import check_array, gen_batches
from ..utils.extmath import svd_flip, _incremental_mean_and_var
from ..utils.validation import check_partial_fit_n_features


class IncrementalPCA(_BasePCA):
@@ -224,6 +225,8 @@ def partial_fit(self, X, y=None, check_input=True):
n_samples, n_features = X.shape
if not hasattr(self, 'components_'):
self.components_ = None
elif self.components_ is not None:
check_partial_fit_n_features(X, self.components_, self)

if self.n_components is None:
if self.components_ is None:
@@ -243,7 +246,7 @@

if (self.components_ is not None) and (self.components_.shape[0] !=
self.n_components_):
raise ValueError("Number of input features has changed from %i "
raise ValueError("Number of components has changed from %i "
"to %i between calls to partial_fit! Try "
"setting n_components to a fixed value." %
(self.components_.shape[0], self.n_components_))
10 changes: 3 additions & 7 deletions sklearn/decomposition/online_lda.py
@@ -19,7 +19,7 @@
from ..utils import (check_random_state, check_array,
gen_batches, gen_even_slices)
from ..utils.fixes import logsumexp
from ..utils.validation import check_non_negative
from ..utils.validation import check_non_negative, check_partial_fit_n_features
from ..utils import Parallel, delayed, effective_n_jobs
from ..externals.six.moves import xrange
from ..exceptions import NotFittedError
@@ -493,12 +493,8 @@ def partial_fit(self, X, y=None):
# initialize parameters or check
if not hasattr(self, 'components_'):
self._init_latent_vars(n_features)

if n_features != self.components_.shape[1]:
raise ValueError(
"The provided data has %d dimensions while "
"the model was trained with feature size %d." %
(n_features, self.components_.shape[1]))
else:
check_partial_fit_n_features(X, self.components_, self)

n_jobs = effective_n_jobs(self.n_jobs)
with Parallel(n_jobs=n_jobs, verbose=max(0,
17 changes: 1 addition & 16 deletions sklearn/decomposition/tests/test_online_lda.py
@@ -142,21 +142,6 @@ def test_lda_fit_transform(method):
assert_array_almost_equal(X_fit, X_trans, 4)


def test_lda_partial_fit_dim_mismatch():
# test `n_features` mismatch in `partial_fit`
rng = np.random.RandomState(0)
n_components = rng.randint(3, 6)
@jeremiedbb (Member, Author) commented:
I removed this test because it does exactly what the common test check_estimators_partial_fit_n_features does.
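For reference, a minimal sketch of the behaviour that common test exercises, run directly against LDA (the sizes and parameters below are illustrative, not taken from the test suite):

import numpy as np
from sklearn.decomposition import LatentDirichletAllocation

X = np.random.RandomState(0).randint(4, size=(10, 8))
lda = LatentDirichletAllocation(n_components=3, random_state=0)
lda.partial_fit(X)
try:
    # a second call with one feature fewer must raise a consistent error
    lda.partial_fit(X[:, :-1])
except ValueError as e:
    print(e)  # with this PR: "Number of input features has changed from 8
              # to 7 between calls to partial_fit of LatentDirichletAllocation."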

n_col = rng.randint(6, 10)
X_1 = np.random.randint(4, size=(10, n_col))
X_2 = np.random.randint(4, size=(10, n_col + 1))
lda = LatentDirichletAllocation(n_components=n_components,
learning_offset=5., total_samples=20,
random_state=rng)
lda.partial_fit(X_1)
assert_raises_regexp(ValueError, r"^The provided data has",
lda.partial_fit, X_2)


def test_invalid_params():
# test `_check_params` method
X = np.ones((5, 10))
@@ -202,7 +187,7 @@ def test_lda_transform_mismatch():
random_state=rng)
lda.partial_fit(X)
assert_raises_regexp(ValueError, r"^The provided data has",
lda.partial_fit, X_2)
lda.transform, X_2)
@jeremiedbb (Member, Author) commented on Oct 29, 2018:
I changed this because, as written, it was a duplicate of test_lda_partial_fit_dim_mismatch, and I think it did not do what it was expected to do, i.e. "# test n_features mismatch in partial_fit and transform".
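A rough sketch of the mismatch the updated test now targets (illustrative sizes; the error here is LDA's pre-existing transform-time message, unchanged by this PR):

import numpy as np
from sklearn.decomposition import LatentDirichletAllocation

X = np.random.RandomState(0).randint(4, size=(20, 10))
lda = LatentDirichletAllocation(n_components=5, random_state=0)
lda.partial_fit(X)
try:
    lda.transform(X[:, :-1])  # fewer features at transform time than at fit time
except ValueError as e:
    print(e)  # message starting with "The provided data has ..."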



@if_safe_multiprocessing_with_blas
13 changes: 6 additions & 7 deletions sklearn/linear_model/stochastic_gradient.py
@@ -18,7 +18,7 @@
from ..utils import check_array, check_random_state, check_X_y
from ..utils.extmath import safe_sparse_dot
from ..utils.multiclass import _check_partial_fit_first_call
from ..utils.validation import check_is_fitted
from ..utils.validation import check_is_fitted, check_partial_fit_n_features
from ..exceptions import ConvergenceWarning
from ..externals import six
from ..model_selection import StratifiedShuffleSplit, ShuffleSplit
@@ -533,9 +533,8 @@ def _partial_fit(self, X, y, alpha, C,
if getattr(self, "coef_", None) is None or coef_init is not None:
self._allocate_parameter_mem(n_classes, n_features,
coef_init, intercept_init)
elif n_features != self.coef_.shape[-1]:
raise ValueError("Number of features %d does not match previous "
"data %d." % (n_features, self.coef_.shape[-1]))
else:
check_partial_fit_n_features(X, self.coef_, self)

self.loss_function_ = self._get_loss_function(loss)
if not hasattr(self, "t_"):
@@ -1144,9 +1143,9 @@ def _partial_fit(self, X, y, alpha, C, loss, learning_rate,
if getattr(self, "coef_", None) is None:
self._allocate_parameter_mem(1, n_features, coef_init,
intercept_init)
elif n_features != self.coef_.shape[-1]:
raise ValueError("Number of features %d does not match previous "
"data %d." % (n_features, self.coef_.shape[-1]))
else:
check_partial_fit_n_features(X, self.coef_, self)

if self.average > 0 and getattr(self, "average_coef_", None) is None:
self.average_coef_ = np.zeros(n_features,
dtype=np.float64,
12 changes: 5 additions & 7 deletions sklearn/naive_bayes.py
@@ -30,7 +30,7 @@
from .utils.extmath import safe_sparse_dot
from .utils.fixes import logsumexp
from .utils.multiclass import _check_partial_fit_first_call
from .utils.validation import check_is_fitted
from .utils.validation import check_is_fitted, check_partial_fit_n_features
from .externals import six

__all__ = ['BernoulliNB', 'GaussianNB', 'MultinomialNB', 'ComplementNB']
@@ -382,9 +382,8 @@ def _partial_fit(self, X, y, classes=None, _refit=False,
self.class_prior_ = np.zeros(len(self.classes_),
dtype=np.float64)
else:
if X.shape[1] != self.theta_.shape[1]:
msg = "Number of features %d does not match previous data %d."
raise ValueError(msg % (X.shape[1], self.theta_.shape[1]))
check_partial_fit_n_features(X, self.theta_, self)

# Put epsilon back in each time
self.sigma_[:, :] -= self.epsilon_

@@ -527,9 +526,8 @@ def partial_fit(self, X, y, classes=None, sample_weight=None):
self.class_count_ = np.zeros(n_effective_classes, dtype=np.float64)
self.feature_count_ = np.zeros((n_effective_classes, n_features),
dtype=np.float64)
elif n_features != self.coef_.shape[1]:
msg = "Number of features %d does not match previous data %d."
raise ValueError(msg % (n_features, self.coef_.shape[-1]))
else:
check_partial_fit_n_features(X, self.coef_, self)

Y = label_binarize(y, classes=self.classes_)
if Y.shape[1] == 1:
4 changes: 3 additions & 1 deletion sklearn/neural_network/multilayer_perceptron.py
@@ -24,7 +24,7 @@
from ..utils import check_array, check_X_y, column_or_1d
from ..exceptions import ConvergenceWarning
from ..utils.extmath import safe_sparse_dot
from ..utils.validation import check_is_fitted
from ..utils.validation import check_is_fitted, check_partial_fit_n_features
from ..utils.multiclass import _check_partial_fit_first_call, unique_labels
from ..utils.multiclass import type_of_target

@@ -340,6 +340,8 @@ def _fit(self, X, y, incremental=False):
incremental):
# First time training the model
self._initialize(y, layer_units)
else:
check_partial_fit_n_features(X, self.coefs_[0].T, self)

# lbfgs does not support mini-batches
if self.solver == 'lbfgs':
4 changes: 3 additions & 1 deletion sklearn/neural_network/rbm.py
@@ -21,7 +21,7 @@
from ..utils import gen_even_slices
from ..utils.extmath import safe_sparse_dot
from ..utils.extmath import log_logistic
from ..utils.validation import check_is_fitted
from ..utils.validation import check_is_fitted, check_partial_fit_n_features


class BernoulliRBM(BaseEstimator, TransformerMixin):
@@ -243,6 +243,8 @@ def partial_fit(self, X, y=None):
(self.n_components, X.shape[1])
),
order='F')
else:
check_partial_fit_n_features(X, self.components_, self)
if not hasattr(self, 'intercept_hidden_'):
self.intercept_hidden_ = np.zeros(self.n_components, )
if not hasattr(self, 'intercept_visible_'):
16 changes: 15 additions & 1 deletion sklearn/preprocessing/data.py
@@ -31,7 +31,7 @@
mean_variance_axis, incr_mean_variance_axis,
min_max_axis)
from ..utils.validation import (check_is_fitted, check_random_state,
FLOAT_DTYPES)
FLOAT_DTYPES, check_partial_fit_n_features)

from ._csr_polynomial_expansion import _csr_polynomial_expansion

@@ -358,6 +358,8 @@ def partial_fit(self, X, y=None):
self.n_samples_seen_ = X.shape[0]
# Next steps
else:
check_partial_fit_n_features(X, self.scale_, self)

data_min = np.minimum(self.data_min_, data_min)
data_max = np.maximum(self.data_max_, data_max)
self.n_samples_seen_ += X.shape[0]
@@ -652,6 +654,16 @@ def partial_fit(self, X, y=None):
self.n_samples_seen_ = np.repeat(self.n_samples_seen_,
X.shape[1]).astype(np.int64)

# if first pass: store number of features
if not hasattr(self, "mean_"):
self._n_features_ = X.shape[1]

# check number of features consistency for next passes
if hasattr(self, "mean_") and self.mean_ is not None:
check_partial_fit_n_features(X, self.mean_, self)
if hasattr(self, "scale_") and self.scale_ is not None:
check_partial_fit_n_features(X, self.scale_, self)

if sparse.issparse(X):
if self.with_mean:
raise ValueError(
@@ -911,6 +923,8 @@ def partial_fit(self, X, y=None):
self.n_samples_seen_ = X.shape[0]
# Next passes
else:
check_partial_fit_n_features(X, self.scale_, self)

max_abs = np.maximum(self.max_abs_, max_abs)
self.n_samples_seen_ += X.shape[0]

16 changes: 8 additions & 8 deletions sklearn/utils/estimator_checks.py
@@ -126,7 +126,6 @@ def _yield_classifier_checks(name, classifier):
# test classifiers trained on a single label always return this label
yield check_classifiers_one_label
yield check_classifiers_classes
yield check_estimators_partial_fit_n_features
# basic consistency testing
yield check_classifiers_train
yield partial(check_classifiers_train, readonly_memmap=True)
@@ -179,7 +178,6 @@ def _yield_regressor_checks(name, regressor):
yield check_regressors_train
yield partial(check_regressors_train, readonly_memmap=True)
yield check_regressor_data_not_an_array
yield check_estimators_partial_fit_n_features
yield check_regressors_no_decision_function
yield check_supervised_y_2d
yield check_supervised_y_no_nan
@@ -220,7 +218,6 @@ def _yield_clustering_checks(name, clusterer):
# let's not test that here.
yield check_clustering
yield partial(check_clustering, readonly_memmap=True)
yield check_estimators_partial_fit_n_features
yield check_non_transformer_estimators_n_iter


@@ -268,6 +265,7 @@ def _yield_all_checks(name, estimator):
yield check_dict_unchanged
yield check_dont_overwrite_parameters
yield check_fit_idempotent
yield check_estimators_partial_fit_n_features


def check_estimator(Estimator):
@@ -1261,11 +1259,13 @@ def check_estimators_partial_fit_n_features(name, estimator_orig):
except NotImplementedError:
return

with assert_raises(ValueError,
msg="The estimator {} does not raise an"
" error when the number of features"
" changes between calls to "
"partial_fit.".format(name)):
match = ("Number of input features has changed .* between "
"calls to partial_fit")
msg = ("The estimator {} does not raise an appropriate error when "
"the number of features changes between calls to "
"partial_fit.".format(name))

with assert_raises_regex(ValueError, match, msg=msg):
estimator.partial_fit(X[:, :-1], y)


23 changes: 23 additions & 0 deletions sklearn/utils/validation.py
@@ -971,3 +971,26 @@ def check_non_negative(X, whom):

if X_min < 0:
raise ValueError("Negative values in data passed to %s" % whom)


def check_partial_fit_n_features(X, components, estimator):
"""
Check that the number of features is preserved between calls to partial_fit.

Parameters
----------
X : array-like
Input data for the new call to partial_fit.

components : array-like
Fitted attribute of the estimator, with the same number of features
as the input data seen at the first call to partial_fit.

estimator : estimator instance
Estimator instance for which the check is performed.
"""
if X.shape[-1] != components.shape[-1]:
raise ValueError("Number of input features has changed from {0} to {1}"
" between calls to partial_fit of {2}."
"".format(components.shape[-1], X.shape[-1],
type(estimator).__name__))
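
As a usage sketch of the new helper (assuming this branch is installed so that check_partial_fit_n_features is importable from sklearn.utils.validation; the estimator and arrays are only illustrative):

import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.utils.validation import check_partial_fit_n_features

components = np.zeros((2, 3))  # stand-in for a fitted attribute such as coef_
# same number of features: the check passes silently
check_partial_fit_n_features(np.zeros((5, 3)), components, SGDClassifier())
try:
    # feature count changed from 3 to 4: the check raises
    check_partial_fit_n_features(np.zeros((5, 4)), components, SGDClassifier())
except ValueError as e:
    print(e)  # Number of input features has changed from 3 to 4 between
              # calls to partial_fit of SGDClassifier.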