[MRG] FIX remaining bug in precision, recall and fscore with multilabel data #1988
Changes from all commits: 068dcc3, 1868790, 42a1270, 98a6b70, cc6963b, 6a4a362, 645bacc, 377a963, 06c2c7b, 6dafe57, 79a0cc9, aa0c47e, a0aa777, a5a026c
```diff
@@ -910,11 +910,9 @@ def jaccard_similarity_score(y_true, y_pred, normalize=True):
     # Compute accuracy for each possible representation
     y_type, y_true, y_pred = _check_clf_targets(y_true, y_pred)
     if y_type == 'multilabel-indicator':
-        try:
+        with np.errstate(divide='ignore', invalid='ignore'):
             # oddly, we may get an "invalid" rather than a "divide"
             # error here
-            old_err_settings = np.seterr(divide='ignore',
-                                         invalid='ignore')
             y_pred_pos_label = y_pred == 1
             y_true_pos_label = y_true == 1
             pred_inter_true = np.sum(np.logical_and(y_pred_pos_label,
```
```diff
@@ -929,8 +927,6 @@ def jaccard_similarity_score(y_true, y_pred, normalize=True):
             # the jaccard to 1: lim_{x->0} x/x = 1
             # Note with py2.6 and np 1.3: we can't check safely for nan.
             score[pred_union_true == 0.0] = 1.0
-        finally:
-            np.seterr(**old_err_settings)
 
     elif y_type == 'multilabel-sequences':
         score = np.empty(len(y_true), dtype=np.float)
```
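As background on the pattern adopted throughout this PR: `np.errstate` scopes NumPy's floating-point error settings to a block and restores the previous settings on exit, which is exactly what the removed `try`/`finally` around `np.seterr` did by hand. A minimal sketch (the array values are made up for illustration):

```python
import numpy as np

num = np.array([1.0, 0.0, 2.0])
den = np.array([2.0, 0.0, 0.0])  # zeros would normally trigger warnings

# Scoped: divide/invalid warnings are silenced only inside the block,
# and the previous error settings are restored on exit, even on error.
with np.errstate(divide='ignore', invalid='ignore'):
    ratio = num / den  # yields [0.5, nan, inf] without warnings

# Equivalent hand-rolled version that the PR removes:
old_err_settings = np.seterr(divide='ignore', invalid='ignore')
try:
    ratio = num / den
finally:
    np.seterr(**old_err_settings)  # must remember to restore by hand
```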
```diff
@@ -1448,24 +1444,37 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
            size_true[i] = len(true_set)
    else:
        raise ValueError("Example-based precision, recall, fscore is "
-                         "not meaning full outside multilabe"
-                         "classification. See the accuracy_score instead.")
+                         "not meaningful outside of multilabel"
+                         "classification. Use accuracy_score instead.")
 
-    try:
+    warning_msg = ""
+    if np.any(size_pred == 0):
+        warning_msg += ("Sample-based precision is undefined for some "
+                        "samples. ")
+
+    if np.any(size_true == 0):
+        warning_msg += ("Sample-based recall is undefined for some "
+                        "samples. ")
+
+    if np.any((beta2 * size_true + size_pred) == 0):
+        warning_msg += ("Sample-based f_score is undefined for some "
+                        "samples. ")
+
+    if warning_msg:
+        warnings.warn(warning_msg)
+
+    with np.errstate(divide="ignore", invalid="ignore"):
        # oddly, we may get an "invalid" rather than a "divide" error
        # here
-        old_err_settings = np.seterr(divide='ignore', invalid='ignore')
-
-        precision = size_inter / size_true
-        recall = size_inter / size_pred
-        f_score = ((1 + beta2 ** 2) * size_inter /
-                   (beta2 * size_pred + size_true))
-    finally:
-        np.seterr(**old_err_settings)
+        precision = divide(size_inter, size_pred, dtype=np.double)
+        recall = divide(size_inter, size_true, dtype=np.double)
+        f_score = divide((1 + beta2) * size_inter,
+                         (beta2 * size_true + size_pred),
+                         dtype=np.double)
 
-    precision[size_true == 0] = 1.0
-    recall[size_pred == 0] = 1.0
-    f_score[(beta2 * size_pred + size_true) == 0] = 1.0
+    precision[size_pred == 0] = 0.0
+    recall[size_true == 0] = 0.0
+    f_score[(beta2 * size_true + size_pred) == 0] = 0.0
```
Review discussion on this hunk:

> Warning messages are not raised for this case.

> I can understand why you might want that to be so, but to me it's a bigger problem that in a multilabel problem with these true values: … For the first example you get 0 recall and 0 precision no matter your prediction, when surely … This applies too to label-based averaging, although there such cases are considered pathological.

> Yes, I agree. But this means we have to change the default behaviour...

> No, this means sample-based averaging is not a good idea if some samples have …
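To make the disputed behaviour concrete, here is a hypothetical multilabel indicator matrix with one sample that has no true labels; the data and helper computations below are mine (mirroring the hunk), not from the PR. Under the convention introduced here (ill-defined means 0), that sample drags the sample-averaged scores down no matter what is predicted:

```python
import numpy as np

# Hypothetical data: sample 0 has no true labels at all.
y_true = np.array([[0, 0, 0],
                   [1, 0, 1]])
y_pred = np.array([[0, 0, 0],   # a "perfect" prediction for sample 0
                   [1, 0, 1]])

size_inter = np.logical_and(y_true == 1, y_pred == 1).sum(axis=1)
size_true = (y_true == 1).sum(axis=1)
size_pred = (y_pred == 1).sum(axis=1)

with np.errstate(divide='ignore', invalid='ignore'):
    precision = size_inter / size_pred.astype(np.double)
    recall = size_inter / size_true.astype(np.double)

# The convention in this PR: undefined ratios become 0.
precision[size_pred == 0] = 0.0
recall[size_true == 0] = 0.0

print(precision.mean(), recall.mean())  # 0.5 0.5, despite a perfect match
```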
```diff
 
    precision = np.mean(precision)
    recall = np.mean(recall)
```
```diff
@@ -1476,26 +1485,50 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
    true_pos, _, false_pos, false_neg = _tp_tn_fp_fn(y_true, y_pred, labels)
    support = true_pos + false_neg
 
-    try:
+    with np.errstate(divide='ignore', invalid='ignore'):
        # oddly, we may get an "invalid" rather than a "divide" error here
-        old_err_settings = np.seterr(divide='ignore', invalid='ignore')
-
        # precision and recall
        precision = divide(true_pos.astype(np.float), true_pos + false_pos)
        recall = divide(true_pos.astype(np.float), true_pos + false_neg)
 
+        idx_ill_defined_precision = (true_pos + false_pos) == 0
+        idx_ill_defined_recall = (true_pos + false_neg) == 0
+
        # handle division by 0 in precision and recall
-        precision[(true_pos + false_pos) == 0] = 0.0
-        recall[(true_pos + false_neg) == 0] = 0.0
+        precision[idx_ill_defined_precision] = 0.0
+        recall[idx_ill_defined_recall] = 0.0
 
        # fbeta score
        fscore = divide((1 + beta2) * precision * recall,
                        beta2 * precision + recall)
 
        # handle division by 0 in fscore
-        fscore[(beta2 * precision + recall) == 0] = 0.0
-    finally:
-        np.seterr(**old_err_settings)
+        idx_ill_defined_fbeta_score = (beta2 * precision + recall) == 0
+        fscore[idx_ill_defined_fbeta_score] = 0.0
 
+    if average in (None, "macro", "weighted"):
+        warning_msg = ""
+        if np.any(idx_ill_defined_precision):
+            warning_msg += ("The sum of true positives and false positives "
+                            "are equal to zero for some labels. Precision is "
+                            "ill defined for those labels %s. "
+                            % labels[idx_ill_defined_precision])
+
+        if np.any(idx_ill_defined_recall):
+            warning_msg += ("The sum of true positives and false negatives "
+                            "are equal to zero for some labels. Recall is ill "
+                            "defined for those labels %s. "
+                            % labels[idx_ill_defined_recall])
+
+        if np.any(idx_ill_defined_fbeta_score):
+            warning_msg += ("The precision and recall are equal to zero for "
+                            "some labels. fbeta_score is ill defined for "
+                            "those labels %s. "
+                            % labels[idx_ill_defined_fbeta_score])
+
+        if warning_msg:
+            warnings.warn(warning_msg, stacklevel=2)
 
    if not average:
        return precision, recall, fscore, support
```
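A standalone sketch of the mask-then-warn pattern this hunk introduces, with hypothetical tp/fp/fn counts and an example `labels` array of mine: the ill-defined positions are remembered in boolean masks so they can be both zeroed out and reported in a single warning.

```python
import warnings
import numpy as np

labels = np.array([0, 1, 2])
true_pos = np.array([3., 0., 0.])
false_pos = np.array([1., 0., 2.])
false_neg = np.array([0., 4., 0.])

with np.errstate(divide='ignore', invalid='ignore'):
    precision = true_pos / (true_pos + false_pos)
    recall = true_pos / (true_pos + false_neg)

# Boolean masks double as indices for zeroing and for the warning text.
idx_ill_defined_precision = (true_pos + false_pos) == 0
idx_ill_defined_recall = (true_pos + false_neg) == 0
precision[idx_ill_defined_precision] = 0.0
recall[idx_ill_defined_recall] = 0.0

if np.any(idx_ill_defined_precision):
    warnings.warn("Precision is ill defined for labels %s."
                  % labels[idx_ill_defined_precision], stacklevel=2)

print(precision)  # [ 0.75  0.    0.  ]  (label 1 zeroed by the mask)
```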
```diff
@@ -1513,24 +1546,40 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
    else:
        average_options = (None, 'micro', 'macro', 'weighted', 'samples')
        if average == 'micro':
-            avg_precision = divide(true_pos.sum(),
-                                   true_pos.sum() + false_pos.sum(),
-                                   dtype=np.double)
-            avg_recall = divide(true_pos.sum(),
-                                true_pos.sum() + false_neg.sum(),
-                                dtype=np.double)
-            avg_fscore = divide((1 + beta2) * (avg_precision * avg_recall),
-                                beta2 * avg_precision + avg_recall,
-                                dtype=np.double)
+            with np.errstate(divide='ignore', invalid='ignore'):
+                # oddly, we may get an "invalid" rather than a "divide" error
+                # here
 
+                tp_sum = true_pos.sum()
+                fp_sum = false_pos.sum()
+                fn_sum = false_neg.sum()
+                avg_precision = divide(tp_sum, tp_sum + fp_sum,
+                                       dtype=np.double)
+                avg_recall = divide(tp_sum, tp_sum + fn_sum, dtype=np.double)
+                avg_fscore = divide((1 + beta2) * (avg_precision * avg_recall),
+                                    beta2 * avg_precision + avg_recall,
+                                    dtype=np.double)
 
-            if np.isnan(avg_precision):
+            warning_msg = ""
+            if tp_sum + fp_sum == 0:
                avg_precision = 0.
+                warning_msg += ("The sum of true positives and false "
+                                "positives are equal to zero. Micro-precision"
+                                " is ill defined. ")
 
-            if np.isnan(avg_recall):
+            if tp_sum + fn_sum == 0:
                avg_recall = 0.
+                warning_msg += ("The sum of true positives and false "
+                                "negatives are equal to zero. Micro-recall "
+                                "is ill defined. ")
 
-            if np.isnan(avg_fscore):
+            if beta2 * avg_precision + avg_recall == 0:
                avg_fscore = 0.
+                warning_msg += ("Micro-precision and micro-recall are equal "
+                                "to zero. Micro-fbeta_score is ill defined.")
+
+            if warning_msg:
+                warnings.warn(warning_msg, stacklevel=2)
 
        elif average == 'macro':
            avg_precision = np.mean(precision)
```
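For intuition, a worked sketch of the micro-averaged quantities computed above, using made-up per-label counts; micro-averaging pools true/false positives and negatives across all labels before dividing:

```python
import numpy as np

beta2 = 1.0  # beta ** 2, i.e. the plain F1 case

# Hypothetical per-label counts for three labels.
true_pos = np.array([5, 0, 2])
false_pos = np.array([1, 0, 1])
false_neg = np.array([0, 3, 1])

tp_sum = true_pos.sum()   # 7
fp_sum = false_pos.sum()  # 2
fn_sum = false_neg.sum()  # 4

avg_precision = tp_sum / float(tp_sum + fp_sum)  # 7/9  ~ 0.778
avg_recall = tp_sum / float(tp_sum + fn_sum)     # 7/11 ~ 0.636
avg_fscore = ((1 + beta2) * avg_precision * avg_recall /
              (beta2 * avg_precision + avg_recall))

print(avg_precision, avg_recall, avg_fscore)     # ~0.778 ~0.636 ~0.7
```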
```diff
@@ -1542,6 +1591,11 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
                avg_precision = 0.
                avg_recall = 0.
                avg_fscore = 0.
+                warnings.warn("There isn't any labels in y_true. "
+                              "Weighted-precision, weighted-recall and "
+                              "weighted-fbeta_score are ill defined.",
+                              stacklevel=2)
 
            else:
                avg_precision = np.average(precision, weights=support)
                avg_recall = np.average(recall, weights=support)
```
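And a small sketch of the weighted branch that follows: support (the per-label count of true instances) weights each label's score, so the guard above fires only when y_true contains no labels at all, i.e. every support is zero. The counts here are made up.

```python
import numpy as np

precision = np.array([1.0, 0.0, 0.5])
recall = np.array([1.0, 0.0, 1.0])
support = np.array([4, 0, 2])  # label 1 never occurs in y_true

# np.average raises ZeroDivisionError when all weights are zero,
# hence the explicit guard before the else branch in the PR.
if support.sum() == 0:
    avg_precision = 0.0
else:
    avg_precision = np.average(precision, weights=support)

print(avg_precision)  # (1.0*4 + 0.0*0 + 0.5*2) / 6 = 5/6 ~ 0.833
```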
```diff
@@ -1698,6 +1752,7 @@ def recall_score(y_true, y_pred, labels=None, pos_label=1, average='weighted'):
    >>> recall_score(y_true, y_pred, average=None)
    array([ 1.,  0.,  0.])
 
+
    """
    _, r, _, _ = precision_recall_fscore_support(y_true, y_pred,
                                                 labels=labels,
```
Review discussion:

> Do we want a "stacklevel" here? I think that stacklevel=2 would be a good idea.

> See 0f0f4a3
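On the stacklevel question, a minimal illustration of what `stacklevel=2` changes: with the default `stacklevel=1` the warning is reported against the `warnings.warn` line inside the library, while `stacklevel=2` points at the user's call site, which is usually more helpful. The function name below is hypothetical.

```python
import warnings

def recall_like_metric():
    # stacklevel=2 attributes the warning to the caller of this
    # function rather than to this line inside the library.
    warnings.warn("Recall is ill defined.", stacklevel=2)
    return 0.0

recall_like_metric()  # UserWarning reported at this line, not inside
```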