From cc00d295a7519e9e5a404f2c19af711482861290 Mon Sep 17 00:00:00 2001 From: Mikhail Korobov Date: Wed, 2 Nov 2016 19:43:46 +0500 Subject: [PATCH 1/3] OneVsRestClassifier: don't expose predict_proba and decision_function methods if they are not supported by base estimator. --- sklearn/multiclass.py | 7 +++---- sklearn/tests/test_multiclass.py | 14 +++++++++----- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/sklearn/multiclass.py b/sklearn/multiclass.py index e3fad7e08e3e0..9f93bc8e7e239 100644 --- a/sklearn/multiclass.py +++ b/sklearn/multiclass.py @@ -50,7 +50,7 @@ from .utils.multiclass import (_check_partial_fit_first_call, check_classification_targets, _ovr_decision_function) -from .utils.metaestimators import _safe_split +from .utils.metaestimators import _safe_split, if_delegate_has_method from .externals.joblib import Parallel from .externals.joblib import delayed @@ -309,6 +309,7 @@ def predict(self, X): shape=(n_samples, len(self.estimators_))) return self.label_binarizer_.inverse_transform(indicator) + @if_delegate_has_method('estimator') def predict_proba(self, X): """Probability estimates. @@ -347,6 +348,7 @@ def predict_proba(self, X): Y /= np.sum(Y, axis=1)[:, np.newaxis] return Y + @if_delegate_has_method('estimator') def decision_function(self, X): """Returns the distance of each sample from the decision boundary for each class. 
This can only be used with estimators which implement the @@ -361,9 +363,6 @@ def decision_function(self, X): T : array-like, shape = [n_samples, n_classes] """ check_is_fitted(self, 'estimators_') - if not hasattr(self.estimators_[0], "decision_function"): - raise AttributeError( - "Base estimator doesn't have a decision_function attribute.") return np.array([est.decision_function(X).ravel() for est in self.estimators_]).T diff --git a/sklearn/tests/test_multiclass.py b/sklearn/tests/test_multiclass.py index 5bdc13f8d5d9a..04e4c32057d6e 100644 --- a/sklearn/tests/test_multiclass.py +++ b/sklearn/tests/test_multiclass.py @@ -314,14 +314,17 @@ def test_ovr_multilabel_predict_proba(): X_test = X[80:] clf = OneVsRestClassifier(base_clf).fit(X_train, Y_train) - # decision function only estimator. Fails in current implementation. + # Decision function only estimator. decision_only = OneVsRestClassifier(svm.SVR()).fit(X_train, Y_train) - assert_raises(AttributeError, decision_only.predict_proba, X_test) + assert not hasattr(decision_only, 'predict_proba') + assert hasattr(decision_only, 'decision_function') # Estimator with predict_proba disabled, depending on parameters. decision_only = OneVsRestClassifier(svm.SVC(probability=False)) + assert not hasattr(decision_only, 'predict_proba') decision_only.fit(X_train, Y_train) - assert_raises(AttributeError, decision_only.predict_proba, X_test) + assert not hasattr(decision_only, 'predict_proba') + assert hasattr(decision_only, 'decision_function') Y_pred = clf.predict(X_test) Y_proba = clf.predict_proba(X_test) @@ -339,9 +342,10 @@ def test_ovr_single_label_predict_proba(): X_test = X[80:] clf = OneVsRestClassifier(base_clf).fit(X_train, Y_train) - # decision function only estimator. Fails in current implementation. + # Decision function only estimator. 
decision_only = OneVsRestClassifier(svm.SVR()).fit(X_train, Y_train) - assert_raises(AttributeError, decision_only.predict_proba, X_test) + assert not hasattr(decision_only, 'predict_proba') + assert hasattr(decision_only, 'decision_function') Y_pred = clf.predict(X_test) Y_proba = clf.predict_proba(X_test) From 5c34996cb40ec3d1013be4e80e069bfbcf907201 Mon Sep 17 00:00:00 2001 From: Mikhail Korobov Date: Thu, 3 Nov 2016 02:21:48 +0500 Subject: [PATCH 2/3] TST use nose-style asserts --- sklearn/tests/test_multiclass.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sklearn/tests/test_multiclass.py b/sklearn/tests/test_multiclass.py index 04e4c32057d6e..fda86f0639cff 100644 --- a/sklearn/tests/test_multiclass.py +++ b/sklearn/tests/test_multiclass.py @@ -316,15 +316,15 @@ def test_ovr_multilabel_predict_proba(): # Decision function only estimator. decision_only = OneVsRestClassifier(svm.SVR()).fit(X_train, Y_train) - assert not hasattr(decision_only, 'predict_proba') - assert hasattr(decision_only, 'decision_function') + assert_false(hasattr(decision_only, 'predict_proba')) + assert_true(hasattr(decision_only, 'decision_function')) # Estimator with predict_proba disabled, depending on parameters. decision_only = OneVsRestClassifier(svm.SVC(probability=False)) - assert not hasattr(decision_only, 'predict_proba') + assert_false(hasattr(decision_only, 'predict_proba')) decision_only.fit(X_train, Y_train) - assert not hasattr(decision_only, 'predict_proba') - assert hasattr(decision_only, 'decision_function') + assert_false(hasattr(decision_only, 'predict_proba')) + assert_true(hasattr(decision_only, 'decision_function')) Y_pred = clf.predict(X_test) Y_proba = clf.predict_proba(X_test) @@ -344,8 +344,8 @@ def test_ovr_single_label_predict_proba(): # Decision function only estimator. 
decision_only = OneVsRestClassifier(svm.SVR()).fit(X_train, Y_train) - assert not hasattr(decision_only, 'predict_proba') - assert hasattr(decision_only, 'decision_function') + assert_false(hasattr(decision_only, 'predict_proba')) + assert_true(hasattr(decision_only, 'decision_function')) Y_pred = clf.predict(X_test) Y_proba = clf.predict_proba(X_test) From 14a9f2f19c1777494ac650be7251d0e0305bbfc6 Mon Sep 17 00:00:00 2001 From: Mikhail Korobov Date: Thu, 3 Nov 2016 02:33:19 +0500 Subject: [PATCH 3/3] handle a case where classifier gets a predict_proba method only after .fit --- sklearn/multiclass.py | 8 ++++++-- sklearn/tests/test_multiclass.py | 8 ++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/sklearn/multiclass.py b/sklearn/multiclass.py index 9f93bc8e7e239..66b0a47da9599 100644 --- a/sklearn/multiclass.py +++ b/sklearn/multiclass.py @@ -309,7 +309,7 @@ def predict(self, X): shape=(n_samples, len(self.estimators_))) return self.label_binarizer_.inverse_transform(indicator) - @if_delegate_has_method('estimator') + @if_delegate_has_method(['_first_estimator', 'estimator']) def predict_proba(self, X): """Probability estimates. @@ -348,7 +348,7 @@ def predict_proba(self, X): Y /= np.sum(Y, axis=1)[:, np.newaxis] return Y - @if_delegate_has_method('estimator') + @if_delegate_has_method(['_first_estimator', 'estimator']) def decision_function(self, X): """Returns the distance of each sample from the decision boundary for each class. 
This can only be used with estimators which implement the @@ -399,6 +399,10 @@ def _pairwise(self): """Indicate if wrapped estimator is using a precomputed Gram matrix""" return getattr(self.estimator, "_pairwise", False) + @property + def _first_estimator(self): + return self.estimators_[0] + def _fit_ovo_binary(estimator, X, y, i, j): """Fit a single binary estimator (one-vs-one).""" diff --git a/sklearn/tests/test_multiclass.py b/sklearn/tests/test_multiclass.py index fda86f0639cff..ca5909ea72a48 100644 --- a/sklearn/tests/test_multiclass.py +++ b/sklearn/tests/test_multiclass.py @@ -326,6 +326,14 @@ def test_ovr_multilabel_predict_proba(): assert_false(hasattr(decision_only, 'predict_proba')) assert_true(hasattr(decision_only, 'decision_function')) + # Estimator which can get predict_proba enabled after fitting + gs = GridSearchCV(svm.SVC(probability=False), + param_grid={'probability': [True]}) + proba_after_fit = OneVsRestClassifier(gs) + assert_false(hasattr(proba_after_fit, 'predict_proba')) + proba_after_fit.fit(X_train, Y_train) + assert_true(hasattr(proba_after_fit, 'predict_proba')) + Y_pred = clf.predict(X_test) Y_proba = clf.predict_proba(X_test)