From a32984673fece64f192a34e7c0124cf9610ccf08 Mon Sep 17 00:00:00 2001
From: Michael Eickenberg
Date: Sun, 20 Jul 2014 20:40:20 +0200
Subject: [PATCH 1/2] replayed manojs changes in pr2493 onto master

---
 sklearn/metrics/metrics.py            | 249 +++++++++++++++++++++-----
 sklearn/metrics/tests/test_metrics.py |  71 +++++++-
 2 files changed, 274 insertions(+), 46 deletions(-)

diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py
index 8c07d5452d2bb..1076b6e2e5c44 100644
--- a/sklearn/metrics/metrics.py
+++ b/sklearn/metrics/metrics.py
@@ -16,6 +16,8 @@
 #          Lars Buitinck
 #          Joel Nothman
 #          Noel Dawe
+#          Manoj Kumar
+#          Michael Eickenberg
 # License: BSD 3 clause

 from __future__ import division
@@ -33,6 +35,7 @@
 from ..utils import check_arrays
 from ..utils import deprecated
 from ..utils import column_or_1d
+from ..utils import safe_asarray
 from ..utils.multiclass import unique_labels
 from ..utils.multiclass import type_of_target
 from ..utils.fixes import isclose
@@ -41,7 +44,7 @@
 ###############################################################################
 # General utilities
 ###############################################################################
-def _check_reg_targets(y_true, y_pred):
+def _check_reg_targets(y_true, y_pred, output_weights):
     """Check that y_true and y_pred belong to the same regression task

     Parameters
@@ -50,6 +53,8 @@ def _check_reg_targets(y_true, y_pred):

     y_pred : array-like,

+    output_weights : array-like or string in ['uniform', 'variance'] or None
+
     Returns
     -------
     type_true : one of {'continuous', 'continuous-multioutput'}
@@ -61,6 +66,13 @@ def _check_reg_targets(y_true, y_pred):

     y_pred : array-like of shape = [n_samples, n_outputs]
         Estimated target values.
+
+    output_weights : array-like of shape = [n_outputs]
+        or string in ['uniform', 'variance'] or None
+        Custom output weights if output_weights is array-like. 'uniform'
+        and 'variance' indicate specific weight vectors. None indicates no
+        aggregation of scores across targets: scores are returned
+        separately per target.
     """
     y_true, y_pred = check_arrays(y_true, y_pred)

@@ -74,9 +86,20 @@ def _check_reg_targets(y_true, y_pred):
         raise ValueError("y_true and y_pred have different number of output "
                          "({0}!={1})".format(y_true.shape[1],
                                              y_pred.shape[1]))
+    n_outputs = y_true.shape[1]
+    output_weights_options = (None, 'uniform', 'variance')
+    if output_weights not in output_weights_options:
+        output_weights = safe_asarray(output_weights)
+        if n_outputs == 1:
+            raise ValueError("Custom weights are useful only in "
+                             "multi-output cases.")
+        elif n_outputs != len(output_weights):
+            raise ValueError(("There must be equally many custom weights "
+                              "(%d) as outputs (%d).") %
+                             (len(output_weights), n_outputs))
     y_type = 'continuous' if y_true.shape[1] == 1 else 'continuous-multioutput'

-    return y_type, y_true, y_pred
+    return y_type, y_true, y_pred, output_weights


 def _check_clf_targets(y_true, y_pred):
@@ -133,7 +156,7 @@ def _check_clf_targets(y_true, y_pred):
     return y_type, y_true, y_pred


-def _average_and_variance(values, sample_weight=None):
+def _average_and_variance(values, sample_weight=None, axis=None):
     """
     Compute the (weighted) average and variance.
@@ -159,8 +182,22 @@ def _average_and_variance(values, sample_weight=None):
         sample_weight = np.asarray(sample_weight)
         if sample_weight.ndim == 1:
             sample_weight = sample_weight.reshape((-1, 1))
-    average = np.average(values, weights=sample_weight)
-    variance = np.average((values - average)**2, weights=sample_weight)
+    if axis is not None:
+        values = np.rollaxis(values, axis)
+        axis = 0
+    # if multi output but sample weight only specified in one column,
+    # then we need to broadcast it over outputs
+    if (sample_weight is not None and
+            values.shape[1] != sample_weight.shape[1]):
+        if sample_weight.shape[1] != 1:
+            raise ValueError("Sample weight shape and data shape "
+                             "do not correspond.")
+        sample_weight = sample_weight * np.ones([1, values.shape[1]])
+
+    average = np.average(values, weights=sample_weight, axis=axis)
+    variance = np.average((values - average)**2,
+                          weights=sample_weight,
+                          axis=axis)

     return average, variance

@@ -2147,7 +2184,9 @@ def hamming_loss(y_true, y_pred, classes=None):
 ###############################################################################
 # Regression metrics
 ###############################################################################
-def mean_absolute_error(y_true, y_pred, sample_weight=None):
+def mean_absolute_error(y_true, y_pred,
+                        output_weights='uniform',
+                        sample_weight=None):
     """Mean absolute error regression loss

     Parameters
@@ -2158,13 +2197,31 @@ def mean_absolute_error(y_true, y_pred, sample_weight=None):
     y_pred : array-like of shape = [n_samples] or [n_samples, n_outputs]
         Estimated target values.

+    output_weights : string in ['uniform'] or None
+        or array-like of shape [n_outputs]
+        Weights by which to average scores of outputs. Useful only if using
+        multiple outputs.
+
+        ``ndarray`` :
+            array containing weights for the weighted average.
+
+        ``'uniform'`` :
+            Scores of all outputs are averaged with uniform weight.
+
+        ``None`` :
+            No averaging is performed, an array of scores is returned.
+
     sample_weight : array-like of shape = [n_samples], optional
         Sample weights.
-
     Returns
     -------
-    loss : float
-        A positive floating point value (the best value is 0.0).
+    loss : float or ndarray of shape [n_outputs]
+        If output_weights is None, then the mean absolute error is returned
+        for each output separately.
+        If output_weights is 'uniform' or an ndarray of weights, then the
+        weighted average of all output scores is returned.
+
+        MAE output is non-negative floating point. The best value is 0.0.

     Examples
     --------
@@ -2177,14 +2234,29 @@ def mean_absolute_error(y_true, y_pred, sample_weight=None):
     >>> y_pred = [[0, 2], [-1, 2], [8, -5]]
     >>> mean_absolute_error(y_true, y_pred)
     0.75
-
+    >>> mean_absolute_error(y_true, y_pred, output_weights=None)
+    array([ 0.5,  1. ])
+    >>> mean_absolute_error(y_true, y_pred, output_weights=[0.3, 0.7])
+    ... # doctest: +ELLIPSIS
+    0.849...
""" - y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred) - return np.average(np.abs(y_pred - y_true).mean(axis=1), - weights=sample_weight) + y_type, y_true, y_pred, output_weights = _check_reg_targets( + y_true, y_pred, output_weights) + individual_errors = np.average(np.abs(y_pred - y_true), + weights=sample_weight, axis=0) + + if output_weights == 'uniform': + error = np.average(individual_errors) + elif output_weights is None: + error = individual_errors + else: + error = np.average(individual_errors, weights=output_weights) + return error -def mean_squared_error(y_true, y_pred, sample_weight=None): +def mean_squared_error(y_true, y_pred, + output_weights='uniform', + sample_weight=None): """Mean squared error regression loss Parameters @@ -2195,6 +2267,20 @@ def mean_squared_error(y_true, y_pred, sample_weight=None): y_pred : array-like of shape = [n_samples] or [n_samples, n_outputs] Estimated target values. + output_weights : string in ['uniform'] or None + or array-like of shape [n_outputs] + Weights by which to average scores of outputs. Useful only if using + multiple outputs. + + ``ndarray`` : + array containing weights for the weighted average. + + ``'uniform'`` : + Scores of all outputs are averaged with uniform weight. + + ``None`` : + No averaging is performed, an array of scores is returned. + sample_weight : array-like of shape = [n_samples], optional Sample weights. @@ -2214,17 +2300,34 @@ def mean_squared_error(y_true, y_pred, sample_weight=None): >>> y_pred = [[0, 2],[-1, 2],[8, -5]] >>> mean_squared_error(y_true, y_pred) # doctest: +ELLIPSIS 0.708... + >>> mean_squared_error(y_true, y_pred, output_weights=None) + ... # doctest: +ELLIPSIS + array([ 0.416..., 1. ]) + >>> mean_squared_error(y_true, y_pred, output_weights=[0.3, 0.7]) + ... # doctest: +ELLIPSIS + 0.824... """ - y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred) - return np.average(((y_pred - y_true) ** 2).mean(axis=1), - weights=sample_weight) + y_type, y_true, y_pred, output_weights = _check_reg_targets( + y_true, y_pred, output_weights) + individual_errors = np.average((y_true - y_pred) ** 2, axis=0, + weights=sample_weight) + if output_weights == 'uniform': + error = np.average(individual_errors) + elif output_weights is None: + error = individual_errors + else: + error = np.average(individual_errors, weights=output_weights) + + return error ############################################################################### # Regression score functions ############################################################################### -def explained_variance_score(y_true, y_pred, sample_weight=None): +def explained_variance_score(y_true, y_pred, + output_weights='uniform', + sample_weight=None): """Explained variance regression score function Best possible score is 1.0, lower values are worse. @@ -2237,6 +2340,24 @@ def explained_variance_score(y_true, y_pred, sample_weight=None): y_pred : array-like Estimated target values. + output_weights : string in ['uniform', 'variance'] or None + or array-like of shape [n_outputs] + Weights by which to average scores of outputs. Useful only if using + multiple outputs. + + ``ndarray`` : + array containing weights for the weighted average. + + ``'uniform'`` : + Scores of all outputs are averaged with uniform weight. + + ``'variance'`` : + Scores of all outputs are averaged, weighted by the variances + of each individual output. This corresponds to a global explained + + ``None`` : + No averaging is performed, an array of scores is returned. 
+
     sample_weight : array-like of shape = [n_samples], optional
         Sample weights.

@@ -2258,24 +2379,38 @@ def explained_variance_score(y_true, y_pred, sample_weight=None):
     0.957...

     """
-    y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred)
+    y_type, y_true, y_pred, output_weights = _check_reg_targets(
+        y_true, y_pred, output_weights)

-    if y_type != "continuous":
+    if y_type not in ["continuous", "continuous-multioutput"]:
         raise ValueError("{0} is not supported".format(y_type))

-    _, numerator = _average_and_variance(y_true - y_pred, sample_weight)
-    _, denominator = _average_and_variance(y_true, sample_weight)
-    if denominator == 0.0:
-        if numerator == 0.0:
-            return 1.0
-        else:
-            # arbitrary set to zero to avoid -inf scores, having a constant
-            # y_true is not interesting for scoring a regression anyway
-            return 0.0
-    return 1 - numerator / denominator
+    _, numerator = _average_and_variance(y_true - y_pred, sample_weight,
+                                         axis=0)
+    _, denominator = _average_and_variance(y_true, sample_weight,
+                                           axis=0)
+
+    nonzero_numerator = numerator != 0
+    nonzero_denominator = denominator != 0
+    valid_score = nonzero_numerator * nonzero_denominator
+    individual_scores = np.ones(y_true.shape[1])
+
+    individual_scores[valid_score] = 1 - (numerator[valid_score] /
+                                          denominator[valid_score])
+    # arbitrarily set to zero to avoid -inf scores; a constant
+    # y_true is not interesting for scoring a regression anyway
+    individual_scores[nonzero_numerator * ~nonzero_denominator] = 0.
+    if output_weights == 'uniform':
+        score = np.average(individual_scores)
+    elif output_weights == 'variance':
+        score = np.average(individual_scores, weights=denominator)
+    elif output_weights is None:
+        score = individual_scores
+    else:
+        score = np.average(individual_scores, weights=output_weights)
+
+    return score


-def r2_score(y_true, y_pred, sample_weight=None):
+def r2_score(y_true, y_pred, output_weights='uniform', sample_weight=None):
     """R^2 (coefficient of determination) regression score function.

     Best possible score is 1.0, lower values are worse.

     Parameters
@@ -2288,6 +2423,21 @@ def r2_score(y_true, y_pred, sample_weight=None):
     y_pred : array-like of shape = [n_samples] or [n_samples, n_outputs]
         Estimated target values.

+    output_weights : string in ['uniform', 'variance'] or None
+        or array-like of shape [n_outputs]
+        Weights by which to average scores of outputs. Useful only if using
+        multiple outputs.
+
+        ``ndarray`` :
+            array containing weights for the weighted average.
+
+        ``'uniform'`` :
+            Scores of all outputs are averaged with uniform weight.
+
+        ``'variance'`` :
+            Scores of all outputs are averaged, weighted by the variances
+            of each individual output. This corresponds to a global r2
+            score.
+
+        ``None`` :
+            No averaging is performed, an array of scores is returned.
+
     sample_weight : array-like of shape = [n_samples], optional
         Sample weights.

@@ -2321,7 +2471,8 @@ def r2_score(y_true, y_pred, sample_weight=None):
-    0.938...
+    0.936...

     """
-    y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred)
+    y_type, y_true, y_pred, output_weights = _check_reg_targets(
+        y_true, y_pred, output_weights=output_weights)

     if sample_weight is not None:
         sample_weight = column_or_1d(sample_weight)
         weight = sample_weight[:, np.newaxis]
     else:
         weight = 1.
-    numerator = (weight * (y_true - y_pred) ** 2).sum(dtype=np.float64)
+    numerator = (weight * (y_true - y_pred) ** 2).sum(axis=0,
+                                                      dtype=np.float64)
     denominator = (weight * (y_true - np.average(
-        y_true, axis=0, weights=sample_weight)) ** 2).sum(dtype=np.float64)
-
-    if denominator == 0.0:
-        if numerator == 0.0:
-            return 1.0
-        else:
-            # arbitrary set to zero to avoid -inf scores, having a constant
-            # y_true is not interesting for scoring a regression anyway
-            return 0.0
+        y_true, axis=0, weights=sample_weight)) ** 2).sum(axis=0,
+                                                          dtype=np.float64)
+
+    nonzero_denominator = denominator != 0
+    nonzero_numerator = numerator != 0
+    valid_score = nonzero_denominator * nonzero_numerator
+    individual_scores = np.ones([y_true.shape[1]])
+    individual_scores[valid_score] = 1 - (numerator[valid_score] /
+                                          denominator[valid_score])
+    # arbitrarily set to zero to avoid -inf scores; a constant
+    # y_true is not interesting for scoring a regression anyway
+    individual_scores[nonzero_numerator * ~nonzero_denominator] = 0.
+    if output_weights == 'uniform':
+        score = np.average(individual_scores)
+    elif output_weights == 'variance':
+        score = np.average(individual_scores, weights=denominator)
+    elif output_weights is None:
+        score = individual_scores
+    else:
+        score = np.average(individual_scores, weights=output_weights)

-    return 1 - numerator / denominator
+    return score
diff --git a/sklearn/metrics/tests/test_metrics.py b/sklearn/metrics/tests/test_metrics.py
index 7074dc00675eb..8f3fa861cf23b 100644
--- a/sklearn/metrics/tests/test_metrics.py
+++ b/sklearn/metrics/tests/test_metrics.py
@@ -283,6 +283,7 @@
 # Regression metrics with "multioutput-continuous" format support
 MULTIOUTPUT_METRICS = [
     "mean_absolute_error", "mean_squared_error", "r2_score",
+    "explained_variance_score"
 ]

 # Symmetric with respect to their input arguments y_true and y_pred
@@ -1632,7 +1633,7 @@ def test_multioutput_regression():
     assert_almost_equal(error, (1. / 3 + 2. / 3 + 2. / 3) / 4.)

     error = r2_score(y_true, y_pred)
-    assert_almost_equal(error, 1 - 5. / 2)
+    assert_almost_equal(error, -0.875)


 def test_multioutput_number_of_output_differ():
@@ -2409,7 +2410,8 @@ def test__check_reg_targets():
                            repeat=2):
         if type1 == type2 and n_out1 == n_out2:
-            y_type, y_check1, y_check2 = _check_reg_targets(y1, y2)
+            y_type, y_check1, y_check2, output_weights = \
+                _check_reg_targets(y1, y2, None)
             assert_equal(type1, y_type)
             if type1 == 'continuous':
                 assert_array_equal(y_check1, np.reshape(y1, (-1, 1)))
                 assert_array_equal(y_check2, np.reshape(y2, (-1, 1)))
             else:
                 assert_array_equal(y_check1, y1)
                 assert_array_equal(y_check2, y2)
         else:
-            assert_raises(ValueError, _check_reg_targets, y1, y2)
+            assert_raises(ValueError, _check_reg_targets, y1, y2, None)


 def test_log_loss():
@@ -2707,3 +2709,66 @@ def test_sample_weight_invariance():
         else:
             yield (check_sample_weight_invariance, name, metric,
                    y_true, y_pred)
+
+
+def test_regression_multioutput_array():
+    y_true = [[1, 2], [2.5, -1], [4.5, 3], [5, 7]]
+    y_pred = [[1, 1], [2, -1], [5, 4], [5, 6.5]]
+
+    mse = mean_squared_error(y_true, y_pred, output_weights=None)
+    mae = mean_absolute_error(y_true, y_pred, output_weights=None)
+    r = r2_score(y_true, y_pred, output_weights=None)
+    evs = explained_variance_score(y_true, y_pred, output_weights=None)
+
+    assert_array_equal(mse, np.array([0.125, 0.5625]))
+    assert_array_equal(mae, np.array([0.25, 0.625]))
+    assert_array_almost_equal(r, np.array([0.95, 0.93]), decimal=2)
+    assert_array_almost_equal(evs, np.array([0.95, 0.93]), decimal=2)

+    # mean_absolute_error and mean_squared_error coincide here because
+    # every error has magnitude exactly 1 (and 1 ** 2 == 1).
+    y_true = [[0, 0]] * 4
+    y_pred = [[1, 1]] * 4
+    mse = mean_squared_error(y_true, y_pred, output_weights=None)
+    mae = mean_absolute_error(y_true, y_pred, output_weights=None)
+    r = r2_score(y_true, y_pred, output_weights=None)
+    assert_array_equal(mse, np.array([1., 1.]))
+    assert_array_equal(mae, np.array([1., 1.]))
+    assert_array_almost_equal(r, np.array([0., 0.]))
+
+    r = r2_score([[0, -1], [0, 1]], [[2, 2], [1, 1]], output_weights=None)
+    assert_array_equal(r, np.array([0, -3.5]))
+    assert_equal(np.mean(r), r2_score([[0, -1], [0, 1]], [[2, 2], [1, 1]]))
+    evs = explained_variance_score([[0, -1], [0, 1]], [[2, 2], [1, 1]],
+                                   output_weights=None)
+    assert_array_equal(evs, np.array([0, -1.25]))
+
+    # Check the case where the numerator is zero: perfect predictions on
+    # one output should give a score of 1 for that output.
+    y_true = [[1, 3], [-1, 2]]
+    y_pred = [[1, 4], [-1, 1]]
+    r = r2_score(y_true, y_pred, output_weights=None)
+    assert_array_equal(r, np.array([1., -3.]))
+    assert_equal(np.mean(r), r2_score(y_true, y_pred))
+    evs = explained_variance_score(y_true, y_pred, output_weights=None)
+    assert_array_equal(evs, np.array([1., -3.]))
+    assert_equal(np.mean(evs), explained_variance_score(y_true, y_pred))
+
+
+def test_regression_custom_weights():
+    y_true = [[1, 2], [2.5, -1], [4.5, 3], [5, 7]]
+    y_pred = [[1, 1], [2, -1], [5, 4], [5, 6.5]]
+
+    msew = mean_squared_error(y_true, y_pred, output_weights=[0.4, 0.6],
+                              sample_weight=None)
+    maew = mean_absolute_error(y_true, y_pred, output_weights=[0.4, 0.6],
+                               sample_weight=None)
+    rw = r2_score(y_true, y_pred, output_weights=[0.4, 0.6],
+                  sample_weight=None)
+    evsw = explained_variance_score(y_true, y_pred,
+                                    output_weights=[0.4, 0.6],
+                                    sample_weight=None)
+
+    assert_almost_equal(msew, 0.39, decimal=2)
+    assert_almost_equal(maew, 0.475, decimal=3)
+    assert_almost_equal(rw, 0.94, decimal=2)
+    assert_almost_equal(evsw, 0.94, decimal=2)

From 5e64244c7bff8df54ca19e6c7dd52553e288945a Mon Sep 17 00:00:00 2001
From: Michael Eickenberg
Date: Mon, 21 Jul 2014 13:30:10 +0200
Subject: [PATCH 2/2] Added documentation

---
 doc/modules/model_evaluation.rst | 44 +++++++++++++++++++++++++++++++-
 1 file changed, 43 insertions(+), 1 deletion(-)

diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst
index c13ca6d232ef4..4df7d06aa2160 100644
--- a/doc/modules/model_evaluation.rst
+++ b/doc/modules/model_evaluation.rst
@@ -929,6 +929,14 @@ variance is estimated as follow:

 The best possible score is 1.0, lower values are worse.

+The :func:`explained_variance_score` function has an `output_weights` keyword
+that can be `None`, `'uniform'`, `'variance'`, or an array of weights. If
+`None`, the explained variance score is calculated for each dimension
+separately and a numpy array of scores is returned. If `'uniform'`, the
+scores of all dimensions are averaged with a weight of `1 / n_outputs` each.
+If `'variance'`, the score of each output is weighted by the variance of the
+corresponding target variable, which corresponds to a global explained
+variance score.
+
 Here a small example of usage of the :func:`explained_variance_score`
 function::

@@ -937,6 +945,14 @@ function::
   >>> y_pred = [2.5, 0.0, 2, 8]
   >>> explained_variance_score(y_true, y_pred)  # doctest: +ELLIPSIS
   0.957...
+  >>> y_true = [[0.5, 1], [-1, 1], [7, -6]]
+  >>> y_pred = [[0, 2], [-1, 2], [8, -5]]
+  >>> explained_variance_score(y_true, y_pred, output_weights=None)
+  ... # doctest: +ELLIPSIS
+  array([ 0.967...,  1.        ])
+  >>> explained_variance_score(y_true, y_pred, output_weights=[0.3, 0.7])
+  ... # doctest: +ELLIPSIS
+  0.990...

 Mean absolute error
 ...................

@@ -954,6 +970,14 @@ and :math:`y_i` is the corresponding true value, then the mean absolute error

   \text{MAE}(y, \hat{y}) = \frac{1}{n_{\text{samples}}} \sum_{i=0}^{n_{\text{samples}}-1} \left| y_i - \hat{y}_i \right|.
+
+The :func:`mean_absolute_error` function has an `output_weights` keyword
+that can be `None`, `'uniform'`, or an array of weights. If `None`, the mean
+absolute error is calculated for each dimension separately and a numpy array
+of errors is returned. If `'uniform'`, the errors of all dimensions are
+averaged with a weight of `1 / n_outputs` each.
+
 Here a small example of usage of the :func:`mean_absolute_error` function::

   >>> from sklearn.metrics import mean_absolute_error
   >>> y_true = [3, -0.5, 2, 7]
   >>> y_pred = [2.5, 0.0, 2, 8]
   >>> mean_absolute_error(y_true, y_pred)
   0.5
   >>> y_true = [[0.5, 1], [-1, 1], [7, -6]]
   >>> y_pred = [[0, 2], [-1, 2], [8, -5]]
   >>> mean_absolute_error(y_true, y_pred)
   0.75
-
+  >>> mean_absolute_error(y_true, y_pred, output_weights=None)
+  array([ 0.5,  1. ])
+  >>> mean_absolute_error(y_true, y_pred, output_weights=[0.3, 0.7])
+  ... # doctest: +ELLIPSIS
+  0.849...


 Mean squared error
 ..................

@@ -1021,6 +1049,14 @@ over :math:`n_{\text{samples}}` is defined as

 where :math:`\bar{y} =  \frac{1}{n_{\text{samples}}} \sum_{i=0}^{n_{\text{samples}} - 1} y_i`.

+The :func:`r2_score` function has an `output_weights` keyword that can be
+`None`, `'uniform'`, `'variance'`, or an array of weights. If `None`, the r2
+score is calculated for each dimension separately and a numpy array of
+scores is returned. If `'uniform'`, the r2 scores of all dimensions are
+averaged with a weight of `1 / n_outputs` each. If `'variance'`, the output
+r2 scores are averaged weighted by the variances of the corresponding target
+variables, which corresponds to a global r2 score.
+
 Here a small example of usage of the :func:`r2_score` function::

   >>> from sklearn.metrics import r2_score
   >>> y_true = [3, -0.5, 2, 7]
   >>> y_pred = [2.5, 0.0, 2, 8]
   >>> r2_score(y_true, y_pred)  # doctest: +ELLIPSIS
   0.948...
   >>> y_true = [[0.5, 1], [-1, 1], [7, -6]]
   >>> y_pred = [[0, 2], [-1, 2], [8, -5]]
   >>> r2_score(y_true, y_pred)  # doctest: +ELLIPSIS
-  0.938...
+  0.936...
+  >>> r2_score(y_true, y_pred, output_weights=None)
+  ... # doctest: +ELLIPSIS
+  array([ 0.965...,  0.908...])
+  >>> r2_score(y_true, y_pred, output_weights=[0.3, 0.7])
+  ... # doctest: +ELLIPSIS
+  0.925...

 .. topic:: Example:
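
Reviewer note, not part of either patch: the snippet below is a minimal standalone sketch of the API these two commits introduce, handy for exercising the branch locally. It assumes a scikit-learn checkout with both patches applied; the function names and the `output_weights` values (`None`, `'uniform'`, `'variance'`, array-like) are taken directly from the diffs above, nothing else is assumed.

    import numpy as np
    from sklearn.metrics import (explained_variance_score, mean_absolute_error,
                                 mean_squared_error, r2_score)

    # Same toy data as the new doctests: 3 samples, 2 outputs.
    y_true = np.array([[0.5, 1], [-1, 1], [7, -6]])
    y_pred = np.array([[0, 2], [-1, 2], [8, -5]])

    # output_weights=None: no aggregation, one score per output column.
    print(r2_score(y_true, y_pred, output_weights=None))        # ~[0.965, 0.908]

    # 'uniform' (the new default): plain mean of the per-output scores.
    print(r2_score(y_true, y_pred, output_weights='uniform'))   # ~0.937

    # Custom array: weighted mean; len(weights) must equal n_outputs.
    print(r2_score(y_true, y_pred, output_weights=[0.3, 0.7]))  # ~0.925

    # 'variance': per-output scores weighted by the variance of each target,
    # recovering a single global score (r2_score and
    # explained_variance_score only).
    print(explained_variance_score(y_true, y_pred, output_weights='variance'))

    # The error metrics accept None, 'uniform' and arrays, but not 'variance'.
    print(mean_absolute_error(y_true, y_pred, output_weights=None))  # ~[0.5, 1.]
    print(mean_squared_error(y_true, y_pred, output_weights=[0.3, 0.7]))

Sample weights and output weights act on different axes (rows vs. columns), so they can be combined freely, as the new `test_regression_custom_weights` test demonstrates.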