[MRG+1] Ovr/OVO classifier decision_function shape fixes #9100
Found via #8022.
@@ -574,6 +576,8 @@ def predict(self, X):
Predicted multi-class targets.
"""
Y = self.decision_function(X)
if self.n_classes_ == 2:
return self.classes_[(Y > 0).astype(np.int)]
This assumes that estimators_[0].decision_function correctly returns a vector. Are we relying on check_estimator to validate this, or should we check it here?
No, we should assume that base_estimator actually works. Otherwise everything in sklearn is broken ;)
makes sense, lgtm then
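For readers following the exchange above, here is a minimal sketch of the binary branch of `predict`, assuming the decision values arrive as a 1-D array of shape `(n_samples,)`. The helper name `predict_binary` and the sample data are hypothetical and not part of the PR. The `ndim` check only illustrates the kind of validation the reviewer asked about; the thread concluded it is not needed inside the library. The sketch uses the builtin `int` where the diff uses `np.int`, since recent NumPy releases no longer provide that alias.

```python
import numpy as np


def predict_binary(decision_values, classes):
    """Map 1-D binary decision values to class labels (illustrative helper)."""
    decision_values = np.asarray(decision_values)
    # Hypothetical sanity check; the PR itself trusts the base estimator.
    if decision_values.ndim != 1:
        raise ValueError(
            "expected shape (n_samples,), got %r" % (decision_values.shape,)
        )
    # Positive scores index classes[1]; zero or negative scores index classes[0].
    return np.asarray(classes)[(decision_values > 0).astype(int)]


classes = np.array(["eggs", "spam"])
scores = np.array([-1.5, 0.2, 0.0, 3.0])
labels = predict_binary(scores, classes)
```

A score of exactly zero falls on the `classes[0]` side because the comparison is strict.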
sklearn/tests/test_multiclass.py
Outdated
@@ -251,6 +251,8 @@ def conduct_test(base_clf, test_predict_proba=False):
assert_equal(set(clf.classes_), classes)
y_pred = clf.predict(np.array([[0, 0, 4]]))[0]
assert_equal(set(y_pred), set("eggs"))
dec = clf.decision_function(X)
assert_equal(dec.shape, (5,))
PEP8: missing space after comma
# first binary
ovo_clf.fit(iris.data, iris.target == 0)
decisions = ovo_clf.decision_function(iris.data)
assert_equal(decisions.shape, (n_samples,))
PEP8 here too
this is pep8....
LGTM. 👍
Aside from the PEP8 issues :)
LGTM +1 for merge when travis is happy
sklearn/tests/test_multiclass.py
Outdated
@@ -251,6 +251,9 @@ def conduct_test(base_clf, test_predict_proba=False):
assert_equal(set(clf.classes_), classes)
y_pred = clf.predict(np.array([[0, 0, 4]]))[0]
assert_equal(set(y_pred), set("eggs"))
if hasattr(base_clf, 'decision_function'):
dec = clf.decision_function(X)
assert_equal(dec.shape, (5,))
this breaks travis, line needs to be indented deeper
green! @vene?
* fix OVR classifier edgecase bugs
* add regression tests for OVO and OVR decision function shapes
The OVR classifier had a `decision_function` shape of `(n_samples, 1)` for binary classification, and the OVO classifier had `(n_samples, 2)`! Fixed to conform to the standard API, which returns shape `(n_samples,)` in the binary case.
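The behavior described in this PR can be checked with a short script along the following lines. This is a sketch rather than the PR's actual regression test: it assumes a scikit-learn release that includes this fix, and the choice of `LinearSVC` as the base estimator and the integer-encoded binary target are illustrative assumptions.

```python
from sklearn.datasets import load_iris
from sklearn.multiclass import OneVsOneClassifier, OneVsRestClassifier
from sklearn.svm import LinearSVC

iris = load_iris()
X = iris.data
# Binary problem: class 0 versus the rest, encoded as 0/1 integers.
y_binary = (iris.target == 0).astype(int)
n_samples = X.shape[0]

# Record the decision_function output shape for each meta-estimator.
shapes = {}
for wrapper in (OneVsRestClassifier, OneVsOneClassifier):
    clf = wrapper(LinearSVC(random_state=0)).fit(X, y_binary)
    shapes[wrapper.__name__] = clf.decision_function(X).shape
```

With the fix in place, both entries in `shapes` should be the 1-D shape `(n_samples,)` instead of `(n_samples, 1)` or `(n_samples, 2)`.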