
average_precision_score breaks on string labels #12312

@amueller

import numpy as np
from sklearn.metrics import average_precision_score
probs = np.array([0.41722746, 0.07162791, 0.41722746, 0.07162791, 0.69208494,
                  0.69208494, 0.40750916, 0.18227092, 0.40750916, 0.07162791])
labels = np.array(['No', 'No', 'Yes', 'No', 'Yes', 'Yes', 'No', 'No', 'Yes', 'No'])

average_precision_score(labels, probs)

TypeError: 'bool' object is not subscriptable

That error message is not very helpful. Fixed in #12313

Casting to dtype object (as coming from pandas):

import numpy as np
from sklearn.metrics import average_precision_score
probs = np.array([0.41722746, 0.07162791, 0.41722746, 0.07162791, 0.69208494,
                  0.69208494, 0.40750916, 0.18227092, 0.40750916, 0.07162791])
labels = np.array(['No', 'No', 'Yes', 'No', 'Yes', 'Yes', 'No', 'No', 'Yes', 'No'], dtype=object)

average_precision_score(labels, probs)
RuntimeWarning: invalid value encountered in true_divide
  recall = tps / tps[-1]

np.NaN

That's terrible: the score comes back as NaN with only a RuntimeWarning. Fixed in #12313

What I actually did was pass a misspelled pos_label:

probs = np.array([0.41722746, 0.07162791, 0.41722746, 0.07162791, 0.69208494,
                  0.69208494, 0.40750916, 0.18227092, 0.40750916, 0.07162791])
labels = np.array(['No', 'No', 'Yes', 'No', 'Yes', 'Yes', 'No', 'No', 'Yes', 'No'])

average_precision_score(labels, probs, pos_label='yes')  # TYPO
RuntimeWarning: invalid value encountered in true_divide
  recall = tps / tps[-1]

np.NaN

That's also terrible: a pos_label that doesn't occur in the labels yields NaN instead of an error.
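As a stopgap on the caller's side, one could check the labels before scoring. The following is a minimal sketch, not part of scikit-learn: `binarize_labels` is a hypothetical helper name, and it assumes a binary problem where everything other than `pos_label` counts as the negative class.

```python
import numpy as np


def binarize_labels(labels, pos_label):
    """Return a 0/1 integer array, failing loudly if pos_label is absent.

    Hypothetical helper for illustration; not part of scikit-learn.
    """
    labels = np.asarray(labels)
    present = np.unique(labels).tolist()
    if pos_label not in present:
        raise ValueError(
            f"pos_label={pos_label!r} not found in labels {present}"
        )
    # 1 where the label equals pos_label, 0 everywhere else
    return (labels == pos_label).astype(int)


labels = np.array(['No', 'No', 'Yes', 'No', 'Yes', 'Yes', 'No', 'No', 'Yes', 'No'])

y_true = binarize_labels(labels, pos_label='Yes')

try:
    binarize_labels(labels, pos_label='yes')  # same typo as above
except ValueError as exc:
    print(exc)
```

The resulting `y_true` can be passed to `average_precision_score` in place of the string labels, so a typo surfaces as a `ValueError` rather than a NaN score.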

Originally I used cross-validation, which is arguably worse:

import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_breast_cancer
data = load_breast_cancer()

target = np.array(["yes", "no"], dtype="object")[data.target]
cross_val_score(DecisionTreeClassifier(max_depth=3), data.data, target, scoring='average_precision', cv=5)
/home/andy/checkout/scikit-learn/sklearn/metrics/ranking.py:521: RuntimeWarning: invalid value encountered in true_divide
  recall = tps / tps[-1]
/home/andy/checkout/scikit-learn/sklearn/metrics/ranking.py:521: RuntimeWarning: invalid value encountered in true_divide
  recall = tps / tps[-1]
/home/andy/checkout/scikit-learn/sklearn/metrics/ranking.py:521: RuntimeWarning: invalid value encountered in true_divide
  recall = tps / tps[-1]
/home/andy/checkout/scikit-learn/sklearn/metrics/ranking.py:521: RuntimeWarning: invalid value encountered in true_divide
  recall = tps / tps[-1]
/home/andy/checkout/scikit-learn/sklearn/metrics/ranking.py:521: RuntimeWarning: invalid value encountered in true_divide
  recall = tps / tps[-1]
array([nan, nan, nan, nan, nan])
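A possible workaround for the cross-validation case, sketched under the assumption that "yes" is the class of interest: encode the target as 0/1 integers before scoring, so the 'average_precision' scorer sees the positive class as 1. The `target_binary` name and the `random_state=0` argument are additions for illustration and reproducibility, not part of the original report.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

data = load_breast_cancer()
target = np.array(["yes", "no"], dtype="object")[data.target]

# Assumption: "yes" is the positive class; encode it as 1 and everything else as 0.
target_binary = (target == "yes").astype(int)

scores = cross_val_score(
    DecisionTreeClassifier(max_depth=3, random_state=0),
    data.data,
    target_binary,
    scoring="average_precision",
    cv=5,
)
print(scores)
```

This sidesteps the string-label handling entirely; it does not address the underlying problem that the scorer returns NaN without raising.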
