Description
Noticed while working on #30508 (comment)
Currently the following metric functions do not explicitly check that pos_label is present in y_true:
roc_curve
precision_recall_curve
det_curve
brier_score_loss
AFAICT all (?) other classification metrics (e.g., recall_score, precision_score), including the ranking metric average_precision_score, explicitly check that pos_label is present in y_true.
E.g., this is the error from the recall_score/precision_score/f1 family:
    if y_type == "binary":
        if len(present_labels) == 2 and pos_label not in present_labels:
>           raise ValueError(
                f"pos_label={pos_label} is not a valid label. It should be "
                f"one of {present_labels}"
            )
E           ValueError: pos_label=2 is not a valid label. It should be one of [0, 1]
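For illustration only, here is a minimal standalone sketch of this kind of check. The helper name check_pos_label_present is hypothetical (it is not part of scikit-learn), and unlike the snippet above it drops the len(present_labels) == 2 condition to keep the example short:

```python
def check_pos_label_present(y_true, pos_label):
    """Hypothetical helper: raise if pos_label does not occur in y_true."""
    # Collect the distinct labels that actually appear in y_true.
    present_labels = sorted(set(y_true))
    if pos_label not in present_labels:
        raise ValueError(
            f"pos_label={pos_label} is not a valid label. It should be "
            f"one of {present_labels}"
        )


# pos_label=2 never occurs in y_true, so the check raises.
try:
    check_pos_label_present([0, 1, 1, 0], pos_label=2)
    message = None
except ValueError as exc:
    message = str(exc)
```

Whether the functions listed above should raise like this, or only warn, is the open question of this issue.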
roc_curve and precision_recall_curve do not explicitly check this; they only warn (no error) that there are no 'positive' samples in y_true:
    if tps[-1] <= 0:
>       warnings.warn(
            "No positive samples in y_true, true positive value should be meaningless",
            UndefinedMetricWarning,
        )
E       sklearn.exceptions.UndefinedMetricWarning: No positive samples in y_true, true positive value should be meaningless
Similarly, for det_curve this results in an invalid divide warning (we divide by 0):
File ~/Documents/dev/scikit-learn/sklearn/metrics/_ranking.py:418, in det_curve(y_true, y_score, pos_label, sample_weight, drop_intermediate)
415 sl = slice(first_ind, last_ind)
417 # reverse the output such that list of false positives is decreasing
--> 418 return (fps[sl][::-1] / n_count, fns[sl][::-1] / p_count, thresholds[sl][::-1])
RuntimeWarning: invalid value encountered in divide
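The warning can be reproduced with NumPy alone. This is a sketch under the assumption that, when pos_label is absent from y_true, both the false-negative counts and the positive count inside det_curve are all zero; the arrays below are stand-ins, not values taken from scikit-learn:

```python
import warnings

import numpy as np

# Assumed stand-ins for fns and p_count when y_true has no positive samples.
fns = np.array([0.0, 0.0, 0.0])
p_count = np.float64(0.0)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # 0 / 0 produces nan and emits a RuntimeWarning instead of raising.
    fnr = fns / p_count

warning_categories = [w.category for w in caught]
```

The resulting false negative rate is all nan, so the returned curve carries no usable information even though no exception is raised.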
brier_score_loss gives no error and no warning. _validate_binary_probabilistic_prediction does the following:
scikit-learn/sklearn/metrics/_classification.py, lines 3468 to 3469 in bde701d:
# convert (n_samples,) to (n_samples, 2) shape
y_true = np.array(y_true == pos_label, int)
which just results in y_true being all 0's.
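A small sketch of that effect, using only NumPy. The data are made up for illustration, and the score is computed by hand with the usual mean squared difference rather than by calling brier_score_loss:

```python
import numpy as np

y_true = np.array([0, 1, 1, 0])
y_prob = np.array([0.1, 0.9, 0.8, 0.2])
pos_label = 2  # deliberately not present in y_true

# Same conversion as the quoted line: every comparison is False.
y_true_binary = np.array(y_true == pos_label, int)

# Hand-computed Brier score against the all-zero targets.
brier = np.mean((y_true_binary - y_prob) ** 2)
```

Here y_true_binary is [0, 0, 0, 0], so the score is silently computed as if every sample were negative.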
For completeness, I'm just going to reference #18101, even though I don't think it directly affects this issue; regardless of what pos_label means, the question of whether to check that it is present in y_true is the same. Just a reference for those who were involved in that discussion, e.g., @glemaitre 😬