Skip to content

TST add binary and multiclass test for scorers #18904

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Jan 28, 2021
Merged
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
afe66d9
removed 'log_loss' from CLF_SCORERS list as there is no 'log_loss' sc…
efiegel Nov 24, 2020
95ec70e
added general multiclass test for all multiclass scorers with a logis…
efiegel Nov 24, 2020
9405ed1
added tree classifier in addition to logistic regression classifier. s…
efiegel Nov 25, 2020
30f3f5d
fixed causes of linting errors
efiegel Nov 25, 2020
867249f
fixed causes of linting errors (try #2)
efiegel Nov 25, 2020
a930412
fixed causes of linting errors (try 2). ignore doc tests that are fai…
efiegel Nov 25, 2020
2011f24
fixed causes of linting errors (try 2). ignore doc tests that are fai…
efiegel Nov 26, 2020
41fcda4
fixed causes of linting errors (try 2)
efiegel Nov 26, 2020
64e784a
fixed causes of linting errors (try 2)
efiegel Nov 26, 2020
6009551
parameterized multiclass score test
efiegel Dec 27, 2020
6fe5883
removed BINARY_SCORERS_ONLY since it wasn't used. renamed test_classi…
efiegel Dec 27, 2020
90421eb
moved scorer pickling test to a separate test so that test_classifica…
efiegel Dec 27, 2020
3357945
refactored test_classification_binary_scores for parameterized input
efiegel Dec 27, 2020
ab4c1e9
flake8 tweaks
efiegel Dec 27, 2020
ed7d555
Apply suggestions from code review
efiegel Jan 11, 2021
3139090
Merge remote-tracking branch 'upstream/main' into binary-classificati…
efiegel Jan 24, 2021
9ab3c16
removed clf from parameterized arguments in test_classification_multi…
efiegel Jan 24, 2021
0a33048
removed clf from parameterized arguments in test_classification_multi…
efiegel Jan 24, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 70 additions & 29 deletions sklearn/metrics/tests/test_score_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

from sklearn.base import BaseEstimator
from sklearn.metrics import (
accuracy_score,
balanced_accuracy_score,
average_precision_score,
brier_score_loss,
f1_score,
Expand All @@ -28,13 +30,13 @@
r2_score,
recall_score,
roc_auc_score,
top_k_accuracy_score
)
from sklearn.metrics import cluster as cluster_module
from sklearn.metrics import check_scoring
from sklearn.metrics._scorer import (_PredictScorer, _passthrough_scorer,
_MultimetricScorer,
_check_multimetric_scoring)
from sklearn.metrics import accuracy_score
from sklearn.metrics import make_scorer, get_scorer, SCORERS
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import LinearSVC
Expand Down Expand Up @@ -68,7 +70,7 @@
'roc_auc', 'average_precision', 'precision',
'precision_weighted', 'precision_macro', 'precision_micro',
'recall', 'recall_weighted', 'recall_macro', 'recall_micro',
'neg_log_loss', 'log_loss', 'neg_brier_score',
'neg_log_loss', 'neg_brier_score',
'jaccard', 'jaccard_weighted', 'jaccard_macro',
'jaccard_micro', 'roc_auc_ovr', 'roc_auc_ovo',
'roc_auc_ovr_weighted', 'roc_auc_ovo_weighted']
Expand Down Expand Up @@ -306,46 +308,85 @@ def test_make_scorer():
make_scorer(f, needs_threshold=True, needs_proba=True)


@pytest.mark.parametrize('scorer_name, metric', [
    ('f1', f1_score),
    ('f1_weighted', partial(f1_score, average='weighted')),
    ('f1_macro', partial(f1_score, average='macro')),
    ('f1_micro', partial(f1_score, average='micro')),
    ('precision', precision_score),
    ('precision_weighted', partial(precision_score, average='weighted')),
    ('precision_macro', partial(precision_score, average='macro')),
    ('precision_micro', partial(precision_score, average='micro')),
    ('recall', recall_score),
    ('recall_weighted', partial(recall_score, average='weighted')),
    ('recall_macro', partial(recall_score, average='macro')),
    ('recall_micro', partial(recall_score, average='micro')),
    ('jaccard', jaccard_score),
    ('jaccard_weighted', partial(jaccard_score, average='weighted')),
    ('jaccard_macro', partial(jaccard_score, average='macro')),
    ('jaccard_micro', partial(jaccard_score, average='micro')),
    ('top_k_accuracy', top_k_accuracy_score),
])
def test_classification_binary_scores(scorer_name, metric):
    # check consistency between score and scorer for scores supporting
    # binary classification: the named scorer looked up in SCORERS must
    # produce the same value as calling the underlying metric directly
    # on the classifier's predictions.
    X, y = make_blobs(random_state=0, centers=2)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
    clf = LinearSVC(random_state=0)
    clf.fit(X_train, y_train)

    score = SCORERS[scorer_name](clf, X_test, y_test)
    expected_score = metric(y_test, clf.predict(X_test))
    assert_almost_equal(score, expected_score)
@pytest.mark.parametrize('scorer_name, metric', [
    ('accuracy', accuracy_score),
    ('balanced_accuracy', balanced_accuracy_score),
    ('f1_weighted', partial(f1_score, average='weighted')),
    ('f1_macro', partial(f1_score, average='macro')),
    ('f1_micro', partial(f1_score, average='micro')),
    ('precision_weighted', partial(precision_score, average='weighted')),
    ('precision_macro', partial(precision_score, average='macro')),
    ('precision_micro', partial(precision_score, average='micro')),
    ('recall_weighted', partial(recall_score, average='weighted')),
    ('recall_macro', partial(recall_score, average='macro')),
    ('recall_micro', partial(recall_score, average='micro')),
    ('jaccard_weighted', partial(jaccard_score, average='weighted')),
    ('jaccard_macro', partial(jaccard_score, average='macro')),
    ('jaccard_micro', partial(jaccard_score, average='micro')),
])
def test_classification_multiclass_scores(scorer_name, metric):
    # check consistency between score and scorer for scores supporting
    # multiclass classification: the named scorer looked up in SCORERS
    # must match the underlying metric applied to the predictions.
    X, y = make_classification(
        n_classes=3, n_informative=3, n_samples=30, random_state=0
    )

    # use `stratify=y` to ensure train and test sets capture all classes
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, random_state=0, stratify=y
    )

    clf = DecisionTreeClassifier(random_state=0)
    clf.fit(X_train, y_train)
    score = SCORERS[scorer_name](clf, X_test, y_test)
    expected_score = metric(y_test, clf.predict(X_test))
    assert score == pytest.approx(expected_score)

def test_custom_scorer_pickling():
    # test that a custom scorer built with make_scorer can be pickled
    # and that the unpickled scorer yields the same score.
    X, y = make_blobs(random_state=0, centers=2)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
    clf = LinearSVC(random_state=0)
    clf.fit(X_train, y_train)

    scorer = make_scorer(fbeta_score, beta=2)
    score1 = scorer(clf, X_test, y_test)
    unpickled_scorer = pickle.loads(pickle.dumps(scorer))
    score2 = unpickled_scorer(clf, X_test, y_test)
    assert score1 == pytest.approx(score2)

    # smoke test the repr:
    repr(fbeta_score)
Expand Down