diff --git a/doc/modules/clustering.rst b/doc/modules/clustering.rst
index e895be80042e8..1410b50d6518b 100644
--- a/doc/modules/clustering.rst
+++ b/doc/modules/clustering.rst
@@ -1158,8 +1158,8 @@ Given the knowledge of the ground truth class assignments ``labels_true`` and
 our clustering algorithm assignments of the same samples ``labels_pred``, the
 **Mutual Information** is a function that measures the **agreement** of the two
 assignments, ignoring permutations. Two different normalized versions of this
-measure are available, **Normalized Mutual Information(NMI)** and **Adjusted
-Mutual Information(AMI)**. NMI is often used in the literature while AMI was
+measure are available, **Normalized Mutual Information (NMI)** and **Adjusted
+Mutual Information (AMI)**. NMI is often used in the literature, while AMI was
 proposed more recently and is **normalized against chance**::
 
   >>> from sklearn import metrics
@@ -1212,17 +1212,11 @@ Advantages
   for any value of ``n_clusters`` and ``n_samples`` (which is not the case for
   raw Mutual Information or the V-measure for instance).
 
-- **Bounded range [0, 1]**: Values close to zero indicate two label
+- **Upper bound of 1**: Values close to zero indicate two label
   assignments that are largely independent, while values close to one
-  indicate significant agreement. Further, values of exactly 0 indicate
-  **purely** independent label assignments and a AMI of exactly 1 indicates
+  indicate significant agreement. Further, an AMI of exactly 1 indicates
   that the two label assignments are equal (with or without permutation).
 
-- **No assumption is made on the cluster structure**: can be used
-  to compare clustering algorithms such as k-means which assumes isotropic
-  blob shapes with results of spectral clustering algorithms which can
-  find cluster with "folded" shapes.
-
 Drawbacks
 ~~~~~~~~~
@@ -1274,7 +1268,7 @@ It also can be expressed in set cardinality formulation:
 
 The normalized mutual information is defined as
 
-.. math:: \text{NMI}(U, V) = \frac{\text{MI}(U, V)}{\sqrt{H(U)H(V)}}
+.. math:: \text{NMI}(U, V) = \frac{\text{MI}(U, V)}{\text{mean}(H(U), H(V))}
 
 This value of the mutual information and also the normalized variant is not
 adjusted for chance and will tend to increase as the number of different labels
@@ -1282,7 +1276,7 @@ adjusted for chance and will tend to increase as the number of different labels
 between the label assignments.
 
 The expected value for the mutual information can be calculated using the
-following equation, from Vinh, Epps, and Bailey, (2009). In this equation,
+following equation [VEB2009]_. In this equation,
 :math:`a_i = |U_i|` (the number of elements in :math:`U_i`) and
 :math:`b_j = |V_j|` (the number of elements in :math:`V_j`).
@@ -1295,7 +1289,19 @@ following equation, from Vinh, Epps, and Bailey, (2009). In this equation,
 Using the expected value, the adjusted mutual information can then be
 calculated using a similar form to that of the adjusted Rand index:
 
-.. math:: \text{AMI} = \frac{\text{MI} - E[\text{MI}]}{\max(H(U), H(V)) - E[\text{MI}]}
+.. math:: \text{AMI} = \frac{\text{MI} - E[\text{MI}]}{\text{mean}(H(U), H(V)) - E[\text{MI}]}
+
+For normalized mutual information and adjusted mutual information, the
+normalizing value is typically some *generalized* mean of the entropies of
+each clustering. Various generalized means exist, and no firm rules exist for
+preferring one over the others. The decision is largely made on a
+field-by-field basis; for instance, in community detection, the arithmetic
+mean is most common. Each normalizing method provides "qualitatively similar
+behaviours" [YAT2016]_. In our implementation, this is controlled by the
+``average_method`` parameter.
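+
+For example, NMI with the arithmetic mean reproduces the V-measure value
+computed later in this chapter (an illustrative sketch, relying on the
+``average_method`` parameter introduced by this change)::
+
+  >>> from sklearn import metrics
+  >>> labels_true = [0, 0, 0, 1, 1, 1]
+  >>> labels_pred = [0, 0, 1, 1, 2, 2]
+  >>> metrics.normalized_mutual_info_score(labels_true, labels_pred,
+  ...                                      average_method='arithmetic')
+  0.51...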
+
+Vinh et al. (2010) named variants of NMI and AMI by their averaging method
+[VEB2010]_. Their 'sqrt' and 'sum' averages are the geometric and arithmetic
+means; we use these more common names.
 
 .. topic:: References
@@ -1304,22 +1310,29 @@ calculated using a similar form to that of the adjusted Rand index:
      Machine Learning Research 3: 583–617.
      `doi:10.1162/153244303321897735 <https://doi.org/10.1162/153244303321897735>`_.
 
-   * Vinh, Epps, and Bailey, (2009). "Information theoretic measures
+   * [VEB2009] Vinh, Epps, and Bailey, (2009). "Information theoretic measures
      for clusterings comparison". Proceedings of the 26th Annual International
      Conference on Machine Learning - ICML '09.
      `doi:10.1145/1553374.1553511 <https://doi.org/10.1145/1553374.1553511>`_.
      ISBN 9781605585161.
 
-   * Vinh, Epps, and Bailey, (2010). Information Theoretic Measures for
+   * [VEB2010] Vinh, Epps, and Bailey, (2010). "Information Theoretic Measures for
      Clusterings Comparison: Variants, Properties, Normalization and
-     Correction for Chance, JMLR
-     http://jmlr.csail.mit.edu/papers/volume11/vinh10a/vinh10a.pdf
+     Correction for Chance". JMLR
+     <http://jmlr.csail.mit.edu/papers/volume11/vinh10a/vinh10a.pdf>
 
    * `Wikipedia entry for the (normalized) Mutual Information
     <https://en.wikipedia.org/wiki/Mutual_information>`_
 
    * `Wikipedia entry for the Adjusted Mutual Information
     <https://en.wikipedia.org/wiki/Adjusted_mutual_information>`_
+
+   * [YAT2016] Yang, Algesheimer, and Tessone, (2016). "A comparative analysis
+     of community detection algorithms on artificial networks". Scientific
+     Reports 6: 30750. `doi:10.1038/srep30750
+     <https://doi.org/10.1038/srep30750>`_.
+
 
 .. _homogeneity_completeness:
@@ -1359,7 +1372,7 @@ Their harmonic mean called **V-measure** is computed by
   0.51...
 
 The V-measure is actually equivalent to the mutual information (NMI)
-discussed above normalized by the sum of the label entropies [B2011]_.
+discussed above, with the aggregation function being the arithmetic mean
+[B2011]_.
 
 Homogeneity, completeness and V-measure can be computed at once using
 :func:`homogeneity_completeness_v_measure` as follows::
@@ -1534,7 +1547,7 @@ Advantages
   for any value of ``n_clusters`` and ``n_samples`` (which is not the case for
   raw Mutual Information or the V-measure for instance).
 
-- **Bounded range [0, 1]**: Values close to zero indicate two label
+- **Upper-bounded at 1**: Values close to zero indicate two label
   assignments that are largely independent, while values close to one
   indicate significant agreement. Further, values of exactly 0 indicate
   **purely** independent label assignments and a AMI of exactly 1 indicates
diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst
index 5b9216926b834..dca1157ac0368 100644
--- a/doc/whats_new/v0.20.rst
+++ b/doc/whats_new/v0.20.rst
@@ -201,6 +201,12 @@ Metrics
   :func:`metrics.roc_auc_score`.
   :issue:`3273` by :user:`Alexander Niederbühl `.
 
+- Added control over the normalization in
+  :func:`metrics.normalized_mutual_info_score` and
+  :func:`metrics.adjusted_mutual_info_score` via the ``average_method``
+  parameter. In version 0.22, the default normalizer for each will become
+  the *arithmetic* mean of the entropies of each clustering. :issue:`11124`
+  by :user:`Arya McCarthy `.
+
 - Added ``output_dict`` parameter in :func:`metrics.classification_report`
   to return classification statistics as dictionary.
   :issue:`11160` by :user:`Dan Barkhorn `.
@@ -768,6 +774,17 @@ Metrics
   due to floating point error in the input.
   :issue:`9851` by :user:`Hanmin Qin `.
 
+- In :func:`metrics.normalized_mutual_info_score` and
+  :func:`metrics.adjusted_mutual_info_score`, warn that ``average_method``
+  will have a new default value. In version 0.22, the default normalizer for
+  each will become the *arithmetic* mean of the entropies of each clustering.
+  Currently, :func:`metrics.normalized_mutual_info_score` uses the default of
+  ``average_method='geometric'``, and
+  :func:`metrics.adjusted_mutual_info_score` uses the default of
+  ``average_method='max'`` to match their behaviors in version 0.19.
+  :issue:`11124` by :user:`Arya McCarthy `.
+
 - The ``batch_size`` parameter to :func:`metrics.pairwise_distances_argmin_min`
   and :func:`metrics.pairwise_distances_argmin` is deprecated to be removed in
   v0.22. It no longer has any effect, as batch size is determined by global
diff --git a/sklearn/metrics/cluster/supervised.py b/sklearn/metrics/cluster/supervised.py
index 381f51777b6ae..13addf29fdc00 100644
--- a/sklearn/metrics/cluster/supervised.py
+++ b/sklearn/metrics/cluster/supervised.py
@@ -11,11 +11,13 @@
 # Thierry Guillemot
 # Gregory Stupp
 # Joel Nothman
+# Arya McCarthy
 # License: BSD 3 clause
 
 from __future__ import division
 
 from math import log
+import warnings
 
 import numpy as np
 from scipy import sparse as sp
@@ -59,6 +61,21 @@ def check_clusterings(labels_true, labels_pred):
     return labels_true, labels_pred
 
 
+def _generalized_average(U, V, average_method):
+    """Return a particular mean of two numbers."""
+    if average_method == "min":
+        return min(U, V)
+    elif average_method == "geometric":
+        return np.sqrt(U * V)
+    elif average_method == "arithmetic":
+        return np.mean([U, V])
+    elif average_method == "max":
+        return max(U, V)
+    else:
+        raise ValueError("'average_method' must be 'min', 'geometric', "
+                         "'arithmetic', or 'max'")
+
+
 def contingency_matrix(labels_true, labels_pred, eps=None, sparse=False):
     """Build a contingency matrix describing the relationship between labels.
@@ -245,7 +262,9 @@ def homogeneity_completeness_v_measure(labels_true, labels_pred):
     V-Measure is furthermore symmetric: swapping ``labels_true`` and
     ``label_pred`` will give the same score. This does not hold for
-    homogeneity and completeness.
+    homogeneity and completeness. V-Measure is identical to
+    :func:`normalized_mutual_info_score` with the arithmetic averaging
+    method.
 
     Read more in the :ref:`User Guide <homogeneity_completeness>`.
@@ -444,7 +463,8 @@ def completeness_score(labels_true, labels_pred):
 def v_measure_score(labels_true, labels_pred):
     """V-measure cluster labeling given a ground truth.
 
-    This score is identical to :func:`normalized_mutual_info_score`.
+    This score is identical to :func:`normalized_mutual_info_score` with
+    the ``'arithmetic'`` option for averaging.
 
     The V-measure is the harmonic mean between homogeneity and completeness::
@@ -459,6 +479,7 @@ def v_measure_score(labels_true, labels_pred):
     measure the agreement of two independent label assignments strategies
     on the same dataset when the real ground truth is not known.
 
+
     Read more in the :ref:`User Guide <homogeneity_completeness>`.
 
     Parameters
@@ -485,6 +506,7 @@ def v_measure_score(labels_true, labels_pred):
     --------
     homogeneity_score
     completeness_score
+    normalized_mutual_info_score
 
     Examples
     --------
@@ -617,7 +639,8 @@ def mutual_info_score(labels_true, labels_pred, contingency=None):
     return mi.sum()
 
 
-def adjusted_mutual_info_score(labels_true, labels_pred):
+def adjusted_mutual_info_score(labels_true, labels_pred,
+                               average_method='warn'):
     """Adjusted Mutual Information between two clusterings.
 
     Adjusted Mutual Information (AMI) is an adjustment of the Mutual
@@ -626,7 +649,7 @@ def adjusted_mutual_info_score(labels_true, labels_pred):
     clusters, regardless of whether there is actually more information shared.
     For two clusterings :math:`U` and :math:`V`, the AMI is given as::
 
-        AMI(U, V) = [MI(U, V) - E(MI(U, V))] / [max(H(U), H(V)) - E(MI(U, V))]
+        AMI(U, V) = [MI(U, V) - E(MI(U, V))] / [avg(H(U), H(V)) - E(MI(U, V))]
 
     This metric is independent of the absolute values of the labels:
     a permutation of the class or cluster label values won't change the
@@ -650,9 +673,17 @@ def adjusted_mutual_info_score(labels_true, labels_pred):
     labels_pred : array, shape = [n_samples]
         A clustering of the data into disjoint subsets.
 
+    average_method : string, optional (default: 'warn')
+        How to compute the normalizer in the denominator. Possible options
+        are 'min', 'geometric', 'arithmetic', and 'max'.
+        If 'warn', 'max' will be used. The default will change to
+        'arithmetic' in version 0.22.
+
+        .. versionadded:: 0.20
+
     Returns
     -------
-    ami: float(upperlimited by 1.0)
+    ami : float (upper bound of 1.0)
        The AMI returns a value of 1 when the two partitions are identical
        (ie perfectly matched). Random partitions (independent labellings) have
        an expected AMI around 0 on average hence can be negative.
@@ -691,6 +722,12 @@ def adjusted_mutual_info_score(labels_true, labels_pred):
        <https://en.wikipedia.org/wiki/Adjusted_Mutual_Information>`_
 
     """
+    if average_method == 'warn':
+        warnings.warn("The behavior of AMI will change in version 0.22. "
+                      "To match the behavior of 'v_measure_score', AMI will "
+                      "use average_method='arithmetic' by default.",
+                      FutureWarning)
+        average_method = 'max'
     labels_true, labels_pred = check_clusterings(labels_true, labels_pred)
     n_samples = labels_true.shape[0]
     classes = np.unique(labels_true)
@@ -709,17 +746,29 @@ def adjusted_mutual_info_score(labels_true, labels_pred):
     emi = expected_mutual_information(contingency, n_samples)
     # Calculate entropy for each labeling
     h_true, h_pred = entropy(labels_true), entropy(labels_pred)
-    ami = (mi - emi) / (max(h_true, h_pred) - emi)
+    normalizer = _generalized_average(h_true, h_pred, average_method)
+    denominator = normalizer - emi
+    # Avoid 0.0 / 0.0 when expectation equals maximum, i.e. a perfect match.
+    # The normalizer should always be >= emi, but because of floating-point
+    # representation, sometimes emi is slightly larger. Correct this
+    # by preserving the sign.
+    if denominator < 0:
+        denominator = min(denominator, -np.finfo('float64').eps)
+    else:
+        denominator = max(denominator, np.finfo('float64').eps)
+    ami = (mi - emi) / denominator
     return ami
 
 
-def normalized_mutual_info_score(labels_true, labels_pred):
+def normalized_mutual_info_score(labels_true, labels_pred,
+                                 average_method='warn'):
     """Normalized Mutual Information between two clusterings.
 
     Normalized Mutual Information (NMI) is an normalization of the Mutual
     Information (MI) score to scale the results between 0 (no mutual
     information) and 1 (perfect correlation). In this function, mutual
-    information is normalized by ``sqrt(H(labels_true) * H(labels_pred))``.
+    information is normalized by some generalized mean of ``H(labels_true)``
+    and ``H(labels_pred)``, defined by the ``average_method``.
 
     This measure is not adjusted for chance. Therefore
     :func:`adjusted_mustual_info_score` might be preferred.
@@ -743,6 +792,14 @@ def normalized_mutual_info_score(labels_true, labels_pred):
     labels_pred : array, shape = [n_samples]
         A clustering of the data into disjoint subsets.
 
+    average_method : string, optional (default: 'warn')
+        How to compute the normalizer in the denominator. Possible options
+        are 'min', 'geometric', 'arithmetic', and 'max'.
+        If 'warn', 'geometric' will be used. The default will change to
+        'arithmetic' in version 0.22.
+
+        .. versionadded:: 0.20
+
     Returns
     -------
     nmi : float
 
     See also
     --------
+    v_measure_score: V-Measure (NMI with arithmetic mean option)
     adjusted_rand_score: Adjusted Rand Index
     adjusted_mutual_info_score: Adjusted Mutual Information (adjusted
         against chance)
 
     Examples
     --------
@@ -773,6 +831,12 @@ def normalized_mutual_info_score(labels_true, labels_pred):
       0.0
 
     """
+    if average_method == 'warn':
+        warnings.warn("The behavior of NMI will change in version 0.22. "
+                      "To match the behavior of 'v_measure_score', NMI will "
+                      "use average_method='arithmetic' by default.",
+                      FutureWarning)
+        average_method = 'geometric'
     labels_true, labels_pred = check_clusterings(labels_true, labels_pred)
     classes = np.unique(labels_true)
     clusters = np.unique(labels_pred)
@@ -789,7 +853,10 @@ def normalized_mutual_info_score(labels_true, labels_pred):
     # Calculate the expected value for the mutual information
     # Calculate entropy for each labeling
     h_true, h_pred = entropy(labels_true), entropy(labels_pred)
-    nmi = mi / max(np.sqrt(h_true * h_pred), 1e-10)
+    normalizer = _generalized_average(h_true, h_pred, average_method)
+    # Avoid 0.0 / 0.0 when either entropy is zero.
+    normalizer = max(normalizer, np.finfo('float64').eps)
+    nmi = mi / normalizer
     return nmi
diff --git a/sklearn/metrics/cluster/tests/test_supervised.py b/sklearn/metrics/cluster/tests/test_supervised.py
index 8be39cd220d2a..46b95cfd8fda4 100644
--- a/sklearn/metrics/cluster/tests/test_supervised.py
+++ b/sklearn/metrics/cluster/tests/test_supervised.py
@@ -12,10 +12,12 @@
 from sklearn.metrics.cluster import mutual_info_score
 from sklearn.metrics.cluster import normalized_mutual_info_score
 from sklearn.metrics.cluster import v_measure_score
+from sklearn.metrics.cluster.supervised import _generalized_average
 from sklearn.utils import assert_all_finite
 from sklearn.utils.testing import (
     assert_equal, assert_almost_equal, assert_raise_message,
+    assert_warns_message, ignore_warnings
 )
 from numpy.testing import assert_array_almost_equal
@@ -30,6 +32,18 @@
 ]
 
 
+def test_future_warning():
+    score_funcs_with_changing_means = [
+        normalized_mutual_info_score,
+        adjusted_mutual_info_score,
+    ]
+    warning_msg = "The behavior of "
+    args = [0, 0, 0], [0, 0, 0]
+    for score_func in score_funcs_with_changing_means:
+        assert_warns_message(FutureWarning, warning_msg, score_func, *args)
+
+
+@ignore_warnings(category=FutureWarning)
 def test_error_messages_on_wrong_input():
     for score_func in score_funcs:
         expected = ('labels_true and labels_pred must have same size,'
@@ -46,6 +60,17 @@ def test_error_messages_on_wrong_input():
                              [0, 1, 0], [[1, 1], [0, 0]])
 
 
+def test_generalized_average():
+    a, b = 1, 2
+    methods = ["min", "geometric", "arithmetic", "max"]
+    means = [_generalized_average(a, b, method) for method in methods]
+    assert means[0] <= means[1] <= means[2] <= means[3]
+    c, d = 12, 12
+    means = [_generalized_average(c, d, method) for method in methods]
+    assert means[0] == means[1] == means[2] == means[3]
+
+
+@ignore_warnings(category=FutureWarning)
 def test_perfect_matches():
     for score_func in score_funcs:
         assert_equal(score_func([], []), 1.0)
@@ -55,6 +80,20 @@ def test_perfect_matches():
         assert_equal(score_func([0., 1., 0.],
                                 [42., 7., 42.]), 1.0)
         assert_equal(score_func([0., 1., 2.], [42., 7., 2.]), 1.0)
         assert_equal(score_func([0, 1, 2], [42, 7, 2]), 1.0)
+    score_funcs_with_changing_means = [
+        normalized_mutual_info_score,
+        adjusted_mutual_info_score,
+    ]
+    means = {"min", "geometric", "arithmetic", "max"}
+    for score_func in score_funcs_with_changing_means:
+        for mean in means:
+            assert score_func([], [], mean) == 1.0
+            assert score_func([0], [1], mean) == 1.0
+            assert score_func([0, 0, 0], [0, 0, 0], mean) == 1.0
+            assert score_func([0, 1, 0], [42, 7, 42], mean) == 1.0
+            assert score_func([0., 1., 0.], [42., 7., 42.], mean) == 1.0
+            assert score_func([0., 1., 2.], [42., 7., 2.], mean) == 1.0
+            assert score_func([0, 1, 2], [42, 7, 2], mean) == 1.0
 
 
 def test_homogeneous_but_not_complete_labeling():
@@ -87,7 +126,7 @@ def test_not_complete_and_not_homogeneous_labeling():
     assert_almost_equal(v, 0.52, 2)
 
 
-def test_non_consicutive_labels():
+def test_non_consecutive_labels():
     # regression tests for labels with gaps
     h, c, v = homogeneity_completeness_v_measure(
         [0, 0, 0, 2, 2, 2],
@@ -109,6 +148,7 @@
     assert_almost_equal(ari_2, 0.24, 2)
 
 
+@ignore_warnings(category=FutureWarning)
 def uniform_labelings_scores(score_func, n_samples, k_range, n_runs=10,
                              seed=42):
     # Compute score for random uniform cluster labelings
@@ -122,6 +162,7 @@ def uniform_labelings_scores(score_func, n_samples, k_range, n_runs=10,
     return scores
 
 
+@ignore_warnings(category=FutureWarning)
 def test_adjustment_for_chance():
     # Check that adjusted scores are almost zero on random labels
     n_clusters_range = [2, 10, 50, 90]
@@ -135,6 +176,7 @@ def test_adjustment_for_chance():
     assert_array_almost_equal(max_abs_scores, [0.02, 0.03, 0.03, 0.02], 2)
 
 
+@ignore_warnings(category=FutureWarning)
 def test_adjusted_mutual_info_score():
     # Compute the Adjusted Mutual Information and test against known values
     labels_a = np.array([1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3])
@@ -215,6 +257,7 @@ def test_contingency_matrix_sparse():
                             eps=1e-10, sparse=True)
 
 
+@ignore_warnings(category=FutureWarning)
 def test_exactly_zero_info_score():
     # Check numerical stability when information is exactly zero
     for i in np.logspace(1, 4, 4).astype(np.int):
@@ -224,6 +267,11 @@ def test_exactly_zero_info_score():
         assert_equal(v_measure_score(labels_a, labels_b), 0.0)
         assert_equal(adjusted_mutual_info_score(labels_a, labels_b), 0.0)
         assert_equal(normalized_mutual_info_score(labels_a, labels_b), 0.0)
+        for method in ["min", "geometric", "arithmetic", "max"]:
+            assert adjusted_mutual_info_score(labels_a, labels_b,
+                                              method) == 0.0
+            assert normalized_mutual_info_score(labels_a, labels_b,
+                                                method) == 0.0
 
 
 def test_v_measure_and_mutual_information(seed=36):
@@ -235,6 +283,11 @@ def test_v_measure_and_mutual_information(seed=36):
     assert_almost_equal(v_measure_score(labels_a, labels_b),
                         2.0 * mutual_info_score(labels_a, labels_b) /
                         (entropy(labels_a) + entropy(labels_b)), 0)
+    avg = 'arithmetic'
+    assert_almost_equal(v_measure_score(labels_a, labels_b),
+                        normalized_mutual_info_score(labels_a, labels_b,
+                                                     average_method=avg)
+                        )
 
 
 def test_fowlkes_mallows_score():
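
As a usage sketch of the changeset as a whole (illustrative only; function and
parameter names are those introduced above, and the example labels are
arbitrary)::

    from sklearn.metrics import (adjusted_mutual_info_score,
                                 normalized_mutual_info_score)

    labels_true = [0, 0, 0, 1, 1, 1]
    labels_pred = [0, 0, 1, 1, 2, 2]

    # Passing average_method explicitly avoids the FutureWarning and pins the
    # normalizer; 'arithmetic' matches v_measure_score and the 0.22 default.
    nmi = normalized_mutual_info_score(labels_true, labels_pred,
                                       average_method='arithmetic')
    ami = adjusted_mutual_info_score(labels_true, labels_pred,
                                     average_method='arithmetic')

    # For any pair of entropies, min <= geometric <= arithmetic <= max
    # (e.g. 1.0 and 4.0 give 1.0, 2.0, 2.5, 4.0), so a larger normalizer can
    # only shrink the score: 'min' yields the largest NMI, 'max' the smallest.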