Fixes #16065 Ignore weights equal zero in precision recall curve #16319
Conversation
sklearn/metrics/_ranking.py (outdated)
indexes_zeros = sample_weight.index(0)
del y_true[indexes_zeros]
del probas_pred[indexes_zeros]
del sample_weight[indexes_zeros]
All of those arguments are expected to be numpy arrays rather than lists, so instead we should write numpy code such as the following:
if sample_weight is not None:
    if not np.all(sample_weight >= 0):
        raise ValueError("negative values in sample_weight are invalid")
    nonzero_weight_mask = sample_weight > 0
    y_true = y_true[nonzero_weight_mask]
    probas_pred = probas_pred[nonzero_weight_mask]
    sample_weight = sample_weight[nonzero_weight_mask]
But this should only be done after we call check_array, column_or_1d, or similar validation functions on all the arguments. Those checks are centralized in the _binary_clf_curve function itself, so the zero-weight filtering code should probably be moved to the body of that function.
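As a self-contained illustration of this masking approach (the data values below are made up purely for the example):

import numpy as np

# Toy data; the values are hypothetical and only illustrate the masking.
y_true = np.array([0, 1, 1, 0, 1])
probas_pred = np.array([0.1, 0.8, 0.6, 0.4, 0.9])
sample_weight = np.array([1.0, 0.0, 2.0, 1.0, 0.0])

if not np.all(sample_weight >= 0):
    raise ValueError("negative values in sample_weight are invalid")

# Keep only the samples with strictly positive weight.
nonzero_weight_mask = sample_weight > 0
y_true = y_true[nonzero_weight_mask]
probas_pred = probas_pred[nonzero_weight_mask]
sample_weight = sample_weight[nonzero_weight_mask]

print(y_true)         # [0 1 0]
print(probas_pred)    # [0.1 0.6 0.4]
print(sample_weight)  # [1. 2. 1.]

Unlike the list-based del approach in the original diff, boolean masking handles any number of zero weights in one vectorized step and keeps the three arrays aligned.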
Thank you very much, Olivier. I have added a check_negative option to _check_sample_weight, which partially answers #15531. I have also added some lines to check that the weights are not zero.
We might want to move the check for negative sample weights in
Hi @alonsosilvaallende, could you please check the lint errors? Thanks!
More suggestions:
@@ -1172,7 +1172,7 @@ def _check_psd_eigenvalues(lambdas, enable_warnings=False):
     return lambdas


-def _check_sample_weight(sample_weight, X, dtype=None):
+def _check_sample_weight(sample_weight, X, dtype=None, check_negative=False):
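For context, a minimal sketch of what such a flag could do inside a _check_sample_weight-style helper; this is a simplified illustration, not the actual scikit-learn implementation, which performs more validation:

import numbers
import numpy as np

def _check_sample_weight(sample_weight, X, dtype=None, check_negative=False):
    # Simplified sketch of the validation helper with the proposed flag.
    n_samples = X.shape[0]
    if dtype is None:
        dtype = np.float64
    if sample_weight is None:
        # No weights given: every sample counts equally.
        sample_weight = np.ones(n_samples, dtype=dtype)
    elif isinstance(sample_weight, numbers.Number):
        # A scalar weight is broadcast to all samples.
        sample_weight = np.full(n_samples, sample_weight, dtype=dtype)
    else:
        sample_weight = np.asarray(sample_weight, dtype=dtype)
        if sample_weight.ndim != 1 or sample_weight.shape[0] != n_samples:
            raise ValueError("sample_weight must be 1D with one entry per sample")
    if check_negative and np.any(sample_weight < 0):
        raise ValueError("negative values in sample_weight are invalid")
    return sample_weight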
I find the naming check_negative confusing. What about:
- def _check_sample_weight(sample_weight, X, dtype=None, check_negative=False):
+ def _check_sample_weight(sample_weight, X, dtype=None, check_nonnegative=False):
@@ -544,6 +545,13 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):
     if sample_weight is not None:
         sample_weight = column_or_1d(sample_weight)

+        # Check to make sure sample_weight is strictly positive
- # Check to make sure sample_weight is strictly positive
+ # Check to make sure sample_weight are non negative and filter out zero
+ # weighted samples as they should not impact the score value
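Putting the suggestions together, here is a hedged sketch of how the validation plus zero-weight filtering could look at the start of _binary_clf_curve; the helper name below is hypothetical and only shows the relevant fragment of the function:

import numpy as np
from sklearn.utils import column_or_1d

def _filter_zero_weight_samples(y_true, y_score, sample_weight):
    # Hypothetical helper mirroring the suggested _binary_clf_curve change.
    y_true = column_or_1d(y_true)
    y_score = column_or_1d(y_score)
    if sample_weight is not None:
        sample_weight = column_or_1d(sample_weight)
        # Check to make sure sample_weight are non negative and filter out
        # zero weighted samples as they should not impact the score value.
        if np.any(sample_weight < 0):
            raise ValueError("negative values in sample_weight are invalid")
        nonzero_weight_mask = sample_weight > 0
        y_true = y_true[nonzero_weight_mask]
        y_score = y_score[nonzero_weight_mask]
        sample_weight = sample_weight[nonzero_weight_mask]
    return y_true, y_score, sample_weight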
I am taking over this PR. Should I merge or rebase onto master?
@albertvillanova, you can find some documentation here about how to take over stalled pull requests. I hope this helps!
Thanks @cmarmo, I have already pulled it. My question was: once pulled, should I merge or rebase onto master? I will do whichever makes it easier for you to manage my eventual PR.
We prefer to merge.
Thank you, @thomasjpfan. Indeed, rebasing was generating conflicts, while merging worked without problems.
Take
Fixes #16065.
What does this implement/fix? Explain your changes.
It makes the precision_recall_curve function ignore samples whose weight is equal to zero.
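For example, with this fix the following two calls should produce identical curves (an illustrative snippet with made-up data, not a test from this PR):

import numpy as np
from sklearn.metrics import precision_recall_curve

y_true = np.array([0, 1, 1, 0, 1])
scores = np.array([0.1, 0.8, 0.6, 0.4, 0.9])
weights = np.array([1.0, 1.0, 2.0, 0.0, 1.0])

# A zero-weighted sample should now be equivalent to dropping it entirely.
prec_a, rec_a, thr_a = precision_recall_curve(y_true, scores, sample_weight=weights)
keep = weights > 0
prec_b, rec_b, thr_b = precision_recall_curve(
    y_true[keep], scores[keep], sample_weight=weights[keep])

assert np.allclose(prec_a, prec_b)
assert np.allclose(rec_a, rec_b)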
Any other comments?
I'm not sure if there is a more elegant solution.