MNT refactor _get_response_values #21538
Conversation
```diff
@@ -26,72 +24,6 @@
 )


 def test_errors(pyplot):
```
Those tests were moved to the common tests in test_plot_curve_common.py.
```diff
@@ -634,42 +634,6 @@ def iris_data_binary(iris_data):
     return X[y < 2], y[y < 2]


 def test_calibration_display_validation(pyplot, iris_data, iris_data_binary):
```
Those tests were moved to the common tests in test_plot_curve_common.py.
```diff
@@ -189,7 +189,7 @@ def test_parallel_execution(data, method, ensemble):
     X, y = data
     X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

-    base_estimator = LinearSVC(random_state=42)
+    base_estimator = make_pipeline(StandardScaler(), LinearSVC(random_state=42))
```
Removes a bunch of ConvergenceWarnings for free.
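For intuition, here is a minimal sketch of why the pipeline helps (the dataset is illustrative, not the one used in the PR's test): liblinear-based LinearSVC often hits max_iter on unscaled features, and standardizing first lets it converge cleanly.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC

X, y = load_breast_cancer(return_X_y=True)

# LinearSVC alone may emit ConvergenceWarning on unscaled features;
# standardizing first lets the liblinear solver converge within max_iter.
base_estimator = make_pipeline(StandardScaler(), LinearSVC(random_state=42))
base_estimator.fit(X, y)
```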
```diff
@@ -21,48 +20,6 @@
 )


 def test_precision_recall_display_validation(pyplot):
```
Those tests were moved to the common tests in plot_common_curve_display.py.
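As an illustration only (the test name and body below are hypothetical, not the PR's actual code; `pyplot` is assumed to be sklearn's matplotlib test fixture), a common file like this can parametrize one validation test over several Display classes at once:

```python
import pytest
from sklearn.datasets import make_classification
from sklearn.linear_model import LinearRegression
from sklearn.metrics import DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay


@pytest.mark.parametrize(
    "Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay]
)
def test_display_error_with_regressor(pyplot, Display):
    # A regressor is not a binary classifier, so from_estimator is
    # expected to raise for every curve display.
    X, y = make_classification(random_state=0)
    regressor = LinearRegression().fit(X, y)
    with pytest.raises(ValueError):
        Display.from_estimator(regressor, X, y)
```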
Maybe this would be easier to review, @ogrisel @thomasjpfan.
Overall, I like the refactor.
Force-pushed from abea58f to a7e46da (compare).
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
Thanks for the follow-up!
```python
target_type = type_of_target(y_true)
if target_type != "binary":
    raise ValueError(
        f"The target y is not binary. Got {target_type} type of target."
    )
```
Do we need this check here, since we do the check in det_curve?
scikit-learn/sklearn/metrics/_ranking.py, lines 310 to 314 in 845b1fa:
```python
if len(np.unique(y_true)) != 2:
    raise ValueError(
        "Only one class present in y_true. Detection error "
        "tradeoff curve is not defined in that case."
    )
```
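A quick sketch of the duplication being pointed out (the exact error message may vary by version):

```python
from sklearn.metrics import det_curve

# det_curve validates its own input: a single-class y_true triggers the
# ValueError quoted above, so a second check at the Display level would
# validate the same thing twice.
try:
    det_curve([0, 0, 0, 0], [0.1, 0.2, 0.3, 0.4])
except ValueError as exc:
    print(exc)  # "Only one class present in y_true. ..."
```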
…nto is/18212_again_again
First pass of comments
```python
    response_method="predict",
    target_type=target_type,
)
assert_allclose(y_pred, regressor.predict(X))
```
This one could be checking strict equality, right?
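For intuition (a generic regressor, not the PR's fixture): predict is deterministic on a fitted estimator, so the compared arrays are bit-for-bit identical and exact equality can be asserted instead of allclose.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

X = np.arange(10, dtype=float).reshape(-1, 1)
y = 2.0 * X.ravel() + 1.0

regressor = LinearRegression().fit(X, y)
# Two predict calls on the same fitted estimator return identical
# arrays, so strict equality holds and allclose is needlessly loose.
np.testing.assert_array_equal(regressor.predict(X), regressor.predict(X))
```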
```
Raises
------
ValueError
```
```diff
-ValueError
+AttributeError
```
```python
    y_pred = y_pred[:, col_idx]
else:
    err_msg = (
        f"Got predict_proba of shape {y_pred.shape}, but need "
```
I find the error message not super explicit; I would find it easier to understand if it said that the classifier was fitted with only one class (rather than complaining about the shape of predict_proba). Maybe there are some edge cases I haven't thought of, though.
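To illustrate the situation behind that shape (a hedged sketch; DecisionTreeClassifier is just one estimator that accepts a single-class fit):

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

X = np.arange(6, dtype=float).reshape(-1, 1)
y = np.zeros(6)  # the classifier only ever sees one class

clf = DecisionTreeClassifier().fit(X, y)
# predict_proba has one column per class seen during fit, hence the
# (n_samples, 1) shape the error message is really complaining about.
print(clf.predict_proba(X).shape)  # (6, 1)
```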
f"one of {classes}" | ||
) | ||
elif pos_label is None and target_type == "binary": | ||
pos_label = pos_label if pos_label is not None else classes[-1] |
Although -1 and 1 are equivalent, I would use 1 since classes[1] is used elsewhere, e.g. in the _get_response_values docstring:
```diff
-pos_label = pos_label if pos_label is not None else classes[-1]
+pos_label = pos_label if pos_label is not None else classes[1]
```
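The equivalence holds because classes_ always has exactly two entries once the target is validated as binary; a small sketch:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

X = np.arange(8, dtype=float).reshape(-1, 1)
y = np.array([0, 0, 0, 0, 1, 1, 1, 1])

clf = LogisticRegression().fit(X, y)
# For a binary target, classes_ is the sorted pair of labels, so index
# 1 and index -1 pick the same (by-convention positive) class.
assert clf.classes_[1] == clf.classes_[-1] == 1
```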
```python
method = (
    ["predict_proba", "decision_function", "predict"]
    if method == "auto"
    else method
)
```
A bit lighter to parse IMO (and also what is used in other places in this PR):
```diff
-method = (
-    ["predict_proba", "decision_function", "predict"]
-    if method == "auto"
-    else method
-)
+if method == "auto":
+    method = ["predict_proba", "decision_function", "predict"]
```
"'fit' with appropriate arguments before intending to use it to plotting " | ||
"functionalities." |
Maybe a bit lighter to read (better suggestion welcome):
"'fit' with appropriate arguments before intending to use it to plotting " | |
"functionalities." | |
"'fit' with appropriate arguments before using it for plotting " | |
"functionalities." |
The CI is red at the moment; maybe I got something wrong in my merge ...
Comment that still needs to be addressed, with the same concern:
- https://github.com/scikit-learn/scikit-learn/pull/21538/files#r769872637
- https://github.com/scikit-learn/scikit-learn/pull/21538/files#r769873989

The metric checks if the input is binary. If the *Display object checks too, then there is double validation and np.unique is called twice.
I guess this is okay.
```diff
@@ -330,6 +342,12 @@ def from_predictions(
     """
     check_matplotlib_support(f"{cls.__name__}.from_predictions")

+    target_type = type_of_target(y_true)
+    if target_type != "binary":
```
I do not think we need this check. roc_curve calls _binary_clf_curve, which ends up doing the binary check itself:
scikit-learn/sklearn/metrics/_ranking.py, lines 736 to 738 in 22ca942:
```python
y_type = type_of_target(y_true, input_name="y_true")
if not (y_type == "binary" or (y_type == "multiclass" and pos_label is not None)):
    raise ValueError("{0} format is not supported".format(y_type))
```
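A sketch of that delegation in action (the exact message may vary by version):

```python
from sklearn.metrics import roc_curve

# _binary_clf_curve rejects multiclass targets when no pos_label is
# given, so roc_curve itself already enforces the binary check.
try:
    roc_curve([0, 1, 2], [0.1, 0.4, 0.8])
except ValueError as exc:
    print(exc)  # "multiclass format is not supported"
```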
Let's move that to 1.2.
Closing this one since we merged #23073.
Partially addresses #18212.
This is a simplification towards merging #20999.