From a33a9d1751c8c626fa92149e231de17fc95e9e69 Mon Sep 17 00:00:00 2001
From: Manoj-Kumar-S
Date: Sun, 6 Oct 2013 20:51:47 +0530
Subject: [PATCH 01/13] Regression metrics return arrays for multi-output cases

---
 sklearn/metrics/metrics.py | 77 +++++++++++++++++++++++++++-----------
 1 file changed, 56 insertions(+), 21 deletions(-)

diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py
index 410308afb81a7..541d62b349324 100644
--- a/sklearn/metrics/metrics.py
+++ b/sklearn/metrics/metrics.py
@@ -1883,7 +1883,7 @@ def hamming_loss(y_true, y_pred, classes=None):
 ###############################################################################
 # Regression loss functions
 ###############################################################################
-def mean_absolute_error(y_true, y_pred):
+def mean_absolute_error(y_true, y_pred, average=True):
     """Mean absolute error regression loss
 
     Parameters
@@ -1894,10 +1894,14 @@ def mean_absolute_error(y_true, y_pred):
     y_pred : array-like of shape = [n_samples] or [n_samples, n_outputs]
         Estimated target values.
 
+    average : True or False
+        Default value is True. If False, returns an array (multi-output)
+
     Returns
     -------
-    loss : float
-        A positive floating point value (the best value is 0.0).
+    loss : float or a numpy array of floats
+        If average is True, a positive floating point value (the best value is 0.0).
+        Else, a numpy array of positive floating points is returned.
 
     Examples
     --------
@@ -1910,13 +1914,18 @@ def mean_absolute_error(y_true, y_pred):
     >>> y_pred = [[0, 2], [-1, 2], [8, -5]]
     >>> mean_absolute_error(y_true, y_pred)
     0.75
-
+    >>> mean_absolute_error(y_true, y_pred, average=False)
+    array([0.5, 1.])
     """
+    if not average:
+        axis = 0
+    else:
+        axis = None
     y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred)
-    return np.mean(np.abs(y_pred - y_true))
+    return np.mean(np.abs(y_pred - y_true), axis=axis)
 
 
-def mean_squared_error(y_true, y_pred):
+def mean_squared_error(y_true, y_pred, average=True):
     """Mean squared error regression loss
 
     Parameters
@@ -1927,10 +1936,14 @@ def mean_squared_error(y_true, y_pred):
     y_pred : array-like of shape = [n_samples] or [n_samples, n_outputs]
         Estimated target values.
 
+    average : True or False
+        Default value is True. If False, returns an array (multi-output)
+
    Returns
     -------
-    loss : float
-        A positive floating point value (the best value is 0.0).
+    loss : float or a numpy array of floats
+        If average is True, a positive floating point value (the best value is 0.0).
+        Else, a numpy array of positive floating points is returned.
 
     Examples
     --------
@@ -1943,10 +1956,16 @@ def mean_squared_error(y_true, y_pred):
     >>> y_pred = [[0, 2],[-1, 2],[8, -5]]
     >>> mean_squared_error(y_true, y_pred) # doctest: +ELLIPSIS
     0.708...
-
+    >>> mean_squared_error(y_true, y_pred, average=False)
+    0.708...
+    array([0.41666667, 1.])
     """
+    if not average:
+        axis = 0
+    else:
+        axis = None
     y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred)
-    return np.mean((y_pred - y_true) ** 2)
+    return np.mean((y_pred - y_true) ** 2, axis=axis)
 
 
 ###############################################################################
@@ -2000,7 +2019,7 @@ def explained_variance_score(y_true, y_pred):
     return 1 - numerator / denominator
 
 
-def r2_score(y_true, y_pred):
+def r2_score(y_true, y_pred, average=True):
     """R^2 (coefficient of determination) regression score function.
 
     Best possible score is 1.0, lower values are worse.
@@ -2013,10 +2032,15 @@ def r2_score(y_true, y_pred): y_pred : array-like of shape = [n_samples] or [n_samples, n_outputs] Estimated target values. + average : True or False + Default value is True. If False, returns an array (multi-output) + Returns ------- - z : float - The R^2 score. + z : float or a numpy array of floats + If average is true, it returns the R^2 score, flattened across 1-D. + If average is False, it returns an array of floats corresponding to + the R^2 score of each dimension. Notes ----- @@ -2041,23 +2065,34 @@ def r2_score(y_true, y_pred): >>> y_pred = [[0, 2], [-1, 2], [8, -5]] >>> r2_score(y_true, y_pred) # doctest: +ELLIPSIS 0.938... - + >>> r2_score(y_true, y_pred, average=False) # doctest: +ELLIPSIS + array([0.96543779, 0.90816327]) """ - y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred) + y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred) if len(y_true) == 1: raise ValueError("r2_score can only be computed given more than one" " sample.") - numerator = ((y_true - y_pred) ** 2).sum(dtype=np.float64) - denominator = ((y_true - y_true.mean(axis=0)) ** 2).sum(dtype=np.float64) - if denominator == 0.0: - if numerator == 0.0: - return 1.0 + if not average: + axis = 0 + else: + axis = None + numerator = ((y_true - y_pred) ** 2).sum(dtype=np.float64, axis=axis) + denominator = ((y_true - y_true.mean(axis=0)) ** 2).sum(dtype=np.float64, axis=axis) + if denominator.sum() == 0.0: + if numerator.sum() == 0.0: + if average: + return 1.0 + else: + return np.ones(y_true.shape[0], dtype=np.float64) else: # arbitrary set to zero to avoid -inf scores, having a constant # y_true is not interesting for scoring a regression anyway - return 0.0 + if average: + return 0.0 + else: + return np.zeros(y_true.shape[0], dtype=np.float64) return 1 - numerator / denominator From ce91610ec39ffae19c306c3f7f902d34ae651f48 Mon Sep 17 00:00:00 2001 From: Manoj-Kumar-S Date: Mon, 7 Oct 2013 16:42:13 +0530 Subject: [PATCH 02/13] Added tests, hopefully resolved tests failures. --- sklearn/metrics/metrics.py | 30 +++++++++++++++------------ sklearn/metrics/tests/test_metrics.py | 22 ++++++++++++++++++++ 2 files changed, 39 insertions(+), 13 deletions(-) diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py index 541d62b349324..0be4644e7a7cb 100644 --- a/sklearn/metrics/metrics.py +++ b/sklearn/metrics/metrics.py @@ -1895,11 +1895,13 @@ def mean_absolute_error(y_true, y_pred, average=True): Estimated target values. average : True or False - Default value is True. If False, returns an array (multi-output) + If True returns a float. + If False, returns an array (multi-output) + (default: True) Returns ------- - loss : float or a numpy array of floats + loss : float or a numpy array of shape[n_outputs] If average is True, a positive floating point value (the best value is 0.0). Else, a numpy array of positive floating points is returned. @@ -1915,7 +1917,7 @@ def mean_absolute_error(y_true, y_pred, average=True): >>> mean_absolute_error(y_true, y_pred) 0.75 >>> mean_absolute_error(y_true, y_pred, average=False) - array([0.5, 1.]) + array([ 0.5, 1. ]) """ if not average: axis = 0 @@ -1937,11 +1939,13 @@ def mean_squared_error(y_true, y_pred, average=True): Estimated target values. average : True or False - Default value is True. If False, returns an array (multi-output) + If True returns a float. 
+ If False, returns an array (multi-output) + (default: True) Returns ------- - loss : float or a numpy array of floats + loss : float or a numpy array of shape[n_outputs] If average is True, a positive floating point value (the best value is 0.0). Else, a numpy array of positive floating points is returned. @@ -1957,8 +1961,7 @@ def mean_squared_error(y_true, y_pred, average=True): >>> mean_squared_error(y_true, y_pred) # doctest: +ELLIPSIS 0.708... >>> mean_squared_error(y_true, y_pred, average=False) - 0.708... - array([0.41666667, 1.]) + array([ 0.41666667, 1. ]) """ if not average: axis = 0 @@ -2033,11 +2036,13 @@ def r2_score(y_true, y_pred, average=True): Estimated target values. average : True or False - Default value is True. If False, returns an array (multi-output) + If True returns a float. + If False, returns an array (multi-output) + (default: True) Returns ------- - z : float or a numpy array of floats + z : float or a numpy array of shape[n_outputs] If average is true, it returns the R^2 score, flattened across 1-D. If average is False, it returns an array of floats corresponding to the R^2 score of each dimension. @@ -2066,9 +2071,8 @@ def r2_score(y_true, y_pred, average=True): >>> r2_score(y_true, y_pred) # doctest: +ELLIPSIS 0.938... >>> r2_score(y_true, y_pred, average=False) # doctest: +ELLIPSIS - array([0.96543779, 0.90816327]) + array([ 0.96543779, 0.90816327]) """ - y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred) if len(y_true) == 1: raise ValueError("r2_score can only be computed given more than one" @@ -2085,14 +2089,14 @@ def r2_score(y_true, y_pred, average=True): if average: return 1.0 else: - return np.ones(y_true.shape[0], dtype=np.float64) + return np.ones(y_true.shape[1], dtype=np.float64) else: # arbitrary set to zero to avoid -inf scores, having a constant # y_true is not interesting for scoring a regression anyway if average: return 0.0 else: - return np.zeros(y_true.shape[0], dtype=np.float64) + return np.zeros(y_true.shape[1], dtype=np.float64) return 1 - numerator / denominator diff --git a/sklearn/metrics/tests/test_metrics.py b/sklearn/metrics/tests/test_metrics.py index 1ca78d556be42..b2283dbc92304 100644 --- a/sklearn/metrics/tests/test_metrics.py +++ b/sklearn/metrics/tests/test_metrics.py @@ -1992,3 +1992,25 @@ def test_log_loss(): y_pred = np.asarray(y_pred) > .5 loss = log_loss(y_true, y_pred, normalize=True, eps=.1) assert_almost_equal(loss, log_loss(y_true, np.clip(y_pred, .1, .9))) + +def test_issue_2200(): + y_true = [[1, 2], [2.5, -1], [4.5, 3], [5, 7]] + y_pred = [[1, 1], [2, -1], [5, 4], [5, 6.5]] + + mse = list(mean_squared_error(y_true, y_pred, average=False)) + mae = list(mean_absolute_error(y_true, y_pred, average=False)) + r = list(r2_score(y_true, y_pred, average=False)) + assert_equal(mse, [0.125, 0.5625]) + assert_equal(mae, [0.25, 0.625]) + assert_almost_equal(r, [0.95, 0.93], decimal=2) + + # mean_absolute_error and mean_squared_error are equal because + # it is a binary problem. 
+ y_true = [[0, 0]]*4 + y_pred = [[1, 1]]*4 + mse = list(mean_squared_error(y_true, y_pred, average=False)) + mae = list(mean_absolute_error(y_true, y_pred, average=False)) + r = list(r2_score(y_true, y_pred, average=False)) + assert_equal(mse, [1., 1.]) + assert_equal(mae, [1., 1.]) + assert_equal(r, [0., 0.]) From 5f0ed74cebbf30d744721a6a6871e3a174fd5e9a Mon Sep 17 00:00:00 2001 From: Manoj-Kumar-S Date: Mon, 7 Oct 2013 18:52:11 +0530 Subject: [PATCH 03/13] Changed True to micro --- sklearn/metrics/metrics.py | 28 +++++++++++++-------------- sklearn/metrics/tests/test_metrics.py | 2 +- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py index 0be4644e7a7cb..5e779d58debca 100644 --- a/sklearn/metrics/metrics.py +++ b/sklearn/metrics/metrics.py @@ -1883,7 +1883,7 @@ def hamming_loss(y_true, y_pred, classes=None): ############################################################################### # Regression loss functions ############################################################################### -def mean_absolute_error(y_true, y_pred, average=True): +def mean_absolute_error(y_true, y_pred, average='micro'): """Mean absolute error regression loss Parameters @@ -1894,10 +1894,10 @@ def mean_absolute_error(y_true, y_pred, average=True): y_pred : array-like of shape = [n_samples] or [n_samples, n_outputs] Estimated target values. - average : True or False - If True returns a float. + average : 'micro' or False + If 'micro' returns a float. If False, returns an array (multi-output) - (default: True) + (default: 'micro') Returns ------- @@ -1927,7 +1927,7 @@ def mean_absolute_error(y_true, y_pred, average=True): return np.mean(np.abs(y_pred - y_true), axis=axis) -def mean_squared_error(y_true, y_pred, average=True): +def mean_squared_error(y_true, y_pred, average='micro'): """Mean squared error regression loss Parameters @@ -1938,15 +1938,15 @@ def mean_squared_error(y_true, y_pred, average=True): y_pred : array-like of shape = [n_samples] or [n_samples, n_outputs] Estimated target values. - average : True or False - If True returns a float. + average : 'micro' or False + If 'micro' returns a float. If False, returns an array (multi-output) - (default: True) + (default: 'micro') Returns ------- loss : float or a numpy array of shape[n_outputs] - If average is True, a positive floating point value (the best value is 0.0). + If average is "micro", a positive floating point value (the best value is 0.0). Else, a numpy array of positive floating points is returned. Examples @@ -2022,7 +2022,7 @@ def explained_variance_score(y_true, y_pred): return 1 - numerator / denominator -def r2_score(y_true, y_pred, average=True): +def r2_score(y_true, y_pred, average='micro'): """R^2 (coefficient of determination) regression score function. Best possible score is 1.0, lower values are worse. @@ -2035,15 +2035,15 @@ def r2_score(y_true, y_pred, average=True): y_pred : array-like of shape = [n_samples] or [n_samples, n_outputs] Estimated target values. - average : True or False - If True returns a float. + average : 'micro' or False + If 'micro' returns a float. If False, returns an array (multi-output) - (default: True) + (default: 'micro') Returns ------- z : float or a numpy array of shape[n_outputs] - If average is true, it returns the R^2 score, flattened across 1-D. + If average is 'micro', it returns the R^2 score, flattened across 1-D. If average is False, it returns an array of floats corresponding to the R^2 score of each dimension. 
diff --git a/sklearn/metrics/tests/test_metrics.py b/sklearn/metrics/tests/test_metrics.py index b2283dbc92304..1e22624ce48fd 100644 --- a/sklearn/metrics/tests/test_metrics.py +++ b/sklearn/metrics/tests/test_metrics.py @@ -1993,7 +1993,7 @@ def test_log_loss(): loss = log_loss(y_true, y_pred, normalize=True, eps=.1) assert_almost_equal(loss, log_loss(y_true, np.clip(y_pred, .1, .9))) -def test_issue_2200(): +def test_regression_multioutput_array(): y_true = [[1, 2], [2.5, -1], [4.5, 3], [5, 7]] y_pred = [[1, 1], [2, -1], [5, 4], [5, 6.5]] From 1829ad9b29fe7414d5ef293d37ab363f359099a3 Mon Sep 17 00:00:00 2001 From: Manoj-Kumar-S Date: Mon, 7 Oct 2013 23:30:46 +0530 Subject: [PATCH 04/13] Changed False to None, assert_array_equal --- sklearn/metrics/metrics.py | 18 +++++++++--------- sklearn/metrics/tests/test_metrics.py | 24 ++++++++++++------------ 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py index 5e779d58debca..ba6e686a96590 100644 --- a/sklearn/metrics/metrics.py +++ b/sklearn/metrics/metrics.py @@ -1894,15 +1894,15 @@ def mean_absolute_error(y_true, y_pred, average='micro'): y_pred : array-like of shape = [n_samples] or [n_samples, n_outputs] Estimated target values. - average : 'micro' or False + average : 'micro' or None If 'micro' returns a float. - If False, returns an array (multi-output) + If None, returns an array (multi-output) (default: 'micro') Returns ------- loss : float or a numpy array of shape[n_outputs] - If average is True, a positive floating point value (the best value is 0.0). + If average is 'micro', a positive floating point value (the best value is 0.0). Else, a numpy array of positive floating points is returned. Examples @@ -1938,15 +1938,15 @@ def mean_squared_error(y_true, y_pred, average='micro'): y_pred : array-like of shape = [n_samples] or [n_samples, n_outputs] Estimated target values. - average : 'micro' or False + average : 'micro' or None If 'micro' returns a float. - If False, returns an array (multi-output) + If None, returns an array (multi-output) (default: 'micro') Returns ------- loss : float or a numpy array of shape[n_outputs] - If average is "micro", a positive floating point value (the best value is 0.0). + If average is 'micro', a positive floating point value (the best value is 0.0). Else, a numpy array of positive floating points is returned. Examples @@ -2035,16 +2035,16 @@ def r2_score(y_true, y_pred, average='micro'): y_pred : array-like of shape = [n_samples] or [n_samples, n_outputs] Estimated target values. - average : 'micro' or False + average : 'micro' or None If 'micro' returns a float. - If False, returns an array (multi-output) + If None, returns an array (multi-output) (default: 'micro') Returns ------- z : float or a numpy array of shape[n_outputs] If average is 'micro', it returns the R^2 score, flattened across 1-D. - If average is False, it returns an array of floats corresponding to + If average is None, it returns an array of floats corresponding to the R^2 score of each dimension. 
Notes diff --git a/sklearn/metrics/tests/test_metrics.py b/sklearn/metrics/tests/test_metrics.py index 1e22624ce48fd..c68503c360c8d 100644 --- a/sklearn/metrics/tests/test_metrics.py +++ b/sklearn/metrics/tests/test_metrics.py @@ -1997,20 +1997,20 @@ def test_regression_multioutput_array(): y_true = [[1, 2], [2.5, -1], [4.5, 3], [5, 7]] y_pred = [[1, 1], [2, -1], [5, 4], [5, 6.5]] - mse = list(mean_squared_error(y_true, y_pred, average=False)) - mae = list(mean_absolute_error(y_true, y_pred, average=False)) - r = list(r2_score(y_true, y_pred, average=False)) - assert_equal(mse, [0.125, 0.5625]) - assert_equal(mae, [0.25, 0.625]) - assert_almost_equal(r, [0.95, 0.93], decimal=2) + mse = mean_squared_error(y_true, y_pred, average=False) + mae = mean_absolute_error(y_true, y_pred, average=False) + r = r2_score(y_true, y_pred, average=False) + assert_array_equal(mse, np.array([0.125, 0.5625])) + assert_array_equal(mae, np.array([0.25, 0.625])) + assert_array_almost_equal(r, np.array([0.95, 0.93]), decimal=2) # mean_absolute_error and mean_squared_error are equal because # it is a binary problem. y_true = [[0, 0]]*4 y_pred = [[1, 1]]*4 - mse = list(mean_squared_error(y_true, y_pred, average=False)) - mae = list(mean_absolute_error(y_true, y_pred, average=False)) - r = list(r2_score(y_true, y_pred, average=False)) - assert_equal(mse, [1., 1.]) - assert_equal(mae, [1., 1.]) - assert_equal(r, [0., 0.]) + mse = mean_squared_error(y_true, y_pred, average=False) + mae = mean_absolute_error(y_true, y_pred, average=False) + r = r2_score(y_true, y_pred, average=False) + assert_array_equal(mse, np.array([1., 1.])) + assert_array_equal(mae, np.array([1., 1.])) + assert_array_almost_equal(r, np.array([0., 0.])) From 557f4a7133d0929877ab59961e58dc87a959c76d Mon Sep 17 00:00:00 2001 From: Manoj-Kumar-S Date: Thu, 17 Oct 2013 01:50:09 +0530 Subject: [PATCH 05/13] Added output_weights, docs, tests etc --- doc/modules/model_evaluation.rst | 26 +++++- sklearn/metrics/metrics.py | 121 ++++++++++++++------------ sklearn/metrics/tests/test_metrics.py | 26 ++++-- 3 files changed, 110 insertions(+), 63 deletions(-) diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst index ca3ad443d1224..67d2370a99778 100644 --- a/doc/modules/model_evaluation.rst +++ b/doc/modules/model_evaluation.rst @@ -945,6 +945,12 @@ and :math:`y_i` is the corresponding true value, then the mean absolute error \text{MAE}(y, \hat{y}) = \frac{1}{n_{\text{samples}}} \sum_{i=0}^{n_{\text{samples}}-1} \left| y_i - \hat{y}_i \right|. +Output_weights is a function argument that can take two values, None and uniform. +If the value provided is None, then the mean absolute error is calculated +for each dimension separately and a numpy array is returned. If the value given +is uniform, then a weight of unity is assigned to each dimension during +macro-averaging, and a float is returned. + Here a small example of usage of the :func:`mean_absolute_error` function:: >>> from sklearn.metrics import mean_absolute_error @@ -956,7 +962,8 @@ Here a small example of usage of the :func:`mean_absolute_error` function:: >>> y_pred = [[0, 2], [-1, 2], [8, -5]] >>> mean_absolute_error(y_true, y_pred) 0.75 - + >>> mean_absolute_error(y_true, y_pred, output_weights=None) + array([ 0.5, 1. ]) Mean squared error @@ -975,6 +982,12 @@ and :math:`y_i` is the corresponding true value, then the mean squared error \text{MSE}(y, \hat{y}) = \frac{1}{n_\text{samples}} \sum_{i=0}^{n_\text{samples} - 1} (y_i - \hat{y}_i)^2. 
+Output_weights is a function argument that can take two values, None and uniform. +If the value provided is None, then the mean squared error is calculated +for each dimension separately and a numpy array is returned. If the value given +is uniform, then a weight of unity is assigned to each dimension during +macro-averaging, and a float is returned. + Here a small example of usage of the :func:`mean_squared_error` function:: @@ -987,6 +1000,8 @@ function:: >>> y_pred = [[0, 2], [-1, 2], [8, -5]] >>> mean_squared_error(y_true, y_pred) # doctest: +ELLIPSIS 0.7083... + >>> mean_squared_error(y_true, y_pred, output_weights=None) + array([ 0.417..., 1. ]) .. topic:: Examples: @@ -1012,6 +1027,12 @@ over :math:`n_{\text{samples}}` is defined as where :math:`\bar{y} = \frac{1}{n_{\text{samples}}} \sum_{i=0}^{n_{\text{samples}} - 1} y_i`. +Output_weights is a function argument that can take two values, None and uniform. +If the value provided is None, then the r2 score is calculated +for each dimension separately and a numpy array is returned. If the value given +is uniform, then a weight of unity is assigned to each dimension during +macro-averaging, and a float is returned. + Here a small example of usage of the :func:`r2_score` function:: >>> from sklearn.metrics import r2_score @@ -1023,7 +1044,8 @@ Here a small example of usage of the :func:`r2_score` function:: >>> y_pred = [[0, 2], [-1, 2], [8, -5]] >>> r2_score(y_true, y_pred) # doctest: +ELLIPSIS 0.938... - + >>> r2_score(y_true, y_pred, output_weights=None) # doctest: +ELLIPSIS + array([ 0.965..., 0.908...]) .. topic:: Example: diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py index b0db4d3163ff2..1bbea0eb9b96f 100644 --- a/sklearn/metrics/metrics.py +++ b/sklearn/metrics/metrics.py @@ -1954,15 +1954,16 @@ def mean_absolute_error(y_true, y_pred, average='micro'): y_pred : array-like of shape = [n_samples] or [n_samples, n_outputs] Estimated target values. - average : 'micro' or None - If 'micro' returns a float. - If None, returns an array (multi-output) - (default: 'micro') + output_weights : string, ['uniform' (default), None] + Assigns weight to each dimension of the given input. + If the option given is uniform, then a weight of unity is assigned + to each dimension during averaging. If the option given is None, + then no averaging is done. - Returns + Returns ------- loss : float or a numpy array of shape[n_outputs] - If average is 'micro', a positive floating point value (the best value is 0.0). + If output_weights is 'uniform', a positive floating point value (the best value is 0.0). Else, a numpy array of positive floating points is returned. Examples @@ -1976,18 +1977,22 @@ def mean_absolute_error(y_true, y_pred, average='micro'): >>> y_pred = [[0, 2], [-1, 2], [8, -5]] >>> mean_absolute_error(y_true, y_pred) 0.75 - >>> mean_absolute_error(y_true, y_pred, average=False) + >>> mean_absolute_error(y_true, y_pred, output_weights=None) array([ 0.5, 1. 
]) """ - if not average: - axis = 0 - else: + output_weights_options = (None, 'uniform') + if output_weights not in output_weights_options: + raise ValueError('output_weights has to be one of ' + + str(output_weights_options)) + if output_weights == 'uniform': axis = None + else: + axis = 0 y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred) return np.mean(np.abs(y_pred - y_true), axis=axis) -def mean_squared_error(y_true, y_pred, average='micro'): +def mean_squared_error(y_true, y_pred, output_weights='uniform'): """Mean squared error regression loss Parameters @@ -1998,15 +2003,16 @@ def mean_squared_error(y_true, y_pred, average='micro'): y_pred : array-like of shape = [n_samples] or [n_samples, n_outputs] Estimated target values. - average : 'micro' or None - If 'micro' returns a float. - If None, returns an array (multi-output) - (default: 'micro') + output_weights : string, ['uniform' (default), None] + Assigns weight to each dimension of the given input. + If the option given is uniform, then a weight of unity is assigned + to each dimension during averaging. If the option given is None, + then no averaging is done. Returns ------- loss : float or a numpy array of shape[n_outputs] - If average is 'micro', a positive floating point value (the best value is 0.0). + If output_weights is 'ones', a positive floating point value (the best value is 0.0). Else, a numpy array of positive floating points is returned. Examples @@ -2020,13 +2026,17 @@ def mean_squared_error(y_true, y_pred, average='micro'): >>> y_pred = [[0, 2],[-1, 2],[8, -5]] >>> mean_squared_error(y_true, y_pred) # doctest: +ELLIPSIS 0.708... - >>> mean_squared_error(y_true, y_pred, average=False) - array([ 0.41666667, 1. ]) + >>> mean_squared_error(y_true, y_pred, output_weights=None) + array([ 0.417..., 1. ]) """ - if not average: - axis = 0 - else: + output_weights_options = (None, 'uniform') + if output_weights not in output_weights_options: + raise ValueError('output_weights has to be one of ' + + str(output_weights_options)) + if output_weights == 'uniform': axis = None + else: + axis = 0 y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred) return np.mean((y_pred - y_true) ** 2, axis=axis) @@ -2082,7 +2092,7 @@ def explained_variance_score(y_true, y_pred): return 1 - numerator / denominator -def r2_score(y_true, y_pred, average='micro'): +def r2_score(y_true, y_pred, output_weights='uniform'): """R^2 (coefficient of determination) regression score function. Best possible score is 1.0, lower values are worse. @@ -2095,17 +2105,18 @@ def r2_score(y_true, y_pred, average='micro'): y_pred : array-like of shape = [n_samples] or [n_samples, n_outputs] Estimated target values. - average : 'micro' or None - If 'micro' returns a float. - If None, returns an array (multi-output) - (default: 'micro') + output_weights : string, ['uniform' (default), None] + Assigns weight to each dimension of the given input. + If the option given is uniform, then a weight of unity is assigned + to each dimension while averaging. If the option given is None, then + no averaging is done. Returns ------- - z : float or a numpy array of shape[n_outputs] - If average is 'micro', it returns the R^2 score, flattened across 1-D. - If average is None, it returns an array of floats corresponding to - the R^2 score of each dimension. 
+ z : float or a numpy array of shape [n_outputs] + If output_weights is 'uniform', it returns the macro-averaged R^2 score, + If output_weights is None, it returns a numpy array of floats corresponding + to the R^2 score of each dimension. Notes ----- @@ -2129,33 +2140,35 @@ def r2_score(y_true, y_pred, average='micro'): >>> y_true = [[0.5, 1], [-1, 1], [7, -6]] >>> y_pred = [[0, 2], [-1, 2], [8, -5]] >>> r2_score(y_true, y_pred) # doctest: +ELLIPSIS - 0.938... - >>> r2_score(y_true, y_pred, average=False) # doctest: +ELLIPSIS - array([ 0.96543779, 0.90816327]) + 0.937... + >>> r2_score(y_true, y_pred, output_weights=None) # doctest: +ELLIPSIS + array([ 0.965..., 0.908...]) """ y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred) if len(y_true) == 1: raise ValueError("r2_score can only be computed given more than one" " sample.") - if not average: - axis = 0 - else: - axis = None - numerator = ((y_true - y_pred) ** 2).sum(dtype=np.float64, axis=axis) - denominator = ((y_true - y_true.mean(axis=0)) ** 2).sum(dtype=np.float64, axis=axis) - if denominator.sum() == 0.0: - if numerator.sum() == 0.0: - if average: - return 1.0 - else: - return np.ones(y_true.shape[1], dtype=np.float64) - else: - # arbitrary set to zero to avoid -inf scores, having a constant - # y_true is not interesting for scoring a regression anyway - if average: - return 0.0 - else: - return np.zeros(y_true.shape[1], dtype=np.float64) - - return 1 - numerator / denominator + output_weights_options = (None, 'uniform') + if output_weights not in output_weights_options: + raise ValueError('output_weights has to be one of ' + + str(output_weights_options)) + + numerator = ((y_true - y_pred) ** 2).sum(dtype=np.float64, axis=0) + denominator = ((y_true - y_true.mean(axis=0)) ** 2).sum(dtype=np.float64, axis=0) + + # Set an array of ones for the condition that both numerator + # and denominator are zero. + r2 = np.ones(y_true.shape[1]) + nonzero_denominator = (denominator != 0.0) + nonzero_numerator = (numerator != 0.0) + valid_score = np.logical_and(nonzero_numerator, nonzero_denominator) + r2[valid_score] = 1 - numerator[valid_score]/denominator[valid_score] + # Denominator is zero and numerator is non-zero + # arbitrary set to zero to avoid -inf scores, having a constant + # y_true is not interesting for scoring a regression anyway + r2[np.logical_and(np.logical_not(nonzero_denominator), + nonzero_numerator)] = 0.0 + if output_weights == 'uniform': + return np.mean(r2) + return r2 diff --git a/sklearn/metrics/tests/test_metrics.py b/sklearn/metrics/tests/test_metrics.py index 97faf5b0de037..9159f808a9923 100644 --- a/sklearn/metrics/tests/test_metrics.py +++ b/sklearn/metrics/tests/test_metrics.py @@ -1244,7 +1244,7 @@ def test_multioutput_regression(): assert_almost_equal(error, (1. / 3 + 2. / 3 + 2. / 3) / 4.) error = r2_score(y_true, y_pred) - assert_almost_equal(error, 1 - 5. 
/ 2) + assert_almost_equal(error, -0.875) def test_multioutput_number_of_output_differ(): @@ -2004,9 +2004,9 @@ def test_regression_multioutput_array(): y_true = [[1, 2], [2.5, -1], [4.5, 3], [5, 7]] y_pred = [[1, 1], [2, -1], [5, 4], [5, 6.5]] - mse = mean_squared_error(y_true, y_pred, average=False) - mae = mean_absolute_error(y_true, y_pred, average=False) - r = r2_score(y_true, y_pred, average=False) + mse = mean_squared_error(y_true, y_pred, output_weights=None) + mae = mean_absolute_error(y_true, y_pred, output_weights=None) + r = r2_score(y_true, y_pred, output_weights=None) assert_array_equal(mse, np.array([0.125, 0.5625])) assert_array_equal(mae, np.array([0.25, 0.625])) assert_array_almost_equal(r, np.array([0.95, 0.93]), decimal=2) @@ -2015,9 +2015,21 @@ def test_regression_multioutput_array(): # it is a binary problem. y_true = [[0, 0]]*4 y_pred = [[1, 1]]*4 - mse = mean_squared_error(y_true, y_pred, average=False) - mae = mean_absolute_error(y_true, y_pred, average=False) - r = r2_score(y_true, y_pred, average=False) + mse = mean_squared_error(y_true, y_pred, output_weights=None) + mae = mean_absolute_error(y_true, y_pred, output_weights=None) + r = r2_score(y_true, y_pred, output_weights=None) assert_array_equal(mse, np.array([1., 1.])) assert_array_equal(mae, np.array([1., 1.])) assert_array_almost_equal(r, np.array([0., 0.])) + + r = r2_score([[0, -1], [0, 1]], [[2, 2], [1, 1]], output_weights=None) + assert_array_equal(r, np.array([0, -3.5])) + assert_equal(np.mean(r), r2_score([[0, -1], [0, 1]], [[2, 2], [1, 1]])) + + # Checking for the condition in which both numerator and denominator is + # zero. + y_true = [[1, 3], [-1, 2]] + y_pred = [[1, 4], [-1, 1]] + r = r2_score(y_true, y_pred, output_weights=None) + assert_array_equal(r, np.array([1., -3.])) + assert_equal(np.mean(r), r2_score(y_true, y_pred)) From cd81cee0823f1f5f8308695fe663626f32d6531f Mon Sep 17 00:00:00 2001 From: Manoj-Kumar-S Date: Thu, 17 Oct 2013 13:19:46 +0530 Subject: [PATCH 06/13] Minor changes --- doc/modules/model_evaluation.rst | 36 +++++++++++++++++--------------- sklearn/metrics/metrics.py | 13 ++++++------ 2 files changed, 26 insertions(+), 23 deletions(-) diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst index 67d2370a99778..2fa1d3e758dc9 100644 --- a/doc/modules/model_evaluation.rst +++ b/doc/modules/model_evaluation.rst @@ -945,11 +945,12 @@ and :math:`y_i` is the corresponding true value, then the mean absolute error \text{MAE}(y, \hat{y}) = \frac{1}{n_{\text{samples}}} \sum_{i=0}^{n_{\text{samples}}-1} \left| y_i - \hat{y}_i \right|. -Output_weights is a function argument that can take two values, None and uniform. -If the value provided is None, then the mean absolute error is calculated -for each dimension separately and a numpy array is returned. If the value given -is uniform, then a weight of unity is assigned to each dimension during -macro-averaging, and a float is returned. +The :func:`mean_absolute_error` function has an `output_weights` keyword +with two possible values `None` and 'uniform'. If the value provided is +`None`, then the mean absolute error is calculated for each dimension +separately and a numpy array is returned. If the value given is `uniform`, +then a weight of unity is assigned to each dimension during macro-averaging, +and a float is returned. 
Here a small example of usage of the :func:`mean_absolute_error` function:: @@ -982,11 +983,12 @@ and :math:`y_i` is the corresponding true value, then the mean squared error \text{MSE}(y, \hat{y}) = \frac{1}{n_\text{samples}} \sum_{i=0}^{n_\text{samples} - 1} (y_i - \hat{y}_i)^2. -Output_weights is a function argument that can take two values, None and uniform. -If the value provided is None, then the mean squared error is calculated -for each dimension separately and a numpy array is returned. If the value given -is uniform, then a weight of unity is assigned to each dimension during -macro-averaging, and a float is returned. +The :func:`mean_squared_error` function has an `output_weights` keyword +with two possible values `None` and 'uniform'. If the value provided is +`None`, then the mean squared error is calculated for each dimension +separately and a numpy array is returned. If the value given is `uniform`, +then a weight of unity is assigned to each dimension during macro-averaging, +and a float is returned. Here a small example of usage of the :func:`mean_squared_error` function:: @@ -1000,8 +1002,8 @@ function:: >>> y_pred = [[0, 2], [-1, 2], [8, -5]] >>> mean_squared_error(y_true, y_pred) # doctest: +ELLIPSIS 0.7083... - >>> mean_squared_error(y_true, y_pred, output_weights=None) - array([ 0.417..., 1. ]) + >>> mean_squared_error(y_true, y_pred, output_weights=None) # doctest: +ELLIPSIS + array([ 0.416..., 1. ]) .. topic:: Examples: @@ -1027,10 +1029,10 @@ over :math:`n_{\text{samples}}` is defined as where :math:`\bar{y} = \frac{1}{n_{\text{samples}}} \sum_{i=0}^{n_{\text{samples}} - 1} y_i`. -Output_weights is a function argument that can take two values, None and uniform. -If the value provided is None, then the r2 score is calculated -for each dimension separately and a numpy array is returned. If the value given -is uniform, then a weight of unity is assigned to each dimension during +The :func:`r2_score` function has an `output_weights` keyword with two possible +values `None` and 'uniform'. If the value provided is `None`, then the r2 score +is calculated for each dimension separately and a numpy array is returned. If the +value given is `uniform`, then a weight of unity is assigned to each dimension during macro-averaging, and a float is returned. Here a small example of usage of the :func:`r2_score` function:: @@ -1043,7 +1045,7 @@ Here a small example of usage of the :func:`r2_score` function:: >>> y_true = [[0.5, 1], [-1, 1], [7, -6]] >>> y_pred = [[0, 2], [-1, 2], [8, -5]] >>> r2_score(y_true, y_pred) # doctest: +ELLIPSIS - 0.938... + 0.936... 
>>> r2_score(y_true, y_pred, output_weights=None) # doctest: +ELLIPSIS array([ 0.965..., 0.908...]) diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py index 1bbea0eb9b96f..dabaaf0901651 100644 --- a/sklearn/metrics/metrics.py +++ b/sklearn/metrics/metrics.py @@ -15,6 +15,7 @@ # Jochen Wersdörfer # Lars Buitinck # Joel Nothman +# Manoj Kumar # License: BSD 3 clause from __future__ import division @@ -1943,7 +1944,7 @@ def hamming_loss(y_true, y_pred, classes=None): ############################################################################### # Regression loss functions ############################################################################### -def mean_absolute_error(y_true, y_pred, average='micro'): +def mean_absolute_error(y_true, y_pred, output_weights='uniform'): """Mean absolute error regression loss Parameters @@ -2011,8 +2012,8 @@ def mean_squared_error(y_true, y_pred, output_weights='uniform'): Returns ------- - loss : float or a numpy array of shape[n_outputs] - If output_weights is 'ones', a positive floating point value (the best value is 0.0). + loss : float or a numpy array of shape [n_outputs] + If output_weights is 'uniform', a positive floating point value (the best value is 0.0). Else, a numpy array of positive floating points is returned. Examples @@ -2026,8 +2027,8 @@ def mean_squared_error(y_true, y_pred, output_weights='uniform'): >>> y_pred = [[0, 2],[-1, 2],[8, -5]] >>> mean_squared_error(y_true, y_pred) # doctest: +ELLIPSIS 0.708... - >>> mean_squared_error(y_true, y_pred, output_weights=None) - array([ 0.417..., 1. ]) + >>> mean_squared_error(y_true, y_pred, output_weights=None) # doctest: +ELLIPSIS + array([ 0.416..., 1. ]) """ output_weights_options = (None, 'uniform') if output_weights not in output_weights_options: @@ -2140,7 +2141,7 @@ def r2_score(y_true, y_pred, output_weights='uniform'): >>> y_true = [[0.5, 1], [-1, 1], [7, -6]] >>> y_pred = [[0, 2], [-1, 2], [8, -5]] >>> r2_score(y_true, y_pred) # doctest: +ELLIPSIS - 0.937... + 0.936... >>> r2_score(y_true, y_pred, output_weights=None) # doctest: +ELLIPSIS array([ 0.965..., 0.908...]) """ From 1d6fc195a8aedb742c6768c3582fc958a4a6eaa6 Mon Sep 17 00:00:00 2001 From: Manoj-Kumar-S Date: Thu, 17 Oct 2013 13:37:33 +0530 Subject: [PATCH 07/13] Change to docs --- doc/modules/model_evaluation.rst | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst index 2fa1d3e758dc9..95cd55b137565 100644 --- a/doc/modules/model_evaluation.rst +++ b/doc/modules/model_evaluation.rst @@ -948,9 +948,8 @@ and :math:`y_i` is the corresponding true value, then the mean absolute error The :func:`mean_absolute_error` function has an `output_weights` keyword with two possible values `None` and 'uniform'. If the value provided is `None`, then the mean absolute error is calculated for each dimension -separately and a numpy array is returned. If the value given is `uniform`, -then a weight of unity is assigned to each dimension during macro-averaging, -and a float is returned. +separately and a numpy array is returned. If the value given is `uniform`, the +mean absolute error is averaged over each dimension with a weight of `1 / n_outputs`. Here a small example of usage of the :func:`mean_absolute_error` function:: @@ -986,9 +985,8 @@ and :math:`y_i` is the corresponding true value, then the mean squared error The :func:`mean_squared_error` function has an `output_weights` keyword with two possible values `None` and 'uniform'. 
If the value provided is `None`, then the mean squared error is calculated for each dimension -separately and a numpy array is returned. If the value given is `uniform`, -then a weight of unity is assigned to each dimension during macro-averaging, -and a float is returned. +separately and a numpy array is returned. If the value given is `uniform`, the +mean squared error is averaged over each dimension with a weight of `1 / n_outputs`. Here a small example of usage of the :func:`mean_squared_error` function:: From 76936ea649375a93c6e630f2876557b73ac5794f Mon Sep 17 00:00:00 2001 From: Manoj-Kumar-S Date: Fri, 18 Oct 2013 00:53:01 +0530 Subject: [PATCH 08/13] Extended multioutput to explained_variance_score --- doc/modules/model_evaluation.rst | 12 ++++++ sklearn/metrics/metrics.py | 55 +++++++++++++++++++-------- sklearn/metrics/tests/test_metrics.py | 9 +++++ 3 files changed, 60 insertions(+), 16 deletions(-) diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst index 95cd55b137565..0659601151926 100644 --- a/doc/modules/model_evaluation.rst +++ b/doc/modules/model_evaluation.rst @@ -920,6 +920,13 @@ variance is estimated as follow: The best possible score is 1.0, lower values are worse. +The :func:`explained_variance_score` function has an `output_weights` keyword +with two possible values `None` and 'uniform'. If the value provided is `None`, +then the explained variance score is calculated for each dimension separately +and a numpy array is returned. If the value given is `uniform`, then a weight +of unity is assigned to each dimension during macro-averaging, and a float +is returned. + Here a small example of usage of the :func:`explained_variance_score` function:: @@ -928,6 +935,11 @@ function:: >>> y_pred = [2.5, 0.0, 2, 8] >>> explained_variance_score(y_true, y_pred) # doctest: +ELLIPSIS 0.957... + >>> y_true = [[0.5, 1], [-1, 1], [7, -6]] + >>> y_pred = [[0, 2], [-1, 2], [8, -5]] + >>> explained_variance_score(y_true, + ... y_pred, output_weights=None) # doctest: +ELLIPSIS + array([ 0.967..., 1. ]) Mean absolute error ................... diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py index dabaaf0901651..3ef9da9762433 100644 --- a/sklearn/metrics/metrics.py +++ b/sklearn/metrics/metrics.py @@ -2045,7 +2045,7 @@ def mean_squared_error(y_true, y_pred, output_weights='uniform'): ############################################################################### # Regression score functions ############################################################################### -def explained_variance_score(y_true, y_pred): +def explained_variance_score(y_true, y_pred, output_weights='uniform'): """Explained variance regression score function Best possible score is 1.0, lower values are worse. @@ -2058,10 +2058,19 @@ def explained_variance_score(y_true, y_pred): y_pred : array-like Estimated target values. + output_weights : string, ['uniform' (default), None] + Assigns weight to each dimension of the given input. + If the option given is uniform, then a weight of unity is assigned + to each dimension while averaging. If the option given is None, then + no averaging is done. + Returns ------- - score : float - The explained variance. + score : float or a numpy array of shape [n_outputs] + If output_weights is 'uniform', it returns the macro-averaged explained + variance score. + If output_weights is None, it returns a numpy array of floats corresponding + to the explained variance score of each dimension. 
Notes ----- @@ -2074,24 +2083,38 @@ def explained_variance_score(y_true, y_pred): >>> y_pred = [2.5, 0.0, 2, 8] >>> explained_variance_score(y_true, y_pred) # doctest: +ELLIPSIS 0.957... - + >>> y_true = [[0.5, 1], [-1, 1], [7, -6]] + >>> y_pred = [[0, 2], [-1, 2], [8, -5]] + >>> explained_variance_score(y_true, + ... y_pred, output_weights=None) # doctest: +ELLIPSIS + array([ 0.967..., 1. ]) """ y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred) + output_weights_options = (None, 'uniform') + if output_weights not in output_weights_options: + raise ValueError('output_weights has to be one of ' + + str(output_weights_options)) - if y_type != "continuous": - raise ValueError("{0} is not supported".format(y_type)) + numerator = np.var(y_true - y_pred, axis=0) + denominator = np.var(y_true, axis=0) - numerator = np.var(y_true - y_pred) - denominator = np.var(y_true) - if denominator == 0.0: - if numerator == 0.0: - return 1.0 - else: - # arbitrary set to zero to avoid -inf scores, having a constant - # y_true is not interesting for scoring a regression anyway - return 0.0 - return 1 - numerator / denominator + # Setting an array of ones for the case in which both numerator + # and denominator are zero. + explained_variance = np.ones(y_true.shape[1]) + nonzero_denominator = (denominator != 0.0) + nonzero_numerator = (numerator != 0.0) + valid_score = np.logical_and(nonzero_numerator, + nonzero_denominator) + explained_variance[valid_score] = (1 - + numerator[valid_score]/denominator[valid_score]) + # arbitrary set to zero to avoid -inf scores, having a constant + # y_true is not interesting for scoring a regression anyway + explained_variance[np.logical_and(nonzero_numerator, + np.logical_not(nonzero_denominator))] = 0.0 + if output_weights == 'uniform': + return np.mean(explained_variance) + return explained_variance def r2_score(y_true, y_pred, output_weights='uniform'): """R^2 (coefficient of determination) regression score function. diff --git a/sklearn/metrics/tests/test_metrics.py b/sklearn/metrics/tests/test_metrics.py index 9159f808a9923..ee429721dd854 100644 --- a/sklearn/metrics/tests/test_metrics.py +++ b/sklearn/metrics/tests/test_metrics.py @@ -204,6 +204,7 @@ "mean_absolute_error": mean_absolute_error, "mean_squared_error": mean_squared_error, "r2_score": r2_score, + "explained_variance_score": explained_variance_score } @@ -2007,9 +2008,11 @@ def test_regression_multioutput_array(): mse = mean_squared_error(y_true, y_pred, output_weights=None) mae = mean_absolute_error(y_true, y_pred, output_weights=None) r = r2_score(y_true, y_pred, output_weights=None) + evs = explained_variance_score(y_true, y_pred, output_weights=None) assert_array_equal(mse, np.array([0.125, 0.5625])) assert_array_equal(mae, np.array([0.25, 0.625])) assert_array_almost_equal(r, np.array([0.95, 0.93]), decimal=2) + assert_array_almost_equal(evs, np.array([0.95, 0.93]), decimal=2) # mean_absolute_error and mean_squared_error are equal because # it is a binary problem. @@ -2025,6 +2028,9 @@ def test_regression_multioutput_array(): r = r2_score([[0, -1], [0, 1]], [[2, 2], [1, 1]], output_weights=None) assert_array_equal(r, np.array([0, -3.5])) assert_equal(np.mean(r), r2_score([[0, -1], [0, 1]], [[2, 2], [1, 1]])) + evs = explained_variance_score([[0, -1], [0, 1]], [[2, 2], [1, 1]], + output_weights=None) + assert_array_equal(evs, np.array([0, -1.25])) # Checking for the condition in which both numerator and denominator is # zero. 
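
For reference, the per-output behaviour asserted by the tests above can be reproduced with plain NumPy. The following is an illustrative sketch against the output_weights API introduced in this series, not part of the patch itself:

    import numpy as np
    from sklearn.metrics import explained_variance_score

    y_true = np.array([[0.5, 1], [-1, 1], [7, -6]])
    y_pred = np.array([[0, 2], [-1, 2], [8, -5]])

    # Column-wise score: 1 - Var(y_true - y_pred) / Var(y_true).
    manual = 1 - np.var(y_true - y_pred, axis=0) / np.var(y_true, axis=0)
    per_output = explained_variance_score(y_true, y_pred, output_weights=None)
    # Both give array([ 0.967...,  1.  ]), matching the doctest above.
    assert np.allclose(manual, per_output)
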
@@ -2033,3 +2039,6 @@ def test_regression_multioutput_array(): r = r2_score(y_true, y_pred, output_weights=None) assert_array_equal(r, np.array([1., -3.])) assert_equal(np.mean(r), r2_score(y_true, y_pred)) + evs = explained_variance_score(y_true, y_pred, output_weights=None) + assert_array_equal(evs, np.array([1., -3.])) + assert_equal(np.mean(evs), explained_variance_score(y_true, y_pred)) From 887ed0734faf50c8dc7886624114e18e933daf2b Mon Sep 17 00:00:00 2001 From: Manoj-Kumar-S Date: Fri, 18 Oct 2013 17:55:01 +0530 Subject: [PATCH 09/13] DOC: Minor change --- doc/modules/model_evaluation.rst | 10 ++++++---- sklearn/metrics/metrics.py | 22 ++++++++++++++-------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst index 0659601151926..9062e4e93f2da 100644 --- a/doc/modules/model_evaluation.rst +++ b/doc/modules/model_evaluation.rst @@ -937,8 +937,8 @@ function:: 0.957... >>> y_true = [[0.5, 1], [-1, 1], [7, -6]] >>> y_pred = [[0, 2], [-1, 2], [8, -5]] - >>> explained_variance_score(y_true, - ... y_pred, output_weights=None) # doctest: +ELLIPSIS + >>> explained_variance_score(y_true, y_pred, + ... output_weights=None) # doctest: +ELLIPSIS array([ 0.967..., 1. ]) Mean absolute error @@ -1012,7 +1012,8 @@ function:: >>> y_pred = [[0, 2], [-1, 2], [8, -5]] >>> mean_squared_error(y_true, y_pred) # doctest: +ELLIPSIS 0.7083... - >>> mean_squared_error(y_true, y_pred, output_weights=None) # doctest: +ELLIPSIS + >>> mean_squared_error(y_true, y_pred, + ... output_weights=None) # doctest: +ELLIPSIS array([ 0.416..., 1. ]) .. topic:: Examples: @@ -1056,7 +1057,8 @@ Here a small example of usage of the :func:`r2_score` function:: >>> y_pred = [[0, 2], [-1, 2], [8, -5]] >>> r2_score(y_true, y_pred) # doctest: +ELLIPSIS 0.936... - >>> r2_score(y_true, y_pred, output_weights=None) # doctest: +ELLIPSIS + >>> r2_score(y_true, y_pred, + ... output_weights=None) # doctest: +ELLIPSIS array([ 0.965..., 0.908...]) .. topic:: Example: diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py index 3ef9da9762433..20d2ed6d7820f 100644 --- a/sklearn/metrics/metrics.py +++ b/sklearn/metrics/metrics.py @@ -1959,7 +1959,8 @@ def mean_absolute_error(y_true, y_pred, output_weights='uniform'): Assigns weight to each dimension of the given input. If the option given is uniform, then a weight of unity is assigned to each dimension during averaging. If the option given is None, - then no averaging is done. + then no averaging is done. This parameter is useful only for multi-output + tasks where both y_true and y_pred have shape [n_samples, n_outputs] Returns ------- @@ -2008,7 +2009,8 @@ def mean_squared_error(y_true, y_pred, output_weights='uniform'): Assigns weight to each dimension of the given input. If the option given is uniform, then a weight of unity is assigned to each dimension during averaging. If the option given is None, - then no averaging is done. + then no averaging is done. This parameter is useful only for multi-output + tasks where both y_true and y_pred have shape [n_samples, n_outputs] Returns ------- @@ -2027,7 +2029,8 @@ def mean_squared_error(y_true, y_pred, output_weights='uniform'): >>> y_pred = [[0, 2],[-1, 2],[8, -5]] >>> mean_squared_error(y_true, y_pred) # doctest: +ELLIPSIS 0.708... - >>> mean_squared_error(y_true, y_pred, output_weights=None) # doctest: +ELLIPSIS + >>> mean_squared_error(y_true, y_pred, + ... output_weights=None) # doctest: +ELLIPSIS array([ 0.416..., 1. 
]) """ output_weights_options = (None, 'uniform') @@ -2062,7 +2065,8 @@ def explained_variance_score(y_true, y_pred, output_weights='uniform'): Assigns weight to each dimension of the given input. If the option given is uniform, then a weight of unity is assigned to each dimension while averaging. If the option given is None, then - no averaging is done. + no averaging is done. This parameter is useful only for multi-output + tasks where both y_true and y_pred have shape [n_samples, n_outputs] Returns ------- @@ -2085,8 +2089,8 @@ def explained_variance_score(y_true, y_pred, output_weights='uniform'): 0.957... >>> y_true = [[0.5, 1], [-1, 1], [7, -6]] >>> y_pred = [[0, 2], [-1, 2], [8, -5]] - >>> explained_variance_score(y_true, - ... y_pred, output_weights=None) # doctest: +ELLIPSIS + >>> explained_variance_score(y_true, y_pred, + ... output_weights=None) # doctest: +ELLIPSIS array([ 0.967..., 1. ]) """ y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred) @@ -2133,7 +2137,8 @@ def r2_score(y_true, y_pred, output_weights='uniform'): Assigns weight to each dimension of the given input. If the option given is uniform, then a weight of unity is assigned to each dimension while averaging. If the option given is None, then - no averaging is done. + no averaging is done. This parameter is useful only for multi-output + tasks where both y_true and y_pred have shape [n_samples, n_outputs] Returns ------- @@ -2165,7 +2170,8 @@ def r2_score(y_true, y_pred, output_weights='uniform'): >>> y_pred = [[0, 2], [-1, 2], [8, -5]] >>> r2_score(y_true, y_pred) # doctest: +ELLIPSIS 0.936... - >>> r2_score(y_true, y_pred, output_weights=None) # doctest: +ELLIPSIS + >>> r2_score(y_true, y_pred, + ... output_weights=None) # doctest: +ELLIPSIS array([ 0.965..., 0.908...]) """ y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred) From aabdc4dae206eb7d5313240c74c5c0c44c36f627 Mon Sep 17 00:00:00 2001 From: Manoj-Kumar-S Date: Tue, 22 Oct 2013 00:12:08 +0530 Subject: [PATCH 10/13] Added custom_weights option --- doc/modules/model_evaluation.rst | 25 +++- sklearn/metrics/metrics.py | 203 +++++++++++++++++++------- sklearn/metrics/tests/test_metrics.py | 16 ++ 3 files changed, 182 insertions(+), 62 deletions(-) diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst index 9062e4e93f2da..29c503db23789 100644 --- a/doc/modules/model_evaluation.rst +++ b/doc/modules/model_evaluation.rst @@ -937,9 +937,12 @@ function:: 0.957... >>> y_true = [[0.5, 1], [-1, 1], [7, -6]] >>> y_pred = [[0, 2], [-1, 2], [8, -5]] - >>> explained_variance_score(y_true, y_pred, - ... output_weights=None) # doctest: +ELLIPSIS + >>> explained_variance_score(y_true, y_pred, output_weights=None) + ... # doctest: +ELLIPSIS array([ 0.967..., 1. ]) + >>> explained_variance_score(y_true, y_pred, output_weights=[0.3, 0.7]) + ... # doctest: +ELLIPSIS + 0.990... Mean absolute error ................... @@ -976,7 +979,9 @@ Here a small example of usage of the :func:`mean_absolute_error` function:: 0.75 >>> mean_absolute_error(y_true, y_pred, output_weights=None) array([ 0.5, 1. ]) - + >>> mean_absolute_error(y_true, y_pred, output_weights=[0.3, 0.7]) + ... # doctest: +ELLIPSIS + 0.849... Mean squared error ................... @@ -1012,9 +1017,12 @@ function:: >>> y_pred = [[0, 2], [-1, 2], [8, -5]] >>> mean_squared_error(y_true, y_pred) # doctest: +ELLIPSIS 0.7083... - >>> mean_squared_error(y_true, y_pred, - ... output_weights=None) # doctest: +ELLIPSIS + >>> mean_squared_error(y_true, y_pred, output_weights=None) + ... 
# doctest: +ELLIPSIS
    array([ 0.416...,  1.  ])
+   >>> mean_squared_error(y_true, y_pred, output_weights=[0.3, 0.7])
+   ... # doctest: +ELLIPSIS
+   0.824...
 
 .. topic:: Examples:
 
@@ -1057,9 +1065,12 @@ Here a small example of usage of the :func:`r2_score` function::
    >>> y_pred = [[0, 2], [-1, 2], [8, -5]]
    >>> r2_score(y_true, y_pred) # doctest: +ELLIPSIS
    0.936...
-   >>> r2_score(y_true, y_pred,
-   ...          output_weights=None) # doctest: +ELLIPSIS
+   >>> r2_score(y_true, y_pred, output_weights=None)
+   ... # doctest: +ELLIPSIS
    array([ 0.965..., 0.908...])
+   >>> r2_score(y_true, y_pred, output_weights=[0.3, 0.7])
+   ... # doctest: +ELLIPSIS
+   0.925...
 
 .. topic:: Example:
 
diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py
index 20d2ed6d7820f..38812799306d6 100644
--- a/sklearn/metrics/metrics.py
+++ b/sklearn/metrics/metrics.py
@@ -32,6 +32,7 @@
 from ..utils import check_arrays
 from ..utils import deprecated
 from ..utils import column_or_1d
+from ..utils import safe_asarray
 from ..utils.multiclass import unique_labels
 from ..utils.multiclass import type_of_target
 from ..utils.fixes import bincount
@@ -1955,18 +1956,29 @@ def mean_absolute_error(y_true, y_pred, output_weights='uniform'):
     y_pred : array-like of shape = [n_samples] or [n_samples, n_outputs]
         Estimated target values.
 
-    output_weights : string, ['uniform' (default), None]
-        Assigns weight to each dimension of the given input.
-        If the option given is uniform, then a weight of unity is assigned
-        to each dimension during averaging. If the option given is None,
-        then no averaging is done. This parameter is useful only for multi-output
+    output_weights : string, ['uniform' (default), None] or
+                     array-like of shape = [n_outputs]
+
+        This parameter is useful only for multi-output
         tasks where both y_true and y_pred have shape [n_samples, n_outputs]
 
+        ``'uniform'``:
+            A weight of unity is assigned to each dimension while averaging.
+
+        ``None``:
+            No averaging is done.
+
+        Apart from this, custom weights of shape [n_outputs] can be given that
+        assign a user-defined weight to each dimension of the given input.
+
     Returns
     -------
     loss : float or a numpy array of shape [n_outputs]
-        If output_weights is 'uniform', a positive floating point value (the best value is 0.0).
-        Else, a numpy array of positive floating points is returned.
+        If output_weights is None, it returns a numpy array of floats corresponding
+        to the mean absolute error of each dimension.
+
+        If output_weights is 'uniform' or user-defined, it returns the corresponding
+        weighted macro-averaged mean absolute error.
 
     Examples
     --------
@@ -1980,18 +1991,31 @@ def mean_absolute_error(y_true, y_pred, output_weights='uniform'):
     >>> mean_absolute_error(y_true, y_pred)
     0.75
     >>> mean_absolute_error(y_true, y_pred, output_weights=None)
     array([ 0.5,  1. ])
+    >>> mean_absolute_error(y_true, y_pred, output_weights=[0.3, 0.7])
+    ... # doctest: +ELLIPSIS
+    0.849...
     """
     output_weights_options = (None, 'uniform')
+    y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred)
+    output_shape = y_true.shape[1]
     if output_weights not in output_weights_options:
-        raise ValueError('output_weights has to be one of ' +
-                         str(output_weights_options))
-    if output_weights == 'uniform':
-        axis = None
-    else:
-        axis = 0
+        # Check for custom weights.
+        output_weights = safe_asarray(output_weights)
+        if output_shape == 1:
+            raise ValueError("Custom weights are useful only in "
+                             "multi output cases.")
+        elif output_shape != output_weights.shape[0]:
+            raise ValueError("Custom weights must have shape "
+                             "(%d,)."
% output_shape)
+
     y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred)
-    return np.mean(np.abs(y_pred - y_true), axis=axis)
+    mae_array = np.mean(np.abs(y_pred - y_true), axis=0)
+    if output_weights == 'uniform':
+        return np.mean(mae_array)
+    elif output_weights is None:
+        return mae_array
+    return np.average(mae_array, weights=output_weights)
 
 
 def mean_squared_error(y_true, y_pred, output_weights='uniform'):
@@ -2030,18 +2041,29 @@ def mean_squared_error(y_true, y_pred, output_weights='uniform'):
     y_pred : array-like of shape = [n_samples] or [n_samples, n_outputs]
         Estimated target values.
 
-    output_weights : string, ['uniform' (default), None]
-        Assigns weight to each dimension of the given input.
-        If the option given is uniform, then a weight of unity is assigned
-        to each dimension during averaging. If the option given is None,
-        then no averaging is done. This parameter is useful only for multi-output
+    output_weights : string, ['uniform' (default), None] or
+                     array-like of shape = [n_outputs]
+
+        This parameter is useful only for multi-output
         tasks where both y_true and y_pred have shape [n_samples, n_outputs]
 
+        ``'uniform'``:
+            A weight of unity is assigned to each dimension while averaging.
+
+        ``None``:
+            No averaging is done.
+
+        Apart from this, custom weights of shape [n_outputs] can be given that
+        assign a user-defined weight to each dimension of the given input.
+
     Returns
     -------
     loss : float or a numpy array of shape [n_outputs]
-        If output_weights is 'uniform', a positive floating point value (the best value is 0.0).
-        Else, a numpy array of positive floating points is returned.
+        If output_weights is None, it returns a numpy array of floats corresponding
+        to the mean squared error of each dimension.
+
+        If output_weights is 'uniform' or user-defined, it returns the corresponding
+        weighted macro-averaged mean squared error.
 
     Examples
     --------
@@ -2054,20 +2076,33 @@ def mean_squared_error(y_true, y_pred, output_weights='uniform'):
     >>> y_pred = [[0, 2],[-1, 2],[8, -5]]
     >>> mean_squared_error(y_true, y_pred) # doctest: +ELLIPSIS
     0.708...
     >>> mean_squared_error(y_true, y_pred, output_weights=None)
     ... # doctest: +ELLIPSIS
     array([ 0.416...,  1.  ])
+    >>> mean_squared_error(y_true, y_pred, output_weights=[0.3, 0.7])
+    ... # doctest: +ELLIPSIS
+    0.824...
     """
     output_weights_options = (None, 'uniform')
+    y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred)
+    output_shape = y_true.shape[1]
     if output_weights not in output_weights_options:
-        raise ValueError('output_weights has to be one of ' +
-                         str(output_weights_options))
+        # Check for custom weights.
+        output_weights = safe_asarray(output_weights)
+        if output_shape == 1:
+            raise ValueError("Custom weights are useful only in "
+                             "multi output cases.")
+        elif output_shape != output_weights.shape[0]:
+            raise ValueError("Custom weights must have shape "
+                             "(%d,)."
% output_shape) + + mse_array = np.mean((y_pred - y_true)**2, axis=0) if output_weights == 'uniform': - axis = None + return np.mean(mse_array) + elif output_weights is None: + return mse_array else: - axis = 0 - y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred) - return np.mean((y_pred - y_true) ** 2, axis=axis) + return np.average(mse_array, weights=output_weights) ############################################################################### @@ -2061,20 +2110,29 @@ def explained_variance_score(y_true, y_pred, output_weights='uniform'): y_pred : array-like Estimated target values. - output_weights : string, ['uniform' (default), None] - Assigns weight to each dimension of the given input. - If the option given is uniform, then a weight of unity is assigned - to each dimension while averaging. If the option given is None, then - no averaging is done. This parameter is useful only for multi-output + output_weights : string, ['uniform' (default), None] or + array-like of shape = [n_outputs] + + This parameter is useful only for multi-output tasks where both y_true and y_pred have shape [n_samples, n_outputs] + ``'uniform'``: + A weight of unity is assigned to each dimension while averaging. + + ``None``: + No averaging is done. + + Apart from this custom_weights of shape [n_outputs] can be given that + assigns user-defined weight to each dimension of the given input. + Returns ------- score : float or a numpy array of shape [n_outputs] - If output_weights is 'uniform', it returns the macro-averaged explained - variance score. If output_weights is None, it returns a numpy array of floats corresponding - to the explained variance score of each dimension. + to the R^2 score of each dimension. + + If output_weights is 'uniform' or user-defined it returns the corresponding + weighted macro-averaged R^2 score. Notes ----- @@ -2089,15 +2147,25 @@ def explained_variance_score(y_true, y_pred, output_weights='uniform'): 0.957... >>> y_true = [[0.5, 1], [-1, 1], [7, -6]] >>> y_pred = [[0, 2], [-1, 2], [8, -5]] - >>> explained_variance_score(y_true, y_pred, - ... output_weights=None) # doctest: +ELLIPSIS + >>> explained_variance_score(y_true, y_pred, output_weights=None) + ... # doctest: +ELLIPSIS array([ 0.967..., 1. ]) + >>> explained_variance_score(y_true, y_pred, output_weights=[0.3, 0.7]) + ... # doctest: +ELLIPSIS + 0.990... """ y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred) + output_shape = y_true.shape[1] output_weights_options = (None, 'uniform') if output_weights not in output_weights_options: - raise ValueError('output_weights has to be one of ' + - str(output_weights_options)) + # Check for custom weights. + output_weights = safe_asarray(output_weights) + if output_shape == 1: + raise ValueError("Custom weights are useful only in " + "multi output cases.") + elif output_shape != output_weights.shape[0]: + raise ValueError("Custom weights must have shape " + "(1, %d)." % output_shape) numerator = np.var(y_true - y_pred, axis=0) denominator = np.var(y_true, axis=0) @@ -2118,7 +2186,9 @@ def explained_variance_score(y_true, y_pred, output_weights='uniform'): np.logical_not(nonzero_denominator))] = 0.0 if output_weights == 'uniform': return np.mean(explained_variance) - return explained_variance + elif output_weights is None: + return explained_variance + return np.average(explained_variance, weights=output_weights) def r2_score(y_true, y_pred, output_weights='uniform'): """R^2 (coefficient of determination) regression score function. 
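All four metrics touched by this patch converge on the same three-way rule: per-output values for ``output_weights=None``, a plain mean for ``'uniform'``, and ``np.average`` for user-supplied weights. A standalone sketch of that rule, using the doctest values from the hunks above (illustrative, not part of the diff):

    import numpy as np

    y_true = np.array([[0.5, 1], [-1, 1], [7, -6]], dtype=float)
    y_pred = np.array([[0, 2], [-1, 2], [8, -5]], dtype=float)

    # output_weights=None: one score per output, no averaging.
    per_output = np.mean((y_pred - y_true) ** 2, axis=0)
    print(per_output)            # [0.4167, 1.0] (rounded)

    # 'uniform': a plain mean over outputs, matching the 0.708... doctest.
    print(np.mean(per_output))

    # Custom weights: np.average divides by the sum of the weights.
    print(np.average(per_output, weights=[0.3, 0.7]))
    # 0.3 * 0.41666... + 0.7 * 1.0 = 0.825 exactly, but floating point
    # rounding prints 0.82499..., hence the "0.824..." in the doctest.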
@@ -2133,20 +2203,31 @@ def r2_score(y_true, y_pred, output_weights='uniform'): y_pred : array-like of shape = [n_samples] or [n_samples, n_outputs] Estimated target values. - output_weights : string, ['uniform' (default), None] - Assigns weight to each dimension of the given input. - If the option given is uniform, then a weight of unity is assigned - to each dimension while averaging. If the option given is None, then - no averaging is done. This parameter is useful only for multi-output + output_weights : string, ['uniform' (default), None] or + array-like of shape = [n_outputs] + + This parameter is useful only for multi-output tasks where both y_true and y_pred have shape [n_samples, n_outputs] + ``'uniform'``: + A weight of unity is assigned to each dimension while averaging. + + ``None``: + No averaging is done. + + Apart from this custom_weights of shape [n_outputs] can be given that + assigns user-defined weight to each dimension of the given input. + + Returns ------- z : float or a numpy array of shape [n_outputs] - If output_weights is 'uniform', it returns the macro-averaged R^2 score, If output_weights is None, it returns a numpy array of floats corresponding to the R^2 score of each dimension. + If output_weights is 'uniform' or user-defined it returns the corresponding + weighted macro-averaged R^2 score. + Notes ----- This is not a symmetric function. @@ -2170,26 +2251,36 @@ def r2_score(y_true, y_pred, output_weights='uniform'): >>> y_pred = [[0, 2], [-1, 2], [8, -5]] >>> r2_score(y_true, y_pred) # doctest: +ELLIPSIS 0.936... - >>> r2_score(y_true, y_pred, - ... output_weights=None) # doctest: +ELLIPSIS + >>> r2_score(y_true, y_pred, output_weights=None) + ... # doctest: +ELLIPSIS array([ 0.965..., 0.908...]) + >>> r2_score(y_true, y_pred, output_weights=[0.3, 0.7]) + ... # doctest: +ELLIPSIS + 0.925... """ y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred) + output_shape = y_true.shape[1] if len(y_true) == 1: raise ValueError("r2_score can only be computed given more than one" " sample.") output_weights_options = (None, 'uniform') if output_weights not in output_weights_options: - raise ValueError('output_weights has to be one of ' + - str(output_weights_options)) + # Check for custom weights. + output_weights = safe_asarray(output_weights) + if output_shape == 1: + raise ValueError("Custom weights are useful only in " + "multi output cases.") + elif output_shape != output_weights.shape[0]: + raise ValueError("Custom weights must have shape " + "(1, %d)." % output_shape) numerator = ((y_true - y_pred) ** 2).sum(dtype=np.float64, axis=0) denominator = ((y_true - y_true.mean(axis=0)) ** 2).sum(dtype=np.float64, axis=0) # Set an array of ones for the condition that both numerator # and denominator are zero. 
-    r2 = np.ones(y_true.shape[1])
+    r2 = np.ones(output_shape)
     nonzero_denominator = (denominator != 0.0)
     nonzero_numerator = (numerator != 0.0)
     valid_score = np.logical_and(nonzero_numerator, nonzero_denominator)
@@ -2201,4 +2292,6 @@ def r2_score(y_true, y_pred, output_weights='uniform'):
                      nonzero_numerator)] = 0.0
     if output_weights == 'uniform':
         return np.mean(r2)
-    return r2
+    elif output_weights is None:
+        return r2
+    return np.average(r2, weights=output_weights)
diff --git a/sklearn/metrics/tests/test_metrics.py b/sklearn/metrics/tests/test_metrics.py
index ee429721dd854..c29bf5717c834 100644
--- a/sklearn/metrics/tests/test_metrics.py
+++ b/sklearn/metrics/tests/test_metrics.py
@@ -2009,6 +2009,7 @@ def test_regression_multioutput_array():
     mae = mean_absolute_error(y_true, y_pred, output_weights=None)
     r = r2_score(y_true, y_pred, output_weights=None)
     evs = explained_variance_score(y_true, y_pred, output_weights=None)
+
     assert_array_equal(mse, np.array([0.125, 0.5625]))
     assert_array_equal(mae, np.array([0.25, 0.625]))
     assert_array_almost_equal(r, np.array([0.95, 0.93]), decimal=2)
@@ -2042,3 +2043,18 @@ def test_regression_multioutput_array():
     evs = explained_variance_score(y_true, y_pred, output_weights=None)
     assert_array_equal(evs, np.array([1., -3.]))
     assert_equal(np.mean(evs), explained_variance_score(y_true, y_pred))
+
+
+def test_regression_custom_weights():
+    y_true = [[1, 2], [2.5, -1], [4.5, 3], [5, 7]]
+    y_pred = [[1, 1], [2, -1], [5, 4], [5, 6.5]]
+
+    msew = mean_squared_error(y_true, y_pred, output_weights=[0.4, 0.6])
+    maew = mean_absolute_error(y_true, y_pred, output_weights=[0.4, 0.6])
+    rw = r2_score(y_true, y_pred, output_weights=[0.4, 0.6])
+    evsw = explained_variance_score(y_true, y_pred, output_weights=[0.4, 0.6])
+
+    assert_almost_equal(msew, 0.39, decimal=2)
+    assert_almost_equal(maew, 0.475, decimal=3)
+    assert_almost_equal(rw, 0.94, decimal=2)
+    assert_almost_equal(evsw, 0.94, decimal=2)
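The hunks above also settle how ``r2_score`` treats degenerate outputs: a constant ``y_true`` column makes the denominator zero, so the old scalar short-circuit has to become per-output masking. A condensed sketch of that masking idea (simplified from the diff, not a verbatim copy):

    import numpy as np

    y_true = np.array([[2., 1.], [2., 2.], [2., 3.]])   # first output constant
    y_pred = np.array([[2., 1.], [2., 2.], [2., 4.]])

    numerator = ((y_true - y_pred) ** 2).sum(axis=0, dtype=np.float64)
    denominator = ((y_true - y_true.mean(axis=0)) ** 2).sum(axis=0,
                                                            dtype=np.float64)

    r2 = np.ones(y_true.shape[1])       # outputs where 0/0 occurs keep 1.0
    valid = (numerator != 0.0) & (denominator != 0.0)
    r2[valid] = 1.0 - numerator[valid] / denominator[valid]
    # Constant y_true but imperfect y_pred: pinned to 0.0 instead of -inf.
    r2[(denominator == 0.0) & (numerator != 0.0)] = 0.0
    print(r2)                           # [1.0, 0.5]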
From 057c0c406f04fc4ba81dd7c5e52208b62a884e59 Mon Sep 17 00:00:00 2001
From: Manoj-Kumar-S
Date: Wed, 23 Oct 2013 00:07:08 +0530
Subject: [PATCH 11/13] Moved stuff to _check_reg_targets; Doc changes

---
 doc/modules/model_evaluation.rst      |  18 +++--
 sklearn/metrics/metrics.py            | 108 +++++++++++---------------
 sklearn/metrics/tests/test_metrics.py |   5 +-
 3 files changed, 57 insertions(+), 74 deletions(-)

diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst
index 29c503db23789..150e592871e6a 100644
--- a/doc/modules/model_evaluation.rst
+++ b/doc/modules/model_evaluation.rst
@@ -923,9 +923,9 @@ The best possible score is 1.0, lower values are worse.
 The :func:`explained_variance_score` function has an `output_weights` keyword
 with two possible values `None` and 'uniform'. If the value provided is `None`,
 then the explained variance score is calculated for each dimension separately
-and a numpy array is returned. If the value given is `uniform`, then a weight
-of unity is assigned to each dimension during macro-averaging, and a float
-is returned.
+and a numpy array is returned. If the value given is `uniform`, the
+explained variance score is averaged over each dimension with a weight of
+`1 / n_outputs`.
 
 Here a small example of usage of the :func:`explained_variance_score`
 function::
@@ -964,7 +964,8 @@ The :func:`mean_absolute_error` function has an `output_weights` keyword
 with two possible values `None` and 'uniform'. If the value provided is `None`,
 then the mean absolute error is calculated for each dimension separately
 and a numpy array is returned. If the value given is `uniform`, the
-mean absolute error is averaged over each dimension with a weight of `1 / n_outputs`.
+mean absolute error is averaged over each dimension with a weight of
+`1 / n_outputs`.
 
 Here a small example of usage of the :func:`mean_absolute_error` function::
 
@@ -1003,7 +1004,8 @@ The :func:`mean_squared_error` function has an `output_weights` keyword
 with two possible values `None` and 'uniform'. If the value provided is `None`,
 then the mean squared error is calculated for each dimension separately
 and a numpy array is returned. If the value given is `uniform`, the
-mean squared error is averaged over each dimension with a weight of `1 / n_outputs`.
+mean squared error is averaged over each dimension with a weight of
+`1 / n_outputs`.
 
 Here a small example of usage of the :func:`mean_squared_error`
 function::
@@ -1050,9 +1052,9 @@ where :math:`\bar{y} = \frac{1}{n_{\text{samples}}} \sum_{i=0}^{n_{\text{sample
 The :func:`r2_score` function has an `output_weights` keyword with two possible
 values `None` and 'uniform'. If the value provided is `None`, then the r2 score
-is calculated for each dimension separately and a numpy array is returned. If the
-value given is `uniform`, then a weight of unity is assigned to each dimension during
-macro-averaging, and a float is returned.
+is calculated for each dimension separately and a numpy array is returned.
+If the value given is `uniform`, the r2 score is averaged over each dimension
+with a weight of `1 / n_outputs`.
 
 Here a small example of usage of the :func:`r2_score` function::
 
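The custom-weights values in the documentation examples above follow directly from the per-output arrays; for the ``r2_score`` example, a quick hand check (not part of the diff):

    >>> import numpy as np
    >>> round(np.average([0.96543779, 0.90816327], weights=[0.3, 0.7]), 6)
    0.925346

which is the ``0.925...`` the example expects.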
diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py
index 38812799306d6..90d8454b52fa5 100644
--- a/sklearn/metrics/metrics.py
+++ b/sklearn/metrics/metrics.py
@@ -41,8 +41,9 @@
 ###############################################################################
 # General utilities
 ###############################################################################
-def _check_reg_targets(y_true, y_pred):
-    """Check that y_true and y_pred belong to the same regression task
+def _check_reg_targets(y_true, y_pred, output_weights):
+    """Check that y_true, y_pred and output_weights belong to the
+    same regression task
 
     Parameters
     ----------
@@ -50,6 +51,8 @@ def _check_reg_targets(y_true, y_pred):
 
     y_pred : array-like,
 
+    output_weights : array-like
+
     Returns
     -------
     type_true : one of {'continuous', 'continuous-multioutput'}
@@ -61,6 +64,9 @@ def _check_reg_targets(y_true, y_pred):
 
     y_pred : array-like of shape = [n_samples, n_outputs]
         Estimated target values.
+
+    output_weights : array-like of shape = [n_outputs]
+        Custom weights.
     """
     y_true, y_pred = check_arrays(y_true, y_pred)
 
@@ -74,9 +80,20 @@ def _check_reg_targets(y_true, y_pred):
         raise ValueError("y_true and y_pred have different number of output "
                          "({0}!={1})".format(y_true.shape[1],
                                              y_pred.shape[1]))
+    n_outputs = y_true.shape[1]
+    output_weights_options = (None, 'uniform')
+    if output_weights not in output_weights_options:
+        output_weights = safe_asarray(output_weights)
+        if n_outputs == 1:
+            raise ValueError("Custom weights are useful only in "
+                             "multi-output cases.")
+        elif n_outputs != output_weights.shape[0]:
+            raise ValueError("Custom weights must have shape "
+                             "(1, %d)." % n_outputs)
+
     y_type = 'continuous' if y_true.shape[1] == 1 else 'continuous-multioutput'
 
-    return y_type, y_true, y_pred
+    return y_type, y_true, y_pred, output_weights
 
 
 def _check_clf_targets(y_true, y_pred):
@@ -1963,7 +1980,8 @@ def mean_absolute_error(y_true, y_pred, output_weights='uniform'):
     tasks where both y_true and y_pred have shape [n_samples, n_outputs]
 
     ``'uniform'``:
-        A weight of unity is assigned to each dimension while averaging.
+        A weight of 1/n_outputs is assigned to each dimension while
+        averaging.
 
     ``None``:
         No averaging is done.
@@ -1975,10 +1993,10 @@ def mean_absolute_error(y_true, y_pred, output_weights='uniform'):
     -------
     loss : float or a numpy array of shape [n_outputs]
         If output_weights is None, it returns a numpy array of floats corresponding
-        to the R^2 score of each dimension.
+        to the mean absolute error of each dimension.
 
         If output_weights is 'uniform' or user-defined it returns the corresponding
-        weighted macro-averaged R^2 score.
+        weighted macro-averaged mean absolute error.
 
     Examples
@@ -1997,20 +2015,9 @@ def mean_absolute_error(y_true, y_pred, output_weights='uniform'):
     ...     # doctest: +ELLIPSIS
     0.849...
     """
-    output_weights_options = (None, 'uniform')
-    y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred)
-    output_shape = y_true.shape[1]
-    if output_weights not in output_weights_options:
-        # Check for custom weights.
-        output_weights = safe_asarray(output_weights)
-        if output_shape == 1:
-            raise ValueError("Custom weights are useful only in "
-                             "multi output cases.")
-        elif output_shape != output_weights.shape[0]:
-            raise ValueError("Custom weights must have shape "
-                             "(1, %d)." % output_shape)
+    y_type, y_true, y_pred, output_weights = \
+        _check_reg_targets(y_true, y_pred, output_weights)
 
-    y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred)
     mae_array = np.mean(np.abs(y_pred - y_true), axis=0)
     if output_weights == 'uniform':
         return np.mean(mae_array)
@@ -2037,7 +2044,8 @@ def mean_squared_error(y_true, y_pred, output_weights='uniform'):
     tasks where both y_true and y_pred have shape [n_samples, n_outputs]
 
     ``'uniform'``:
-        A weight of unity is assigned to each dimension while averaging.
+        A weight of 1/n_outputs is assigned to each dimension while
+        averaging.
 
     ``None``:
         No averaging is done.
@@ -2049,10 +2057,10 @@ def mean_squared_error(y_true, y_pred, output_weights='uniform'):
     -------
     loss : float or a numpy array of shape [n_outputs]
         If output_weights is None, it returns a numpy array of floats corresponding
-        to the R^2 score of each dimension.
+        to the mean squared error of each dimension.
 
         If output_weights is 'uniform' or user-defined it returns the corresponding
-        weighted macro-averaged R^2 score.
+        weighted macro-averaged mean squared error.
 
     Examples
@@ -2072,18 +2080,8 @@ def mean_squared_error(y_true, y_pred, output_weights='uniform'):
     ...     # doctest: +ELLIPSIS
     0.824...
     """
-    output_weights_options = (None, 'uniform')
-    y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred)
-    output_shape = y_true.shape[1]
-    if output_weights not in output_weights_options:
-        # Check for custom weights.
-        output_weights = safe_asarray(output_weights)
-        if output_shape == 1:
-            raise ValueError("Custom weights are useful only in "
-                             "multi output cases.")
-        elif output_shape != output_weights.shape[0]:
-            raise ValueError("Custom weights must have shape "
-                             "(1, %d)." 
% output_shape) + y_type, y_true, y_pred, output_weights = \ + _check_reg_targets(y_true, y_pred, output_weights) mse_array = np.mean((y_pred - y_true)**2, axis=0) if output_weights == 'uniform': @@ -2117,7 +2115,8 @@ def explained_variance_score(y_true, y_pred, output_weights='uniform'): tasks where both y_true and y_pred have shape [n_samples, n_outputs] ``'uniform'``: - A weight of unity is assigned to each dimension while averaging. + A weight of 1/n_outputs is assigned to each dimension while + averaging. ``None``: No averaging is done. @@ -2129,10 +2128,10 @@ def explained_variance_score(y_true, y_pred, output_weights='uniform'): ------- score : float or a numpy array of shape [n_outputs] If output_weights is None, it returns a numpy array of floats corresponding - to the R^2 score of each dimension. + to the explained variance score of each dimension. If output_weights is 'uniform' or user-defined it returns the corresponding - weighted macro-averaged R^2 score. + weighted macro-averaged explained variance score. Notes ----- @@ -2154,18 +2153,8 @@ def explained_variance_score(y_true, y_pred, output_weights='uniform'): ... # doctest: +ELLIPSIS 0.990... """ - y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred) - output_shape = y_true.shape[1] - output_weights_options = (None, 'uniform') - if output_weights not in output_weights_options: - # Check for custom weights. - output_weights = safe_asarray(output_weights) - if output_shape == 1: - raise ValueError("Custom weights are useful only in " - "multi output cases.") - elif output_shape != output_weights.shape[0]: - raise ValueError("Custom weights must have shape " - "(1, %d)." % output_shape) + y_type, y_true, y_pred, output_weights = \ + _check_reg_targets(y_true, y_pred, output_weights) numerator = np.var(y_true - y_pred, axis=0) denominator = np.var(y_true, axis=0) @@ -2210,7 +2199,8 @@ def r2_score(y_true, y_pred, output_weights='uniform'): tasks where both y_true and y_pred have shape [n_samples, n_outputs] ``'uniform'``: - A weight of unity is assigned to each dimension while averaging. + A weight of 1/n_outputs is assigned to each dimension while + averaging. ``None``: No averaging is done. @@ -2258,29 +2248,19 @@ def r2_score(y_true, y_pred, output_weights='uniform'): ... # doctest: +ELLIPSIS 0.925... """ - y_type, y_true, y_pred = _check_reg_targets(y_true, y_pred) - output_shape = y_true.shape[1] + y_type, y_true, y_pred, output_weights = \ + _check_reg_targets(y_true, y_pred, output_weights) + if len(y_true) == 1: raise ValueError("r2_score can only be computed given more than one" " sample.") - output_weights_options = (None, 'uniform') - if output_weights not in output_weights_options: - # Check for custom weights. - output_weights = safe_asarray(output_weights) - if output_shape == 1: - raise ValueError("Custom weights are useful only in " - "multi output cases.") - elif output_shape != output_weights.shape[0]: - raise ValueError("Custom weights must have shape " - "(1, %d)." % output_shape) - numerator = ((y_true - y_pred) ** 2).sum(dtype=np.float64, axis=0) denominator = ((y_true - y_true.mean(axis=0)) ** 2).sum(dtype=np.float64, axis=0) # Set an array of ones for the condition that both numerator # and denominator are zero. 
-    r2 = np.ones(output_shape)
+    r2 = np.ones(y_true.shape[1])
     nonzero_denominator = (denominator != 0.0)
     nonzero_numerator = (numerator != 0.0)
     valid_score = np.logical_and(nonzero_numerator, nonzero_denominator)
diff --git a/sklearn/metrics/tests/test_metrics.py b/sklearn/metrics/tests/test_metrics.py
index c29bf5717c834..6dc8bf06f8c56 100644
--- a/sklearn/metrics/tests/test_metrics.py
+++ b/sklearn/metrics/tests/test_metrics.py
@@ -1963,7 +1963,8 @@ def test__check_reg_targets():
                        repeat=2):
 
         if type1 == type2 and n_out1 == n_out2:
-            y_type, y_check1, y_check2 = _check_reg_targets(y1, y2)
+            y_type, y_check1, y_check2, output_weights = \
+                _check_reg_targets(y1, y2, None)
             assert_equal(type1, y_type)
             if type1 == 'continuous':
                 assert_array_equal(y_check1, np.reshape(y1, (-1, 1)))
@@ -1972,7 +1973,7 @@
                 assert_array_equal(y_check1, y1)
                 assert_array_equal(y_check2, y2)
         else:
-            assert_raises(ValueError, _check_reg_targets, y1, y2)
+            assert_raises(ValueError, _check_reg_targets, y1, y2, None)
 
 
 def test_log_loss():
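With the checks now centralized in ``_check_reg_targets``, every regression metric on this branch rejects ill-shaped custom weights in one place. A sketch of the intended behavior (the ``output_weights`` keyword exists only on this PR's branch, not in released scikit-learn):

    from sklearn.metrics import mean_squared_error

    y_true = [[1, 2], [2.5, -1], [4.5, 3]]
    y_pred = [[1, 1], [2, -1], [5, 4]]

    try:
        # Three weights for two outputs: rejected by _check_reg_targets.
        mean_squared_error(y_true, y_pred, output_weights=[0.2, 0.3, 0.5])
    except ValueError as err:
        print(err)    # "Custom weights must have shape ..."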
From de43e6ea93ec30daa9bda1668c13348a1c3834a1 Mon Sep 17 00:00:00 2001
From: Manoj-Kumar-S
Date: Wed, 23 Oct 2013 14:41:34 +0530
Subject: [PATCH 12/13] Minor changes

---
 sklearn/metrics/metrics.py | 28 +++++++++++++++++-----------
 1 file changed, 17 insertions(+), 11 deletions(-)

diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py
index 90d8454b52fa5..3071dd6212891 100644
--- a/sklearn/metrics/metrics.py
+++ b/sklearn/metrics/metrics.py
@@ -51,7 +51,7 @@ def _check_reg_targets(y_true, y_pred, output_weights):
 
     y_pred : array-like,
 
-    output_weights : array-like
+    output_weights : array-like or string, ['uniform', None]
 
     Returns
     -------
@@ -66,7 +66,13 @@ def _check_reg_targets(y_true, y_pred, output_weights):
         Estimated target values.
 
     output_weights : array-like of shape = [n_outputs]
-        Custom weights.
+        or string, ['uniform', None]
+
+        The custom weights, converted to an array, if
+        output_weights was given as array-like.
+        'uniform' or None, passed through unchanged, if
+        output_weights was given as one of those options.
+
     """
     y_true, y_pred = check_arrays(y_true, y_pred)
 
@@ -91,7 +97,7 @@ def _check_reg_targets(y_true, y_pred, output_weights):
             raise ValueError("Custom weights must have shape "
                              "(1, %d)." % n_outputs)
 
-    y_type = 'continuous' if y_true.shape[1] == 1 else 'continuous-multioutput'
+    y_type = 'continuous' if n_outputs == 1 else 'continuous-multioutput'
 
     return y_type, y_true, y_pred, output_weights
 
@@ -2018,12 +2024,12 @@ def mean_absolute_error(y_true, y_pred, output_weights='uniform'):
     y_type, y_true, y_pred, output_weights = \
         _check_reg_targets(y_true, y_pred, output_weights)
 
-    mae_array = np.mean(np.abs(y_pred - y_true), axis=0)
+    error = np.mean(np.abs(y_pred - y_true), axis=0)
     if output_weights == 'uniform':
-        return np.mean(mae_array)
+        return np.mean(error)
     elif output_weights is None:
-        return mae_array
-    return np.average(mae_array, weights=output_weights)
+        return error
+    return np.average(error, weights=output_weights)
 
 
 def mean_squared_error(y_true, y_pred, output_weights='uniform'):
@@ -2083,13 +2089,13 @@ def mean_squared_error(y_true, y_pred, output_weights='uniform'):
     y_type, y_true, y_pred, output_weights = \
         _check_reg_targets(y_true, y_pred, output_weights)
 
-    mse_array = np.mean((y_pred - y_true)**2, axis=0)
+    error = np.mean((y_pred - y_true)**2, axis=0)
     if output_weights == 'uniform':
-        return np.mean(mse_array)
+        return np.mean(error)
     elif output_weights is None:
-        return mse_array
+        return error
     else:
-        return np.average(mse_array, weights=output_weights)
+        return np.average(error, weights=output_weights)
 
 
 ###############################################################################

From b56ef150860881d1a218e6cfb83889309503fdfc Mon Sep 17 00:00:00 2001
From: Manoj-Kumar-S
Date: Fri, 1 Nov 2013 01:44:55 +0530
Subject: [PATCH 13/13] Added docstring to advise the user to normalize y_true

---
 sklearn/metrics/metrics.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py
index 3071dd6212891..e3b8c1f309a03 100644
--- a/sklearn/metrics/metrics.py
+++ b/sklearn/metrics/metrics.py
@@ -2040,6 +2040,8 @@ def mean_squared_error(y_true, y_pred, output_weights='uniform'):
     y_true : array-like of shape = [n_samples] or [n_samples, n_outputs]
         Ground truth (correct) target values.
 
+        It is recommended to normalize y_true before using this function.
+
     y_pred : array-like of shape = [n_samples] or [n_samples, n_outputs]
         Estimated target values.
 
@@ -2111,6 +2113,8 @@ def explained_variance_score(y_true, y_pred, output_weights='uniform'):
     y_true : array-like
         Ground truth (correct) target values.
 
+        It is recommended to normalize y_true before using this function.
+
     y_pred : array-like
         Estimated target values.
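A sketch of why PATCH 13 recommends normalizing ``y_true``: under 'uniform' averaging, an output measured on a larger scale dominates the macro-average, so the averaged error says little about the smaller-scale outputs (illustrative values, not from the diff):

    import numpy as np

    y_true = np.array([[1., 1000.], [2., 2000.], [3., 3000.]])
    y_pred = y_true + np.array([0.1, 100.])    # similar relative error

    per_output_mse = np.mean((y_pred - y_true) ** 2, axis=0)
    print(per_output_mse)           # [0.01, 10000.0]
    print(per_output_mse.mean())    # ~5000, blind to the first output

    # Scaling each output to unit variance puts them on an equal footing.
    scale = y_true.std(axis=0)
    scaled_mse = np.mean(((y_pred - y_true) / scale) ** 2, axis=0)
    print(scaled_mse)               # [0.015, 0.015]: now comparable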