Merged
36 changes: 27 additions & 9 deletions sklearn/metrics/classification.py
@@ -922,6 +922,11 @@ def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary',
>>> f1_score(y_true, y_pred, average=None)
array([0.8, 0. , 0. ])

Notes
-----
When ``true positive + false positive == 0`` or
``true positive + false negative == 0``, f-score returns 0 and raises
``UndefinedMetricWarning``.
"""
return fbeta_score(y_true, y_pred, 1, labels=labels,
pos_label=pos_label, average=average,
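
The Notes added above can be exercised directly. A minimal hedged sketch (the inputs below are assumed for illustration, not taken from this diff): when the positive class is never predicted, the precision term is undefined, so the F-score is set to 0 and an ``UndefinedMetricWarning`` is emitted instead of an error.

```python
import warnings

from sklearn.metrics import f1_score

# Assumed toy input: class 1 occurs in y_true but is never predicted,
# so tp + fp == 0 for the positive class.
y_true = [0, 1, 1]
y_pred = [0, 0, 0]

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    score = f1_score(y_true, y_pred)

print(score)                        # 0.0
print(caught[0].category.__name__)  # UndefinedMetricWarning
```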
@@ -1036,6 +1041,11 @@ def fbeta_score(y_true, y_pred, beta, labels=None, pos_label=1,
... # doctest: +ELLIPSIS
array([0.71..., 0. , 0. ])

Notes
-----
When ``true positive + false positive == 0`` or
``true positive + false negative == 0``, f-score returns 0 and raises
``UndefinedMetricWarning``.
"""
_, _, f, _ = precision_recall_fscore_support(y_true, y_pred,
beta=beta,
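
``fbeta_score`` behaves the same way; a short hedged sketch with an assumed input (not from this diff): with no predicted positive samples, the score falls back to 0 and the same warning class is raised.

```python
import warnings

from sklearn.metrics import fbeta_score

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Class 1 is never predicted, so the precision term of F-beta is undefined.
    score = fbeta_score([0, 1, 1], [0, 0, 0], beta=0.5)

print(score)  # 0.0
print(any(w.category.__name__ == "UndefinedMetricWarning" for w in caught))  # True
```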
@@ -1233,6 +1243,12 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
array([0., 0., 1.]), array([0. , 0. , 0.8]),
array([2, 2, 2]))

Notes
-----
When ``true positive + false positive == 0``, precision is undefined;
When ``true positive + false negative == 0``, recall is undefined.
In such cases, the metric will be set to 0, as will f-score, and
``UndefinedMetricWarning`` will be raised.
"""
average_options = (None, 'micro', 'macro', 'weighted', 'samples')
if average not in average_options and average != 'binary':
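
A hedged illustration of the Notes above, mirroring the new test case added in this PR (``precision_recall_fscore_support([0, 0], [0, 0], average="binary")``): with the default ``pos_label=1`` there are no true and no predicted positive samples, so precision, recall and F-score are all ill-defined, set to 0, and one ``UndefinedMetricWarning`` is emitted per case.

```python
import warnings

from sklearn.metrics import precision_recall_fscore_support

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    p, r, f, _ = precision_recall_fscore_support([0, 0], [0, 0],
                                                 average="binary")

print(p, r, f)  # 0.0 0.0 0.0
for w in caught:
    print(w.message)
# Expected messages (per the new tests in this PR):
#   Precision and F-score are ill-defined and being set to 0.0 due to no predicted samples.
#   Recall and F-score are ill-defined and being set to 0.0 due to no true samples.
```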
@@ -1247,13 +1263,9 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,

if average == 'binary':
if y_type == 'binary':
if pos_label not in present_labels:
if len(present_labels) < 2:
# Only negative labels
return (0., 0., 0., 0)
else:
raise ValueError("pos_label=%r is not a valid label: %r" %
(pos_label, present_labels))
if pos_label not in present_labels and len(present_labels) >= 2:
raise ValueError("pos_label=%r is not a valid label: %r" %
(pos_label, present_labels))

[Review comment — Member]
I'm a bit confused here. This is the only functional part which is changing, but all the new tests are only testing the warning messages. There's no new test which tests this change in behavior (i.e. not returning 0, 0, 0, 0).

Doesn't this need a whats_new entry?

[Reply — Member, Author]
We now calculate these values manually and raise a consistent warning in _prf_divide. I guess we don't need a what's new entry here, since users will only get some extra warnings.
labels = [pos_label]
else:
raise ValueError("Target is %s but average='binary'. Please "
@@ -1279,7 +1291,6 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
true_sum = np.array([true_sum.sum()])

# Finally, we have all our sufficient statistics. Divide! #

beta2 = beta ** 2
with np.errstate(divide='ignore', invalid='ignore'):
# Divide, and on zero-division, set scores to 0 and warn:
@@ -1297,7 +1308,6 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
f_score[tp_sum == 0] = 0.0

# Average the results

if average == 'weighted':
weights = true_sum
if weights.sum() == 0:
@@ -1410,6 +1420,10 @@ def precision_score(y_true, y_pred, labels=None, pos_label=1,
>>> precision_score(y_true, y_pred, average=None) # doctest: +ELLIPSIS
array([0.66..., 0. , 0. ])

Notes
-----
When ``true positive + false positive == 0``, precision returns 0 and
raises ``UndefinedMetricWarning``.
"""
p, _, _, _ = precision_recall_fscore_support(y_true, y_pred,
labels=labels,
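
A brief hedged example for the precision Notes (inputs assumed): the positive class is never predicted, so precision is ill-defined; ``precision_score`` returns 0.0 and emits ``UndefinedMetricWarning`` rather than raising an error.

```python
import warnings

from sklearn.metrics import precision_score

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Class 1 is present in y_true but never predicted: tp + fp == 0.
    p = precision_score([0, 1, 1], [0, 0, 0])

print(p)                            # 0.0
print(caught[0].category.__name__)  # UndefinedMetricWarning
```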
@@ -1512,6 +1526,10 @@ def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary',
>>> recall_score(y_true, y_pred, average=None)
array([1., 0., 0.])

Notes
-----
When ``true positive + false negative == 0``, recall returns 0 and raises
``UndefinedMetricWarning``.
"""
_, r, _, _ = precision_recall_fscore_support(y_true, y_pred,
labels=labels,
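
A hedged example mirroring the new test case for recall (``recall_score([0, 0], [0, 0])``): there are no true positive-class samples, so recall is ill-defined; the function returns 0.0 and emits ``UndefinedMetricWarning``.

```python
import warnings

from sklearn.metrics import recall_score

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    r = recall_score([0, 0], [0, 0])

print(r)  # 0.0
print(str(caught[-1].message))
# Expected (per the new test): Recall is ill-defined and being set to 0.0
# due to no true samples.
```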
29 changes: 28 additions & 1 deletion sklearn/metrics/tests/test_classification.py
@@ -198,6 +198,7 @@ def test_precision_recall_f1_score_binary():
(1 + 2 ** 2) * ps * rs / (2 ** 2 * ps + rs), 2)


@ignore_warnings
def test_precision_recall_f_binary_single_class():
# Test precision, recall and F1 score behave with a single positive or
# negative class
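
A hedged note on why the ``@ignore_warnings`` decorator is added here: the single-class inputs exercised by this test now emit ``UndefinedMetricWarning``, and the decorator (assumed importable from ``sklearn.utils.testing`` in this version of the code base) silences them for the duration of the test. A minimal sketch of the pattern:

```python
from sklearn.metrics import precision_score
from sklearn.utils.testing import ignore_warnings  # assumed location in this era


@ignore_warnings
def check_single_negative_class():
    # Only the negative class is present, so precision for pos_label=1 is
    # ill-defined and falls back to 0.0; the warning is suppressed here.
    assert precision_score([-1, -1], [-1, -1]) == 0.0


check_single_negative_class()
```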
@@ -1065,6 +1066,7 @@ def test_classification_report_no_labels_target_names_unequal_length():
y_true, y_pred, target_names=target_names)


@ignore_warnings
def test_multilabel_classification_report():
n_classes = 4
n_samples = 50
@@ -1446,6 +1448,17 @@ def test_prf_warnings():
'being set to 0.0 due to no true samples.')
my_assert(w, msg, f, [-1, -1], [1, 1], average='binary')

clean_warning_registry()
with warnings.catch_warnings(record=True) as record:
warnings.simplefilter('always')
precision_recall_fscore_support([0, 0], [0, 0], average="binary")
msg = ('Recall and F-score are ill-defined and '
'being set to 0.0 due to no true samples.')
assert_equal(str(record.pop().message), msg)
msg = ('Precision and F-score are ill-defined and '
'being set to 0.0 due to no predicted samples.')
assert_equal(str(record.pop().message), msg)


def test_recall_warnings():
assert_no_warnings(recall_score,
@@ -1461,19 +1474,26 @@ def test_recall_warnings():
assert_equal(str(record.pop().message),
'Recall is ill-defined and '
'being set to 0.0 due to no true samples.')
recall_score([0, 0], [0, 0])
assert_equal(str(record.pop().message),
'Recall is ill-defined and '
'being set to 0.0 due to no true samples.')


def test_precision_warnings():
clean_warning_registry()
with warnings.catch_warnings(record=True) as record:
warnings.simplefilter('always')

precision_score(np.array([[1, 1], [1, 1]]),
np.array([[0, 0], [0, 0]]),
average='micro')
assert_equal(str(record.pop().message),
'Precision is ill-defined and '
'being set to 0.0 due to no predicted samples.')
precision_score([0, 0], [0, 0])
assert_equal(str(record.pop().message),
'Precision is ill-defined and '
'being set to 0.0 due to no predicted samples.')

assert_no_warnings(precision_score,
np.array([[0, 0], [0, 0]]),
@@ -1499,6 +1519,13 @@ def test_fscore_warnings():
assert_equal(str(record.pop().message),
'F-score is ill-defined and '
'being set to 0.0 due to no true samples.')
score([0, 0], [0, 0])
assert_equal(str(record.pop().message),
'F-score is ill-defined and '
'being set to 0.0 due to no true samples.')
assert_equal(str(record.pop().message),
'F-score is ill-defined and '
'being set to 0.0 due to no predicted samples.')


def test_prf_average_binary_data_non_binary():