-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
[MRG]Clusters-Class auto-match #10604
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2213018
ef2ec2b
5fab323
8402cf2
e97ef2b
625ced6
46fd79f
d1bdae1
07be65c
2b01392
b5116b7
81804bf
4ba0cac
d4aab23
29b7154
2588538
6e436d8
c3d1ea5
10892dc
8e12b2a
63dd108
c50284f
5c8c3f3
fefe91c
c10f69c
0225c32
bd31fb9
cd3a09d
1539a0b
9d39163
1e08a8b
51a68cb
01e9253
e11cd91
320858d
0a65cc2
4847df3
42a7a1d
cdbe4a3
1474d53
f582884
7227ab5
e1686d6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ | |
# Thierry Guillemot <thierry.guillemot.work@gmail.com> | ||
# Gregory Stupp <stuppie@gmail.com> | ||
# Joel Nothman <joel.nothman@gmail.com> | ||
# Lucas Pugens Fernandes <lpfernandes@gmail.com> | ||
# License: BSD 3 clause | ||
|
||
from __future__ import division | ||
|
@@ -22,7 +23,9 @@ | |
|
||
from .expected_mutual_info_fast import expected_mutual_information | ||
from ...utils.validation import check_array | ||
from ...utils.multiclass import unique_labels | ||
from ...utils.fixes import comb | ||
from ...utils.linear_assignment_ import linear_assignment | ||
|
||
|
||
def comb2(n): | ||
|
@@ -871,3 +874,71 @@ def entropy(labels): | |
# log(a / b) should be calculated as log(a) - log(b) for | ||
# possible loss of precision | ||
return -np.sum((pi / pi_sum) * (np.log(pi) - log(pi_sum))) | ||
|
||
|
||
def map_cluster_labels(labels_true, labels_pred): | ||
"""Translate prediction labels to maximize the accuracy. | ||
|
||
Translate the prediction labels of a clustering output to those in the | ||
ground truth to enable calc of external metrics (eg. accuracy, f1_score, | ||
...). Translation is done by maximization of the confusion matrix :math:`C` | ||
main diagonal sum :math:`\sum{i=0}^{K}C_{i, i}`. | ||
|
||
Parameters | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. blank line before this |
||
---------- | ||
labels_true : array, shape = [n_samples] | ||
Ground truth (correct) target values. | ||
labels_pred : array, shape = [n_samples] | ||
Estimated clusters as returned by a clustering algorithm. | ||
|
||
Returns | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. blank line before this |
||
------- | ||
trans : array, shape = [n_classes, n_classes] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not what you return There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd be tempted to return three arrays:
Each is of length There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm.. I've now realised you're producing something slightly different to what I was thinking. And fair enough. Perhaps you're right that we should produce your more usable/understandable interface, but it is not sufficient to implement something like CEAF :( |
||
Mapping of labels_pred clusters, such that :math:`trans\subseteq | ||
labels_true` | ||
|
||
References | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. blank line before this |
||
---------- | ||
|
||
Examples | ||
-------- | ||
>>> from sklearn.metrics import confusion_matrix | ||
>>> from sklearn.metrics.cluster import map_cluster_labels | ||
>>> labels_true = ["class1", "class2", "class3", "class1", "class1", | ||
... "class3"] | ||
>>> labels_pred = [0, 0, 2, 2, 0, 2] | ||
>>> y_pred_translated = map_cluster_labels(labels_true, labels_pred) | ||
>>> y_pred_translated | ||
['class1', 'class1', 'class3', 'class3', 'class1', 'class3'] | ||
>>> confusion_matrix(labels_true, y_pred_translated) | ||
array([[2, 0, 1], | ||
[1, 0, 0], | ||
[0, 0, 2]]) | ||
""" | ||
|
||
classes = unique_labels(labels_true).tolist() | ||
n_classes = len(classes) | ||
clusters = unique_labels(labels_pred).tolist() | ||
n_clusters = len(clusters) | ||
|
||
if n_clusters > n_classes: | ||
classes += ['DEFAULT_LABEL_'+str(i) for i in | ||
range(n_clusters-n_classes)] | ||
elif n_classes > n_clusters: | ||
clusters += ['DEFAULT_CLUSTER_'+str(i) for i in | ||
range(n_classes-n_clusters)] | ||
|
||
C = contingency_matrix(labels_true, labels_pred) | ||
true_idx, pred_idx = linear_assignment(-C).T | ||
|
||
true_idx = true_idx.tolist() | ||
pred_idx = pred_idx.tolist() | ||
|
||
true_idx = [classes[idx] for idx in true_idx] | ||
true_idx = true_idx + sorted(set(classes) - set(true_idx)) | ||
pred_idx = [clusters[idx] for idx in pred_idx] | ||
pred_idx = pred_idx + sorted(set(clusters) - set(pred_idx)) | ||
|
||
return_list = [true_idx[pred_idx.index(y)] for y in labels_pred] | ||
|
||
return return_list |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,10 +12,11 @@ | |
from sklearn.metrics.cluster import mutual_info_score | ||
from sklearn.metrics.cluster import normalized_mutual_info_score | ||
from sklearn.metrics.cluster import v_measure_score | ||
from sklearn.metrics.cluster import map_cluster_labels | ||
|
||
from sklearn.utils import assert_all_finite | ||
from sklearn.utils.testing import ( | ||
assert_equal, assert_almost_equal, assert_raise_message, | ||
assert_equal, assert_almost_equal, assert_raise_message, | ||
) | ||
from numpy.testing import assert_array_almost_equal | ||
|
||
|
@@ -275,3 +276,33 @@ def test_fowlkes_mallows_score_properties(): | |
# symmetric and permutation(both together) | ||
score_both = fowlkes_mallows_score(labels_b, (labels_a + 2) % 3) | ||
assert_almost_equal(score_both, expected) | ||
|
||
|
||
def test_map_cluster_labels(): | ||
# handcrafted example - same number of clusters and classes | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unless there is a reason to use a large example, it is best to use something in tests that a reader can easily reason about, i.e. a small example. Besides, linear_assignment is tested: your job here is to check you've used it correctly. |
||
y_true = ['a', 'b', 'b', 'c', 'c', 'a'] | ||
y_pred = [1, 0, 0, 1, 2, 1] | ||
|
||
expected = ['a', 'b', 'b', 'a', 'c', 'a'] | ||
|
||
y_pred_translated = map_cluster_labels(y_true, y_pred) | ||
assert_equal(y_pred_translated, expected) | ||
|
||
# handcrafted example - more clusters than classes | ||
y_true = ['a', 'a', 'a', 'b', 'b', 'b'] | ||
y_pred = [4, 0, 1, 1, 2, 2] | ||
|
||
expected = ['DEFAULT_LABEL_1', 'a', 'DEFAULT_LABEL_0', 'DEFAULT_LABEL_0', | ||
'b', 'b'] | ||
|
||
y_pred_translated = map_cluster_labels(y_true, y_pred) | ||
assert_equal(y_pred_translated, expected) | ||
|
||
# handcrafted example - more classes than clusters | ||
y_true = ['a', 'd', 'e', 'b', 'b', 'b'] | ||
y_pred = [0, 0, -1, -1, 2, 2] | ||
|
||
expected = ['a', 'a', 'e', 'e', 'b', 'b'] | ||
|
||
y_pred_translated = map_cluster_labels(y_true, y_pred) | ||
assert_equal(y_pred_translated, expected) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"Find the best mapping between true and predicted cluster labels"