diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py index d38afc347a46b..da83cbe9c1ad9 100644 --- a/sklearn/metrics/metrics.py +++ b/sklearn/metrics/metrics.py @@ -34,102 +34,60 @@ ############################################################################### # General utilities ############################################################################### -def _is_1d(x): - """Return True if x is 1d or a column vector +def _column_or_1d(y): + """ Ravel column or 1d numpy array, else raises an error Parameters ---------- - x : numpy array. + y : array-like Returns ------- - is_1d : boolean, - Return True if x can be considered as a 1d vector. - - Examples - -------- - >>> import numpy as np - >>> from sklearn.metrics.metrics import _is_1d - >>> _is_1d([1, 2, 3]) - True - >>> _is_1d(np.array([1, 2, 3])) - True - >>> _is_1d([[1, 2, 3]]) - False - >>> _is_1d(np.array([[1, 2, 3]])) - False - >>> _is_1d([[1], [2], [3]]) - True - >>> _is_1d(np.array([[1], [2], [3]])) - True - >>> _is_1d([[1, 2], [3, 4]]) - False - >>> _is_1d(np.array([[1, 2], [3, 4]])) - False - - See also - -------- - _check_1d_array + y : array """ - shape = np.shape(x) - return len(shape) == 1 or len(shape) == 2 and shape[1] == 1 - + shape = np.shape(y) + if len(shape) == 1 or (len(shape) == 2 and shape[1] == 1): + return np.ravel(y) + raise ValueError("bad input shape {0}".format(shape)) -def _check_1d_array(y1, y2, ravel=False): - """Check that y1 and y2 are vectors of the same shape. - It convert 1d arrays (y1 and y2) of various shape to a common shape - representation. Note that ``y1`` and ``y2`` should have the same number of - elements. +def _check_reg_targets(y_true, y_pred): + """Check that y_true and y_pred belong to the same regression task Parameters ---------- - y1 : array-like, - y1 must be a "vector". - - y2 : array-like - y2 must be a "vector". 
+ y_true : array-like, - ravel : boolean, optional (default=False), - If ``ravel``` is set to ``True``, then ``y1`` and ``y2`` are raveled. + y_pred : array-like, Returns ------- - y1 : numpy array, - If ``ravel`` is set to ``True``, return np.ravel(y1), else - return y1. - - y2 : numpy array, - Return y2 reshaped to have the shape of y1. - - Examples - -------- - >>> from sklearn.metrics.metrics import _check_1d_array - >>> _check_1d_array([1, 2], [[3], [4]]) - (array([1, 2]), array([3, 4])) + type_true : one of {'continuous', 'continuous-multioutput'} + The type of the true target data, as output by + ``utils.multiclass.type_of_target`` - See also - -------- - _is_1d + y_true : array-like of shape = [n_samples, n_outputs] + Ground truth (correct) target values. + y_pred : array-like of shape = [n_samples, n_outputs] + Estimated target values. """ - y1 = np.asarray(y1) - y2 = np.asarray(y2) + y_true, y_pred = check_arrays(y_true, y_pred) - if not _is_1d(y1): - raise ValueError("y1 can't be considered as a vector") + if y_true.ndim == 1: + y_true = y_true.reshape((-1, 1)) - if not _is_1d(y2): - raise ValueError("y2 can't be considered as a vector") + if y_pred.ndim == 1: + y_pred = y_pred.reshape((-1, 1)) - if ravel: - return np.ravel(y1), np.ravel(y2) - else: - if np.shape(y1) != np.shape(y2): - y2 = np.reshape(y2, np.shape(y1)) + if y_true.shape[1] != y_pred.shape[1]: + raise ValueError("y_true and y_pred have different number of output " + "({0}!={1})".format(y_true.shape[1], y_pred.shape[1])) + + y_type = 'continuous' if y_true.shape[1] == 1 else 'continuous-multioutput' - return y1, y2 + return y_type, y_true, y_pred def _check_clf_targets(y_true, y_pred): @@ -181,8 +139,8 @@ def _check_clf_targets(y_true, y_pred): # 'binary' can be removed type_true = type_pred = 'multiclass' - y_true = np.ravel(y_true) - y_pred = np.ravel(y_pred) + y_true = _column_or_1d(y_true) + y_pred = _column_or_1d(y_pred) else: raise ValueError("Can't handle %s/%s targets" % (type_true, 
type_pred)) @@ -466,8 +424,10 @@ def matthews_corrcoef(y_true, y_pred): -0.33... """ - y_true, y_pred = check_arrays(y_true, y_pred) - y_true, y_pred = _check_1d_array(y_true, y_pred, ravel=True) + y_type, y_true, y_pred = _check_clf_targets(y_true, y_pred) + + if y_type != "binary": + raise ValueError("%s is not supported" % y_type) mcc = np.corrcoef(y_true, y_pred)[0, 1] if np.isnan(mcc): @@ -508,7 +468,8 @@ def _binary_clf_curve(y_true, y_score, pos_label=None): Decreasing score values. """ y_true, y_score = check_arrays(y_true, y_score) - y_true, y_score = _check_1d_array(y_true, y_score, ravel=True) + y_true = _column_or_1d(y_true) + y_score = _column_or_1d(y_score) # ensure binary classification if pos_label is not specified classes = np.unique(y_true) @@ -744,8 +705,9 @@ def confusion_matrix(y_true, y_pred, labels=None): [1, 0, 2]]) """ - y_true, y_pred = check_arrays(y_true, y_pred) - y_true, y_pred = _check_1d_array(y_true, y_pred, ravel=True) + y_type, y_true, y_pred = _check_clf_targets(y_true, y_pred) + if y_type not in ("binary", "multiclass"): + raise ValueError("%s is not supported" % y_type) if labels is None: labels = unique_labels(y_true, y_pred) @@ -833,20 +795,13 @@ def zero_one_loss(y_true, y_pred, normalize=True): """ - y_true, y_pred = check_arrays(y_true, y_pred, allow_lists=True) score = accuracy_score(y_true, y_pred, normalize=normalize) if normalize: return 1 - score else: - if hasattr(y_true, "shape"): - n_samples = (np.max(y_true.shape) if _is_1d(y_true) - else y_true.shape[0]) - - else: - n_samples = len(y_true) - + n_samples = len(y_true) return n_samples - score @@ -1074,6 +1029,7 @@ def accuracy_score(y_true, y_pred, normalize=True): 0.0 """ + # Compute accuracy for each possible representation y_type, y_true, y_pred = _check_clf_targets(y_true, y_pred) if y_type == 'multilabel-indicator': @@ -1422,9 +1378,9 @@ def _tp_tn_fp_fn(y_true, y_pred, labels=None): else: labels = np.asarray(labels) n_labels = labels.size - true_pos = 
np.zeros((n_labels), dtype=np.int) - false_pos = np.zeros((n_labels), dtype=np.int) - false_neg = np.zeros((n_labels), dtype=np.int) + true_pos = np.zeros((n_labels,), dtype=np.int) + false_pos = np.zeros((n_labels,), dtype=np.int) + false_neg = np.zeros((n_labels,), dtype=np.int) if y_type == 'multilabel-indicator': true_pos = np.sum(np.logical_and(y_true == 1, @@ -1454,12 +1410,7 @@ def _tp_tn_fp_fn(y_true, y_pred, labels=None): false_neg[i] = np.sum(y_pred[y_true == label_i] != label_i) # Compute the true_neg using the tp, fp and fn - if hasattr(y_true, "shape"): - n_samples = (np.max(y_true.shape) if _is_1d(y_true) - else y_true.shape[0]) - else: - n_samples = len(y_true) - + n_samples = len(y_true) true_neg = n_samples - true_pos - false_pos - false_neg return true_pos, true_neg, false_pos, false_neg @@ -2192,6 +2143,7 @@ def hamming_loss(y_true, y_pred, classes=None): """ y_type, y_true, y_pred = _check_clf_targets(y_true, y_pred) + if classes is None: classes = unique_labels(y_true, y_pred) else: @@ -2200,10 +2152,10 @@ def hamming_loss(y_true, y_pred, classes=None): if y_type == 'multilabel-indicator': return np.mean(y_true != y_pred) elif y_type == 'multilabel-sequences': - loss = np.array([len(set(pred) ^ set(true)) - for pred, true in zip(y_pred, y_true)]) + loss = np.array([len(set(pred) ^ set(true)) + for pred, true in zip(y_pred, y_true)]) - return np.mean(loss) / np.size(classes) + return np.mean(loss) / np.size(classes) else: return sp_hamming(y_true, y_pred) @@ -2241,12 +2193,7 @@ def mean_absolute_error(y_true, y_pred): 0.75 """ - y_true, y_pred = check_arrays(y_true, y_pred) - - # Handle mix 1d representation - if _is_1d(y_true): - y_true, y_pred = _check_1d_array(y_true, y_pred) - + y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred) return np.mean(np.abs(y_pred - y_true)) @@ -2279,12 +2226,7 @@ def mean_squared_error(y_true, y_pred): 0.708... 
""" - y_true, y_pred = check_arrays(y_true, y_pred) - - # Handle mix 1d representation - if _is_1d(y_true): - y_true, y_pred = _check_1d_array(y_true, y_pred) - + y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred) return np.mean((y_pred - y_true) ** 2) @@ -2322,11 +2264,10 @@ def explained_variance_score(y_true, y_pred): 0.957... """ - y_true, y_pred = check_arrays(y_true, y_pred) + y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred) - # Handle mix 1d representation - if _is_1d(y_true): - y_true, y_pred = _check_1d_array(y_true, y_pred) + if y_type != "continuous": + raise ValueError("{0} is not supported".format(y_type)) numerator = np.var(y_true - y_pred) denominator = np.var(y_true) @@ -2383,11 +2324,7 @@ def r2_score(y_true, y_pred): 0.938... """ - y_true, y_pred = check_arrays(y_true, y_pred) - - # Handle mix 1d representation - if _is_1d(y_true): - y_true, y_pred = _check_1d_array(y_true, y_pred, ravel=True) + y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred) if len(y_true) == 1: raise ValueError("r2_score can only be computed given more than one" diff --git a/sklearn/metrics/tests/test_metrics.py b/sklearn/metrics/tests/test_metrics.py index 51638bdc56459..b70a912894eb2 100644 --- a/sklearn/metrics/tests/test_metrics.py +++ b/sklearn/metrics/tests/test_metrics.py @@ -2,6 +2,7 @@ from functools import partial import warnings +from itertools import product import numpy as np from sklearn import datasets @@ -47,12 +48,16 @@ zero_one_score, zero_one_loss) from sklearn.metrics.metrics import _check_clf_targets +from sklearn.metrics.metrics import _check_reg_targets +from sklearn.metrics.metrics import _column_or_1d + + from sklearn.externals.six.moves import xrange ALL_METRICS = { "accuracy_score": accuracy_score, "unormalized_accuracy_score": partial(accuracy_score, normalize=False), - + "confusion_matrix": confusion_matrix, "hamming_loss": hamming_loss, "jaccard_similarity_score": jaccard_similarity_score, @@ -115,6 +120,12 @@ 
"zero_one_loss": zero_one_loss, "unnormalized_zero_one_loss": partial(zero_one_loss, normalize=False), + "precision_score": precision_score, + "recall_score": recall_score, + "f1_score": f1_score, + "f2_score": partial(fbeta_score, beta=2), + "f0.5_score": partial(fbeta_score, beta=0.5), + "weighted_f0.5_score": partial(fbeta_score, average="weighted", beta=0.5), "weighted_f1_score": partial(f1_score, average="weighted"), "weighted_f2_score": partial(fbeta_score, average="weighted", beta=2), @@ -140,6 +151,13 @@ "macro_recall_score": partial(recall_score, average="macro"), } +MULTIOUTPUT_METRICS = { + "mean_absolute_error": mean_absolute_error, + "mean_squared_error": mean_squared_error, + "r2_score": r2_score, +} + + SYMMETRIC_METRICS = { "accuracy_score": accuracy_score, "unormalized_accuracy_score": partial(accuracy_score, normalize=False), @@ -167,6 +185,8 @@ "explained_variance_score": explained_variance_score, "r2_score": r2_score, + "confusion_matrix": confusion_matrix, + "precision_score": precision_score, "recall_score": recall_score, "f2_score": partial(fbeta_score, beta=2), @@ -973,15 +993,19 @@ def test_format_invariance_with_1d_vectors(): "with mix list and np-array-column" % name) - # At the moment, these mix representations aren't allowed + # These mix representations aren't allowed assert_raises(ValueError, metric, y1_1d, y2_row) assert_raises(ValueError, metric, y1_row, y2_1d) assert_raises(ValueError, metric, y1_list, y2_row) assert_raises(ValueError, metric, y1_row, y2_list) assert_raises(ValueError, metric, y1_column, y2_row) assert_raises(ValueError, metric, y1_row, y2_column) + # NB: We do not test for y1_row, y2_row as these may be # interpreted as multilabel or multioutput data. 
+ if (name not in MULTIOUTPUT_METRICS and + name not in MULTILABELS_METRICS): + assert_raises(ValueError, metric, y1_row, y2_row) def test_clf_single_sample(): @@ -1040,9 +1064,8 @@ def test_multioutput_number_of_output_differ(): y_true = np.array([[1, 0, 0, 1], [0, 1, 1, 1], [1, 1, 0, 1]]) y_pred = np.array([[0, 0], [1, 0], [0, 0]]) - assert_raises(ValueError, mean_squared_error, y_true, y_pred) - assert_raises(ValueError, mean_absolute_error, y_true, y_pred) - assert_raises(ValueError, r2_score, y_true, y_pred) + for name, metrics in MULTIOUTPUT_METRICS.items(): + assert_raises(ValueError, metrics, y_true, y_pred) def test_multioutput_regression_invariance_to_dimension_shuffling(): @@ -1053,13 +1076,15 @@ def test_multioutput_regression_invariance_to_dimension_shuffling(): y_pred = np.reshape(y_pred, (-1, n_dims)) rng = check_random_state(314159) - for metric in [r2_score, mean_squared_error, mean_absolute_error]: + for name, metric in MULTIOUTPUT_METRICS.items(): error = metric(y_true, y_pred) for _ in xrange(3): perm = rng.permutation(n_dims) assert_almost_equal(metric(y_true[:, perm], y_pred[:, perm]), - error) + error, + err_msg="%s is not dimension shuffling " + "invariant" % name) def test_multilabel_representation_invariance(): @@ -1619,3 +1644,52 @@ def test__check_clf_targets(): assert_array_equal(y1out, np.squeeze(y1)) assert_array_equal(y2out, np.squeeze(y2)) assert_raises(ValueError, _check_clf_targets, y1[:-1], y2) + + +def test__check_reg_targets(): + # All of length 3 + EXAMPLES = [ + ("continuous", [1, 2, 3], 1), + ("continuous", [[1], [2], [3]], 1), + ("continuous-multioutput", [[1, 1], [2, 2], [3, 1]], 2), + ("continuous-multioutput", [[5, 1], [4, 2], [3, 1]], 2), + ("continuous-multioutput", [[1, 3, 4], [2, 2, 2], [3, 1, 1]], 3), + ] + + for (type1, y1, n_out1), (type2, y2, n_out2) in product(EXAMPLES, + EXAMPLES): + + if type1 == type2 and n_out1 == n_out2: + y_type, y_check1, y_check2 = _check_reg_targets(y1, y2) + assert_equal(type1, y_type) 
+ if type1 == 'continuous': + assert_array_equal(y_check1, np.reshape(y1, (-1, 1))) + assert_array_equal(y_check2, np.reshape(y2, (-1, 1))) + else: + assert_array_equal(y_check1, y1) + assert_array_equal(y_check2, y2) + else: + assert_raises(ValueError, _check_reg_targets, y1, y2) + + +def test__column_or_1d(): + EXAMPLES = [ + ("binary", ["spam", "egg", "spam"]), + ("binary", [0, 1, 0, 1]), + ("continuous", np.arange(10) / 20.), + ("multiclass", [1, 2, 3]), + ("multiclass", [0, 1, 2, 2, 0]), + ("multiclass", [[1], [2], [3]]), + ("multilabel-indicator", [[0, 1, 0], [0, 0, 1]]), + ("multiclass-multioutput", [[1, 2, 3]]), + ("multiclass-multioutput", [[1, 1], [2, 2], [3, 1]]), + ("multiclass-multioutput", [[5, 1], [4, 2], [3, 1]]), + ("multiclass-multioutput", [[1, 2, 3]]), + ("continuous-multioutput", np.arange(30).reshape((-1, 3))), + ] + + for y_type, y in EXAMPLES: + if y_type in ["binary", 'multiclass', "continuous"]: + assert_array_equal(_column_or_1d(y), np.ravel(y)) + else: + assert_raises(ValueError, _column_or_1d, y) diff --git a/sklearn/utils/multiclass.py b/sklearn/utils/multiclass.py index 11289c433bf2d..3f365ab1e4d09 100644 --- a/sklearn/utils/multiclass.py +++ b/sklearn/utils/multiclass.py @@ -212,6 +212,8 @@ def is_multilabel(y): True >>> is_multilabel(np.array([[1], [0], [0]])) False + >>> is_multilabel(np.array([[1, 0, 0]])) + True """ return is_label_indicator_matrix(y) or is_sequence_of_sequences(y)