From eb4eebd3eecf16f2eb851fcf07da261f270fa2af Mon Sep 17 00:00:00 2001 From: Arnaud Joly Date: Sun, 26 May 2013 17:21:43 +0200 Subject: [PATCH 01/14] ENH remove _is_1d and _check_1d_array thanks to @GaelVaroquaux --- sklearn/metrics/metrics.py | 167 ++++++-------------------- sklearn/metrics/tests/test_metrics.py | 5 +- 2 files changed, 42 insertions(+), 130 deletions(-) diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py index d38afc347a46b..770f8506b4c2b 100644 --- a/sklearn/metrics/metrics.py +++ b/sklearn/metrics/metrics.py @@ -34,103 +34,6 @@ ############################################################################### # General utilities ############################################################################### -def _is_1d(x): - """Return True if x is 1d or a column vector - - Parameters - ---------- - x : numpy array. - - Returns - ------- - is_1d : boolean, - Return True if x can be considered as a 1d vector. - - Examples - -------- - >>> import numpy as np - >>> from sklearn.metrics.metrics import _is_1d - >>> _is_1d([1, 2, 3]) - True - >>> _is_1d(np.array([1, 2, 3])) - True - >>> _is_1d([[1, 2, 3]]) - False - >>> _is_1d(np.array([[1, 2, 3]])) - False - >>> _is_1d([[1], [2], [3]]) - True - >>> _is_1d(np.array([[1], [2], [3]])) - True - >>> _is_1d([[1, 2], [3, 4]]) - False - >>> _is_1d(np.array([[1, 2], [3, 4]])) - False - - See also - -------- - _check_1d_array - - """ - shape = np.shape(x) - return len(shape) == 1 or len(shape) == 2 and shape[1] == 1 - - -def _check_1d_array(y1, y2, ravel=False): - """Check that y1 and y2 are vectors of the same shape. - - It convert 1d arrays (y1 and y2) of various shape to a common shape - representation. Note that ``y1`` and ``y2`` should have the same number of - elements. - - Parameters - ---------- - y1 : array-like, - y1 must be a "vector". - - y2 : array-like - y2 must be a "vector". - - ravel : boolean, optional (default=False), - If ``ravel``` is set to ``True``, then ``y1`` and ``y2`` are raveled. - - Returns - ------- - y1 : numpy array, - If ``ravel`` is set to ``True``, return np.ravel(y1), else - return y1. - - y2 : numpy array, - Return y2 reshaped to have the shape of y1. - - Examples - -------- - >>> from sklearn.metrics.metrics import _check_1d_array - >>> _check_1d_array([1, 2], [[3], [4]]) - (array([1, 2]), array([3, 4])) - - See also - -------- - _is_1d - - """ - y1 = np.asarray(y1) - y2 = np.asarray(y2) - - if not _is_1d(y1): - raise ValueError("y1 can't be considered as a vector") - - if not _is_1d(y2): - raise ValueError("y2 can't be considered as a vector") - - if ravel: - return np.ravel(y1), np.ravel(y2) - else: - if np.shape(y1) != np.shape(y2): - y2 = np.reshape(y2, np.shape(y1)) - - return y1, y2 - def _check_clf_targets(y_true, y_pred): """Check that y_true and y_pred belong to the same classification task @@ -467,7 +370,8 @@ def matthews_corrcoef(y_true, y_pred): """ y_true, y_pred = check_arrays(y_true, y_pred) - y_true, y_pred = _check_1d_array(y_true, y_pred, ravel=True) + y_true = np.squeeze(y_true) + y_pred = np.squeeze(y_pred) mcc = np.corrcoef(y_true, y_pred)[0, 1] if np.isnan(mcc): @@ -508,7 +412,8 @@ def _binary_clf_curve(y_true, y_score, pos_label=None): Decreasing score values. """ y_true, y_score = check_arrays(y_true, y_score) - y_true, y_score = _check_1d_array(y_true, y_score, ravel=True) + y_true = np.squeeze(y_true) + y_score = np.squeeze(y_score) # ensure binary classification if pos_label is not specified classes = np.unique(y_true) @@ -745,7 +650,8 @@ def confusion_matrix(y_true, y_pred, labels=None): """ y_true, y_pred = check_arrays(y_true, y_pred) - y_true, y_pred = _check_1d_array(y_true, y_pred, ravel=True) + y_true = np.squeeze(y_true) + y_pred = np.squeeze(y_pred) if labels is None: labels = unique_labels(y_true, y_pred) @@ -840,13 +746,7 @@ def zero_one_loss(y_true, y_pred, normalize=True): if normalize: return 1 - score else: - if hasattr(y_true, "shape"): - n_samples = (np.max(y_true.shape) if _is_1d(y_true) - else y_true.shape[0]) - - else: - n_samples = len(y_true) - + n_samples = len(y_true) return n_samples - score @@ -1006,6 +906,11 @@ def jaccard_similarity_score(y_true, y_pred, normalize=True): score[i] = (len(true_set & pred_set) / size_true_union_pred) else: + y_true, y_pred = check_arrays(y_true, y_pred) + + y_true = np.squeeze(y_true) + y_pred = np.squeeze(y_pred) + score = y_true == y_pred if normalize: @@ -1082,6 +987,11 @@ def accuracy_score(y_true, y_pred, normalize=True): score = np.array([len(set(true) ^ set(pred)) == 0 for pred, true in zip(y_pred, y_true)]) else: + y_true, y_pred = check_arrays(y_true, y_pred) + + y_true = np.squeeze(y_true) + y_pred = np.squeeze(y_pred) + score = y_true == y_pred if normalize: @@ -1422,9 +1332,9 @@ def _tp_tn_fp_fn(y_true, y_pred, labels=None): else: labels = np.asarray(labels) n_labels = labels.size - true_pos = np.zeros((n_labels), dtype=np.int) - false_pos = np.zeros((n_labels), dtype=np.int) - false_neg = np.zeros((n_labels), dtype=np.int) + true_pos = np.zeros((n_labels,), dtype=np.int) + false_pos = np.zeros((n_labels,), dtype=np.int) + false_neg = np.zeros((n_labels,), dtype=np.int) if y_type == 'multilabel-indicator': true_pos = np.sum(np.logical_and(y_true == 1, @@ -1448,18 +1358,16 @@ def _tp_tn_fp_fn(y_true, y_pred, labels=None): false_neg[np.setdiff1d(true_set, pred_set)] += 1 else: + + y_true, y_pred = check_arrays(y_true, y_pred) + for i, label_i in enumerate(labels): true_pos[i] = np.sum(y_pred[y_true == label_i] == label_i) false_pos[i] = np.sum(y_pred[y_true != label_i] == label_i) false_neg[i] = np.sum(y_pred[y_true == label_i] != label_i) # Compute the true_neg using the tp, fp and fn - if hasattr(y_true, "shape"): - n_samples = (np.max(y_true.shape) if _is_1d(y_true) - else y_true.shape[0]) - else: - n_samples = len(y_true) - + n_samples = len(y_true) true_neg = n_samples - true_pos - false_pos - false_neg return true_pos, true_neg, false_pos, false_neg @@ -2206,6 +2114,11 @@ def hamming_loss(y_true, y_pred, classes=None): return np.mean(loss) / np.size(classes) else: + + y_true, y_pred = check_arrays(y_true, y_pred) + y_true = np.squeeze(y_true) + y_pred = np.squeeze(y_pred) + return sp_hamming(y_true, y_pred) @@ -2243,9 +2156,8 @@ def mean_absolute_error(y_true, y_pred): """ y_true, y_pred = check_arrays(y_true, y_pred) - # Handle mix 1d representation - if _is_1d(y_true): - y_true, y_pred = _check_1d_array(y_true, y_pred) + y_true = np.squeeze(y_true) + y_pred = np.squeeze(y_pred) return np.mean(np.abs(y_pred - y_true)) @@ -2281,9 +2193,8 @@ def mean_squared_error(y_true, y_pred): """ y_true, y_pred = check_arrays(y_true, y_pred) - # Handle mix 1d representation - if _is_1d(y_true): - y_true, y_pred = _check_1d_array(y_true, y_pred) + y_true = np.squeeze(y_true) + y_pred = np.squeeze(y_pred) return np.mean((y_pred - y_true) ** 2) @@ -2324,9 +2235,8 @@ def explained_variance_score(y_true, y_pred): """ y_true, y_pred = check_arrays(y_true, y_pred) - # Handle mix 1d representation - if _is_1d(y_true): - y_true, y_pred = _check_1d_array(y_true, y_pred) + y_true = np.atleast_1d(np.squeeze(y_true)) + y_pred = np.atleast_1d(np.squeeze(y_pred)) numerator = np.var(y_true - y_pred) denominator = np.var(y_true) @@ -2385,11 +2295,10 @@ def r2_score(y_true, y_pred): """ y_true, y_pred = check_arrays(y_true, y_pred) - # Handle mix 1d representation - if _is_1d(y_true): - y_true, y_pred = _check_1d_array(y_true, y_pred, ravel=True) + y_true = np.squeeze(y_true) + y_pred = np.squeeze(y_pred) - if len(y_true) == 1: + if y_true.size == 1: raise ValueError("r2_score can only be computed given more than one" " sample.") numerator = ((y_true - y_pred) ** 2).sum() diff --git a/sklearn/metrics/tests/test_metrics.py b/sklearn/metrics/tests/test_metrics.py index 51638bdc56459..56ac9a3ebafaa 100644 --- a/sklearn/metrics/tests/test_metrics.py +++ b/sklearn/metrics/tests/test_metrics.py @@ -1,6 +1,8 @@ from __future__ import division, print_function from functools import partial +from collections import namedtuple +from itertools import product import warnings import numpy as np @@ -19,7 +21,8 @@ assert_not_equal, assert_array_equal, assert_array_almost_equal, - assert_greater) + assert_greater, + assert_false) from sklearn.metrics import (accuracy_score, From 23ca7148418661e046f3dfdf8635c69709f78f03 Mon Sep 17 00:00:00 2001 From: Arnaud Joly Date: Mon, 27 May 2013 08:44:19 +0200 Subject: [PATCH 02/14] flake8 --- sklearn/metrics/tests/test_metrics.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/sklearn/metrics/tests/test_metrics.py b/sklearn/metrics/tests/test_metrics.py index 56ac9a3ebafaa..51638bdc56459 100644 --- a/sklearn/metrics/tests/test_metrics.py +++ b/sklearn/metrics/tests/test_metrics.py @@ -1,8 +1,6 @@ from __future__ import division, print_function from functools import partial -from collections import namedtuple -from itertools import product import warnings import numpy as np @@ -21,8 +19,7 @@ assert_not_equal, assert_array_equal, assert_array_almost_equal, - assert_greater, - assert_false) + assert_greater) from sklearn.metrics import (accuracy_score, From 2982f080644cc359fdea76260f70cb28e5739e42 Mon Sep 17 00:00:00 2001 From: Arnaud Joly Date: Mon, 27 May 2013 09:38:28 +0200 Subject: [PATCH 03/14] ENH raise ValueError with row vector if multilabel or multioutput is not supported --- sklearn/metrics/metrics.py | 27 ++++++++++++++++++++++- sklearn/metrics/tests/test_metrics.py | 31 +++++++++++++++++++++------ 2 files changed, 51 insertions(+), 7 deletions(-) diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py index 770f8506b4c2b..c75bf5e785825 100644 --- a/sklearn/metrics/metrics.py +++ b/sklearn/metrics/metrics.py @@ -370,6 +370,13 @@ def matthews_corrcoef(y_true, y_pred): """ y_true, y_pred = check_arrays(y_true, y_pred) + + if (y_true.ndim == 2 and y_true.shape[1] > 1): + raise ValueError("Bad y_true input shape") + + if (y_pred.ndim == 2 and y_pred.shape[1] > 1): + raise ValueError("Bad y_pred input shape") + y_true = np.squeeze(y_true) y_pred = np.squeeze(y_pred) @@ -412,6 +419,13 @@ def _binary_clf_curve(y_true, y_score, pos_label=None): Decreasing score values. """ y_true, y_score = check_arrays(y_true, y_score) + + if (y_true.ndim == 2 and y_true.shape[1] > 1): + raise ValueError("Bad y_true input shape") + + if (y_score.ndim == 2 and y_score.shape[1] > 1): + raise ValueError("Bad y_score input shape") + y_true = np.squeeze(y_true) y_score = np.squeeze(y_score) @@ -650,6 +664,12 @@ def confusion_matrix(y_true, y_pred, labels=None): """ y_true, y_pred = check_arrays(y_true, y_pred) + + if (y_true.ndim == 2 and y_true.shape[1] > 1): + raise ValueError("Bad y_true input shape") + if (y_pred.ndim == 2 and y_pred.shape[1] > 1): + raise ValueError("Bad y_pred input shape") + y_true = np.squeeze(y_true) y_pred = np.squeeze(y_pred) @@ -739,7 +759,6 @@ def zero_one_loss(y_true, y_pred, normalize=True): """ - y_true, y_pred = check_arrays(y_true, y_pred, allow_lists=True) score = accuracy_score(y_true, y_pred, normalize=normalize) @@ -2235,6 +2254,12 @@ def explained_variance_score(y_true, y_pred): """ y_true, y_pred = check_arrays(y_true, y_pred) + if (y_true.ndim == 2 and y_true.shape[1] > 1): + raise ValueError("Bad y_true input shape") + + if (y_pred.ndim == 2 and y_pred.shape[1] > 1): + raise ValueError("Bad y_pred input shape") + y_true = np.atleast_1d(np.squeeze(y_true)) y_pred = np.atleast_1d(np.squeeze(y_pred)) diff --git a/sklearn/metrics/tests/test_metrics.py b/sklearn/metrics/tests/test_metrics.py index 51638bdc56459..e1234d3b2a2ad 100644 --- a/sklearn/metrics/tests/test_metrics.py +++ b/sklearn/metrics/tests/test_metrics.py @@ -52,7 +52,7 @@ ALL_METRICS = { "accuracy_score": accuracy_score, "unormalized_accuracy_score": partial(accuracy_score, normalize=False), - + "confusion_matrix": confusion_matrix, "hamming_loss": hamming_loss, "jaccard_similarity_score": jaccard_similarity_score, @@ -115,6 +115,12 @@ "zero_one_loss": zero_one_loss, "unnormalized_zero_one_loss": partial(zero_one_loss, normalize=False), + "precision_score": precision_score, + "recall_score": recall_score, + "f1_score": f1_score, + "f2_score": partial(fbeta_score, beta=2), + "f0.5_score": partial(fbeta_score, beta=0.5), + "weighted_f0.5_score": partial(fbeta_score, average="weighted", beta=0.5), "weighted_f1_score": partial(f1_score, average="weighted"), "weighted_f2_score": partial(fbeta_score, average="weighted", beta=2), @@ -140,6 +146,13 @@ "macro_recall_score": partial(recall_score, average="macro"), } +MULTIOUTPUT_METRICS = { + "mean_absolute_error": mean_absolute_error, + "mean_squared_error": mean_squared_error, + "r2_score": r2_score, +} + + SYMMETRIC_METRICS = { "accuracy_score": accuracy_score, "unormalized_accuracy_score": partial(accuracy_score, normalize=False), @@ -167,6 +180,8 @@ "explained_variance_score": explained_variance_score, "r2_score": r2_score, + "confusion_matrix": confusion_matrix, + "precision_score": precision_score, "recall_score": recall_score, "f2_score": partial(fbeta_score, beta=2), @@ -974,6 +989,9 @@ def test_format_invariance_with_1d_vectors(): % name) # At the moment, these mix representations aren't allowed + if name not in MULTILABELS_METRICS and not name in MULTIOUTPUT_METRICS: + assert_raises(ValueError, metric, y1_row, y2_row) + assert_raises(ValueError, metric, y1_1d, y2_row) assert_raises(ValueError, metric, y1_row, y2_1d) assert_raises(ValueError, metric, y1_list, y2_row) @@ -1040,9 +1058,8 @@ def test_multioutput_number_of_output_differ(): y_true = np.array([[1, 0, 0, 1], [0, 1, 1, 1], [1, 1, 0, 1]]) y_pred = np.array([[0, 0], [1, 0], [0, 0]]) - assert_raises(ValueError, mean_squared_error, y_true, y_pred) - assert_raises(ValueError, mean_absolute_error, y_true, y_pred) - assert_raises(ValueError, r2_score, y_true, y_pred) + for name, metrics in MULTIOUTPUT_METRICS.items(): + assert_raises(ValueError, metrics, y_true, y_pred) def test_multioutput_regression_invariance_to_dimension_shuffling(): @@ -1053,13 +1070,15 @@ def test_multioutput_regression_invariance_to_dimension_shuffling(): y_pred = np.reshape(y_pred, (-1, n_dims)) rng = check_random_state(314159) - for metric in [r2_score, mean_squared_error, mean_absolute_error]: + for name, metric in MULTIOUTPUT_METRICS.items(): error = metric(y_true, y_pred) for _ in xrange(3): perm = rng.permutation(n_dims) assert_almost_equal(metric(y_true[:, perm], y_pred[:, perm]), - error) + error, + err_msg="%s is not dimension shuffling" + "invariant" % name) def test_multilabel_representation_invariance(): From 39ad1d7323a86f9aa19ceaeffc504f9721c8304d Mon Sep 17 00:00:00 2001 From: Arnaud Joly Date: Mon, 27 May 2013 09:47:13 +0200 Subject: [PATCH 04/14] ENH being less permissive thanks to @jnothman --- sklearn/metrics/metrics.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py index c75bf5e785825..ac678201fcdec 100644 --- a/sklearn/metrics/metrics.py +++ b/sklearn/metrics/metrics.py @@ -371,12 +371,14 @@ def matthews_corrcoef(y_true, y_pred): """ y_true, y_pred = check_arrays(y_true, y_pred) - if (y_true.ndim == 2 and y_true.shape[1] > 1): + if not (y_true.ndim == 1 or (y_true.ndim == 2 and y_true.shape[1] == 1)): raise ValueError("Bad y_true input shape") - if (y_pred.ndim == 2 and y_pred.shape[1] > 1): + if not (y_pred.ndim == 1 or + (y_pred.ndim == 2 and y_pred.shape[1] == 1)): raise ValueError("Bad y_pred input shape") + y_true = np.squeeze(y_true) y_pred = np.squeeze(y_pred) @@ -420,10 +422,11 @@ def _binary_clf_curve(y_true, y_score, pos_label=None): """ y_true, y_score = check_arrays(y_true, y_score) - if (y_true.ndim == 2 and y_true.shape[1] > 1): + if not (y_true.ndim == 1 or (y_true.ndim == 2 and y_true.shape[1] == 1)): raise ValueError("Bad y_true input shape") - if (y_score.ndim == 2 and y_score.shape[1] > 1): + if not (y_score.ndim == 1 or + (y_score.ndim == 2 and y_score.shape[1] == 1)): raise ValueError("Bad y_score input shape") y_true = np.squeeze(y_true) @@ -665,9 +668,10 @@ def confusion_matrix(y_true, y_pred, labels=None): """ y_true, y_pred = check_arrays(y_true, y_pred) - if (y_true.ndim == 2 and y_true.shape[1] > 1): + if not (y_true.ndim == 1 or (y_true.ndim == 2 and y_true.shape[1] == 1)): raise ValueError("Bad y_true input shape") - if (y_pred.ndim == 2 and y_pred.shape[1] > 1): + + if not (y_pred.ndim == 1 or (y_pred.ndim == 2 and y_pred.shape[1] == 1)): raise ValueError("Bad y_pred input shape") y_true = np.squeeze(y_true) @@ -2254,10 +2258,10 @@ def explained_variance_score(y_true, y_pred): """ y_true, y_pred = check_arrays(y_true, y_pred) - if (y_true.ndim == 2 and y_true.shape[1] > 1): + if not (y_true.ndim == 1 or (y_true.ndim == 2 and y_true.shape[1] == 1)): raise ValueError("Bad y_true input shape") - if (y_pred.ndim == 2 and y_pred.shape[1] > 1): + if not (y_pred.ndim == 1 or (y_pred.ndim == 2 and y_pred.shape[1] == 1)): raise ValueError("Bad y_pred input shape") y_true = np.atleast_1d(np.squeeze(y_true)) From ca8e8034e5390dd831c9a1b11d0983cb58d0d7ca Mon Sep 17 00:00:00 2001 From: Arnaud Joly Date: Mon, 27 May 2013 09:50:36 +0200 Subject: [PATCH 05/14] DOC add example is_multilabel --- sklearn/utils/multiclass.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sklearn/utils/multiclass.py b/sklearn/utils/multiclass.py index 11289c433bf2d..3f365ab1e4d09 100644 --- a/sklearn/utils/multiclass.py +++ b/sklearn/utils/multiclass.py @@ -212,6 +212,8 @@ def is_multilabel(y): True >>> is_multilabel(np.array([[1], [0], [0]])) False + >>> is_multilabel(np.array([[1, 0, 0]])) + True """ return is_label_indicator_matrix(y) or is_sequence_of_sequences(y) From e83f73d21bccd5f8887c0851c1fb83426a7345d6 Mon Sep 17 00:00:00 2001 From: Arnaud Joly Date: Mon, 27 May 2013 12:32:35 +0200 Subject: [PATCH 06/14] ENH handle properly row vector --- sklearn/metrics/metrics.py | 73 +++++++++------------------ sklearn/metrics/tests/test_metrics.py | 10 ++-- 2 files changed, 29 insertions(+), 54 deletions(-) diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py index ac678201fcdec..8806222cc797d 100644 --- a/sklearn/metrics/metrics.py +++ b/sklearn/metrics/metrics.py @@ -84,8 +84,16 @@ def _check_clf_targets(y_true, y_pred): # 'binary' can be removed type_true = type_pred = 'multiclass' - y_true = np.ravel(y_true) - y_pred = np.ravel(y_pred) + y_true, y_pred = check_arrays(y_true, y_pred) + + if (not (y_true.ndim == 1 or + (y_true.ndim == 2 and y_true.shape[1] == 1)) or + not (y_pred.ndim == 1 or + (y_pred.ndim == 2 and y_pred.shape[1] == 1))): + raise ValueError("Bad input shape") + + y_true = np.atleast_1d(np.squeeze(y_true)) + y_pred = np.atleast_1d(np.squeeze(y_pred)) else: raise ValueError("Can't handle %s/%s targets" % (type_true, type_pred)) @@ -369,18 +377,10 @@ def matthews_corrcoef(y_true, y_pred): -0.33... """ - y_true, y_pred = check_arrays(y_true, y_pred) - - if not (y_true.ndim == 1 or (y_true.ndim == 2 and y_true.shape[1] == 1)): - raise ValueError("Bad y_true input shape") - - if not (y_pred.ndim == 1 or - (y_pred.ndim == 2 and y_pred.shape[1] == 1)): - raise ValueError("Bad y_pred input shape") + y_type, y_true, y_pred = _check_clf_targets(y_true, y_pred) - - y_true = np.squeeze(y_true) - y_pred = np.squeeze(y_pred) + if y_type != "binary": + raise ValueError("%s is not supported" % y_type) mcc = np.corrcoef(y_true, y_pred)[0, 1] if np.isnan(mcc): @@ -422,12 +422,11 @@ def _binary_clf_curve(y_true, y_score, pos_label=None): """ y_true, y_score = check_arrays(y_true, y_score) - if not (y_true.ndim == 1 or (y_true.ndim == 2 and y_true.shape[1] == 1)): - raise ValueError("Bad y_true input shape") - - if not (y_score.ndim == 1 or - (y_score.ndim == 2 and y_score.shape[1] == 1)): - raise ValueError("Bad y_score input shape") + if (not (y_true.ndim == 1 or + (y_true.ndim == 2 and y_true.shape[1] == 1)) or + not (y_score.ndim == 1 or + (y_score.ndim == 2 and y_score.shape[1] == 1))): + raise ValueError("Bad input shape") y_true = np.squeeze(y_true) y_score = np.squeeze(y_score) @@ -666,16 +665,9 @@ def confusion_matrix(y_true, y_pred, labels=None): [1, 0, 2]]) """ - y_true, y_pred = check_arrays(y_true, y_pred) - - if not (y_true.ndim == 1 or (y_true.ndim == 2 and y_true.shape[1] == 1)): - raise ValueError("Bad y_true input shape") - - if not (y_pred.ndim == 1 or (y_pred.ndim == 2 and y_pred.shape[1] == 1)): - raise ValueError("Bad y_pred input shape") - - y_true = np.squeeze(y_true) - y_pred = np.squeeze(y_pred) + y_type, y_true, y_pred = _check_clf_targets(y_true, y_pred) + if y_type not in ("binary", "multiclass"): + raise ValueError("%s is not supported" % y_type) if labels is None: labels = unique_labels(y_true, y_pred) @@ -929,11 +921,6 @@ def jaccard_similarity_score(y_true, y_pred, normalize=True): score[i] = (len(true_set & pred_set) / size_true_union_pred) else: - y_true, y_pred = check_arrays(y_true, y_pred) - - y_true = np.squeeze(y_true) - y_pred = np.squeeze(y_pred) - score = y_true == y_pred if normalize: @@ -1002,6 +989,7 @@ def accuracy_score(y_true, y_pred, normalize=True): 0.0 """ + # Compute accuracy for each possible representation y_type, y_true, y_pred = _check_clf_targets(y_true, y_pred) if y_type == 'multilabel-indicator': @@ -1010,11 +998,6 @@ def accuracy_score(y_true, y_pred, normalize=True): score = np.array([len(set(true) ^ set(pred)) == 0 for pred, true in zip(y_pred, y_true)]) else: - y_true, y_pred = check_arrays(y_true, y_pred) - - y_true = np.squeeze(y_true) - y_pred = np.squeeze(y_pred) - score = y_true == y_pred if normalize: @@ -1381,9 +1364,6 @@ def _tp_tn_fp_fn(y_true, y_pred, labels=None): false_neg[np.setdiff1d(true_set, pred_set)] += 1 else: - - y_true, y_pred = check_arrays(y_true, y_pred) - for i, label_i in enumerate(labels): true_pos[i] = np.sum(y_pred[y_true == label_i] == label_i) false_pos[i] = np.sum(y_pred[y_true != label_i] == label_i) @@ -2123,6 +2103,7 @@ def hamming_loss(y_true, y_pred, classes=None): """ y_type, y_true, y_pred = _check_clf_targets(y_true, y_pred) + if classes is None: classes = unique_labels(y_true, y_pred) else: @@ -2137,12 +2118,7 @@ def hamming_loss(y_true, y_pred, classes=None): return np.mean(loss) / np.size(classes) else: - - y_true, y_pred = check_arrays(y_true, y_pred) - y_true = np.squeeze(y_true) - y_pred = np.squeeze(y_pred) - - return sp_hamming(y_true, y_pred) + return sp_hamming(y_true, y_pred) ############################################################################### @@ -2215,7 +2191,6 @@ def mean_squared_error(y_true, y_pred): """ y_true, y_pred = check_arrays(y_true, y_pred) - y_true = np.squeeze(y_true) y_pred = np.squeeze(y_pred) diff --git a/sklearn/metrics/tests/test_metrics.py b/sklearn/metrics/tests/test_metrics.py index e1234d3b2a2ad..73a6ab124fdfd 100644 --- a/sklearn/metrics/tests/test_metrics.py +++ b/sklearn/metrics/tests/test_metrics.py @@ -988,19 +988,19 @@ def test_format_invariance_with_1d_vectors(): "with mix list and np-array-column" % name) - # At the moment, these mix representations aren't allowed - if name not in MULTILABELS_METRICS and not name in MULTIOUTPUT_METRICS: - assert_raises(ValueError, metric, y1_row, y2_row) - + # These mix representations aren't allowed assert_raises(ValueError, metric, y1_1d, y2_row) assert_raises(ValueError, metric, y1_row, y2_1d) assert_raises(ValueError, metric, y1_list, y2_row) assert_raises(ValueError, metric, y1_row, y2_list) assert_raises(ValueError, metric, y1_column, y2_row) assert_raises(ValueError, metric, y1_row, y2_column) + # NB: We do not test for y1_row, y2_row as these may be # interpreted as multilabel or multioutput data. - + if (name not in MULTIOUTPUT_METRICS and + name not in MULTILABELS_METRICS): + assert_raises(ValueError, metric, y1_row, y2_row) def test_clf_single_sample(): """Non-regression test: scores should work with a single sample. From d884512459536ff8ef2602afd8ab8f3aa4bd534a Mon Sep 17 00:00:00 2001 From: Arnaud Joly Date: Mon, 27 May 2013 14:47:34 +0200 Subject: [PATCH 07/14] Flake8 --- sklearn/metrics/metrics.py | 2 +- sklearn/metrics/tests/test_metrics.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py index 8806222cc797d..7a8ac84d30864 100644 --- a/sklearn/metrics/metrics.py +++ b/sklearn/metrics/metrics.py @@ -2118,7 +2118,7 @@ def hamming_loss(y_true, y_pred, classes=None): return np.mean(loss) / np.size(classes) else: - return sp_hamming(y_true, y_pred) + return sp_hamming(y_true, y_pred) ############################################################################### diff --git a/sklearn/metrics/tests/test_metrics.py b/sklearn/metrics/tests/test_metrics.py index 73a6ab124fdfd..e69649903fa87 100644 --- a/sklearn/metrics/tests/test_metrics.py +++ b/sklearn/metrics/tests/test_metrics.py @@ -2,6 +2,7 @@ from functools import partial import warnings +from itertools import product import numpy as np from sklearn import datasets @@ -47,6 +48,7 @@ zero_one_score, zero_one_loss) from sklearn.metrics.metrics import _check_clf_targets + from sklearn.externals.six.moves import xrange ALL_METRICS = { From 9e0c896136343a2999a8d148333c9343ae92cf3f Mon Sep 17 00:00:00 2001 From: Arnaud Joly Date: Mon, 8 Jul 2013 14:07:39 +0200 Subject: [PATCH 08/14] ENH better error message --- sklearn/metrics/metrics.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py index 7a8ac84d30864..d234e87affd0f 100644 --- a/sklearn/metrics/metrics.py +++ b/sklearn/metrics/metrics.py @@ -87,10 +87,12 @@ def _check_clf_targets(y_true, y_pred): y_true, y_pred = check_arrays(y_true, y_pred) if (not (y_true.ndim == 1 or - (y_true.ndim == 2 and y_true.shape[1] == 1)) or - not (y_pred.ndim == 1 or - (y_pred.ndim == 2 and y_pred.shape[1] == 1))): - raise ValueError("Bad input shape") + (y_true.ndim == 2 and y_true.shape[1] == 1))): + raise ValueError("y_true has a bad input shape %s" % y_true.shape) + + if (not (y_pred.ndim == 1 or + (y_pred.ndim == 2 and y_pred.shape[1] == 1))): + raise ValueError("y_pred has a bad input shape %s" % y_pred.shape) y_true = np.atleast_1d(np.squeeze(y_true)) y_pred = np.atleast_1d(np.squeeze(y_pred)) @@ -2234,10 +2236,10 @@ def explained_variance_score(y_true, y_pred): y_true, y_pred = check_arrays(y_true, y_pred) if not (y_true.ndim == 1 or (y_true.ndim == 2 and y_true.shape[1] == 1)): - raise ValueError("Bad y_true input shape") + raise ValueError("y_true has a bad input shape %s" % y_true.shape) if not (y_pred.ndim == 1 or (y_pred.ndim == 2 and y_pred.shape[1] == 1)): - raise ValueError("Bad y_pred input shape") + raise ValueError("y_pred has a bad input shape %s" % y_pred.shape) y_true = np.atleast_1d(np.squeeze(y_true)) y_pred = np.atleast_1d(np.squeeze(y_pred)) From 07be6557028fb2a22898ff22e596f5c9c0a62710 Mon Sep 17 00:00:00 2001 From: Arnaud Joly Date: Mon, 8 Jul 2013 14:38:56 +0200 Subject: [PATCH 09/14] FIX switch to the new format syntax --- sklearn/metrics/metrics.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py index d234e87affd0f..3fdee6975ad4b 100644 --- a/sklearn/metrics/metrics.py +++ b/sklearn/metrics/metrics.py @@ -88,11 +88,13 @@ def _check_clf_targets(y_true, y_pred): if (not (y_true.ndim == 1 or (y_true.ndim == 2 and y_true.shape[1] == 1))): - raise ValueError("y_true has a bad input shape %s" % y_true.shape) + raise ValueError("y_true has a bad input shape " + "{0}".format(y_true.shape)) if (not (y_pred.ndim == 1 or (y_pred.ndim == 2 and y_pred.shape[1] == 1))): - raise ValueError("y_pred has a bad input shape %s" % y_pred.shape) + raise ValueError("y_pred has a bad input shape " + "{0}".format(y_pred.shape)) y_true = np.atleast_1d(np.squeeze(y_true)) y_pred = np.atleast_1d(np.squeeze(y_pred)) @@ -2236,10 +2238,12 @@ def explained_variance_score(y_true, y_pred): y_true, y_pred = check_arrays(y_true, y_pred) if not (y_true.ndim == 1 or (y_true.ndim == 2 and y_true.shape[1] == 1)): - raise ValueError("y_true has a bad input shape %s" % y_true.shape) + raise ValueError("y_true has a bad input shape " + "{0}".format(y_true.shape)) if not (y_pred.ndim == 1 or (y_pred.ndim == 2 and y_pred.shape[1] == 1)): - raise ValueError("y_pred has a bad input shape %s" % y_pred.shape) + raise ValueError("y_pred has a bad input shape " + "{0}".format(y_pred.shape)) y_true = np.atleast_1d(np.squeeze(y_true)) y_pred = np.atleast_1d(np.squeeze(y_pred)) From 9555fe75e219accf33418b148d189c7aee712cc0 Mon Sep 17 00:00:00 2001 From: Arnaud Joly Date: Tue, 9 Jul 2013 08:34:41 +0200 Subject: [PATCH 10/14] ENH prettier error message for _binary_clf_curve with bad input shape --- sklearn/metrics/metrics.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py index 3fdee6975ad4b..292e89a56971e 100644 --- a/sklearn/metrics/metrics.py +++ b/sklearn/metrics/metrics.py @@ -427,10 +427,14 @@ def _binary_clf_curve(y_true, y_score, pos_label=None): y_true, y_score = check_arrays(y_true, y_score) if (not (y_true.ndim == 1 or - (y_true.ndim == 2 and y_true.shape[1] == 1)) or - not (y_score.ndim == 1 or - (y_score.ndim == 2 and y_score.shape[1] == 1))): - raise ValueError("Bad input shape") + (y_true.ndim == 2 and y_true.shape[1] == 1))): + raise ValueError("y_true has a bad input shape " + "{0}".format(y_true.shape)) + + if (not (y_score.ndim == 1 or + (y_score.ndim == 2 and y_score.shape[1] == 1))): + raise ValueError("y_score has a bad input shape " + "{0}".format(y_score.shape)) y_true = np.squeeze(y_true) y_score = np.squeeze(y_score) From f68d270e13c03bb42040b83f8a84ec937f2a816f Mon Sep 17 00:00:00 2001 From: Arnaud Joly Date: Wed, 10 Jul 2013 11:38:05 +0200 Subject: [PATCH 11/14] ENH use ravel instead of atleast_1d and squeeze whenever possible --- sklearn/metrics/metrics.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py index 292e89a56971e..869b274400c1e 100644 --- a/sklearn/metrics/metrics.py +++ b/sklearn/metrics/metrics.py @@ -96,8 +96,8 @@ def _check_clf_targets(y_true, y_pred): raise ValueError("y_pred has a bad input shape " "{0}".format(y_pred.shape)) - y_true = np.atleast_1d(np.squeeze(y_true)) - y_pred = np.atleast_1d(np.squeeze(y_pred)) + y_true = np.ravel(y_true) + y_pred = np.ravel(y_pred) else: raise ValueError("Can't handle %s/%s targets" % (type_true, type_pred)) @@ -2249,8 +2249,8 @@ def explained_variance_score(y_true, y_pred): raise ValueError("y_pred has a bad input shape " "{0}".format(y_pred.shape)) - y_true = np.atleast_1d(np.squeeze(y_true)) - y_pred = np.atleast_1d(np.squeeze(y_pred)) + y_true = np.ravel(y_true) + y_pred = np.ravel(y_pred)) numerator = np.var(y_true - y_pred) denominator = np.var(y_true) From 43199195516ee92330bc5e8fbd91f3888f237fb0 Mon Sep 17 00:00:00 2001 From: Arnaud Joly Date: Wed, 10 Jul 2013 15:06:21 +0200 Subject: [PATCH 12/14] ENH coherently input checking for regression metrics --- sklearn/metrics/metrics.py | 103 +++++++++++++++----------- sklearn/metrics/tests/test_metrics.py | 28 +++++++ 2 files changed, 87 insertions(+), 44 deletions(-) diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py index 869b274400c1e..9ca68485b4424 100644 --- a/sklearn/metrics/metrics.py +++ b/sklearn/metrics/metrics.py @@ -34,6 +34,43 @@ ############################################################################### # General utilities ############################################################################### +def _check_reg_targets(y_true, y_pred): + """Check that y_true and y_pred belong to the same regression task + + Parameters + ---------- + y_true : array-like, + + y_pred : array-like + + Returns + ------- + type_true : one of {'continuous', continuous-multioutput'} + The type of the true target data, as output by + ``utils.multiclass.type_of_target`` + + y_true : array-like of shape = [n_samples, n_outputs] + Ground truth (correct) target values. + + y_pred : array-like of shape = [n_samples, n_outputs] + Estimated target values. + """ + y_true, y_pred = check_arrays(y_true, y_pred) + + if y_true.ndim == 1: + y_true = y_true.reshape((-1, 1)) + + if y_pred.ndim == 1: + y_pred = y_pred.reshape((-1, 1)) + + if y_true.shape[1] != y_pred.shape[1]: + raise ValueError("y_true and y_pred have different number of output " + "({0}!={1})".format(y_true.shape[1], y_true.shape[1])) + + y_type = 'continuous' if y_true.shape[1] == 1 else 'continuous-multioutput' + + return y_type, y_true, y_pred + def _check_clf_targets(y_true, y_pred): """Check that y_true and y_pred belong to the same classification task @@ -85,19 +122,17 @@ def _check_clf_targets(y_true, y_pred): type_true = type_pred = 'multiclass' y_true, y_pred = check_arrays(y_true, y_pred) - if (not (y_true.ndim == 1 or (y_true.ndim == 2 and y_true.shape[1] == 1))): - raise ValueError("y_true has a bad input shape " - "{0}".format(y_true.shape)) + raise ValueError("y_true has a bad input shape {0} " + "".format(y_true.shape)) if (not (y_pred.ndim == 1 or (y_pred.ndim == 2 and y_pred.shape[1] == 1))): - raise ValueError("y_pred has a bad input shape " - "{0}".format(y_pred.shape)) - - y_true = np.ravel(y_true) - y_pred = np.ravel(y_pred) + raise ValueError("y_pred has a bad input shape {0}" + "".format(y_pred.shape)) + y_true = y_true.ravel() + y_pred = y_pred.ravel() else: raise ValueError("Can't handle %s/%s targets" % (type_true, type_pred)) @@ -425,19 +460,17 @@ def _binary_clf_curve(y_true, y_score, pos_label=None): Decreasing score values. """ y_true, y_score = check_arrays(y_true, y_score) - if (not (y_true.ndim == 1 or (y_true.ndim == 2 and y_true.shape[1] == 1))): - raise ValueError("y_true has a bad input shape " - "{0}".format(y_true.shape)) + raise ValueError("y_true has a bad input shape {0} " + "".format(y_true.shape)) if (not (y_score.ndim == 1 or (y_score.ndim == 2 and y_score.shape[1] == 1))): - raise ValueError("y_score has a bad input shape " - "{0}".format(y_score.shape)) - - y_true = np.squeeze(y_true) - y_score = np.squeeze(y_score) + raise ValueError("y_score has a bad input shape {0}" + "".format(y_score.shape)) + y_true = np.ravel(y_true) + y_score = np.ravel(y_score) # ensure binary classification if pos_label is not specified classes = np.unique(y_true) @@ -2120,10 +2153,10 @@ def hamming_loss(y_true, y_pred, classes=None): if y_type == 'multilabel-indicator': return np.mean(y_true != y_pred) elif y_type == 'multilabel-sequences': - loss = np.array([len(set(pred) ^ set(true)) - for pred, true in zip(y_pred, y_true)]) + loss = np.array([len(set(pred) ^ set(true)) + for pred, true in zip(y_pred, y_true)]) - return np.mean(loss) / np.size(classes) + return np.mean(loss) / np.size(classes) else: return sp_hamming(y_true, y_pred) @@ -2161,11 +2194,7 @@ def mean_absolute_error(y_true, y_pred): 0.75 """ - y_true, y_pred = check_arrays(y_true, y_pred) - - y_true = np.squeeze(y_true) - y_pred = np.squeeze(y_pred) - + y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred) return np.mean(np.abs(y_pred - y_true)) @@ -2198,10 +2227,7 @@ def mean_squared_error(y_true, y_pred): 0.708... """ - y_true, y_pred = check_arrays(y_true, y_pred) - y_true = np.squeeze(y_true) - y_pred = np.squeeze(y_pred) - + y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred) return np.mean((y_pred - y_true) ** 2) @@ -2239,18 +2265,10 @@ def explained_variance_score(y_true, y_pred): 0.957... """ - y_true, y_pred = check_arrays(y_true, y_pred) - - if not (y_true.ndim == 1 or (y_true.ndim == 2 and y_true.shape[1] == 1)): - raise ValueError("y_true has a bad input shape " - "{0}".format(y_true.shape)) + y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred) - if not (y_pred.ndim == 1 or (y_pred.ndim == 2 and y_pred.shape[1] == 1)): - raise ValueError("y_pred has a bad input shape " - "{0}".format(y_pred.shape)) - - y_true = np.ravel(y_true) - y_pred = np.ravel(y_pred)) + if y_type != "continuous": + raise ValueError("{0} is not supported".format(y_type)) numerator = np.var(y_true - y_pred) denominator = np.var(y_true) @@ -2307,12 +2325,9 @@ def r2_score(y_true, y_pred): 0.938... """ - y_true, y_pred = check_arrays(y_true, y_pred) - - y_true = np.squeeze(y_true) - y_pred = np.squeeze(y_pred) + y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred) - if y_true.size == 1: + if len(y_true) == 1: raise ValueError("r2_score can only be computed given more than one" " sample.") numerator = ((y_true - y_pred) ** 2).sum() diff --git a/sklearn/metrics/tests/test_metrics.py b/sklearn/metrics/tests/test_metrics.py index e69649903fa87..c6facf6528d64 100644 --- a/sklearn/metrics/tests/test_metrics.py +++ b/sklearn/metrics/tests/test_metrics.py @@ -48,6 +48,7 @@ zero_one_score, zero_one_loss) from sklearn.metrics.metrics import _check_clf_targets +from sklearn.metrics.metrics import _check_reg_targets from sklearn.externals.six.moves import xrange @@ -1004,6 +1005,7 @@ def test_format_invariance_with_1d_vectors(): name not in MULTILABELS_METRICS): assert_raises(ValueError, metric, y1_row, y2_row) + def test_clf_single_sample(): """Non-regression test: scores should work with a single sample. @@ -1640,3 +1642,29 @@ def test__check_clf_targets(): assert_array_equal(y1out, np.squeeze(y1)) assert_array_equal(y2out, np.squeeze(y2)) assert_raises(ValueError, _check_clf_targets, y1[:-1], y2) + + +def test__check_reg_targets(): + # All of length 3 + EXAMPLES = [ + ("continuous", [1, 2, 3], 1), + ("continuous", [[1], [2], [3]], 1), + ("continuous-multioutput", [[1, 1], [2, 2], [3, 1]], 2), + ("continuous-multioutput", [[5, 1], [4, 2], [3, 1]], 2), + ("continuous-multioutput", [[1, 3, 4], [2, 2, 2], [3, 1, 1]], 3), + ] + + for (type1, y1, n_out1), (type2, y2, n_out2) in product(EXAMPLES, + EXAMPLES): + + if type1 == type2 and n_out1 == n_out2: + y_type, y_check1, y_check2 = _check_reg_targets(y1, y2) + assert_equal(type1, y_type) + if type1 == 'continuous': + assert_array_equal(y_check1, np.reshape(y1, (-1, 1))) + assert_array_equal(y_check2, np.reshape(y2, (-1, 1))) + else: + assert_array_equal(y_check1, y1) + assert_array_equal(y_check2, y2) + else: + assert_raises(ValueError, _check_reg_targets, y1, y2) From a7353446736d4589847e2c11f9c6d9dd28a1ecd2 Mon Sep 17 00:00:00 2001 From: Arnaud Joly Date: Wed, 10 Jul 2013 18:55:24 +0200 Subject: [PATCH 13/14] ENH dryer thanks to @jnothman --- sklearn/metrics/metrics.py | 47 +++++++++++++-------------- sklearn/metrics/tests/test_metrics.py | 25 ++++++++++++++ 2 files changed, 48 insertions(+), 24 deletions(-) diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py index 9ca68485b4424..da83cbe9c1ad9 100644 --- a/sklearn/metrics/metrics.py +++ b/sklearn/metrics/metrics.py @@ -34,6 +34,24 @@ ############################################################################### # General utilities ############################################################################### +def _column_or_1d(y): + """ Ravel column or 1d numpy array, else raises an error + + Parameters + ---------- + y : array-like + + Returns + ------- + y : array + + """ + shape = np.shape(y) + if len(shape) == 1 or (len(shape) == 2 and shape[1] == 1): + return np.ravel(y) + raise ValueError("bad input shape {0}".format(shape)) + + def _check_reg_targets(y_true, y_pred): """Check that y_true and y_pred belong to the same regression task @@ -41,7 +59,7 @@ def _check_reg_targets(y_true, y_pred): ---------- y_true : array-like, - y_pred : array-like + y_pred : array-like, Returns ------- @@ -121,18 +139,8 @@ def _check_clf_targets(y_true, y_pred): # 'binary' can be removed type_true = type_pred = 'multiclass' - y_true, y_pred = check_arrays(y_true, y_pred) - if (not (y_true.ndim == 1 or - (y_true.ndim == 2 and y_true.shape[1] == 1))): - raise ValueError("y_true has a bad input shape {0} " - "".format(y_true.shape)) - - if (not (y_pred.ndim == 1 or - (y_pred.ndim == 2 and y_pred.shape[1] == 1))): - raise ValueError("y_pred has a bad input shape {0}" - "".format(y_pred.shape)) - y_true = y_true.ravel() - y_pred = y_pred.ravel() + y_true = _column_or_1d(y_true) + y_pred = _column_or_1d(y_pred) else: raise ValueError("Can't handle %s/%s targets" % (type_true, type_pred)) @@ -460,17 +468,8 @@ def _binary_clf_curve(y_true, y_score, pos_label=None): Decreasing score values. """ y_true, y_score = check_arrays(y_true, y_score) - if (not (y_true.ndim == 1 or - (y_true.ndim == 2 and y_true.shape[1] == 1))): - raise ValueError("y_true has a bad input shape {0} " - "".format(y_true.shape)) - - if (not (y_score.ndim == 1 or - (y_score.ndim == 2 and y_score.shape[1] == 1))): - raise ValueError("y_score has a bad input shape {0}" - "".format(y_score.shape)) - y_true = np.ravel(y_true) - y_score = np.ravel(y_score) + y_true = _column_or_1d(y_true) + y_score = _column_or_1d(y_score) # ensure binary classification if pos_label is not specified classes = np.unique(y_true) diff --git a/sklearn/metrics/tests/test_metrics.py b/sklearn/metrics/tests/test_metrics.py index c6facf6528d64..8be5981a4b84b 100644 --- a/sklearn/metrics/tests/test_metrics.py +++ b/sklearn/metrics/tests/test_metrics.py @@ -49,6 +49,8 @@ zero_one_loss) from sklearn.metrics.metrics import _check_clf_targets from sklearn.metrics.metrics import _check_reg_targets +from sklearn.metrics.metrics import _column_or_1d + from sklearn.externals.six.moves import xrange @@ -1668,3 +1670,26 @@ def test__check_reg_targets(): assert_array_equal(y_check2, y2) else: assert_raises(ValueError, _check_reg_targets, y1, y2) + + +def test__column_or_1d(): + EXAMPLES = [ + ("binary", ["spam", "egg", "spam"]), + ("binary", [0, 1, 0, 1]), + ("continuous", np.arange(10) / 20.), + ("multiclass", [1, 2, 3]), + ("multiclass", [0, 1, 2, 2, 0]), + ("multiclass", [[1], [2], [3]]), + ("multilabel-indicator", [[0, 1, 0], [0, 0, 1]]), + ("multiclass-multioutput", [[1, 2, 3]]), + ("multiclass-multioutput", [[1, 1], [2, 2], [3, 1]]), + ("multiclass-multioutput", [[5, 1], [4, 2], [3, 1]]), + ("multiclass-multioutput", [[1, 2, 3]]), + ("continuous-multioutput", np.arange(30).reshape((-1, 3))), + ] + + for y_type, y in EXAMPLES: + if y_type in ["binary", 'multiclass', "continuous"]: + _column_or_1d(y) + else: + assert_raises(ValueError, _column_or_1d, y) From 93335af3818cb93dcbf3f0e9a3ac99ec158ddfe5 Mon Sep 17 00:00:00 2001 From: Arnaud Joly Date: Thu, 11 Jul 2013 07:55:53 +0200 Subject: [PATCH 14/14] TST stronger test for _column_or_1d function --- sklearn/metrics/tests/test_metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/tests/test_metrics.py b/sklearn/metrics/tests/test_metrics.py index 8be5981a4b84b..b70a912894eb2 100644 --- a/sklearn/metrics/tests/test_metrics.py +++ b/sklearn/metrics/tests/test_metrics.py @@ -1690,6 +1690,6 @@ def test__column_or_1d(): for y_type, y in EXAMPLES: if y_type in ["binary", 'multiclass', "continuous"]: - _column_or_1d(y) + assert_array_equal(_column_or_1d(y), np.ravel(y)) else: assert_raises(ValueError, _column_or_1d, y)