MNT refactor _get_response_values #21538

Closed
wants to merge 14 commits

Conversation

@glemaitre glemaitre commented Nov 3, 2021

Partially addresses #18212.
This is a simplification toward merging #20999.

@@ -26,72 +24,6 @@
)


def test_errors(pyplot):
Member Author

These tests were moved to the common tests in test_plot_curve_common.py.
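
For context, such moved common tests are typically parametrized over the Display classes; here is a hypothetical sketch (the test body and class list are illustrative assumptions, not this PR's actual code, reusing scikit-learn's pyplot test fixture):

import pytest
from sklearn.datasets import make_classification
from sklearn.linear_model import LinearRegression
from sklearn.metrics import DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay

# Hypothetical common test: every curve display should reject a regressor.
@pytest.mark.parametrize(
    "Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay]
)
def test_display_rejects_regressor(pyplot, Display):
    X, y = make_classification(random_state=0)
    regressor = LinearRegression().fit(X, y)
    with pytest.raises(ValueError):
        Display.from_estimator(regressor, X, y)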

@@ -634,42 +634,6 @@ def iris_data_binary(iris_data):
return X[y < 2], y[y < 2]


def test_calibration_display_validation(pyplot, iris_data, iris_data_binary):
Member Author

These tests were moved to the common tests in test_plot_curve_common.py.

@@ -189,7 +189,7 @@ def test_parallel_execution(data, method, ensemble):
X, y = data
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

- base_estimator = LinearSVC(random_state=42)
+ base_estimator = make_pipeline(StandardScaler(), LinearSVC(random_state=42))
Member Author

Removes a bunch of ConvergenceWarnings for free.
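
For context, a minimal sketch of the effect (the toy dataset is an assumption; liblinear converges much more reliably on standardized features):

from sklearn.datasets import make_classification
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC

X, y = make_classification(n_samples=200, random_state=42)

# Standardizing first typically lets liblinear converge within the default
# max_iter, so fit() no longer emits a ConvergenceWarning.
clf = make_pipeline(StandardScaler(), LinearSVC(random_state=42)).fit(X, y)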

@@ -21,48 +20,6 @@
)


def test_precision_recall_display_validation(pyplot):
Member Author

These tests were moved to the common tests in plot_common_curve_display.py.

@glemaitre
Member Author

Maybe this would be easier to review, @ogrisel @thomasjpfan.

@thomasjpfan thomasjpfan left a comment

Overall, I like the refactor.

@glemaitre glemaitre force-pushed the is/18212_again_again branch from abea58f to a7e46da Compare November 29, 2021 15:18
glemaitre and others added 3 commits November 29, 2021 16:33
@glemaitre glemaitre removed the cython label Nov 29, 2021

@thomasjpfan thomasjpfan left a comment

Thanks for the follow-up!

Comment on lines +280 to +284
target_type = type_of_target(y_true)
if target_type != "binary":
    raise ValueError(
        f"The target y is not binary. Got {target_type} type of target."
    )
Member

Do we need this check here, since we do the check in det_curve?

if len(np.unique(y_true)) != 2:
    raise ValueError(
        "Only one class present in y_true. Detection error "
        "tradeoff curve is not defined in that case."
    )

@glemaitre glemaitre added this to the 1.1 milestone Jan 26, 2022

@lesteve lesteve left a comment

First pass of comments

    response_method="predict",
    target_type=target_type,
)
assert_allclose(y_pred, regressor.predict(X))
Member

This one could be checking strict equality, right?
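
A sketch of what that strict check could look like, reusing y_pred, regressor, and X from the snippet above and assuming the "predict" output is returned unchanged:

import numpy as np

# "predict" values are passed through as-is (no thresholding or column
# selection), so exact equality should hold:
np.testing.assert_array_equal(y_pred, regressor.predict(X))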


Raises
------
ValueError
Member

Suggested change
-    ValueError
+    AttributeError

    y_pred = y_pred[:, col_idx]
else:
    err_msg = (
        f"Got predict_proba of shape {y_pred.shape}, but need "
Member

I find the error message not super explicit; I would find it easier to understand if it said the classifier was fitted with only one class (rather than complaining about the shape of predict_proba). Maybe there are some edge cases I haven't thought of, though.
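
For reference, a reproduction of the edge case under discussion (the estimator choice is an assumption; tree-based classifiers accept single-class training data):

import numpy as np
from sklearn.tree import DecisionTreeClassifier

X = np.random.RandomState(0).randn(10, 2)
y = np.zeros(10)  # classifier fitted with only one class
clf = DecisionTreeClassifier().fit(X, y)
print(clf.predict_proba(X).shape)  # (10, 1): one column instead of two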

f"one of {classes}"
)
elif pos_label is None and target_type == "binary":
pos_label = pos_label if pos_label is not None else classes[-1]
Member

Although -1 and 1 are equivalent here, I would use 1, since classes[1] is used elsewhere, e.g. in the _get_response_values docstring:

Suggested change
-    pos_label = pos_label if pos_label is not None else classes[-1]
+    pos_label = pos_label if pos_label is not None else classes[1]
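
A quick illustration of the equivalence, assuming classes comes from np.unique and is therefore sorted:

import numpy as np

classes = np.unique([0, 1, 0, 1])  # array([0, 1])
assert classes[1] == classes[-1]  # identical for binary targets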

Comment on lines +126 to +130
method = (
    ["predict_proba", "decision_function", "predict"]
    if method == "auto"
    else method
)
Member

A bit lighter to parse IMO (and also what is used in other places in this PR):

Suggested change
-method = (
-    ["predict_proba", "decision_function", "predict"]
-    if method == "auto"
-    else method
-)
+if method == "auto":
+    method = ["predict_proba", "decision_function", "predict"]

Comment on lines +35 to +36
"'fit' with appropriate arguments before intending to use it to plotting "
"functionalities."
Member

Maybe a bit lighter to read (better suggestion welcome):

Suggested change
-"'fit' with appropriate arguments before intending to use it to plotting "
-"functionalities."
+"'fit' with appropriate arguments before using it for plotting "
+"functionalities."

@lesteve
Member

lesteve commented Apr 6, 2022

The CI is red at the moment; maybe I got something wrong in my merge ...

@thomasjpfan thomasjpfan left a comment

Comment that still needs to be addressed, with the same concern:

The metric checks whether the input is binary. If the *Display object checks too, then there is double validation and np.unique is called twice.

I guess this is okay.

@@ -330,6 +342,12 @@ def from_predictions(
"""
check_matplotlib_support(f"{cls.__name__}.from_predictions")

target_type = type_of_target(y_true)
if target_type != "binary":
Member

I do not think we need this check. roc_curve calls _binary_clf_curve, which ends up doing the binary check itself:

y_type = type_of_target(y_true, input_name="y_true")
if not (y_type == "binary" or (y_type == "multiclass" and pos_label is not None)):
    raise ValueError("{0} format is not supported".format(y_type))
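
To illustrate the point, roc_curve already rejects a non-binary target on its own (the example data is an assumption):

import numpy as np
from sklearn.metrics import roc_curve

y_true = np.array([0, 1, 2, 2])  # multiclass target, no pos_label given
y_score = np.array([0.1, 0.4, 0.35, 0.8])

try:
    roc_curve(y_true, y_score)
except ValueError as exc:
    print(exc)  # "multiclass format is not supported"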

@jeremiedbb
Member

Let's move that to 1.2

@jeremiedbb jeremiedbb modified the milestones: 1.1, 1.2 Apr 19, 2022
@glemaitre glemaitre modified the milestones: 1.2, 1.3 Nov 16, 2022
@glemaitre
Member Author

Closing this one since we merged #23073

@glemaitre glemaitre closed this Mar 24, 2023