From bb0aed4596444cba2194e0e9bc578980cfccc783 Mon Sep 17 00:00:00 2001 From: Raghav R V Date: Tue, 30 Jun 2015 16:15:17 +0530 Subject: [PATCH 1/2] FIX use best_estimator_ as delegate to avoid incorrect attr checks --- sklearn/grid_search.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sklearn/grid_search.py b/sklearn/grid_search.py index 07bf443eb9c95..7fa8fed79958d 100644 --- a/sklearn/grid_search.py +++ b/sklearn/grid_search.py @@ -415,7 +415,7 @@ def score(self, X, y=None): ChangedBehaviorWarning) return self.scorer_(self.best_estimator_, X, y) - @if_delegate_has_method(delegate='estimator') + @if_delegate_has_method(delegate='best_estimator_') def predict(self, X): """Call predict on the estimator with the best found parameters. @@ -431,7 +431,7 @@ def predict(self, X): """ return self.best_estimator_.predict(X) - @if_delegate_has_method(delegate='estimator') + @if_delegate_has_method(delegate='best_estimator_') def predict_proba(self, X): """Call predict_proba on the estimator with the best found parameters. @@ -447,7 +447,7 @@ def predict_proba(self, X): """ return self.best_estimator_.predict_proba(X) - @if_delegate_has_method(delegate='estimator') + @if_delegate_has_method(delegate='best_estimator_') def predict_log_proba(self, X): """Call predict_log_proba on the estimator with the best found parameters. @@ -463,7 +463,7 @@ def predict_log_proba(self, X): """ return self.best_estimator_.predict_log_proba(X) - @if_delegate_has_method(delegate='estimator') + @if_delegate_has_method(delegate='best_estimator_') def decision_function(self, X): """Call decision_function on the estimator with the best found parameters. @@ -479,7 +479,7 @@ def decision_function(self, X): """ return self.best_estimator_.decision_function(X) - @if_delegate_has_method(delegate='estimator') + @if_delegate_has_method(delegate='best_estimator_') def transform(self, X): """Call transform on the estimator with the best found parameters. @@ -495,7 +495,7 @@ def transform(self, X): """ return self.best_estimator_.transform(X) - @if_delegate_has_method(delegate='estimator') + @if_delegate_has_method(delegate='best_estimator_') def inverse_transform(self, Xt): """Call inverse_transform on the estimator with the best found parameters. From 9533a814809f51e3685463850331e4c4f91fd602 Mon Sep 17 00:00:00 2001 From: Raghav R V Date: Tue, 30 Jun 2015 16:20:48 +0530 Subject: [PATCH 2/2] FIX raise NotFittedError before checking if model supports predict_proba --- sklearn/svm/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/svm/base.py b/sklearn/svm/base.py index 6bbb4879d5166..df14e61f87678 100644 --- a/sklearn/svm/base.py +++ b/sklearn/svm/base.py @@ -571,6 +571,7 @@ def predict(self, X): # probabilities are not available depending on a setting, introduce two # estimators. def _check_proba(self): + check_is_fitted(self, 'support_') if not self.probability or self.probA_.size == 0 or self.probB_.size == 0: raise AttributeError("predict_proba is not available when fitted with" " probability=False")