diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst
index ca3ad443d1224..150e592871e6a 100644
--- a/doc/modules/model_evaluation.rst
+++ b/doc/modules/model_evaluation.rst
@@ -920,6 +920,13 @@ variance is estimated as follow:

 The best possible score is 1.0, lower values are worse.

+The :func:`explained_variance_score` function has an ``output_weights`` keyword
+that accepts ``'uniform'`` (the default), ``None``, or an array-like of custom
+weights. If ``None`` is passed, the explained variance score is computed for
+each output separately and a numpy array is returned. If ``'uniform'`` is
+passed, the per-output scores are averaged with a weight of ``1 / n_outputs``
+each; custom weights produce the corresponding weighted average.
+
 Here a small example of usage of the :func:`explained_variance_score`
 function::

@@ -928,6 +935,14 @@ function::
   >>> y_pred = [2.5, 0.0, 2, 8]
   >>> explained_variance_score(y_true, y_pred)  # doctest: +ELLIPSIS
   0.957...
+  >>> y_true = [[0.5, 1], [-1, 1], [7, -6]]
+  >>> y_pred = [[0, 2], [-1, 2], [8, -5]]
+  >>> explained_variance_score(y_true, y_pred, output_weights=None)
+  ... # doctest: +ELLIPSIS
+  array([ 0.967...,  1. ])
+  >>> explained_variance_score(y_true, y_pred, output_weights=[0.3, 0.7])
+  ... # doctest: +ELLIPSIS
+  0.990...

 Mean absolute error
 ...................
@@ -945,6 +960,13 @@ and :math:`y_i` is the corresponding true value, then the mean absolute error

   \text{MAE}(y, \hat{y}) = \frac{1}{n_{\text{samples}}} \sum_{i=0}^{n_{\text{samples}}-1} \left| y_i - \hat{y}_i \right|.

+The :func:`mean_absolute_error` function has an ``output_weights`` keyword
+that accepts ``'uniform'`` (the default), ``None``, or an array-like of custom
+weights. If ``None`` is passed, the mean absolute error is computed for each
+output separately and a numpy array is returned. If ``'uniform'`` is passed,
+the per-output errors are averaged with a weight of ``1 / n_outputs`` each;
+custom weights produce the corresponding weighted average.
+
 Here a small example of usage of the :func:`mean_absolute_error` function::

   >>> from sklearn.metrics import mean_absolute_error
@@ -956,8 +978,11 @@ Here a small example of usage of the :func:`mean_absolute_error` function::
   >>> y_pred = [[0, 2], [-1, 2], [8, -5]]
   >>> mean_absolute_error(y_true, y_pred)
   0.75
-
-
+  >>> mean_absolute_error(y_true, y_pred, output_weights=None)
+  array([ 0.5,  1. ])
+  >>> mean_absolute_error(y_true, y_pred, output_weights=[0.3, 0.7])
+  ... # doctest: +ELLIPSIS
+  0.849...

 Mean squared error
 ...................
@@ -975,6 +1000,13 @@ and :math:`y_i` is the corresponding true value, then the mean squared error

   \text{MSE}(y, \hat{y}) = \frac{1}{n_\text{samples}} \sum_{i=0}^{n_\text{samples} - 1} (y_i - \hat{y}_i)^2.

+The :func:`mean_squared_error` function has an ``output_weights`` keyword
+that accepts ``'uniform'`` (the default), ``None``, or an array-like of custom
+weights. If ``None`` is passed, the mean squared error is computed for each
+output separately and a numpy array is returned. If ``'uniform'`` is passed,
+the per-output errors are averaged with a weight of ``1 / n_outputs`` each;
+custom weights produce the corresponding weighted average.
+
 Here a small example of usage of the :func:`mean_squared_error`
 function::

@@ -987,6 +1019,12 @@ function::
   >>> y_pred = [[0, 2], [-1, 2], [8, -5]]
   >>> mean_squared_error(y_true, y_pred)  # doctest: +ELLIPSIS
   0.7083...
+  >>> mean_squared_error(y_true, y_pred, output_weights=None)
+  ... # doctest: +ELLIPSIS
+  array([ 0.416...,  1. ])
+  >>> mean_squared_error(y_true, y_pred, output_weights=[0.3, 0.7])
+  ... # doctest: +ELLIPSIS
+  0.824...

 .. topic:: Examples:
@@ -1012,6 +1050,12 @@ over :math:`n_{\text{samples}}` is defined as

 where :math:`\bar{y} =  \frac{1}{n_{\text{samples}}} \sum_{i=0}^{n_{\text{samples}} - 1} y_i`.

+The :func:`r2_score` function has an ``output_weights`` keyword that accepts
+``'uniform'`` (the default), ``None``, or an array-like of custom weights. If
+``None`` is passed, the :math:`R^2` score is computed for each output
+separately and a numpy array is returned. If ``'uniform'`` is passed, the
+per-output scores are averaged with a weight of ``1 / n_outputs`` each; custom
+weights produce the corresponding weighted average.
+
 Here a small example of usage of the :func:`r2_score` function::

   >>> from sklearn.metrics import r2_score
@@ -1022,8 +1066,13 @@ Here a small example of usage of the :func:`r2_score` function::
   >>> y_true = [[0.5, 1], [-1, 1], [7, -6]]
   >>> y_pred = [[0, 2], [-1, 2], [8, -5]]
   >>> r2_score(y_true, y_pred)  # doctest: +ELLIPSIS
-  0.938...
-
+  0.936...
+  >>> r2_score(y_true, y_pred, output_weights=None)
+  ... # doctest: +ELLIPSIS
+  array([ 0.965...,  0.908...])
+  >>> r2_score(y_true, y_pred, output_weights=[0.3, 0.7])
+  ... # doctest: +ELLIPSIS
+  0.925...

 .. topic:: Example:
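The three ``output_weights`` modes documented above reduce to elementary numpy
operations. A minimal sketch for the mean absolute error example, using plain
numpy rather than the patched API (the numbers match the doctests):

    import numpy as np

    y_true = np.array([[0.5, 1], [-1, 1], [7, -6]])
    y_pred = np.array([[0, 2], [-1, 2], [8, -5]])

    # output_weights=None: one error per output column.
    per_output = np.mean(np.abs(y_true - y_pred), axis=0)
    print(per_output)                    # [ 0.5  1. ]

    # output_weights='uniform': each output weighted by 1 / n_outputs.
    print(per_output.mean())             # 0.75

    # Custom weights: a weighted average instead of a plain mean.
    print(np.average(per_output, weights=[0.3, 0.7]))
    # 0.3 * 0.5 + 0.7 * 1.0 = 0.85; the doctest prints 0.849... due to
    # floating point rounding.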
diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py
index ff7a054baf266..e3b8c1f309a03 100644
--- a/sklearn/metrics/metrics.py
+++ b/sklearn/metrics/metrics.py
@@ -15,6 +15,7 @@
 #          Jochen Wersdörfer
 #          Lars Buitinck
 #          Joel Nothman
+#          Manoj Kumar
 # License: BSD 3 clause

 from __future__ import division
@@ -31,6 +32,7 @@
 from ..utils import check_arrays
 from ..utils import deprecated
 from ..utils import column_or_1d
+from ..utils import safe_asarray
 from ..utils.multiclass import unique_labels
 from ..utils.multiclass import type_of_target
 from ..utils.fixes import bincount
@@ -39,8 +41,9 @@
 ###############################################################################
 # General utilities
 ###############################################################################
-def _check_reg_targets(y_true, y_pred):
-    """Check that y_true and y_pred belong to the same regression task
+def _check_reg_targets(y_true, y_pred, output_weights):
+    """Check that y_true, y_pred and output_weights belong to the
+    same regression task

     Parameters
     ----------
@@ -48,6 +51,8 @@ def _check_reg_targets(y_true, y_pred):

     y_pred : array-like,

+    output_weights : array-like or string, ['uniform', None]
+
     Returns
     -------
     type_true : one of {'continuous', continuous-multioutput'}
@@ -59,6 +64,15 @@ def _check_reg_targets(y_true, y_pred):

     y_pred : array-like of shape = [n_samples, n_outputs]
         Estimated target values.
+
+    output_weights : array-like of shape = [n_outputs],
+        or string, ['uniform', None]
+
+        The custom weights, converted to an array, if the output_weights
+        provided was array-like; otherwise 'uniform' or None, passed
+        through unchanged.
     """
     y_true, y_pred = check_arrays(y_true, y_pred)

@@ -72,9 +86,20 @@ def _check_reg_targets(y_true, y_pred):
         raise ValueError("y_true and y_pred have different number of output "
                          "({0}!={1})".format(y_true.shape[1],
                                              y_pred.shape[1]))
-    y_type = 'continuous' if y_true.shape[1] == 1 else 'continuous-multioutput'

-    return y_type, y_true, y_pred
+    n_outputs = y_true.shape[1]
+    output_weights_options = (None, 'uniform')
+    if output_weights not in output_weights_options:
+        output_weights = safe_asarray(output_weights)
+        if n_outputs == 1:
+            raise ValueError("Custom weights are useful only in "
+                             "multi-output cases.")
+        elif n_outputs != output_weights.shape[0]:
+            raise ValueError("Custom weights must have shape "
+                             "(%d,)." % n_outputs)
+
+    y_type = 'continuous' if n_outputs == 1 else 'continuous-multioutput'
+
+    return y_type, y_true, y_pred, output_weights


 def _check_clf_targets(y_true, y_pred):
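The failure modes of the new validation are easiest to see in isolation. A
sketch under an assumed helper name (`check_output_weights` is illustrative
only; it mirrors the checks in `_check_reg_targets` without the y_true/y_pred
handling):

    import numpy as np

    def check_output_weights(output_weights, n_outputs):
        # 'uniform' and None pass through untouched.
        if output_weights in (None, 'uniform'):
            return output_weights
        output_weights = np.asarray(output_weights)
        if n_outputs == 1:
            raise ValueError("Custom weights are useful only in "
                             "multi-output cases.")
        if output_weights.shape[0] != n_outputs:
            raise ValueError("Custom weights must have shape (%d,)."
                             % n_outputs)
        return output_weights

    print(check_output_weights([0.3, 0.7], n_outputs=2))  # array([ 0.3,  0.7])
    try:
        check_output_weights([0.3, 0.7], n_outputs=3)
    except ValueError as exc:
        print(exc)  # Custom weights must have shape (3,).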
@@ -1943,7 +1968,7 @@ def hamming_loss(y_true, y_pred, classes=None):
 ###############################################################################
 # Regression loss functions
 ###############################################################################
-def mean_absolute_error(y_true, y_pred):
+def mean_absolute_error(y_true, y_pred, output_weights='uniform'):
     """Mean absolute error regression loss

     Parameters
@@ -1954,10 +1979,30 @@ def mean_absolute_error(y_true, y_pred):
     y_pred : array-like of shape = [n_samples] or [n_samples, n_outputs]
         Estimated target values.

-    Returns
+    output_weights : string, ['uniform' (default), None] or
+        array-like of shape = [n_outputs]
+
+        This parameter is only useful for multi-output tasks, where both
+        y_true and y_pred have shape [n_samples, n_outputs].
+
+        ``'uniform'``:
+            A weight of 1 / n_outputs is assigned to each output while
+            averaging.
+
+        ``None``:
+            No averaging is done.
+
+        Alternatively, an array-like of shape [n_outputs] can be given,
+        assigning a user-defined weight to each output.
+
+    Returns
     -------
-    loss : float
-        A positive floating point value (the best value is 0.0).
+    loss : float or numpy array of shape [n_outputs]
+        If output_weights is None, the mean absolute error of each output
+        is returned as a numpy array.
+
+        Otherwise, the weighted average of the per-output errors is
+        returned; 'uniform' weights every output by 1 / n_outputs.

     Examples
     --------
@@ -1970,13 +2015,24 @@ def mean_absolute_error(y_true, y_pred):
     >>> y_pred = [[0, 2], [-1, 2], [8, -5]]
     >>> mean_absolute_error(y_true, y_pred)
     0.75
-
+    >>> mean_absolute_error(y_true, y_pred, output_weights=None)
+    array([ 0.5,  1. ])
+    >>> mean_absolute_error(y_true, y_pred, output_weights=[0.3, 0.7])
+    ... # doctest: +ELLIPSIS
+    0.849...
     """
-    y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred)
-    return np.mean(np.abs(y_pred - y_true))
+    y_type, y_true, y_pred, output_weights = \
+        _check_reg_targets(y_true, y_pred, output_weights)
+
+    error = np.mean(np.abs(y_pred - y_true), axis=0)
+    if output_weights == 'uniform':
+        return np.mean(error)
+    elif output_weights is None:
+        return error
+    return np.average(error, weights=output_weights)
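One subtlety in the reduction above: `output_weights == 'uniform'` is checked
before the `None` branch, and custom weights only reach `np.average` by
falling through both. A sketch of this three-way dispatch, which all four
metrics in the patch share (the helper name `_reduce_scores` is assumed, not
part of the patch):

    import numpy as np

    def _reduce_scores(scores, output_weights):
        # scores holds one metric value per output dimension.
        if output_weights == 'uniform':
            return np.mean(scores)    # weight 1 / n_outputs per output
        elif output_weights is None:
            return scores             # no averaging at all
        # Array-like weights fall through to a weighted average.
        return np.average(scores, weights=output_weights)

    scores = np.array([0.5, 1.0])
    print(_reduce_scores(scores, 'uniform'))   # 0.75
    print(_reduce_scores(scores, None))        # [ 0.5  1. ]
    print(_reduce_scores(scores, [0.3, 0.7]))  # 0.85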


-def mean_squared_error(y_true, y_pred):
+def mean_squared_error(y_true, y_pred, output_weights='uniform'):
     """Mean squared error regression loss

     Parameters
@@ -1984,13 +2040,35 @@ def mean_squared_error(y_true, y_pred):
     y_true : array-like of shape = [n_samples] or [n_samples, n_outputs]
         Ground truth (correct) target values.

+        It is recommended to normalize y_true before using this function.
+
     y_pred : array-like of shape = [n_samples] or [n_samples, n_outputs]
         Estimated target values.

+    output_weights : string, ['uniform' (default), None] or
+        array-like of shape = [n_outputs]
+
+        This parameter is only useful for multi-output tasks, where both
+        y_true and y_pred have shape [n_samples, n_outputs].
+
+        ``'uniform'``:
+            A weight of 1 / n_outputs is assigned to each output while
+            averaging.
+
+        ``None``:
+            No averaging is done.
+
+        Alternatively, an array-like of shape [n_outputs] can be given,
+        assigning a user-defined weight to each output.
+
     Returns
     -------
-    loss : float
-        A positive floating point value (the best value is 0.0).
+    loss : float or numpy array of shape [n_outputs]
+        If output_weights is None, the mean squared error of each output
+        is returned as a numpy array.
+
+        Otherwise, the weighted average of the per-output errors is
+        returned; 'uniform' weights every output by 1 / n_outputs.

     Examples
     --------
@@ -2003,16 +2081,29 @@ def mean_squared_error(y_true, y_pred):
     >>> y_pred = [[0, 2], [-1, 2], [8, -5]]
     >>> mean_squared_error(y_true, y_pred)  # doctest: +ELLIPSIS
     0.708...
-
+    >>> mean_squared_error(y_true, y_pred, output_weights=None)
+    ... # doctest: +ELLIPSIS
+    array([ 0.416...,  1. ])
+    >>> mean_squared_error(y_true, y_pred, output_weights=[0.3, 0.7])
+    ... # doctest: +ELLIPSIS
+    0.824...
     """
-    y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred)
-    return np.mean((y_pred - y_true) ** 2)
+    y_type, y_true, y_pred, output_weights = \
+        _check_reg_targets(y_true, y_pred, output_weights)
+
+    error = np.mean((y_pred - y_true) ** 2, axis=0)
+    if output_weights == 'uniform':
+        return np.mean(error)
+    elif output_weights is None:
+        return error
+    return np.average(error, weights=output_weights)
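The multi-output doctest values for `mean_squared_error` can be reproduced by
hand; a small verification sketch:

    import numpy as np

    y_true = np.array([[0.5, 1], [-1, 1], [7, -6]])
    y_pred = np.array([[0, 2], [-1, 2], [8, -5]])

    per_output = np.mean((y_true - y_pred) ** 2, axis=0)
    # column 0: (0.25 + 0 + 1) / 3 = 0.4166..., column 1: (1 + 1 + 1) / 3 = 1.0
    print(per_output)         # [ 0.41666667  1.        ]
    print(per_output.mean())  # 0.7083..., the 'uniform' default
    print(np.average(per_output, weights=[0.3, 0.7]))
    # 0.3 * 0.4166... + 0.7 * 1.0 = 0.825; printed as 0.824... in the doctest.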
""" - y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred) - - if y_type != "continuous": - raise ValueError("{0} is not supported".format(y_type)) - - numerator = np.var(y_true - y_pred) - denominator = np.var(y_true) - if denominator == 0.0: - if numerator == 0.0: - return 1.0 - else: - # arbitrary set to zero to avoid -inf scores, having a constant - # y_true is not interesting for scoring a regression anyway - return 0.0 - return 1 - numerator / denominator - - -def r2_score(y_true, y_pred): + y_type, y_true, y_pred, output_weights = \ + _check_reg_targets(y_true, y_pred, output_weights) + + numerator = np.var(y_true - y_pred, axis=0) + denominator = np.var(y_true, axis=0) + + # Setting an array of ones for the case in which both numerator + # and denominator are zero. + explained_variance = np.ones(y_true.shape[1]) + nonzero_denominator = (denominator != 0.0) + nonzero_numerator = (numerator != 0.0) + valid_score = np.logical_and(nonzero_numerator, + nonzero_denominator) + explained_variance[valid_score] = (1 - + numerator[valid_score]/denominator[valid_score]) + + # arbitrary set to zero to avoid -inf scores, having a constant + # y_true is not interesting for scoring a regression anyway + explained_variance[np.logical_and(nonzero_numerator, + np.logical_not(nonzero_denominator))] = 0.0 + if output_weights == 'uniform': + return np.mean(explained_variance) + elif output_weights is None: + return explained_variance + return np.average(explained_variance, weights=output_weights) + +def r2_score(y_true, y_pred, output_weights='uniform'): """R^2 (coefficient of determination) regression score function. Best possible score is 1.0, lower values are worse. @@ -2073,10 +2202,31 @@ def r2_score(y_true, y_pred): y_pred : array-like of shape = [n_samples] or [n_samples, n_outputs] Estimated target values. + output_weights : string, ['uniform' (default), None] or + array-like of shape = [n_outputs] + + This parameter is useful only for multi-output + tasks where both y_true and y_pred have shape [n_samples, n_outputs] + + ``'uniform'``: + A weight of 1/n_outputs is assigned to each dimension while + averaging. + + ``None``: + No averaging is done. + + Apart from this custom_weights of shape [n_outputs] can be given that + assigns user-defined weight to each dimension of the given input. + + Returns ------- - z : float - The R^2 score. + z : float or a numpy array of shape [n_outputs] + If output_weights is None, it returns a numpy array of floats corresponding + to the R^2 score of each dimension. + + If output_weights is 'uniform' or user-defined it returns the corresponding + weighted macro-averaged R^2 score. Notes ----- @@ -2100,23 +2250,38 @@ def r2_score(y_true, y_pred): >>> y_true = [[0.5, 1], [-1, 1], [7, -6]] >>> y_pred = [[0, 2], [-1, 2], [8, -5]] >>> r2_score(y_true, y_pred) # doctest: +ELLIPSIS - 0.938... - + 0.936... + >>> r2_score(y_true, y_pred, output_weights=None) + ... # doctest: +ELLIPSIS + array([ 0.965..., 0.908...]) + >>> r2_score(y_true, y_pred, output_weights=[0.3, 0.7]) + ... # doctest: +ELLIPSIS + 0.925... 
""" - y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred) + y_type, y_true, y_pred, output_weights = \ + _check_reg_targets(y_true, y_pred, output_weights) if len(y_true) == 1: raise ValueError("r2_score can only be computed given more than one" " sample.") - numerator = ((y_true - y_pred) ** 2).sum(dtype=np.float64) - denominator = ((y_true - y_true.mean(axis=0)) ** 2).sum(dtype=np.float64) - - if denominator == 0.0: - if numerator == 0.0: - return 1.0 - else: - # arbitrary set to zero to avoid -inf scores, having a constant - # y_true is not interesting for scoring a regression anyway - return 0.0 - return 1 - numerator / denominator + numerator = ((y_true - y_pred) ** 2).sum(dtype=np.float64, axis=0) + denominator = ((y_true - y_true.mean(axis=0)) ** 2).sum(dtype=np.float64, axis=0) + + # Set an array of ones for the condition that both numerator + # and denominator are zero. + r2 = np.ones(y_true.shape[1]) + nonzero_denominator = (denominator != 0.0) + nonzero_numerator = (numerator != 0.0) + valid_score = np.logical_and(nonzero_numerator, nonzero_denominator) + r2[valid_score] = 1 - numerator[valid_score]/denominator[valid_score] + # Denominator is zero and numerator is non-zero + # arbitrary set to zero to avoid -inf scores, having a constant + # y_true is not interesting for scoring a regression anyway + r2[np.logical_and(np.logical_not(nonzero_denominator), + nonzero_numerator)] = 0.0 + if output_weights == 'uniform': + return np.mean(r2) + elif output_weights is None: + return r2 + return np.average(r2, weights=output_weights) diff --git a/sklearn/metrics/tests/test_metrics.py b/sklearn/metrics/tests/test_metrics.py index f4c63b491e8ff..6dc8bf06f8c56 100644 --- a/sklearn/metrics/tests/test_metrics.py +++ b/sklearn/metrics/tests/test_metrics.py @@ -204,6 +204,7 @@ "mean_absolute_error": mean_absolute_error, "mean_squared_error": mean_squared_error, "r2_score": r2_score, + "explained_variance_score": explained_variance_score } @@ -1244,7 +1245,7 @@ def test_multioutput_regression(): assert_almost_equal(error, (1. / 3 + 2. / 3 + 2. / 3) / 4.) error = r2_score(y_true, y_pred) - assert_almost_equal(error, 1 - 5. 
diff --git a/sklearn/metrics/tests/test_metrics.py b/sklearn/metrics/tests/test_metrics.py
index f4c63b491e8ff..6dc8bf06f8c56 100644
--- a/sklearn/metrics/tests/test_metrics.py
+++ b/sklearn/metrics/tests/test_metrics.py
@@ -204,6 +204,7 @@
     "mean_absolute_error": mean_absolute_error,
     "mean_squared_error": mean_squared_error,
     "r2_score": r2_score,
+    "explained_variance_score": explained_variance_score,
 }


@@ -1244,7 +1245,7 @@ def test_multioutput_regression():
     assert_almost_equal(error, (1. / 3 + 2. / 3 + 2. / 3) / 4.)

     error = r2_score(y_true, y_pred)
-    assert_almost_equal(error, 1 - 5. / 2)
+    assert_almost_equal(error, -0.875)


 def test_multioutput_number_of_output_differ():
@@ -1962,7 +1963,8 @@ def test__check_reg_targets():
                                           repeat=2):

         if type1 == type2 and n_out1 == n_out2:
-            y_type, y_check1, y_check2 = _check_reg_targets(y1, y2)
+            y_type, y_check1, y_check2, output_weights = \
+                _check_reg_targets(y1, y2, None)
             assert_equal(type1, y_type)
             if type1 == 'continuous':
                 assert_array_equal(y_check1, np.reshape(y1, (-1, 1)))
@@ -1971,7 +1973,7 @@ def test__check_reg_targets():
                 assert_array_equal(y_check1, y1)
                 assert_array_equal(y_check2, y2)
         else:
-            assert_raises(ValueError, _check_reg_targets, y1, y2)
+            assert_raises(ValueError, _check_reg_targets, y1, y2, None)


 def test_log_loss():
@@ -1999,3 +2001,61 @@ def test_log_loss():
     y_pred = np.asarray(y_pred) > .5
     loss = log_loss(y_true, y_pred, normalize=True, eps=.1)
     assert_almost_equal(loss, log_loss(y_true, np.clip(y_pred, .1, .9)))
+
+
+def test_regression_multioutput_array():
+    y_true = [[1, 2], [2.5, -1], [4.5, 3], [5, 7]]
+    y_pred = [[1, 1], [2, -1], [5, 4], [5, 6.5]]
+
+    mse = mean_squared_error(y_true, y_pred, output_weights=None)
+    mae = mean_absolute_error(y_true, y_pred, output_weights=None)
+    r = r2_score(y_true, y_pred, output_weights=None)
+    evs = explained_variance_score(y_true, y_pred, output_weights=None)
+
+    assert_array_equal(mse, np.array([0.125, 0.5625]))
+    assert_array_equal(mae, np.array([0.25, 0.625]))
+    assert_array_almost_equal(r, np.array([0.95, 0.93]), decimal=2)
+    assert_array_almost_equal(evs, np.array([0.95, 0.93]), decimal=2)
+
+    # mean_absolute_error and mean_squared_error coincide here because
+    # every per-sample error is exactly one.
+    y_true = [[0, 0]] * 4
+    y_pred = [[1, 1]] * 4
+    mse = mean_squared_error(y_true, y_pred, output_weights=None)
+    mae = mean_absolute_error(y_true, y_pred, output_weights=None)
+    r = r2_score(y_true, y_pred, output_weights=None)
+    assert_array_equal(mse, np.array([1., 1.]))
+    assert_array_equal(mae, np.array([1., 1.]))
+    assert_array_almost_equal(r, np.array([0., 0.]))
+
+    r = r2_score([[0, -1], [0, 1]], [[2, 2], [1, 1]], output_weights=None)
+    assert_array_equal(r, np.array([0, -3.5]))
+    assert_equal(np.mean(r), r2_score([[0, -1], [0, 1]], [[2, 2], [1, 1]]))
+    evs = explained_variance_score([[0, -1], [0, 1]], [[2, 2], [1, 1]],
+                                   output_weights=None)
+    assert_array_equal(evs, np.array([0, -1.25]))
+
+    # One output is predicted perfectly (zero numerator), which must yield
+    # a score of exactly one for that output.
+    y_true = [[1, 3], [-1, 2]]
+    y_pred = [[1, 4], [-1, 1]]
+    r = r2_score(y_true, y_pred, output_weights=None)
+    assert_array_equal(r, np.array([1., -3.]))
+    assert_equal(np.mean(r), r2_score(y_true, y_pred))
+    evs = explained_variance_score(y_true, y_pred, output_weights=None)
+    assert_array_equal(evs, np.array([1., -3.]))
+    assert_equal(np.mean(evs), explained_variance_score(y_true, y_pred))
+
+
+def test_regression_custom_weights():
+    y_true = [[1, 2], [2.5, -1], [4.5, 3], [5, 7]]
+    y_pred = [[1, 1], [2, -1], [5, 4], [5, 6.5]]
+
+    msew = mean_squared_error(y_true, y_pred, output_weights=[0.4, 0.6])
+    maew = mean_absolute_error(y_true, y_pred, output_weights=[0.4, 0.6])
+    rw = r2_score(y_true, y_pred, output_weights=[0.4, 0.6])
+    evsw = explained_variance_score(y_true, y_pred, output_weights=[0.4, 0.6])
+
+    assert_almost_equal(msew, 0.39, decimal=2)
+    assert_almost_equal(maew, 0.475, decimal=3)
+    assert_almost_equal(rw, 0.94, decimal=2)
+    assert_almost_equal(evsw, 0.94, decimal=2)
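The expected values in `test_regression_custom_weights` follow directly from
the per-output errors asserted in `test_regression_multioutput_array`; a quick
sanity computation in plain numpy:

    import numpy as np

    mse = np.array([0.125, 0.5625])  # per-output MSE asserted above
    mae = np.array([0.25, 0.625])    # per-output MAE asserted above

    print(np.average(mse, weights=[0.4, 0.6]))  # 0.3875, i.e. 0.39 at decimal=2
    print(np.average(mae, weights=[0.4, 0.6]))  # 0.475 exactly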