diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 7eaedfbdc6eb2..bb43b52cbbf5e 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -330,6 +330,11 @@ Enhancements (`#7248 <https://github.com/scikit-learn/scikit-learn/pull/7248>`_) By `Andreas Müller`_. + - Support sparse contingency matrices in cluster evaluation + (:mod:`metrics.cluster.supervised`) and use these by default. + (`#7419 <https://github.com/scikit-learn/scikit-learn/pull/7419>`_) + By `Gregory Stupp`_ and `Joel Nothman`_. + Bug fixes ......... @@ -4543,3 +4548,5 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson. .. _Sebastián Vanrell: https://github.com/srvanrell .. _Robert McGibbon: https://github.com/rmcgibbo + +.. _Gregory Stupp: https://github.com/stuppie diff --git a/sklearn/metrics/cluster/expected_mutual_info_fast.pyx b/sklearn/metrics/cluster/expected_mutual_info_fast.pyx index d0c08be8d238d..ddb735be59f86 100644 --- a/sklearn/metrics/cluster/expected_mutual_info_fast.pyx +++ b/sklearn/metrics/cluster/expected_mutual_info_fast.pyx @@ -28,8 +28,8 @@ def expected_mutual_information(contingency, int n_samples): #cdef np.ndarray[int, ndim=2] start, end R, C = contingency.shape N = n_samples - a = np.sum(contingency, axis=1).astype(np.int32) - b = np.sum(contingency, axis=0).astype(np.int32) + a = np.ravel(contingency.sum(axis=1).astype(np.int32)) + b = np.ravel(contingency.sum(axis=0).astype(np.int32)) # There are three major terms to the EMI equation, which are multiplied to # and then summed over varying nij values. # While nijs[0] will never be used, having it simplifies the indexing. 
diff --git a/sklearn/metrics/cluster/supervised.py b/sklearn/metrics/cluster/supervised.py index 131c14b5078ca..b5b054623dfa7 100644 --- a/sklearn/metrics/cluster/supervised.py +++ b/sklearn/metrics/cluster/supervised.py @@ -9,16 +9,21 @@ # Diego Molla # Arnaud Fouchet # Thierry Guillemot +# Gregory Stupp +# Joel Nothman # License: BSD 3 clause +from __future__ import division + from math import log -from scipy.misc import comb -from scipy.sparse import coo_matrix import numpy as np +from scipy.misc import comb +from scipy import sparse as sp from .expected_mutual_info_fast import expected_mutual_information from ...utils.fixes import bincount +from ...utils.validation import check_array def comb2(n): @@ -46,7 +51,7 @@ def check_clusterings(labels_true, labels_pred): return labels_true, labels_pred -def contingency_matrix(labels_true, labels_pred, eps=None, max_n_classes=5000): +def contingency_matrix(labels_true, labels_pred, eps=None, sparse=False): """Build a contingency matrix describing the relationship between labels. Parameters @@ -57,52 +62,55 @@ def contingency_matrix(labels_true, labels_pred, eps=None, max_n_classes=5000): labels_pred : array, shape = [n_samples] Cluster labels to evaluate - eps: None or float + eps : None or float, optional. If a float, that value is added to all values in the contingency matrix. This helps to stop NaN propagation. If ``None``, nothing is adjusted. - max_n_classes : int, optional (default=5000) - Maximal number of classeses handled for contingency_matrix. - This help to avoid Memory error with regression target - for mutual_information. + sparse : boolean, optional. + If True, return a sparse CSR contingency matrix. If ``eps is not None``, + and ``sparse is True``, will throw ValueError. + + .. 
versionadded:: 0.18 Returns ------- - contingency: array, shape=[n_classes_true, n_classes_pred] + contingency : {array-like, sparse}, shape=[n_classes_true, n_classes_pred] Matrix :math:`C` such that :math:`C_{i, j}` is the number of samples in true class :math:`i` and in predicted class :math:`j`. If ``eps is None``, the dtype of this array will be integer. If ``eps`` is given, the dtype will be float. + Will be a ``scipy.sparse.csr_matrix`` if ``sparse=True``. """ + + if eps is not None and sparse: + raise ValueError("Cannot set 'eps' when sparse=True") + classes, class_idx = np.unique(labels_true, return_inverse=True) clusters, cluster_idx = np.unique(labels_pred, return_inverse=True) n_classes = classes.shape[0] n_clusters = clusters.shape[0] - if n_classes > max_n_classes: - raise ValueError("Too many classes for a clustering metric. If you " - "want to increase the limit, pass parameter " - "max_n_classes to the scoring function") - if n_clusters > max_n_classes: - raise ValueError("Too many clusters for a clustering metric. If you " - "want to increase the limit, pass parameter " - "max_n_classes to the scoring function") # Using coo_matrix to accelerate simple histogram calculation, # i.e. 
bins are consecutive integers # Currently, coo_matrix is faster than histogram2d for simple cases - contingency = coo_matrix((np.ones(class_idx.shape[0]), - (class_idx, cluster_idx)), - shape=(n_classes, n_clusters), - dtype=np.int).toarray() - if eps is not None: - # don't use += as contingency is integer - contingency = contingency + eps + contingency = sp.coo_matrix((np.ones(class_idx.shape[0]), + (class_idx, cluster_idx)), + shape=(n_classes, n_clusters), + dtype=np.int) + if sparse: + contingency = contingency.tocsr() + contingency.sum_duplicates() + else: + contingency = contingency.toarray() + if eps is not None: + # don't use += as contingency is integer + contingency = contingency + eps return contingency # clustering measures -def adjusted_rand_score(labels_true, labels_pred, max_n_classes=5000): +def adjusted_rand_score(labels_true, labels_pred): """Rand index adjusted for chance. The Rand Index computes a similarity measure between two clusterings @@ -134,11 +142,6 @@ def adjusted_rand_score(labels_true, labels_pred, max_n_classes=5000): labels_pred : array, shape = [n_samples] Cluster labels to evaluate - max_n_classes: int, optional (default=5000) - Maximal number of classes handled by the adjusted_rand_score - metric. Setting it too high can lead to MemoryError or OS - freeze - Returns ------- ari : float @@ -190,31 +193,29 @@ def adjusted_rand_score(labels_true, labels_pred, max_n_classes=5000): """ labels_true, labels_pred = check_clusterings(labels_true, labels_pred) n_samples = labels_true.shape[0] - classes = np.unique(labels_true) - clusters = np.unique(labels_pred) + n_classes = np.unique(labels_true).shape[0] + n_clusters = np.unique(labels_pred).shape[0] + # Special limit cases: no clustering since the data is not split; # or trivial clustering where each document is assigned a unique cluster. # These are perfect matches hence return 1.0. 
- if (classes.shape[0] == clusters.shape[0] == 1 or - classes.shape[0] == clusters.shape[0] == 0 or - classes.shape[0] == clusters.shape[0] == len(labels_true)): + if (n_classes == n_clusters == 1 or + n_classes == n_clusters == 0 or + n_classes == n_clusters == n_samples): return 1.0 - contingency = contingency_matrix(labels_true, labels_pred, - max_n_classes=max_n_classes) - # Compute the ARI using the contingency data - sum_comb_c = sum(comb2(n_c) for n_c in contingency.sum(axis=1)) - sum_comb_k = sum(comb2(n_k) for n_k in contingency.sum(axis=0)) + contingency = contingency_matrix(labels_true, labels_pred, sparse=True) + sum_comb_c = sum(comb2(n_c) for n_c in np.ravel(contingency.sum(axis=1))) + sum_comb_k = sum(comb2(n_k) for n_k in np.ravel(contingency.sum(axis=0))) + sum_comb = sum(comb2(n_ij) for n_ij in contingency.data) - sum_comb = sum(comb2(n_ij) for n_ij in contingency.flatten()) - prod_comb = (sum_comb_c * sum_comb_k) / float(comb(n_samples, 2)) + prod_comb = (sum_comb_c * sum_comb_k) / comb(n_samples, 2) mean_comb = (sum_comb_k + sum_comb_c) / 2. - return ((sum_comb - prod_comb) / (mean_comb - prod_comb)) + return (sum_comb - prod_comb) / (mean_comb - prod_comb) -def homogeneity_completeness_v_measure(labels_true, labels_pred, - max_n_classes=5000): +def homogeneity_completeness_v_measure(labels_true, labels_pred): """Compute the homogeneity and completeness and V-Measure scores at once. Those metrics are based on normalized conditional entropy measures of @@ -248,11 +249,6 @@ def homogeneity_completeness_v_measure(labels_true, labels_pred, labels_pred : array, shape = [n_samples] cluster labels to evaluate - max_n_classes: int, optional (default=5000) - Maximal number of classes handled by the adjusted_rand_score - metric. 
Setting it too high can lead to MemoryError or OS - freeze - Returns ------- homogeneity: float @@ -278,8 +274,8 @@ def homogeneity_completeness_v_measure(labels_true, labels_pred, entropy_C = entropy(labels_true) entropy_K = entropy(labels_pred) - MI = mutual_info_score(labels_true, labels_pred, - max_n_classes=max_n_classes) + contingency = contingency_matrix(labels_true, labels_pred, sparse=True) + MI = mutual_info_score(None, None, contingency=contingency) homogeneity = MI / (entropy_C) if entropy_C else 1.0 completeness = MI / (entropy_K) if entropy_K else 1.0 @@ -293,7 +289,7 @@ def homogeneity_completeness_v_measure(labels_true, labels_pred, return homogeneity, completeness, v_measure_score -def homogeneity_score(labels_true, labels_pred, max_n_classes=5000): +def homogeneity_score(labels_true, labels_pred): """Homogeneity metric of a cluster labeling given a ground truth. A clustering result satisfies homogeneity if all of its clusters @@ -317,11 +313,6 @@ def homogeneity_score(labels_true, labels_pred, max_n_classes=5000): labels_pred : array, shape = [n_samples] cluster labels to evaluate - max_n_classes: int, optional (default=5000) - Maximal number of classes handled by the adjusted_rand_score - metric. Setting it too high can lead to MemoryError or OS - freeze - Returns ------- homogeneity: float @@ -369,11 +360,10 @@ def homogeneity_score(labels_true, labels_pred, max_n_classes=5000): 0.0... """ - return homogeneity_completeness_v_measure(labels_true, labels_pred, - max_n_classes)[0] + return homogeneity_completeness_v_measure(labels_true, labels_pred)[0] -def completeness_score(labels_true, labels_pred, max_n_classes=5000): +def completeness_score(labels_true, labels_pred): """Completeness metric of a cluster labeling given a ground truth. 
A clustering result satisfies completeness if all the data points @@ -397,11 +387,6 @@ def completeness_score(labels_true, labels_pred, max_n_classes=5000): labels_pred : array, shape = [n_samples] cluster labels to evaluate - max_n_classes: int, optional (default=5000) - Maximal number of classes handled by the adjusted_rand_score - metric. Setting it too high can lead to MemoryError or OS - freeze - Returns ------- completeness: float @@ -445,11 +430,10 @@ def completeness_score(labels_true, labels_pred, max_n_classes=5000): 0.0 """ - return homogeneity_completeness_v_measure(labels_true, labels_pred, - max_n_classes)[1] + return homogeneity_completeness_v_measure(labels_true, labels_pred)[1] -def v_measure_score(labels_true, labels_pred, max_n_classes=5000): +def v_measure_score(labels_true, labels_pred): """V-measure cluster labeling given a ground truth. This score is identical to :func:`normalized_mutual_info_score`. @@ -477,11 +461,6 @@ def v_measure_score(labels_true, labels_pred, max_n_classes=5000): labels_pred : array, shape = [n_samples] cluster labels to evaluate - max_n_classes: int, optional (default=5000) - Maximal number of classes handled by the adjusted_rand_score - metric. Setting it too high can lead to MemoryError or OS - freeze - Returns ------- v_measure: float @@ -546,12 +525,10 @@ def v_measure_score(labels_true, labels_pred, max_n_classes=5000): 0.0... """ - return homogeneity_completeness_v_measure(labels_true, labels_pred, - max_n_classes)[2] + return homogeneity_completeness_v_measure(labels_true, labels_pred)[2] -def mutual_info_score(labels_true, labels_pred, contingency=None, - max_n_classes=5000): +def mutual_info_score(labels_true, labels_pred, contingency=None): """Mutual Information between two clusterings. 
The Mutual Information is a measure of the similarity between two labels of @@ -586,16 +563,12 @@ def mutual_info_score(labels_true, labels_pred, contingency=None, labels_pred : array, shape = [n_samples] A clustering of the data into disjoint subsets. - contingency: None or array, shape = [n_classes_true, n_classes_pred] + contingency : {None, array, sparse matrix}, + shape = [n_classes_true, n_classes_pred] A contingency matrix given by the :func:`contingency_matrix` function. If value is ``None``, it will be computed, otherwise the given value is used, with ``labels_true`` and ``labels_pred`` ignored. - max_n_classes: int, optional (default=5000) - Maximal number of classes handled by the mutual_info_score - metric. Setting it too high can lead to MemoryError or OS - freeze - Returns ------- mi: float @@ -608,27 +581,37 @@ def mutual_info_score(labels_true, labels_pred, contingency=None, """ if contingency is None: labels_true, labels_pred = check_clusterings(labels_true, labels_pred) - contingency = contingency_matrix(labels_true, labels_pred, - max_n_classes=max_n_classes) - contingency = np.array(contingency, dtype='float') - contingency_sum = np.sum(contingency) - pi = np.sum(contingency, axis=1) - pj = np.sum(contingency, axis=0) - outer = np.outer(pi, pj) - nnz = contingency != 0.0 - # normalized contingency - contingency_nm = contingency[nnz] - log_contingency_nm = np.log(contingency_nm) - contingency_nm /= contingency_sum - # log(a / b) should be calculated as log(a) - log(b) for - # possible loss of precision - log_outer = -np.log(outer[nnz]) + log(pi.sum()) + log(pj.sum()) + contingency = contingency_matrix(labels_true, labels_pred, sparse=True) + else: + contingency = check_array(contingency, + accept_sparse=['csr', 'csc', 'coo'], + dtype=[int, np.int32, np.int64]) + + if isinstance(contingency, np.ndarray): + # For an array + nzx, nzy = np.nonzero(contingency) + nz_val = contingency[nzx, nzy] + elif sp.issparse(contingency): + # For a sparse matrix + 
nzx, nzy, nz_val = sp.find(contingency) + else: + raise ValueError("Unsupported type for 'contingency': %s" % + type(contingency)) + + contingency_sum = contingency.sum() + pi = np.ravel(contingency.sum(axis=1)) + pj = np.ravel(contingency.sum(axis=0)) + log_contingency_nm = np.log(nz_val) + contingency_nm = nz_val / contingency_sum + # Don't need to calculate the full outer product, just for non-zeroes + outer = pi.take(nzx) * pj.take(nzy) + log_outer = -np.log(outer) + log(pi.sum()) + log(pj.sum()) mi = (contingency_nm * (log_contingency_nm - log(contingency_sum)) + contingency_nm * log_outer) return mi.sum() -def adjusted_mutual_info_score(labels_true, labels_pred, max_n_classes=5000): +def adjusted_mutual_info_score(labels_true, labels_pred): """Adjusted Mutual Information between two clusterings. Adjusted Mutual Information (AMI) is an adjustment of the Mutual @@ -661,11 +644,6 @@ def adjusted_mutual_info_score(labels_true, labels_pred, max_n_classes=5000): labels_pred : array, shape = [n_samples] A clustering of the data into disjoint subsets. - max_n_classes: int, optional (default=5000) - Maximal number of classes handled by the adjusted_rand_score - metric. 
Setting it too high can lead to MemoryError or OS - freeze - Returns ------- ami: float(upperlimited by 1.0) @@ -716,9 +694,8 @@ def adjusted_mutual_info_score(labels_true, labels_pred, max_n_classes=5000): if (classes.shape[0] == clusters.shape[0] == 1 or classes.shape[0] == clusters.shape[0] == 0): return 1.0 - contingency = contingency_matrix(labels_true, labels_pred, - max_n_classes=max_n_classes) - contingency = np.array(contingency, dtype='float') + contingency = contingency_matrix(labels_true, labels_pred, sparse=True) + contingency = contingency.astype(float) # Calculate the MI for the two clusterings mi = mutual_info_score(labels_true, labels_pred, contingency=contingency) @@ -730,7 +707,7 @@ def adjusted_mutual_info_score(labels_true, labels_pred, max_n_classes=5000): return ami -def normalized_mutual_info_score(labels_true, labels_pred, max_n_classes=5000): +def normalized_mutual_info_score(labels_true, labels_pred): """Normalized Mutual Information between two clusterings. Normalized Mutual Information (NMI) is an normalization of the Mutual @@ -760,11 +737,6 @@ def normalized_mutual_info_score(labels_true, labels_pred, max_n_classes=5000): labels_pred : array, shape = [n_samples] A clustering of the data into disjoint subsets. - max_n_classes: int, optional (default=5000) - Maximal number of classes handled by the adjusted_rand_score - metric. 
Setting it too high can lead to MemoryError or OS - freeze - Returns ------- nmi: float @@ -803,9 +775,8 @@ def normalized_mutual_info_score(labels_true, labels_pred, max_n_classes=5000): if (classes.shape[0] == clusters.shape[0] == 1 or classes.shape[0] == clusters.shape[0] == 0): return 1.0 - contingency = contingency_matrix(labels_true, labels_pred, - max_n_classes=max_n_classes) - contingency = np.array(contingency, dtype='float') + contingency = contingency_matrix(labels_true, labels_pred, sparse=True) + contingency = contingency.astype(float) # Calculate the MI for the two clusterings mi = mutual_info_score(labels_true, labels_pred, contingency=contingency) @@ -816,7 +787,7 @@ def normalized_mutual_info_score(labels_true, labels_pred, max_n_classes=5000): return nmi -def fowlkes_mallows_score(labels_true, labels_pred, max_n_classes=5000): +def fowlkes_mallows_score(labels_true, labels_pred, sparse=False): """Measure the similarity of two clusterings of a set of points. The Fowlkes-Mallows index (FMI) is defined as the geometric mean between of @@ -845,11 +816,6 @@ def fowlkes_mallows_score(labels_true, labels_pred, max_n_classes=5000): labels_pred : array, shape = (``n_samples``, ) A clustering of the data into disjoint subsets. - max_n_classes : int, optional (default=5000) - Maximal number of classes handled by the Fowlkes-Mallows - metric. Setting it too high can lead to MemoryError or OS - freeze - Returns ------- score : float @@ -883,15 +849,13 @@ def fowlkes_mallows_score(labels_true, labels_pred, max_n_classes=5000): .. 
[2] `Wikipedia entry for the Fowlkes-Mallows Index <https://en.wikipedia.org/wiki/Fowlkes-Mallows_index>`_ """ - labels_true, labels_pred = check_clusterings(labels_true, labels_pred,) + labels_true, labels_pred = check_clusterings(labels_true, labels_pred) n_samples, = labels_true.shape - c = contingency_matrix(labels_true, labels_pred, - max_n_classes=max_n_classes) - tk = np.dot(c.ravel(), c.ravel()) - n_samples - pk = np.sum(np.sum(c, axis=0) ** 2) - n_samples - qk = np.sum(np.sum(c, axis=1) ** 2) - n_samples - + c = contingency_matrix(labels_true, labels_pred, sparse=True) + tk = np.dot(c.data, c.data) - n_samples + pk = np.sum(np.asarray(c.sum(axis=0)).ravel() ** 2) - n_samples + qk = np.sum(np.asarray(c.sum(axis=1)).ravel() ** 2) - n_samples return tk / np.sqrt(pk * qk) if tk != 0. else 0. diff --git a/sklearn/metrics/cluster/tests/test_supervised.py b/sklearn/metrics/cluster/tests/test_supervised.py index 828c2c544574c..b50f681fd1480 100644 --- a/sklearn/metrics/cluster/tests/test_supervised.py +++ b/sklearn/metrics/cluster/tests/test_supervised.py @@ -1,23 +1,21 @@ import numpy as np +from nose.tools import assert_almost_equal +from nose.tools import assert_equal +from numpy.testing import assert_array_almost_equal +from sklearn.metrics.cluster import adjusted_mutual_info_score from sklearn.metrics.cluster import adjusted_rand_score -from sklearn.metrics.cluster import homogeneity_score from sklearn.metrics.cluster import completeness_score -from sklearn.metrics.cluster import v_measure_score -from sklearn.metrics.cluster import homogeneity_completeness_v_measure -from sklearn.metrics.cluster import adjusted_mutual_info_score -from sklearn.metrics.cluster import normalized_mutual_info_score -from sklearn.metrics.cluster import mutual_info_score -from sklearn.metrics.cluster import expected_mutual_information from sklearn.metrics.cluster import contingency_matrix -from sklearn.metrics.cluster import fowlkes_mallows_score from sklearn.metrics.cluster import entropy - +from sklearn.metrics.cluster import 
expected_mutual_information +from sklearn.metrics.cluster import fowlkes_mallows_score +from sklearn.metrics.cluster import homogeneity_completeness_v_measure +from sklearn.metrics.cluster import homogeneity_score +from sklearn.metrics.cluster import mutual_info_score +from sklearn.metrics.cluster import normalized_mutual_info_score +from sklearn.metrics.cluster import v_measure_score from sklearn.utils.testing import assert_raise_message -from nose.tools import assert_almost_equal -from nose.tools import assert_equal -from numpy.testing import assert_array_almost_equal - score_funcs = [ adjusted_rand_score, @@ -141,9 +139,16 @@ def test_adjusted_mutual_info_score(): # Mutual information mi = mutual_info_score(labels_a, labels_b) assert_almost_equal(mi, 0.41022, 5) - # Expected mutual information + # with provided sparse contingency + C = contingency_matrix(labels_a, labels_b, sparse=True) + mi = mutual_info_score(labels_a, labels_b, contingency=C) + assert_almost_equal(mi, 0.41022, 5) + # with provided dense contingency C = contingency_matrix(labels_a, labels_b) - n_samples = np.sum(C) + mi = mutual_info_score(labels_a, labels_b, contingency=C) + assert_almost_equal(mi, 0.41022, 5) + # Expected mutual information + n_samples = C.sum() emi = expected_mutual_information(C, n_samples) assert_almost_equal(emi, 0.15042, 5) # Adjusted mutual information @@ -183,55 +188,40 @@ def test_contingency_matrix(): assert_array_almost_equal(C, C2 + .1) +def test_contingency_matrix_sparse(): + labels_a = np.array([1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3]) + labels_b = np.array([1, 1, 1, 1, 2, 1, 2, 2, 2, 2, 3, 1, 3, 3, 3, 2, 2]) + C = contingency_matrix(labels_a, labels_b) + C_sparse = contingency_matrix(labels_a, labels_b, sparse=True).toarray() + assert_array_almost_equal(C, C_sparse) + C_sparse = assert_raise_message(ValueError, + "Cannot set 'eps' when sparse=True", + contingency_matrix, labels_a, labels_b, + eps=1e-10, sparse=True) + + def 
test_exactly_zero_info_score(): # Check numerical stability when information is exactly zero for i in np.logspace(1, 4, 4).astype(np.int): - labels_a, labels_b = np.ones(i, dtype=np.int),\ - np.arange(i, dtype=np.int) - assert_equal(normalized_mutual_info_score(labels_a, labels_b, - max_n_classes=1e4), 0.0) - assert_equal(v_measure_score(labels_a, labels_b, - max_n_classes=1e4), 0.0) - assert_equal(adjusted_mutual_info_score(labels_a, labels_b, - max_n_classes=1e4), 0.0) - assert_equal(normalized_mutual_info_score(labels_a, labels_b, - max_n_classes=1e4), 0.0) + labels_a, labels_b = (np.ones(i, dtype=np.int), + np.arange(i, dtype=np.int)) + assert_equal(normalized_mutual_info_score(labels_a, labels_b), 0.0) + assert_equal(v_measure_score(labels_a, labels_b), 0.0) + assert_equal(adjusted_mutual_info_score(labels_a, labels_b), 0.0) + assert_equal(normalized_mutual_info_score(labels_a, labels_b), 0.0) def test_v_measure_and_mutual_information(seed=36): # Check relation between v_measure, entropy and mutual information for i in np.logspace(1, 4, 4).astype(np.int): random_state = np.random.RandomState(seed) - labels_a, labels_b = random_state.randint(0, 10, i),\ - random_state.randint(0, 10, i) + labels_a, labels_b = (random_state.randint(0, 10, i), + random_state.randint(0, 10, i)) assert_almost_equal(v_measure_score(labels_a, labels_b), 2.0 * mutual_info_score(labels_a, labels_b) / (entropy(labels_a) + entropy(labels_b)), 0) -def test_max_n_classes(): - rng = np.random.RandomState(seed=0) - labels_true = rng.rand(53) - labels_pred = rng.rand(53) - labels_zero = np.zeros(53) - labels_true[:2] = 0 - labels_zero[:3] = 1 - labels_pred[:2] = 0 - for score_func in score_funcs: - expected = ("Too many classes for a clustering metric. 
If you " - "want to increase the limit, pass parameter " - "max_n_classes to the scoring function") - assert_raise_message(ValueError, expected, score_func, - labels_true, labels_pred, - max_n_classes=50) - expected = ("Too many clusters for a clustering metric. If you " - "want to increase the limit, pass parameter " - "max_n_classes to the scoring function") - assert_raise_message(ValueError, expected, score_func, - labels_zero, labels_pred, - max_n_classes=50) - - def test_fowlkes_mallows_score(): # General case score = fowlkes_mallows_score([0, 0, 0, 1, 1, 1],