diff --git a/sklearn/multiclass.py b/sklearn/multiclass.py index 99a6db2051030..ad420506a9694 100644 --- a/sklearn/multiclass.py +++ b/sklearn/multiclass.py @@ -114,25 +114,34 @@ def _check_estimator(estimator): class _ConstantPredictor(BaseEstimator): def fit(self, X, y): - self._check_n_features(X, reset=True) + check_params = dict(force_all_finite=False, dtype=None, + ensure_2d=False, accept_sparse=True) + self._validate_data(X, y, reset=True, + validate_separately=(check_params, check_params)) self.y_ = y return self def predict(self, X): check_is_fitted(self) - self._check_n_features(X, reset=True) + self._validate_data(X, force_all_finite=False, dtype=None, + accept_sparse=True, + ensure_2d=False, reset=False) return np.repeat(self.y_, _num_samples(X)) def decision_function(self, X): check_is_fitted(self) - self._check_n_features(X, reset=True) + self._validate_data(X, force_all_finite=False, dtype=None, + accept_sparse=True, + ensure_2d=False, reset=False) return np.repeat(self.y_, _num_samples(X)) def predict_proba(self, X): check_is_fitted(self) - self._check_n_features(X, reset=True) + self._validate_data(X, force_all_finite=False, dtype=None, + accept_sparse=True, + ensure_2d=False, reset=False) return np.repeat([np.hstack([1 - self.y_, self.y_])], _num_samples(X), axis=0)