Skip to content

Commit 2237b17

Browse files
qinhanmin2014 authored and adrinjalali committed
MNT Consistent warning and more doc about the edge cases of P/R/F (#13143)
* consistent error message
* new test
* ignore warnings
* notes
* joel's comment
* adrin's comment
1 parent 39ef674 commit 2237b17

File tree

2 files changed

+55
-10
lines changed

2 files changed

+55
-10
lines changed

sklearn/metrics/classification.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -922,6 +922,11 @@ def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary',
922922
>>> f1_score(y_true, y_pred, average=None)
923923
array([0.8, 0. , 0. ])
924924
925+
Notes
926+
-----
927+
When ``true positive + false positive == 0`` or
928+
``true positive + false negative == 0``, f-score returns 0 and raises
929+
``UndefinedMetricWarning``.
925930
"""
926931
return fbeta_score(y_true, y_pred, 1, labels=labels,
927932
pos_label=pos_label, average=average,
@@ -1036,6 +1041,11 @@ def fbeta_score(y_true, y_pred, beta, labels=None, pos_label=1,
10361041
... # doctest: +ELLIPSIS
10371042
array([0.71..., 0. , 0. ])
10381043
1044+
Notes
1045+
-----
1046+
When ``true positive + false positive == 0`` or
1047+
``true positive + false negative == 0``, f-score returns 0 and raises
1048+
``UndefinedMetricWarning``.
10391049
"""
10401050
_, _, f, _ = precision_recall_fscore_support(y_true, y_pred,
10411051
beta=beta,
@@ -1233,6 +1243,12 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
12331243
array([0., 0., 1.]), array([0. , 0. , 0.8]),
12341244
array([2, 2, 2]))
12351245
1246+
Notes
1247+
-----
1248+
When ``true positive + false positive == 0``, precision is undefined;
1249+
When ``true positive + false negative == 0``, recall is undefined.
1250+
In such cases, the metric will be set to 0, as will f-score, and
1251+
``UndefinedMetricWarning`` will be raised.
12361252
"""
12371253
average_options = (None, 'micro', 'macro', 'weighted', 'samples')
12381254
if average not in average_options and average != 'binary':
@@ -1247,13 +1263,9 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
12471263

12481264
if average == 'binary':
12491265
if y_type == 'binary':
1250-
if pos_label not in present_labels:
1251-
if len(present_labels) < 2:
1252-
# Only negative labels
1253-
return (0., 0., 0., 0)
1254-
else:
1255-
raise ValueError("pos_label=%r is not a valid label: %r" %
1256-
(pos_label, present_labels))
1266+
if pos_label not in present_labels and len(present_labels) >= 2:
1267+
raise ValueError("pos_label=%r is not a valid label: %r" %
1268+
(pos_label, present_labels))
12571269
labels = [pos_label]
12581270
else:
12591271
raise ValueError("Target is %s but average='binary'. Please "
@@ -1279,7 +1291,6 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
12791291
true_sum = np.array([true_sum.sum()])
12801292

12811293
# Finally, we have all our sufficient statistics. Divide! #
1282-
12831294
beta2 = beta ** 2
12841295
with np.errstate(divide='ignore', invalid='ignore'):
12851296
# Divide, and on zero-division, set scores to 0 and warn:
@@ -1297,7 +1308,6 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
12971308
f_score[tp_sum == 0] = 0.0
12981309

12991310
# Average the results
1300-
13011311
if average == 'weighted':
13021312
weights = true_sum
13031313
if weights.sum() == 0:
@@ -1410,6 +1420,10 @@ def precision_score(y_true, y_pred, labels=None, pos_label=1,
14101420
>>> precision_score(y_true, y_pred, average=None) # doctest: +ELLIPSIS
14111421
array([0.66..., 0. , 0. ])
14121422
1423+
Notes
1424+
-----
1425+
When ``true positive + false positive == 0``, precision returns 0 and
1426+
raises ``UndefinedMetricWarning``.
14131427
"""
14141428
p, _, _, _ = precision_recall_fscore_support(y_true, y_pred,
14151429
labels=labels,
@@ -1512,6 +1526,10 @@ def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary',
15121526
>>> recall_score(y_true, y_pred, average=None)
15131527
array([1., 0., 0.])
15141528
1529+
Notes
1530+
-----
1531+
When ``true positive + false negative == 0``, recall returns 0 and raises
1532+
``UndefinedMetricWarning``.
15151533
"""
15161534
_, r, _, _ = precision_recall_fscore_support(y_true, y_pred,
15171535
labels=labels,

sklearn/metrics/tests/test_classification.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ def test_precision_recall_f1_score_binary():
198198
(1 + 2 ** 2) * ps * rs / (2 ** 2 * ps + rs), 2)
199199

200200

201+
@ignore_warnings
201202
def test_precision_recall_f_binary_single_class():
202203
# Test precision, recall and F1 score behave with a single positive or
203204
# negative class
@@ -1065,6 +1066,7 @@ def test_classification_report_no_labels_target_names_unequal_length():
10651066
y_true, y_pred, target_names=target_names)
10661067

10671068

1069+
@ignore_warnings
10681070
def test_multilabel_classification_report():
10691071
n_classes = 4
10701072
n_samples = 50
@@ -1446,6 +1448,17 @@ def test_prf_warnings():
14461448
'being set to 0.0 due to no true samples.')
14471449
my_assert(w, msg, f, [-1, -1], [1, 1], average='binary')
14481450

1451+
clean_warning_registry()
1452+
with warnings.catch_warnings(record=True) as record:
1453+
warnings.simplefilter('always')
1454+
precision_recall_fscore_support([0, 0], [0, 0], average="binary")
1455+
msg = ('Recall and F-score are ill-defined and '
1456+
'being set to 0.0 due to no true samples.')
1457+
assert_equal(str(record.pop().message), msg)
1458+
msg = ('Precision and F-score are ill-defined and '
1459+
'being set to 0.0 due to no predicted samples.')
1460+
assert_equal(str(record.pop().message), msg)
1461+
14491462

14501463
def test_recall_warnings():
14511464
assert_no_warnings(recall_score,
@@ -1461,19 +1474,26 @@ def test_recall_warnings():
14611474
assert_equal(str(record.pop().message),
14621475
'Recall is ill-defined and '
14631476
'being set to 0.0 due to no true samples.')
1477+
recall_score([0, 0], [0, 0])
1478+
assert_equal(str(record.pop().message),
1479+
'Recall is ill-defined and '
1480+
'being set to 0.0 due to no true samples.')
14641481

14651482

14661483
def test_precision_warnings():
14671484
clean_warning_registry()
14681485
with warnings.catch_warnings(record=True) as record:
14691486
warnings.simplefilter('always')
1470-
14711487
precision_score(np.array([[1, 1], [1, 1]]),
14721488
np.array([[0, 0], [0, 0]]),
14731489
average='micro')
14741490
assert_equal(str(record.pop().message),
14751491
'Precision is ill-defined and '
14761492
'being set to 0.0 due to no predicted samples.')
1493+
precision_score([0, 0], [0, 0])
1494+
assert_equal(str(record.pop().message),
1495+
'Precision is ill-defined and '
1496+
'being set to 0.0 due to no predicted samples.')
14771497

14781498
assert_no_warnings(precision_score,
14791499
np.array([[0, 0], [0, 0]]),
@@ -1499,6 +1519,13 @@ def test_fscore_warnings():
14991519
assert_equal(str(record.pop().message),
15001520
'F-score is ill-defined and '
15011521
'being set to 0.0 due to no true samples.')
1522+
score([0, 0], [0, 0])
1523+
assert_equal(str(record.pop().message),
1524+
'F-score is ill-defined and '
1525+
'being set to 0.0 due to no true samples.')
1526+
assert_equal(str(record.pop().message),
1527+
'F-score is ill-defined and '
1528+
'being set to 0.0 due to no predicted samples.')
15021529

15031530

15041531
def test_prf_average_binary_data_non_binary():

0 commit comments

Comments (0)