From ceebfeca1312503b7fa678525dfab939a44db1f8 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 3 Jun 2021 09:13:25 -0400 Subject: [PATCH 1/4] FIX Do not reset for non-fit in multiclass --- sklearn/multiclass.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/multiclass.py b/sklearn/multiclass.py index 99a6db2051030..d40b6e35c7880 100644 --- a/sklearn/multiclass.py +++ b/sklearn/multiclass.py @@ -120,19 +120,19 @@ def fit(self, X, y): def predict(self, X): check_is_fitted(self) - self._check_n_features(X, reset=True) + self._check_n_features(X, reset=False) return np.repeat(self.y_, _num_samples(X)) def decision_function(self, X): check_is_fitted(self) - self._check_n_features(X, reset=True) + self._check_n_features(X, reset=False) return np.repeat(self.y_, _num_samples(X)) def predict_proba(self, X): check_is_fitted(self) - self._check_n_features(X, reset=True) + self._check_n_features(X, reset=False) return np.repeat([np.hstack([1 - self.y_, self.y_])], _num_samples(X), axis=0) From 0618741c209105259a8f6eee1b02d0b2bfeffe22 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Mon, 7 Jun 2021 12:04:17 -0400 Subject: [PATCH 2/4] ENH Adds validation to inner estimator --- sklearn/multiclass.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/sklearn/multiclass.py b/sklearn/multiclass.py index d40b6e35c7880..fe81af746acb3 100644 --- a/sklearn/multiclass.py +++ b/sklearn/multiclass.py @@ -114,25 +114,29 @@ def _check_estimator(estimator): class _ConstantPredictor(BaseEstimator): def fit(self, X, y): - self._check_n_features(X, reset=True) + self._validate_data(X, y, force_all_finite=False, dtype=None, + ensure_2d=False, reset=True) self.y_ = y return self def predict(self, X): check_is_fitted(self) - self._check_n_features(X, reset=False) + self._validate_data(X, force_all_finite=False, dtype=None, + ensure_2d=False, reset=False) return np.repeat(self.y_, _num_samples(X)) def decision_function(self, X): check_is_fitted(self) - self._check_n_features(X, reset=False) + self._validate_data(X, force_all_finite=False, dtype=None, + ensure_2d=False, reset=False) return np.repeat(self.y_, _num_samples(X)) def predict_proba(self, X): check_is_fitted(self) - self._check_n_features(X, reset=False) + self._validate_data(X, force_all_finite=False, dtype=None, + ensure_2d=False, reset=False) return np.repeat([np.hstack([1 - self.y_, self.y_])], _num_samples(X), axis=0) From e687c5e08c49bb07f2d8c3b09b7de161f542f299 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Mon, 7 Jun 2021 13:39:39 -0400 Subject: [PATCH 3/4] FIX Fixes test errors --- sklearn/multiclass.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/sklearn/multiclass.py b/sklearn/multiclass.py index fe81af746acb3..30539e5e3aeac 100644 --- a/sklearn/multiclass.py +++ b/sklearn/multiclass.py @@ -54,7 +54,8 @@ from .utils.validation import _assert_all_finite from .utils.multiclass import (_check_partial_fit_first_call, check_classification_targets, - _ovr_decision_function) + _ovr_decision_function, + check_array) from .utils.metaestimators import _safe_split, if_delegate_has_method from .utils.fixes import delayed @@ -114,14 +115,17 @@ def _check_estimator(estimator): class _ConstantPredictor(BaseEstimator): def fit(self, X, y): - self._validate_data(X, y, force_all_finite=False, dtype=None, - ensure_2d=False, reset=True) + check_params = dict(force_all_finite=False, dtype=None, + ensure_2d=False, accept_sparse=True) + self._validate_data(X, y, reset=True, + validate_separately=(check_params, check_params)) self.y_ = y return self def predict(self, X): check_is_fitted(self) self._validate_data(X, force_all_finite=False, dtype=None, + accept_sparse=True, ensure_2d=False, reset=False) return np.repeat(self.y_, _num_samples(X)) @@ -129,6 +133,7 @@ def predict(self, X): def decision_function(self, X): check_is_fitted(self) self._validate_data(X, force_all_finite=False, dtype=None, + accept_sparse=True, ensure_2d=False, reset=False) return np.repeat(self.y_, _num_samples(X)) @@ -136,6 +141,7 @@ def decision_function(self, X): def predict_proba(self, X): check_is_fitted(self) self._validate_data(X, force_all_finite=False, dtype=None, + accept_sparse=True, ensure_2d=False, reset=False) return np.repeat([np.hstack([1 - self.y_, self.y_])], From f8dc5097152a0829bd9d472c93f90878e0000eae Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Mon, 7 Jun 2021 17:06:35 -0400 Subject: [PATCH 4/4] STY Lint error --- sklearn/multiclass.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sklearn/multiclass.py b/sklearn/multiclass.py index 30539e5e3aeac..ad420506a9694 100644 --- a/sklearn/multiclass.py +++ b/sklearn/multiclass.py @@ -54,8 +54,7 @@ from .utils.validation import _assert_all_finite from .utils.multiclass import (_check_partial_fit_first_call, check_classification_targets, - _ovr_decision_function, - check_array) + _ovr_decision_function) from .utils.metaestimators import _safe_split, if_delegate_has_method from .utils.fixes import delayed