
[MRG + 1] Add Fowlkes-Mallows and other supervised cluster metrics to SCORERS dict so they can be used in hyper-param search #8117

Merged
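A minimal sketch of the workflow this PR enables, assuming a toy blobs dataset (the data and parameters below are illustrative, not taken from the PR):

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.model_selection import cross_val_score

X, y = make_blobs(n_samples=200, centers=3, random_state=0)

# 'fowlkes_mallows_score' (like the other supervised cluster metrics added
# here) can now be passed as a `scoring` string; the ground-truth labels y
# are used only for scoring, never by KMeans.fit.
scores = cross_val_score(KMeans(n_clusters=3, random_state=0), X, y,
                         scoring='fowlkes_mallows_score')
print(scores.mean())
```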
2 changes: 1 addition & 1 deletion doc/modules/model_evaluation.rst
@@ -94,7 +94,7 @@ Usage examples:
>>> model = svm.SVC()
>>> cross_val_score(model, X, y, scoring='wrong_choice')
Traceback (most recent call last):
ValueError: 'wrong_choice' is not a valid scoring value. Valid options are ['accuracy', 'adjusted_rand_score', 'average_precision', 'f1', 'f1_macro', 'f1_micro', 'f1_samples', 'f1_weighted', 'neg_log_loss', 'neg_mean_absolute_error', 'neg_mean_squared_error', 'neg_mean_squared_log_error', 'neg_median_absolute_error', 'precision', 'precision_macro', 'precision_micro', 'precision_samples', 'precision_weighted', 'r2', 'recall', 'recall_macro', 'recall_micro', 'recall_samples', 'recall_weighted', 'roc_auc']
ValueError: 'wrong_choice' is not a valid scoring value. Valid options are ['accuracy', 'adjusted_mutual_info_score', 'adjusted_rand_score', 'average_precision', 'completeness_score', 'f1', 'f1_macro', 'f1_micro', 'f1_samples', 'f1_weighted', 'fowlkes_mallows_score', 'homogeneity_score', 'mutual_info_score', 'neg_log_loss', 'neg_mean_absolute_error', 'neg_mean_squared_error', 'neg_mean_squared_log_error', 'neg_median_absolute_error', 'normalized_mutual_info_score', 'precision', 'precision_macro', 'precision_micro', 'precision_samples', 'precision_weighted', 'r2', 'recall', 'recall_macro', 'recall_micro', 'recall_samples', 'recall_weighted', 'roc_auc', 'v_measure_score']

.. note::

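Since every key of the `SCORERS` dict is a valid `scoring` string, the updated list in the error message above can also be reproduced at runtime; a quick sketch:

```python
from sklearn.metrics import SCORERS

# After this change, the supervised cluster metrics (fowlkes_mallows_score,
# v_measure_score, ...) appear alongside the classification and regression
# scorer names.
print(sorted(SCORERS))
```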
28 changes: 27 additions & 1 deletion sklearn/metrics/scorer.py
@@ -27,7 +27,16 @@
mean_squared_error, mean_squared_log_error, accuracy_score,
f1_score, roc_auc_score, average_precision_score,
precision_score, recall_score, log_loss)

from .cluster import adjusted_rand_score
from .cluster import homogeneity_score
from .cluster import completeness_score
from .cluster import v_measure_score
from .cluster import mutual_info_score
from .cluster import adjusted_mutual_info_score
from .cluster import normalized_mutual_info_score
from .cluster import fowlkes_mallows_score

from ..utils.multiclass import type_of_target
from ..externals import six
from ..base import is_regressor
@@ -393,6 +402,14 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False,

# Clustering scores
adjusted_rand_scorer = make_scorer(adjusted_rand_score)
homogeneity_scorer = make_scorer(homogeneity_score)
completeness_scorer = make_scorer(completeness_score)
v_measure_scorer = make_scorer(v_measure_score)
mutual_info_scorer = make_scorer(mutual_info_score)
adjusted_mutual_info_scorer = make_scorer(adjusted_mutual_info_score)
normalized_mutual_info_scorer = make_scorer(normalized_mutual_info_score)
fowlkes_mallows_scorer = make_scorer(fowlkes_mallows_score)


SCORERS = dict(r2=r2_scorer,
neg_median_absolute_error=neg_median_absolute_error_scorer,
@@ -406,7 +423,16 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False,
average_precision=average_precision_scorer,
log_loss=log_loss_scorer,
neg_log_loss=neg_log_loss_scorer,
adjusted_rand_score=adjusted_rand_scorer)
# Cluster metrics that use supervised evaluation
adjusted_rand_score=adjusted_rand_scorer,
homogeneity_score=homogeneity_scorer,
completeness_score=completeness_scorer,
v_measure_score=v_measure_scorer,
mutual_info_score=mutual_info_scorer,
adjusted_mutual_info_score=adjusted_mutual_info_scorer,
normalized_mutual_info_score=normalized_mutual_info_scorer,
fowlkes_mallows_score=fowlkes_mallows_scorer)


for name, metric in [('precision', precision_score),
('recall', recall_score), ('f1', f1_score)]:
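Each scorer above is just `make_scorer` applied to the corresponding metric, so the resulting callable has the signature `(estimator, X, y)`: it calls `estimator.predict(X)` and feeds the predicted cluster labels to the metric as `y_pred`. A small equivalence sketch (data and estimator are illustrative):

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import get_scorer
from sklearn.metrics.cluster import v_measure_score

X, y = make_blobs(centers=2, random_state=0)
km = KMeans(n_clusters=3, random_state=0).fit(X)

# The scorer predicts cluster labels internally and hands them to the
# metric, so both calls below compute the same number.
s1 = get_scorer('v_measure_score')(km, X, y)
s2 = v_measure_score(y, km.predict(X))
assert abs(s1 - s2) < 1e-12
```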
29 changes: 19 additions & 10 deletions sklearn/metrics/tests/test_score_objects.py
@@ -18,7 +18,7 @@
from sklearn.base import BaseEstimator
from sklearn.metrics import (f1_score, r2_score, roc_auc_score, fbeta_score,
log_loss, precision_score, recall_score)
from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.metrics import cluster as cluster_module
@tguillemot (Contributor) commented on Jan 6, 2017:

> I prefer you export all the necessary metrics directly, as is done on the line before.

Follow-up (Contributor):

> Not a big deal.

from sklearn.metrics.scorer import (check_scoring, _PredictScorer,
_passthrough_scorer)
from sklearn.metrics import make_scorer, get_scorer, SCORERS
@@ -47,9 +47,17 @@
'roc_auc', 'average_precision', 'precision',
'precision_weighted', 'precision_macro', 'precision_micro',
'recall', 'recall_weighted', 'recall_macro', 'recall_micro',
'neg_log_loss', 'log_loss',
'adjusted_rand_score' # not really, but works
]
'neg_log_loss', 'log_loss']

# All supervised cluster scorers (they behave like classification metrics)
CLUSTER_SCORERS = ["adjusted_rand_score",
"homogeneity_score",
"completeness_score",
"v_measure_score",
"mutual_info_score",
"adjusted_mutual_info_score",
"normalized_mutual_info_score",
"fowlkes_mallows_score"]

MULTILABEL_ONLY_SCORERS = ['precision_samples', 'recall_samples', 'f1_samples']

@@ -65,6 +73,7 @@ def _make_estimators(X_train, y_train, y_ml_train):
return dict(
[(name, sensible_regr) for name in REGRESSION_SCORERS] +
[(name, sensible_clf) for name in CLF_SCORERS] +
[(name, sensible_clf) for name in CLUSTER_SCORERS] +
[(name, sensible_ml_clf) for name in MULTILABEL_ONLY_SCORERS]
)

@@ -330,16 +339,16 @@ def test_thresholded_scorers_multilabel_indicator_data():
assert_almost_equal(score1, score2)


def test_unsupervised_scorers():
def test_supervised_cluster_scorers():
# Test clustering scorers against gold standard labeling.
# We don't have any real unsupervised scorers yet.
X, y = make_blobs(random_state=0, centers=2)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
km = KMeans(n_clusters=3)
km.fit(X_train)
score1 = get_scorer('adjusted_rand_score')(km, X_test, y_test)
score2 = adjusted_rand_score(y_test, km.predict(X_test))
assert_almost_equal(score1, score2)
for name in CLUSTER_SCORERS:
score1 = get_scorer(name)(km, X_test, y_test)
score2 = getattr(cluster_module, name)(y_test, km.predict(X_test))
assert_almost_equal(score1, score2)


@ignore_warnings
@@ -445,4 +454,4 @@ def test_scoring_is_not_metric():
assert_raises_regexp(ValueError, 'make_scorer', check_scoring,
Ridge(), r2_score)
assert_raises_regexp(ValueError, 'make_scorer', check_scoring,
KMeans(), adjusted_rand_score)
KMeans(), cluster_module.adjusted_rand_score)
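The last assertion above guards a common mistake: a metric has the signature `(y_true, y_pred)`, while a scoring callable must take `(estimator, X, y)`, so passing the raw metric is rejected with a message that points at `make_scorer`. A sketch of both sides of that check, using the same modules the test imports:

```python
from sklearn.cluster import KMeans
from sklearn.metrics import make_scorer
from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.metrics.scorer import check_scoring

km = KMeans(n_clusters=3)

try:
    check_scoring(km, adjusted_rand_score)      # raw metric: rejected
except ValueError as exc:
    print(exc)                                  # the message mentions make_scorer

scorer = check_scoring(km, make_scorer(adjusted_rand_score))  # accepted
```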
6 changes: 6 additions & 0 deletions sklearn/model_selection/tests/test_search.py
@@ -542,6 +542,12 @@ def test_unsupervised_grid_search():
# ARI can find the right number :)
assert_equal(grid_search.best_params_["n_clusters"], 3)

grid_search = GridSearchCV(km, param_grid=dict(n_clusters=[2, 3, 4]),
scoring='fowlkes_mallows_score')
grid_search.fit(X, y)
# So can FMS ;)
assert_equal(grid_search.best_params_["n_clusters"], 3)

# Now without a score, and without y
grid_search = GridSearchCV(km, param_grid=dict(n_clusters=[2, 3, 4]))
grid_search.fit(X)
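For contrast with the scored searches above: with no `scoring` and no `y`, `GridSearchCV` falls back to `KMeans.score`, i.e. negative inertia on the held-out split. Inertia tends to keep shrinking as `n_clusters` grows, so this fallback is a much weaker model-selection signal than the supervised scorers when ground-truth labels exist. A sketch with illustrative data:

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.model_selection import GridSearchCV

X, _ = make_blobs(n_samples=200, centers=3, random_state=0)

# Each candidate n_clusters is ranked by KMeans.score (negative inertia),
# which tends to favour larger n_clusters.
search = GridSearchCV(KMeans(random_state=0),
                      param_grid={'n_clusters': [2, 3, 4]})
search.fit(X)
print(search.best_params_)
```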