diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst
index 548613bbda656..f3e7ba6d64221 100644
--- a/doc/modules/model_evaluation.rst
+++ b/doc/modules/model_evaluation.rst
@@ -963,7 +963,25 @@ Regression metrics
 The :mod:`sklearn.metrics` implements several losses, scores and utility
 functions to measure regression performance. Some of those have been enhanced
 to handle the multioutput case: :func:`mean_absolute_error`,
-:func:`mean_absolute_error` and :func:`r2_score`.
+:func:`mean_absolute_error`, :func:`explained_variance_score` and
+:func:`r2_score`.
+
+
+These functions have an ``output_weights`` keyword argument which specifies
+the way the scores for each individual target should be averaged. The default
+is ``'uniform'``, which entails a uniformly weighted mean over outputs. If
+an ``ndarray`` of shape ``(n_outputs,)`` is passed, then its entries are
+interpreted as weights and the corresponding weighted average is returned. If
+``output_weights=None`` is specified, then the unaltered individual scores
+are returned in an array of shape ``(n_outputs,)``.
+
+
+The :func:`r2_score` and :func:`explained_variance_score` functions
+additionally accept ``output_weights='variance'``, which weights each
+individual score by the variance of the corresponding target variable.
+This setting quantifies the globally captured, unscaled variance. If the
+target variables are of different scale, then this score puts more
+importance on explaining the higher-variance variables well.
 
 
 Explained variance score
@@ -982,6 +1000,7 @@ variance is estimated as follow:
 
 The best possible score is 1.0, lower values are worse.
 
+
 Here a small example of usage of the :func:`explained_variance_score`
 function::
 
@@ -990,6 +1009,14 @@ function::
   >>> y_pred = [2.5, 0.0, 2, 8]
   >>> explained_variance_score(y_true, y_pred)  # doctest: +ELLIPSIS
   0.957...
+  >>> y_true = [[0.5, 1], [-1, 1], [7, -6]]
+  >>> y_pred = [[0, 2], [-1, 2], [8, -5]]
+  >>> explained_variance_score(y_true, y_pred, output_weights=None)
+  ... # doctest: +ELLIPSIS
+  array([ 0.967...,  1. ])
+  >>> explained_variance_score(y_true, y_pred, output_weights=[0.3, 0.7])
+  ... # doctest: +ELLIPSIS
+  0.990...
 
 Mean absolute error
 -------------------
@@ -1007,6 +1034,7 @@ and :math:`y_i` is the corresponding true value, then the mean absolute error
 
   \text{MAE}(y, \hat{y}) = \frac{1}{n_{\text{samples}}} \sum_{i=0}^{n_{\text{samples}}-1} \left| y_i - \hat{y}_i \right|.
 
+
 Here a small example of usage of the :func:`mean_absolute_error` function::
 
   >>> from sklearn.metrics import mean_absolute_error
   >>> y_true = [3, -0.5, 2, 7]
@@ -1018,7 +1046,11 @@ Here a small example of usage of the :func:`mean_absolute_error` function::
   >>> y_pred = [[0, 2], [-1, 2], [8, -5]]
   >>> mean_absolute_error(y_true, y_pred)
   0.75
-
+  >>> mean_absolute_error(y_true, y_pred, output_weights=None)
+  array([ 0.5,  1. ])
+  >>> mean_absolute_error(y_true, y_pred, output_weights=[0.3, 0.7])
+  ... # doctest: +ELLIPSIS
+  0.849...
 
 
 Mean squared error
@@ -1074,6 +1106,7 @@ over :math:`n_{\text{samples}}` is defined as
 
 where :math:`\bar{y} = \frac{1}{n_{\text{samples}}} \sum_{i=0}^{n_{\text{samples}} - 1} y_i`.
 
+
 Here a small example of usage of the :func:`r2_score` function::
 
   >>> from sklearn.metrics import r2_score
   >>> y_true = [3, -0.5, 2, 7]
@@ -1083,8 +1116,18 @@ Here a small example of usage of the :func:`r2_score` function::
   0.948...
   >>> y_true = [[0.5, 1], [-1, 1], [7, -6]]
   >>> y_pred = [[0, 2], [-1, 2], [8, -5]]
-  >>> r2_score(y_true, y_pred)  # doctest: +ELLIPSIS
+  >>> r2_score(y_true, y_pred, output_weights='variance')  # doctest: +ELLIPSIS
   0.938...
+  >>> y_true = [[0.5, 1], [-1, 1], [7, -6]]
+  >>> y_pred = [[0, 2], [-1, 2], [8, -5]]
+  >>> r2_score(y_true, y_pred)  # doctest: +ELLIPSIS
+  0.936...
+  >>> r2_score(y_true, y_pred, output_weights=None)
+  ... # doctest: +ELLIPSIS
+  array([ 0.965...,  0.908...])
+  >>> r2_score(y_true, y_pred, output_weights=[0.3, 0.7])
+  ... # doctest: +ELLIPSIS
+  0.925...
 
 .. topic:: Example:
diff --git a/sklearn/metrics/regression.py b/sklearn/metrics/regression.py
index 648422fa84693..b9cdc8f6bea1c 100644
--- a/sklearn/metrics/regression.py
+++ b/sklearn/metrics/regression.py
@@ -15,6 +15,8 @@
 # Lars Buitinck
 # Joel Nothman
 # Noel Dawe
+# Manoj Kumar
+# Michael Eickenberg
 # License: BSD 3 clause
 
 from __future__ import division
@@ -32,7 +34,7 @@
 ]
 
 
-def _check_reg_targets(y_true, y_pred):
+def _check_reg_targets(y_true, y_pred, output_weights):
     """Check that y_true and y_pred belong to the same regression task
 
     Parameters
@@ -41,6 +43,8 @@ def _check_reg_targets(y_true, y_pred):
 
     y_pred : array-like,
 
+    output_weights : array-like or string in ['uniform', 'variance'] or None
+
     Returns
     -------
     type_true : one of {'continuous', continuous-multioutput'}
@@ -52,6 +56,13 @@ def _check_reg_targets(y_true, y_pred):
     y_pred : array-like of shape = [n_samples, n_outputs]
         Estimated target values.
 
+    output_weights : array-like of shape = [n_outputs] or in ['uniform', 'variance', None (default)]
+        Custom output weights if output_weights is array-like. 'uniform' and
+        'variance' indicate specific weight vectors.
+        None indicates no aggregation of scores across targets: scores are
+        returned separately per target.
+
     """
     check_consistent_length(y_true, y_pred)
     y_true = check_array(y_true, ensure_2d=False)
@@ -67,12 +78,23 @@ def _check_reg_targets(y_true, y_pred):
         raise ValueError("y_true and y_pred have different number of output "
                          "({0}!={1})".format(y_true.shape[1], y_pred.shape[1]))
 
-    y_type = 'continuous' if y_true.shape[1] == 1 else 'continuous-multioutput'
+    n_outputs = y_true.shape[1]
+    output_weights_options = (None, 'uniform', 'variance')
+    if output_weights not in output_weights_options:
+        output_weights = check_array(output_weights, ensure_2d=False)
+        if n_outputs == 1:
+            raise ValueError("Custom weights are useful only in "
+                             "multi-output cases.")
+        elif n_outputs != len(output_weights):
+            raise ValueError(("There must be equally many custom weights "
+                              "(%d) as outputs (%d).") %
+                             (len(output_weights), n_outputs))
+    y_type = 'continuous' if n_outputs == 1 else 'continuous-multioutput'
 
-    return y_type, y_true, y_pred
+    return y_type, y_true, y_pred, output_weights
 
 
-def _average_and_variance(values, sample_weight=None):
+def _average_and_variance(values, sample_weight=None, axis=None):
     """
     Compute the (weighted) average and variance.
 
@@ -83,6 +105,10 @@ def _average_and_variance(values, sample_weight=None):
     sample_weight : array-like of shape = [n_samples], optional
        Sample weights.
 
+    axis : integer or None, default None
+        Axis along which to calculate average and variance. Full array by
+        default.
+
     Returns
     -------
     average : float
         Weighted average.
 
     variance : float
         Weighted variance.
@@ -94,16 +120,28 @@
     """
     values = np.asarray(values)
     if values.ndim == 1:
         values = values.reshape((-1, 1))
+    n_samples, n_outputs = values.shape
     if sample_weight is not None:
         sample_weight = np.asarray(sample_weight)
         if sample_weight.ndim == 1:
-            sample_weight = sample_weight.reshape((-1, 1))
-    average = np.average(values, weights=sample_weight)
-    variance = np.average((values - average)**2, weights=sample_weight)
+            sample_weight = sample_weight.reshape((n_samples, 1))
+    # if multi output but sample weight only specified in one column,
+    # then we need to broadcast it over outputs
+    if (sample_weight is not None and n_outputs != sample_weight.shape[1]):
+        if sample_weight.shape[1] != 1:
+            raise ValueError("Sample weight shape and data shape "
+                             "do not correspond.")
+        sample_weight = sample_weight * np.ones([1, n_outputs])
+
+    average = np.average(values, weights=sample_weight, axis=axis)
+    variance = np.average((values - average) ** 2,
+                          weights=sample_weight, axis=axis)
     return average, variance
 
 
-def mean_absolute_error(y_true, y_pred, sample_weight=None):
+def mean_absolute_error(y_true, y_pred,
+                        output_weights='uniform',
+                        sample_weight=None):
     """Mean absolute error regression loss
 
     Parameters
@@ -114,13 +152,32 @@ def mean_absolute_error(y_true, y_pred, sample_weight=None):
     y_pred : array-like of shape = [n_samples] or [n_samples, n_outputs]
         Estimated target values.
 
+    output_weights : string in ['uniform'] or None
+        or array-like of shape [n_outputs]
+        Weights by which to average scores of outputs. Useful only if using
+        multiple outputs.
+
+        ``ndarray`` :
+            array containing weights for the weighted average.
+
+        ``'uniform'`` :
+            Scores of all outputs are averaged with uniform weight.
+
+        ``None`` :
+            No averaging is performed, an array of scores is returned.
+
     sample_weight : array-like of shape = [n_samples], optional
         Sample weights.
 
     Returns
     -------
-    loss : float
-        A positive floating point value (the best value is 0.0).
+    loss : float or ndarray of shape [n_outputs]
+        If output_weights is None, then mean absolute error is returned for
+        each output separately.
+        If output_weights is 'uniform' or an ndarray of weights, then the
+        weighted average of all output scores is returned.
+
+        MAE output is non-negative floating point. The best value is 0.0.
 
     Examples
     --------
@@ -133,14 +190,28 @@ def mean_absolute_error(y_true, y_pred, sample_weight=None):
     >>> y_pred = [[0, 2], [-1, 2], [8, -5]]
     >>> mean_absolute_error(y_true, y_pred)
     0.75
-
+    >>> mean_absolute_error(y_true, y_pred, output_weights=None)
+    array([ 0.5,  1. ])
+    >>> mean_absolute_error(y_true, y_pred, output_weights=[0.3, 0.7])
+    ... # doctest: +ELLIPSIS
+    0.849...
""" - y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred) - return np.average(np.abs(y_pred - y_true).mean(axis=1), - weights=sample_weight) - - -def mean_squared_error(y_true, y_pred, sample_weight=None): + y_type, y_true, y_pred, output_weights = _check_reg_targets( + y_true, y_pred, output_weights) + output_errors = np.average(np.abs(y_pred - y_true), + weights=sample_weight, axis=0) + if output_weights is None: + return output_errors + elif output_weights == 'uniform': + # pass None as weights to np.average: uniform mean + output_weights = None + + return np.average(output_errors, weights=output_weights) + + +def mean_squared_error(y_true, y_pred, + output_weights='uniform', + sample_weight=None): """Mean squared error regression loss Parameters @@ -151,13 +222,28 @@ def mean_squared_error(y_true, y_pred, sample_weight=None): y_pred : array-like of shape = [n_samples] or [n_samples, n_outputs] Estimated target values. + output_weights : string in ['uniform'] or None + or array-like of shape [n_outputs] + Weights by which to average scores of outputs. Useful only if using + multiple outputs. + + ``ndarray`` : + array containing weights for the weighted average. + + ``'uniform'`` : + Scores of all outputs are averaged with uniform weight. + + ``None`` : + No averaging is performed, an array of scores is returned. + sample_weight : array-like of shape = [n_samples], optional Sample weights. Returns ------- loss : float - A positive floating point value (the best value is 0.0). + A non-negative floating point value (the best value is 0.0), or an + array of floating point values, one for each individual target. Examples -------- @@ -170,14 +256,30 @@ def mean_squared_error(y_true, y_pred, sample_weight=None): >>> y_pred = [[0, 2],[-1, 2],[8, -5]] >>> mean_squared_error(y_true, y_pred) # doctest: +ELLIPSIS 0.708... + >>> mean_squared_error(y_true, y_pred, output_weights=None) + ... # doctest: +ELLIPSIS + array([ 0.416..., 1. ]) + >>> mean_squared_error(y_true, y_pred, output_weights=[0.3, 0.7]) + ... # doctest: +ELLIPSIS + 0.824... """ - y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred) - return np.average(((y_pred - y_true) ** 2).mean(axis=1), - weights=sample_weight) - - -def explained_variance_score(y_true, y_pred, sample_weight=None): + y_type, y_true, y_pred, output_weights = _check_reg_targets( + y_true, y_pred, output_weights) + output_errors = np.average((y_true - y_pred) ** 2, axis=0, + weights=sample_weight) + if output_weights is None: + return output_errors + elif output_weights == 'uniform': + # pass None as weights to np.average: uniform mean + output_weights = None + + return np.average(output_errors, weights=output_weights) + + +def explained_variance_score(y_true, y_pred, + output_weights='uniform', + sample_weight=None): """Explained variance regression score function Best possible score is 1.0, lower values are worse. @@ -190,6 +292,24 @@ def explained_variance_score(y_true, y_pred, sample_weight=None): y_pred : array-like Estimated target values. + output_weights : string in ['uniform', 'variance'] or None + or array-like of shape [n_outputs] + Weights by which to average scores of outputs. Useful only if using + multiple outputs. + + ``ndarray`` : + array containing weights for the weighted average. + + ``'uniform'`` : + Scores of all outputs are averaged with uniform weight. + + ``'variance'`` : + Scores of all outputs are averaged, weighted by the variances + of each individual output. 
+            This corresponds to a global explained variance.
+
+        ``None`` :
+            No averaging is performed, an array of scores is returned.
+
     sample_weight : array-like of shape = [n_samples], optional
         Sample weights.
 
@@ -211,24 +331,40 @@ def explained_variance_score(y_true, y_pred, sample_weight=None):
     0.957...
 
     """
-    y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred)
+    y_type, y_true, y_pred, output_weights = _check_reg_targets(
+        y_true, y_pred, output_weights)
 
-    if y_type != "continuous":
+    if y_type not in ["continuous", "continuous-multioutput"]:
         raise ValueError("{0} is not supported".format(y_type))
 
-    _, numerator = _average_and_variance(y_true - y_pred, sample_weight)
-    _, denominator = _average_and_variance(y_true, sample_weight)
-    if denominator == 0.0:
-        if numerator == 0.0:
-            return 1.0
-        else:
-            # arbitrary set to zero to avoid -inf scores, having a constant
-            # y_true is not interesting for scoring a regression anyway
-            return 0.0
-    return 1 - numerator / denominator
-
-
-def r2_score(y_true, y_pred, sample_weight=None):
+    _, numerator = _average_and_variance(y_true - y_pred, sample_weight,
+                                         axis=0)
+    _, denominator = _average_and_variance(y_true, sample_weight,
+                                           axis=0)
+
+    nonzero_numerator = numerator != 0
+    nonzero_denominator = denominator != 0
+    valid_score = nonzero_numerator & nonzero_denominator
+    output_scores = np.ones(y_true.shape[1])
+
+    output_scores[valid_score] = 1 - (numerator[valid_score] /
+                                      denominator[valid_score])
+    output_scores[nonzero_numerator & ~nonzero_denominator] = 0.
+    if output_weights is None:
+        # return scores individually
+        return output_scores
+    elif output_weights == 'uniform':
+        # passing None as weights results in a uniform mean
+        output_weights = None
+    elif output_weights == 'variance':
+        output_weights = denominator
+
+    return np.average(output_scores, weights=output_weights)
+
+
+def r2_score(y_true, y_pred,
+             output_weights='uniform',
+             sample_weight=None):
     """R^2 (coefficient of determination) regression score function.
 
     Best possible score is 1.0, lower values are worse.
 
@@ -241,6 +377,21 @@ def r2_score(y_true, y_pred, sample_weight=None):
     y_pred : array-like of shape = [n_samples] or [n_samples, n_outputs]
         Estimated target values.
 
+    output_weights : string in ['uniform', 'variance'] or None
+        or array-like of shape [n_outputs]
+        Weights by which to average scores of outputs. Useful only if using
+        multiple outputs.
+
+        ``ndarray`` :
+            array containing weights for the weighted average.
+
+        ``'uniform'`` :
+            Scores of all outputs are averaged with uniform weight.
+
+        ``'variance'`` :
+            Scores of all outputs are averaged, weighted by the variances
+            of each individual output. This corresponds to a global
+            explained variance.
+
     sample_weight : array-like of shape = [n_samples], optional
         Sample weights.
 
@@ -270,11 +421,12 @@ def r2_score(y_true, y_pred, sample_weight=None):
     0.948...
     >>> y_true = [[0.5, 1], [-1, 1], [7, -6]]
     >>> y_pred = [[0, 2], [-1, 2], [8, -5]]
-    >>> r2_score(y_true, y_pred)  # doctest: +ELLIPSIS
+    >>> r2_score(y_true, y_pred, output_weights='variance')  # doctest: +ELLIPSIS
     0.938...
 
     """
@@ -282,16 +434,27 @@
-    y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred)
+    y_type, y_true, y_pred, output_weights = _check_reg_targets(
+        y_true, y_pred, output_weights)
 
     if sample_weight is not None:
         sample_weight = column_or_1d(sample_weight)
         weight = sample_weight[:, np.newaxis]
     else:
         weight = 1.
 
-    numerator = (weight * (y_true - y_pred) ** 2).sum(dtype=np.float64)
+    numerator = (weight * (y_true - y_pred) ** 2).sum(axis=0,
+                                                      dtype=np.float64)
     denominator = (weight * (y_true - np.average(
-        y_true, axis=0, weights=sample_weight)) ** 2).sum(dtype=np.float64)
-
-    if denominator == 0.0:
-        if numerator == 0.0:
-            return 1.0
-        else:
-            # arbitrary set to zero to avoid -inf scores, having a constant
-            # y_true is not interesting for scoring a regression anyway
-            return 0.0
-
-    return 1 - numerator / denominator
+        y_true, axis=0, weights=sample_weight)) ** 2).sum(axis=0,
+                                                          dtype=np.float64)
+    nonzero_denominator = denominator != 0
+    nonzero_numerator = numerator != 0
+    valid_score = nonzero_denominator & nonzero_numerator
+    output_scores = np.ones([y_true.shape[1]])
+    output_scores[valid_score] = 1 - (numerator[valid_score] /
+                                      denominator[valid_score])
+    # arbitrarily set to zero to avoid -inf scores; having a constant
+    # y_true is not interesting for scoring a regression anyway
+    output_scores[nonzero_numerator & ~nonzero_denominator] = 0.
+    if output_weights is None:
+        # return scores individually
+        return output_scores
+    elif output_weights == 'uniform':
+        # passing None as weights results in a uniform mean
+        output_weights = None
+    elif output_weights == 'variance':
+        output_weights = denominator
+
+    return np.average(output_scores, weights=output_weights)
diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index 8bc17ae327f70..32cd587dca576 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -277,6 +277,7 @@
 # Regression metrics with "multioutput-continuous" format support
 MULTIOUTPUT_METRICS = [
     "mean_absolute_error", "mean_squared_error", "r2_score",
+    "explained_variance_score"
 ]
 
 # Symmetric with respect to their input arguments y_true and y_pred
diff --git a/sklearn/metrics/tests/test_regression.py b/sklearn/metrics/tests/test_regression.py
index 55a26074b0702..ff18af20faa98 100644
--- a/sklearn/metrics/tests/test_regression.py
+++ b/sklearn/metrics/tests/test_regression.py
@@ -7,6 +7,7 @@
 from sklearn.utils.testing import assert_equal
 from sklearn.utils.testing import assert_almost_equal
 from sklearn.utils.testing import assert_array_equal
+from sklearn.utils.testing import assert_array_almost_equal
 
 from sklearn.metrics import explained_variance_score
 from sklearn.metrics import mean_absolute_error
@@ -38,8 +39,11 @@ def test_multioutput_regression():
     error = mean_absolute_error(y_true, y_pred)
     assert_almost_equal(error, (1. / 3 + 2. / 3 + 2. / 3) / 4.)
 
-    error = r2_score(y_true, y_pred)
-    assert_almost_equal(error, 1 - 5. / 2)
+    error = r2_score(y_true, y_pred, output_weights='variance')
+    assert_almost_equal(error, 1. - 5. / 2)
+    error = r2_score(y_true, y_pred, output_weights='uniform')
+    assert_almost_equal(error, -.875)
+
 
 def test_regression_metrics_at_limits():
@@ -63,7 +67,8 @@ def test__check_reg_targets():
                             repeat=2):
 
         if type1 == type2 and n_out1 == n_out2:
-            y_type, y_check1, y_check2 = _check_reg_targets(y1, y2)
+            y_type, y_check1, y_check2, output_weights = _check_reg_targets(
+                y1, y2, None)
             assert_equal(type1, y_type)
             if type1 == 'continuous':
                 assert_array_equal(y_check1, np.reshape(y1, (-1, 1)))
@@ -72,4 +77,68 @@ def test__check_reg_targets():
                 assert_array_equal(y_check1, y1)
                 assert_array_equal(y_check2, y2)
         else:
-            assert_raises(ValueError, _check_reg_targets, y1, y2)
+            assert_raises(ValueError, _check_reg_targets, y1, y2, None)
+
+
+def test_regression_multioutput_array():
+    y_true = [[1, 2], [2.5, -1], [4.5, 3], [5, 7]]
+    y_pred = [[1, 1], [2, -1], [5, 4], [5, 6.5]]
+
+    mse = mean_squared_error(y_true, y_pred, output_weights=None)
+    mae = mean_absolute_error(y_true, y_pred, output_weights=None)
+    r = r2_score(y_true, y_pred, output_weights=None)
+    evs = explained_variance_score(y_true, y_pred, output_weights=None)
+
+    assert_array_equal(mse, np.array([0.125, 0.5625]))
+    assert_array_equal(mae, np.array([0.25, 0.625]))
+    assert_array_almost_equal(r, np.array([0.95, 0.93]), decimal=2)
+    assert_array_almost_equal(evs, np.array([0.95, 0.93]), decimal=2)
+
+    # mean_absolute_error and mean_squared_error are equal here because
+    # every per-sample error is exactly 1 (so |e| == e ** 2).
+    y_true = [[0, 0]]*4
+    y_pred = [[1, 1]]*4
+    mse = mean_squared_error(y_true, y_pred, output_weights=None)
+    mae = mean_absolute_error(y_true, y_pred, output_weights=None)
+    r = r2_score(y_true, y_pred, output_weights=None)
+    assert_array_equal(mse, np.array([1., 1.]))
+    assert_array_equal(mae, np.array([1., 1.]))
+    assert_array_almost_equal(r, np.array([0., 0.]))
+
+    r = r2_score([[0, -1], [0, 1]], [[2, 2], [1, 1]], output_weights=None)
+    assert_array_equal(r, np.array([0, -3.5]))
+    assert_equal(np.mean(r), r2_score([[0, -1], [0, 1]], [[2, 2], [1, 1]]))
+    evs = explained_variance_score([[0, -1], [0, 1]], [[2, 2], [1, 1]],
+                                   output_weights=None)
+    assert_array_equal(evs, np.array([0, -1.25]))
+
+    # Checking the condition in which the numerator of the score is zero
+    # (one output is predicted perfectly).
+    y_true = [[1, 3], [-1, 2]]
+    y_pred = [[1, 4], [-1, 1]]
+    r = r2_score(y_true, y_pred, output_weights=None)
+    assert_array_equal(r, np.array([1., -3.]))
+    assert_equal(np.mean(r), r2_score(y_true, y_pred))
+    evs = explained_variance_score(y_true, y_pred, output_weights=None)
+    assert_array_equal(evs, np.array([1., -3.]))
+    assert_equal(np.mean(evs), explained_variance_score(y_true, y_pred))
+
+
+def test_regression_custom_weights():
+    y_true = [[1, 2], [2.5, -1], [4.5, 3], [5, 7]]
+    y_pred = [[1, 1], [2, -1], [5, 4], [5, 6.5]]
+
+    msew = mean_squared_error(y_true, y_pred, output_weights=[0.4, 0.6],
+                              sample_weight=None)
+    maew = mean_absolute_error(y_true, y_pred, output_weights=[0.4, 0.6],
+                               sample_weight=None)
+    rw = r2_score(y_true, y_pred, output_weights=[0.4, 0.6],
+                  sample_weight=None)
+    evsw = explained_variance_score(y_true, y_pred,
+                                    output_weights=[0.4, 0.6],
+                                    sample_weight=None)
+
+    assert_almost_equal(msew, 0.39, decimal=2)
+    assert_almost_equal(maew, 0.475, decimal=3)
+    assert_almost_equal(rw, 0.94, decimal=2)
+    assert_almost_equal(evsw, 0.94, decimal=2)
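
Not part of the patch itself: below is a small usage sketch of the ``output_weights`` argument introduced above. It assumes a scikit-learn build that includes this branch; the expected values are simply the ones from the doctests in the diff::

    import numpy as np
    from sklearn.metrics import (mean_absolute_error, mean_squared_error,
                                 r2_score, explained_variance_score)

    y_true = np.array([[0.5, 1], [-1, 1], [7, -6]])
    y_pred = np.array([[0, 2], [-1, 2], [8, -5]])

    # One score per output, no averaging across targets.
    print(mean_absolute_error(y_true, y_pred, output_weights=None))
    # -> array([ 0.5,  1. ])

    # Uniformly weighted mean over outputs (the 'uniform' default).
    print(r2_score(y_true, y_pred))
    # -> 0.936...

    # Variance-weighted mean, reproducing the previous multioutput
    # behaviour of r2_score (see the updated doctest and test above).
    print(r2_score(y_true, y_pred, output_weights='variance'))
    # -> 0.938...

    # Custom per-output weights.
    print(mean_absolute_error(y_true, y_pred, output_weights=[0.3, 0.7]))
    # -> 0.849...

Custom weight arrays only make sense for multioutput targets; for a single output, _check_reg_targets raises a ValueError, as enforced in the patch.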