From 2ecc96d9699e89e961716d0a140df67b09aab34b Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 16 Apr 2020 11:08:20 +0200 Subject: [PATCH 01/46] check diabetes --- .../ensemble/tests/test_gradient_boosting.py | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index c7653ddac959c..c9626582e29ee 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -57,6 +57,10 @@ perm = rng.permutation(boston.target.size) boston.data = boston.data[perm] boston.target = boston.target[perm] +diabetes = datasets.load_diabetes() +perm = rng.permutation(diabetes.target.size) +diabetes.data = diabetes.data[perm] +diabetes.target = diabetes.target[perm] # also load the iris dataset # and randomly permute it @@ -215,7 +219,7 @@ def test_classification_synthetic(loss): def check_boston(loss, subsample): # Check consistency on dataset boston house prices with least squares # and least absolute deviation. - ones = np.ones(len(boston.target)) + ones = np.ones(len(diabetes.target)) last_y_pred = None for sample_weight in None, ones, 2 * ones: clf = GradientBoostingRegressor(n_estimators=100, @@ -225,15 +229,17 @@ def check_boston(loss, subsample): min_samples_split=2, random_state=1) - assert_raises(ValueError, clf.predict, boston.data) - clf.fit(boston.data, boston.target, + assert_raises(ValueError, clf.predict, diabetes.data) + clf.fit(diabetes.data, diabetes.target, sample_weight=sample_weight) - leaves = clf.apply(boston.data) - assert leaves.shape == (506, 100) - - y_pred = clf.predict(boston.data) - mse = mean_squared_error(boston.target, y_pred) - assert mse < 6.0 + leaves = clf.apply(diabetes.data) + assert leaves.shape == (442, 100) + + y_pred = clf.predict(diabetes.data) + mse = mean_squared_error(diabetes.target, y_pred) + print('loss:{}, sub:{}'.format(loss, subsample)) + print(mse) + assert mse < 2000.0 if last_y_pred is not None: assert_array_almost_equal(last_y_pred, y_pred) From deacbf5eb741e69a9627efebc9ba55185f093d83 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 16 Apr 2020 12:30:31 +0200 Subject: [PATCH 02/46] use diabetes and cali --- .../ensemble/tests/test_gradient_boosting.py | 62 +++++++++---------- 1 file changed, 30 insertions(+), 32 deletions(-) diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index c9626582e29ee..1972cec1a0fa3 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -51,12 +51,8 @@ true_result = [-1, 1, 1] rng = np.random.RandomState(0) -# also load the boston dataset +# also load the diabetes dataset # and randomly permute it -boston = datasets.load_boston() -perm = rng.permutation(boston.target.size) -boston.data = boston.data[perm] -boston.target = boston.target[perm] diabetes = datasets.load_diabetes() perm = rng.permutation(diabetes.target.size) diabetes.data = diabetes.data[perm] @@ -216,10 +212,14 @@ def test_classification_synthetic(loss): check_classification_synthetic(loss) -def check_boston(loss, subsample): - # Check consistency on dataset boston house prices with least squares +def check_california(loss, subsample): + # Check consistency on dataset california house prices with least squares # and least absolute deviation. 
- ones = np.ones(len(diabetes.target)) + california = datasets.fetch_california_housing() + perm = rng.permutation(500) + california.data = california.data[perm] + california.target = california.target[perm] + ones = np.ones(len(california.target)) last_y_pred = None for sample_weight in None, ones, 2 * ones: clf = GradientBoostingRegressor(n_estimators=100, @@ -229,28 +229,26 @@ def check_boston(loss, subsample): min_samples_split=2, random_state=1) - assert_raises(ValueError, clf.predict, diabetes.data) - clf.fit(diabetes.data, diabetes.target, + assert_raises(ValueError, clf.predict, california.data) + clf.fit(california.data, california.target, sample_weight=sample_weight) - leaves = clf.apply(diabetes.data) - assert leaves.shape == (442, 100) + leaves = clf.apply(california.data) + assert leaves.shape == (500, 100) - y_pred = clf.predict(diabetes.data) - mse = mean_squared_error(diabetes.target, y_pred) - print('loss:{}, sub:{}'.format(loss, subsample)) - print(mse) - assert mse < 2000.0 + y_pred = clf.predict(california.data) + mse = mean_squared_error(california.target, y_pred) + assert mse < 0.1 if last_y_pred is not None: - assert_array_almost_equal(last_y_pred, y_pred) + assert_array_almost_equal(last_y_pred, y_pred, decimal=0) last_y_pred = y_pred @pytest.mark.parametrize('loss', ('ls', 'lad', 'huber')) @pytest.mark.parametrize('subsample', (1.0, 0.5)) -def test_boston(loss, subsample): - check_boston(loss, subsample) +def test_california(loss, subsample): + check_california(loss, subsample) def check_iris(subsample, sample_weight): @@ -317,8 +315,8 @@ def test_regression_synthetic(): def test_feature_importances(): - X = np.array(boston.data, dtype=np.float32) - y = np.array(boston.target, dtype=np.float32) + X = np.array(diabetes.data, dtype=np.float32) + y = np.array(diabetes.target, dtype=np.float32) clf = GradientBoostingRegressor(n_estimators=100, max_depth=5, min_samples_split=2, random_state=1) @@ -605,14 +603,14 @@ def test_quantile_loss(): max_depth=4, alpha=0.5, random_state=7) - clf_quantile.fit(boston.data, boston.target) - y_quantile = clf_quantile.predict(boston.data) + clf_quantile.fit(diabetes.data, diabetes.target) + y_quantile = clf_quantile.predict(diabetes.data) clf_lad = GradientBoostingRegressor(n_estimators=100, loss='lad', max_depth=4, random_state=7) - clf_lad.fit(boston.data, boston.target) - y_lad = clf_lad.predict(boston.data) + clf_lad.fit(diabetes.data, diabetes.target) + y_lad = clf_lad.predict(diabetes.data) assert_array_almost_equal(y_quantile, y_lad, decimal=4) @@ -1019,7 +1017,7 @@ def test_complete_regression(): est = GradientBoostingRegressor(n_estimators=20, max_depth=None, random_state=1, max_leaf_nodes=k + 1) - est.fit(boston.data, boston.target) + est.fit(diabetes.data, diabetes.target) tree = est.estimators_[-1, 0].tree_ assert (tree.children_left[tree.children_left == TREE_LEAF].shape[0] == @@ -1031,14 +1029,14 @@ def test_zero_estimator_reg(): est = GradientBoostingRegressor(n_estimators=20, max_depth=1, random_state=1, init='zero') - est.fit(boston.data, boston.target) - y_pred = est.predict(boston.data) - mse = mean_squared_error(boston.target, y_pred) - assert_almost_equal(mse, 33.0, decimal=0) + est.fit(diabetes.data, diabetes.target) + y_pred = est.predict(diabetes.data) + mse = mean_squared_error(diabetes.target, y_pred) + assert_almost_equal(mse, 3664.0, decimal=0) est = GradientBoostingRegressor(n_estimators=20, max_depth=1, random_state=1, init='foobar') - assert_raises(ValueError, est.fit, boston.data, boston.target) + 
assert_raises(ValueError, est.fit, diabetes.data, diabetes.target) def test_zero_estimator_clf(): From f257ff108ba88fa5084675465676411e387a772f Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 16 Apr 2020 12:31:26 +0200 Subject: [PATCH 03/46] pytest network --- sklearn/ensemble/tests/test_gradient_boosting.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index 1972cec1a0fa3..1504384bf58ef 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -245,6 +245,7 @@ def check_california(loss, subsample): last_y_pred = y_pred +@pytest.mark.network @pytest.mark.parametrize('loss', ('ls', 'lad', 'huber')) @pytest.mark.parametrize('subsample', (1.0, 0.5)) def test_california(loss, subsample): From 31a116edf5b7556eb1be3e4352852a412718cc79 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 28 May 2020 20:10:44 +0200 Subject: [PATCH 04/46] BUG make _weighted_percentile behave as NumPy --- sklearn/utils/stats.py | 74 +++++++++++++++++++++++++------ sklearn/utils/tests/test_stats.py | 16 +++++++ 2 files changed, 77 insertions(+), 13 deletions(-) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index 7b44575e97b33..a4d52a9f11d27 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -4,7 +4,8 @@ from .fixes import _take_along_axis -def _weighted_percentile(array, sample_weight, percentile=50): +def _weighted_percentile(array, sample_weight, percentile=50, + interpolation="linear"): """Compute weighted percentile Computes lower weighted percentile. If `array` is a 2D array, the @@ -25,9 +26,19 @@ def _weighted_percentile(array, sample_weight, percentile=50): percentile: int, default=50 Percentile to compute. Must be value between 0 and 100. + interpolation : {"linear", "lower", "higher"}, default="linear" + This optional parameter specifies the interpolation method to + use when the desired percentile lies between two data points + ``i < j``: + * 'linear': ``i + (j - i) * fraction``, where ``fraction`` + is the fractional part of the index surrounded by ``i`` + and ``j``. + * 'lower': ``i``. + * 'higher': ``j``. + Returns ------- - percentile : int if `array` 1D, ndarray if `array` 2D + percentile_value : int if `array` 1D, ndarray if `array` 2D Weighted percentile. 
""" n_dim = array.ndim @@ -35,27 +46,64 @@ def _weighted_percentile(array, sample_weight, percentile=50): return array[()] if array.ndim == 1: array = array.reshape((-1, 1)) - # When sample_weight 1D, repeat for each array.shape[1] + if (array.shape != sample_weight.shape and array.shape[0] == sample_weight.shape[0]): + # when `sample_weight` is 1D, we repeat it for each column of `array` sample_weight = np.tile(sample_weight, (array.shape[1], 1)).T sorted_idx = np.argsort(array, axis=0) sorted_weights = _take_along_axis(sample_weight, sorted_idx, axis=0) - # Find index of median prediction for each sample + # find the lower percentile value indices weight_cdf = stable_cumsum(sorted_weights, axis=0) adjusted_percentile = percentile / 100 * weight_cdf[-1] - percentile_idx = np.array([ + percentile_value_lower_idx = np.array([ np.searchsorted(weight_cdf[:, i], adjusted_percentile[i]) for i in range(weight_cdf.shape[1]) ]) - percentile_idx = np.array(percentile_idx) - # In rare cases, percentile_idx equals to sorted_idx.shape[0] - max_idx = sorted_idx.shape[0] - 1 - percentile_idx = np.apply_along_axis(lambda x: np.clip(x, 0, max_idx), - axis=0, arr=percentile_idx) + # clip the indices in case the indices is the last element found + max_idx = sorted_idx.shape[0] - 1 col_index = np.arange(array.shape[1]) - percentile_in_sorted = sorted_idx[percentile_idx, col_index] - percentile = array[percentile_in_sorted, col_index] - return percentile[0] if n_dim == 1 else percentile + + def _reduce_dim(arr, n_dim): + return arr[0] if n_dim == 1 else arr + + if interpolation in ("lower", "linear"): + percentile_value_lower_idx = np.apply_along_axis( + lambda x: np.clip(x, 0, max_idx), axis=0, + arr=percentile_value_lower_idx, + ) + percentile_value_lower = array[ + sorted_idx[percentile_value_lower_idx, col_index] + ] + if interpolation == "lower": + return _reduce_dim(percentile_value_lower, n_dim) + + if interpolation in ("higher", "linear"): + percentile_value_higher_idx = np.apply_along_axis( + lambda x: np.clip(x, 0, max_idx), axis=0, + arr=percentile_value_lower_idx + 1, + ) + percentile_value_higher = array[ + sorted_idx[percentile_value_higher_idx, col_index] + ] + if interpolation == "higher": + return _reduce_dim(percentile_value_higher, n_dim) + + ratio = percentile / adjusted_percentile + percentile_lower = weight_cdf[percentile_value_lower_idx] * ratio + percentile_higher = weight_cdf[percentile_value_higher_idx] * ratio + + # interpolate linearly for the given percentile + percentile_value = ( + percentile_value_lower + (percentile - percentile_lower) * + ((percentile_value_higher - percentile_value_lower) / + (percentile_higher - percentile_lower)) + ) + print(percentile_higher, percentile_lower) + print(percentile_value_higher, percentile_value_lower) + print(percentile_value) + print(weight_cdf) + + return _reduce_dim(percentile_value, n_dim) diff --git a/sklearn/utils/tests/test_stats.py b/sklearn/utils/tests/test_stats.py index fe0d267393db0..d55131e42572f 100644 --- a/sklearn/utils/tests/test_stats.py +++ b/sklearn/utils/tests/test_stats.py @@ -87,3 +87,19 @@ def test_weighted_percentile_2d(): for i in range(x_2d.shape[1]) ] assert_allclose(w_median, p_axis_0) + + +def test_weighted_percentile_np_equivalent(): + # check that our weighted percentile lead to the same results than + # unweighted NumPy implementation with unit weights + rng = np.random.RandomState(42) + X = rng.randn(10) + X.sort() + sample_weight = np.ones(X.shape) + + np_median = np.median(X) + np_percentile = 
np.percentile(X, 50) + sklearn_median = _weighted_percentile(X, sample_weight, percentile=50.0) + + assert sklearn_median == approx(np_median) + assert sklearn_median == approx(np_percentile) From 6c8a4056a718ee4a439ea8409a8ba58b645205b2 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 29 May 2020 00:53:09 +0200 Subject: [PATCH 05/46] iter --- setup.cfg | 2 +- sklearn/utils/stats.py | 110 +++++++++++++++++------------- sklearn/utils/tests/test_stats.py | 91 ++++++++++++++++-------- 3 files changed, 127 insertions(+), 76 deletions(-) diff --git a/setup.cfg b/setup.cfg index f086993b26a29..95e4417b816e1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -12,7 +12,7 @@ addopts = --ignore examples --ignore maint_tools --doctest-modules - --disable-pytest-warnings + # --disable-pytest-warnings -rxXs filterwarnings = diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index a4d52a9f11d27..19013e3b83470 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -35,12 +35,21 @@ def _weighted_percentile(array, sample_weight, percentile=50, and ``j``. * 'lower': ``i``. * 'higher': ``j``. + * 'nearest': ``i`` or ``j``, whichever is nearest. Returns ------- percentile_value : int if `array` 1D, ndarray if `array` 2D Weighted percentile. """ + possible_interpolation = ("linear", "lower", "higher", "nearest") + if interpolation not in possible_interpolation: + raise ValueError( + f"'interpolation' should be one of " + f"{', '.join(possible_interpolation)}. Got '{interpolation}' " + f"instead." + ) + n_dim = array.ndim if n_dim == 0: return array[()] @@ -51,59 +60,64 @@ def _weighted_percentile(array, sample_weight, percentile=50, array.shape[0] == sample_weight.shape[0]): # when `sample_weight` is 1D, we repeat it for each column of `array` sample_weight = np.tile(sample_weight, (array.shape[1], 1)).T + + n_rows, n_cols = array.shape + sorted_idx = np.argsort(array, axis=0) sorted_weights = _take_along_axis(sample_weight, sorted_idx, axis=0) - - # find the lower percentile value indices + percentile = [percentile / 100] * n_cols weight_cdf = stable_cumsum(sorted_weights, axis=0) - adjusted_percentile = percentile / 100 * weight_cdf[-1] - percentile_value_lower_idx = np.array([ - np.searchsorted(weight_cdf[:, i], adjusted_percentile[i]) - for i in range(weight_cdf.shape[1]) - ]) - # clip the indices in case the indices is the last element found - max_idx = sorted_idx.shape[0] - 1 - col_index = np.arange(array.shape[1]) - - def _reduce_dim(arr, n_dim): + def _squeeze_arr(arr, n_dim): return arr[0] if n_dim == 1 else arr - if interpolation in ("lower", "linear"): - percentile_value_lower_idx = np.apply_along_axis( - lambda x: np.clip(x, 0, max_idx), axis=0, - arr=percentile_value_lower_idx, - ) - percentile_value_lower = array[ - sorted_idx[percentile_value_lower_idx, col_index] - ] - if interpolation == "lower": - return _reduce_dim(percentile_value_lower, n_dim) - - if interpolation in ("higher", "linear"): - percentile_value_higher_idx = np.apply_along_axis( - lambda x: np.clip(x, 0, max_idx), axis=0, - arr=percentile_value_lower_idx + 1, + if interpolation == "nearest": + # compute by nearest-rank method + adjusted_percentile = percentile * weight_cdf[-1] + percentile_value_idx = np.array([ + np.searchsorted(weight_cdf[:, i], adjusted_percentile[i]) + for i in range(weight_cdf.shape[1]) + ]) + + percentile_value_idx = np.apply_along_axis( + lambda x: np.clip(x, 0, n_rows - 1), axis=0, + arr=percentile_value_idx ) - percentile_value_higher = array[ - 
sorted_idx[percentile_value_higher_idx, col_index] + percentile_value = array[ + sorted_idx[percentile_value_idx, np.arange(n_cols)], + np.arange(n_cols) ] - if interpolation == "higher": - return _reduce_dim(percentile_value_higher, n_dim) - - ratio = percentile / adjusted_percentile - percentile_lower = weight_cdf[percentile_value_lower_idx] * ratio - percentile_higher = weight_cdf[percentile_value_higher_idx] * ratio - - # interpolate linearly for the given percentile - percentile_value = ( - percentile_value_lower + (percentile - percentile_lower) * - ((percentile_value_higher - percentile_value_lower) / - (percentile_higher - percentile_lower)) - ) - print(percentile_higher, percentile_lower) - print(percentile_value_higher, percentile_value_lower) - print(percentile_value) - print(weight_cdf) - - return _reduce_dim(percentile_value, n_dim) + return _squeeze_arr(percentile_value, n_dim) + + elif interpolation in ("linear", "lower", "higher"): + # compute by linear interpolation between closest ranks method + adjusted_percentile = (weight_cdf - 0.5 * sorted_weights) + with np.errstate(invalid="ignore"): + adjusted_percentile /= weight_cdf[-1] + + if interpolation in ("lower", "higher"): + percentile_idx = np.array([ + np.searchsorted(adjusted_percentile[:, col], percentile[col], + side="left") + for col in range(adjusted_percentile.shape[1]) + ]) + if interpolation == "lower": + percentile_idx -= 1 + percentile_idx = np.apply_along_axis( + lambda x: np.clip(x, 0, n_rows - 1), axis=0, + arr=percentile_idx + ) + percentile = np.nan_to_num([ + adjusted_percentile[percentile_idx[i], i] + for i in range(adjusted_percentile.shape[1]) + ], nan=percentile) + + percentile_value = np.array([ + np.interp( + x=percentile[col], + xp=adjusted_percentile[:, col], + fp=array[sorted_idx[:, col], col], + ) + for col in range(array.shape[1]) + ]) + return _squeeze_arr(percentile_value, n_dim) diff --git a/sklearn/utils/tests/test_stats.py b/sklearn/utils/tests/test_stats.py index d55131e42572f..60c60d510ce5a 100644 --- a/sklearn/utils/tests/test_stats.py +++ b/sklearn/utils/tests/test_stats.py @@ -1,11 +1,12 @@ import numpy as np from numpy.testing import assert_allclose -from pytest import approx +import pytest from sklearn.utils.stats import _weighted_percentile -def test_weighted_percentile(): +@pytest.mark.parametrize("interpolation", ["linear", "nearest"]) +def test_weighted_percentile(interpolation): y = np.empty(102, dtype=np.float64) y[:50] = 0 y[-51:] = 2 @@ -13,29 +14,36 @@ def test_weighted_percentile(): y[50] = 1 sw = np.ones(102, dtype=np.float64) sw[-1] = 0.0 - score = _weighted_percentile(y, sw, 50) - assert approx(score) == 1 + score = _weighted_percentile(y, sw, 50, interpolation=interpolation) + assert score == pytest.approx(1) -def test_weighted_percentile_equal(): +@pytest.mark.parametrize( + "interpolation", ["linear", "lower", "higher", "nearest"] +) +def test_weighted_percentile_equal(interpolation): y = np.empty(102, dtype=np.float64) y.fill(0.0) sw = np.ones(102, dtype=np.float64) sw[-1] = 0.0 - score = _weighted_percentile(y, sw, 50) + score = _weighted_percentile(y, sw, 50, interpolation=interpolation) assert score == 0 -def test_weighted_percentile_zero_weight(): +@pytest.mark.parametrize( + "interpolation", ["linear", "lower", "higher", "nearest"] +) +def test_weighted_percentile_zero_weight(interpolation): y = np.empty(102, dtype=np.float64) y.fill(1.0) sw = np.ones(102, dtype=np.float64) sw.fill(0.0) - score = _weighted_percentile(y, sw, 50) - assert approx(score) == 1.0 + 
score = _weighted_percentile(y, sw, 50, interpolation=interpolation) + assert pytest.approx(score) == 1.0 -def test_weighted_median_equal_weights(): +@pytest.mark.parametrize("interpolation", ["linear", "nearest"]) +def test_weighted_median_equal_weights(interpolation): # Checks weighted percentile=0.5 is same as median when weights equal rng = np.random.RandomState(0) # Odd size as _weighted_percentile takes lower weighted percentile @@ -43,11 +51,12 @@ def test_weighted_median_equal_weights(): weights = np.ones(x.shape) median = np.median(x) - w_median = _weighted_percentile(x, weights) - assert median == approx(w_median) + w_median = _weighted_percentile(x, weights, interpolation=interpolation) + assert median == pytest.approx(w_median) -def test_weighted_median_integer_weights(): +@pytest.mark.parametrize("interpolation", ["linear", "nearest"]) +def test_weighted_median_integer_weights(interpolation): # Checks weighted percentile=0.5 is same as median when manually weight # data rng = np.random.RandomState(0) @@ -56,12 +65,15 @@ def test_weighted_median_integer_weights(): x_manual = np.repeat(x, weights) median = np.median(x_manual) - w_median = _weighted_percentile(x, weights) + w_median = _weighted_percentile(x, weights, interpolation=interpolation) - assert median == approx(w_median) + assert median == pytest.approx(w_median) -def test_weighted_percentile_2d(): +@pytest.mark.parametrize( + "interpolation", ["linear", "lower", "higher", "nearest"] +) +def test_weighted_percentile_2d(interpolation): # Check for when array 2D and sample_weight 1D rng = np.random.RandomState(0) x1 = rng.randint(10, size=10) @@ -70,36 +82,61 @@ def test_weighted_percentile_2d(): x2 = rng.randint(20, size=10) x_2d = np.vstack((x1, x2)).T - w_median = _weighted_percentile(x_2d, w1) + w_median = _weighted_percentile(x_2d, w1, interpolation=interpolation) p_axis_0 = [ - _weighted_percentile(x_2d[:, i], w1) + _weighted_percentile(x_2d[:, i], w1, interpolation=interpolation) for i in range(x_2d.shape[1]) ] assert_allclose(w_median, p_axis_0) - # Check when array and sample_weight boht 2D + # Check when array and sample_weight both 2D w2 = rng.choice(5, size=10) w_2d = np.vstack((w1, w2)).T - w_median = _weighted_percentile(x_2d, w_2d) + w_median = _weighted_percentile(x_2d, w_2d, interpolation=interpolation) p_axis_0 = [ - _weighted_percentile(x_2d[:, i], w_2d[:, i]) + _weighted_percentile(x_2d[:, i], w_2d[:, i], + interpolation=interpolation) for i in range(x_2d.shape[1]) ] assert_allclose(w_median, p_axis_0) -def test_weighted_percentile_np_equivalent(): +def test_weighted_percentile_np_median(): # check that our weighted percentile lead to the same results than - # unweighted NumPy implementation with unit weights + # unweighted NumPy implementation with unit weights for the median rng = np.random.RandomState(42) X = rng.randn(10) X.sort() sample_weight = np.ones(X.shape) np_median = np.median(X) - np_percentile = np.percentile(X, 50) - sklearn_median = _weighted_percentile(X, sample_weight, percentile=50.0) + sklearn_median = _weighted_percentile( + X, sample_weight, percentile=50.0, interpolation="linear" + ) - assert sklearn_median == approx(np_median) - assert sklearn_median == approx(np_percentile) + assert sklearn_median == pytest.approx(np_median) + + +@pytest.mark.parametrize("interpolation", ["linear", "lower", "higher"]) +@pytest.mark.parametrize("percentile", np.arange(0, 101, 10)) +def test_weighted_percentile_np_percentile(interpolation, percentile): + rng = np.random.RandomState(0) + X = 
rng.randn(10) + X.sort() + sample_weight = np.ones(X.shape) + + np_percentile = np.percentile(X, percentile, interpolation=interpolation) + sklearn_percentile = _weighted_percentile( + X, sample_weight, percentile=percentile, interpolation=interpolation, + ) + + assert sklearn_percentile == pytest.approx(np_percentile) + + +def test_weighted_percentile_wrong_interpolation(): + err_msg = "'interpolation' should be one of" + with pytest.raises(ValueError, match=err_msg): + X = np.random.randn(10) + sample_weight = np.ones(X.shape) + _weighted_percentile(X, sample_weight, 50, interpolation="xxxx") From 23af759e51001c405040d0d312fcf208ad167077 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 29 May 2020 00:54:26 +0200 Subject: [PATCH 06/46] revert setup.cfg --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 95e4417b816e1..f086993b26a29 100644 --- a/setup.cfg +++ b/setup.cfg @@ -12,7 +12,7 @@ addopts = --ignore examples --ignore maint_tools --doctest-modules - # --disable-pytest-warnings + --disable-pytest-warnings -rxXs filterwarnings = From 3be7c090532f27e9c57806249291c95d44ed73f0 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 29 May 2020 11:08:36 +0200 Subject: [PATCH 07/46] iter --- sklearn/utils/stats.py | 8 +++++--- sklearn/utils/tests/test_stats.py | 5 ++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index 19013e3b83470..0ee44778ea6de 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -91,9 +91,11 @@ def _squeeze_arr(arr, n_dim): elif interpolation in ("linear", "lower", "higher"): # compute by linear interpolation between closest ranks method - adjusted_percentile = (weight_cdf - 0.5 * sorted_weights) - with np.errstate(invalid="ignore"): - adjusted_percentile /= weight_cdf[-1] + # adjusted_percentile = (weight_cdf - 0.5 * sorted_weights) + # with np.errstate(invalid="ignore"): + # adjusted_percentile /= weight_cdf[-1] + adjusted_percentile = (weight_cdf - sorted_weights) + adjusted_percentile /= ((weight_cdf[-1] * (n_rows - 1)) / (n_rows)) if interpolation in ("lower", "higher"): percentile_idx = np.array([ diff --git a/sklearn/utils/tests/test_stats.py b/sklearn/utils/tests/test_stats.py index 60c60d510ce5a..b702921381d6b 100644 --- a/sklearn/utils/tests/test_stats.py +++ b/sklearn/utils/tests/test_stats.py @@ -5,7 +5,6 @@ from sklearn.utils.stats import _weighted_percentile -@pytest.mark.parametrize("interpolation", ["linear", "nearest"]) def test_weighted_percentile(interpolation): y = np.empty(102, dtype=np.float64) y[:50] = 0 @@ -14,7 +13,7 @@ def test_weighted_percentile(interpolation): y[50] = 1 sw = np.ones(102, dtype=np.float64) sw[-1] = 0.0 - score = _weighted_percentile(y, sw, 50, interpolation=interpolation) + score = _weighted_percentile(y, sw, 50, interpolation="nearest") assert score == pytest.approx(1) @@ -124,7 +123,7 @@ def test_weighted_percentile_np_percentile(interpolation, percentile): rng = np.random.RandomState(0) X = rng.randn(10) X.sort() - sample_weight = np.ones(X.shape) + sample_weight = np.ones(X.shape) / X.shape[0] np_percentile = np.percentile(X, percentile, interpolation=interpolation) sklearn_percentile = _weighted_percentile( From 06aeab1acaafd6e6f3f279a3a7e614f08672d4b9 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 29 May 2020 11:11:46 +0200 Subject: [PATCH 08/46] iter --- sklearn/utils/stats.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git 
a/sklearn/utils/stats.py b/sklearn/utils/stats.py index 0ee44778ea6de..e9e79e0cc4b7a 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -91,9 +91,7 @@ def _squeeze_arr(arr, n_dim): elif interpolation in ("linear", "lower", "higher"): # compute by linear interpolation between closest ranks method - # adjusted_percentile = (weight_cdf - 0.5 * sorted_weights) - # with np.errstate(invalid="ignore"): - # adjusted_percentile /= weight_cdf[-1] + # as in NumPy adjusted_percentile = (weight_cdf - sorted_weights) adjusted_percentile /= ((weight_cdf[-1] * (n_rows - 1)) / (n_rows)) From f3892923ae9ad242aab51af9f19f5722afc66b51 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 29 May 2020 11:27:00 +0200 Subject: [PATCH 09/46] iter --- sklearn/utils/stats.py | 7 ++++--- sklearn/utils/tests/test_stats.py | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index e9e79e0cc4b7a..2a522ea6c6fc1 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -65,7 +65,7 @@ def _weighted_percentile(array, sample_weight, percentile=50, sorted_idx = np.argsort(array, axis=0) sorted_weights = _take_along_axis(sample_weight, sorted_idx, axis=0) - percentile = [percentile / 100] * n_cols + percentile = np.array([percentile / 100] * n_cols) weight_cdf = stable_cumsum(sorted_weights, axis=0) def _squeeze_arr(arr, n_dim): @@ -93,7 +93,8 @@ def _squeeze_arr(arr, n_dim): # compute by linear interpolation between closest ranks method # as in NumPy adjusted_percentile = (weight_cdf - sorted_weights) - adjusted_percentile /= ((weight_cdf[-1] * (n_rows - 1)) / (n_rows)) + with np.errstate(invalid="ignore"): + adjusted_percentile /= ((weight_cdf[-1] * (n_rows - 1)) / (n_rows)) if interpolation in ("lower", "higher"): percentile_idx = np.array([ @@ -101,7 +102,7 @@ def _squeeze_arr(arr, n_dim): side="left") for col in range(adjusted_percentile.shape[1]) ]) - if interpolation == "lower": + if interpolation == "lower" and np.all(percentile < 1): percentile_idx -= 1 percentile_idx = np.apply_along_axis( lambda x: np.clip(x, 0, n_rows - 1), axis=0, diff --git a/sklearn/utils/tests/test_stats.py b/sklearn/utils/tests/test_stats.py index b702921381d6b..177cf1687168a 100644 --- a/sklearn/utils/tests/test_stats.py +++ b/sklearn/utils/tests/test_stats.py @@ -5,7 +5,7 @@ from sklearn.utils.stats import _weighted_percentile -def test_weighted_percentile(interpolation): +def test_weighted_percentile(): y = np.empty(102, dtype=np.float64) y[:50] = 0 y[-51:] = 2 @@ -118,7 +118,7 @@ def test_weighted_percentile_np_median(): @pytest.mark.parametrize("interpolation", ["linear", "lower", "higher"]) -@pytest.mark.parametrize("percentile", np.arange(0, 101, 10)) +@pytest.mark.parametrize("percentile", np.arange(0, 101, 2.5)) def test_weighted_percentile_np_percentile(interpolation, percentile): rng = np.random.RandomState(0) X = rng.randn(10) From 9e1222f4b6defda30b73c0fab06ce9ac655517c9 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 29 May 2020 11:40:11 +0200 Subject: [PATCH 10/46] iter --- sklearn/utils/stats.py | 10 ++++++++-- sklearn/utils/tests/test_stats.py | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index 2a522ea6c6fc1..24872bed55e9e 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -26,7 +26,7 @@ def _weighted_percentile(array, sample_weight, percentile=50, percentile: int, default=50 Percentile to compute. 
Must be value between 0 and 100. - interpolation : {"linear", "lower", "higher"}, default="linear" + interpolation : {"linear", "lower", "higher", "nearest"}, default="linear" This optional parameter specifies the interpolation method to use when the desired percentile lies between two data points ``i < j``: @@ -73,6 +73,9 @@ def _squeeze_arr(arr, n_dim): if interpolation == "nearest": # compute by nearest-rank method + # The rank correspond to P / 100 * N; P in [0, 100] + # Here, N is replaced by the sum of the weight and P will be adjusted + # by the weights adjusted_percentile = percentile * weight_cdf[-1] percentile_value_idx = np.array([ np.searchsorted(weight_cdf[:, i], adjusted_percentile[i]) @@ -91,7 +94,9 @@ def _squeeze_arr(arr, n_dim): elif interpolation in ("linear", "lower", "higher"): # compute by linear interpolation between closest ranks method - # as in NumPy + # NumPy uses the variant p = (x - 1) / (N - 1); x in [1, N] + # Here, N is replaced by the sum of the weights and (1) is taking into + # account the weights. adjusted_percentile = (weight_cdf - sorted_weights) with np.errstate(invalid="ignore"): adjusted_percentile /= ((weight_cdf[-1] * (n_rows - 1)) / (n_rows)) @@ -103,6 +108,7 @@ def _squeeze_arr(arr, n_dim): for col in range(adjusted_percentile.shape[1]) ]) if interpolation == "lower" and np.all(percentile < 1): + # P = 100 is a corner case for "lower" percentile_idx -= 1 percentile_idx = np.apply_along_axis( lambda x: np.clip(x, 0, n_rows - 1), axis=0, diff --git a/sklearn/utils/tests/test_stats.py b/sklearn/utils/tests/test_stats.py index 177cf1687168a..b3052e9a01f0c 100644 --- a/sklearn/utils/tests/test_stats.py +++ b/sklearn/utils/tests/test_stats.py @@ -123,7 +123,7 @@ def test_weighted_percentile_np_percentile(interpolation, percentile): rng = np.random.RandomState(0) X = rng.randn(10) X.sort() - sample_weight = np.ones(X.shape) / X.shape[0] + sample_weight = np.ones(X.shape) np_percentile = np.percentile(X, percentile, interpolation=interpolation) sklearn_percentile = _weighted_percentile( From 9314aee394c6e143e0902b6a19b68719feea14aa Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 29 May 2020 12:42:15 +0200 Subject: [PATCH 11/46] iter --- sklearn/metrics/tests/test_score_objects.py | 59 +++++++++-------- sklearn/utils/stats.py | 72 ++++++--------------- sklearn/utils/tests/test_stats.py | 24 +++---- 3 files changed, 61 insertions(+), 94 deletions(-) diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index f49197a706e70..ae72474dee110 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -520,44 +520,49 @@ def test_classification_scorer_sample_weight(): f"with sample weights: {str(e)}") -@ignore_warnings -def test_regression_scorer_sample_weight(): - # Test that regression scorers support sample_weight or raise sensible - # errors - +@pytest.fixture +def regressor_and_data(): # Odd number of test samples req for neg_median_absolute_error X, y = make_regression(n_samples=101, n_features=20, random_state=0) y = _require_positive_y(y) X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) - sample_weight = np.ones_like(y_test) + sample_weight_test = np.ones_like(y_test) # Odd number req for neg_median_absolute_error - sample_weight[:11] = 0 + sample_weight_test[:11] = 0 reg = DecisionTreeRegressor(random_state=0) reg.fit(X_train, y_train) - for name, scorer in SCORERS.items(): - if name not in REGRESSION_SCORERS: - # skip 
classification scorers - continue - try: - weighted = scorer(reg, X_test, y_test, - sample_weight=sample_weight) - ignored = scorer(reg, X_test[11:], y_test[11:]) - unweighted = scorer(reg, X_test, y_test) - assert weighted != unweighted, ( - f"scorer {name} behaves identically when called with " - f"sample weights: {weighted} vs {unweighted}") - assert_almost_equal(weighted, ignored, - err_msg=f"scorer {name} behaves differently " - f"when ignoring samples and setting " - f"sample_weight to 0: {weighted} vs {ignored}") + return reg, X_test, y_test, sample_weight_test - except TypeError as e: - assert "sample_weight" in str(e), ( - f"scorer {name} raises unhelpful exception when called " - f"with sample weights: {str(e)}") + +@ignore_warnings +@pytest.mark.parametrize("name, scorer", SCORERS.items()) +def test_regression_scorer_sample_weight(regressor_and_data, name, scorer): + # Test that regression scorers support sample_weight or raise sensible + # errors + reg, X_test, y_test, sample_weight = regressor_and_data + + if name not in REGRESSION_SCORERS: + # skip classification scorers + return + try: + weighted = scorer(reg, X_test, y_test, sample_weight=sample_weight) + ignored = scorer(reg, X_test[11:], y_test[11:]) + unweighted = scorer(reg, X_test, y_test) + assert weighted != unweighted, ( + f"scorer {name} behaves identically when called with " + f"sample weights: {weighted} vs {unweighted}") + assert_almost_equal(weighted, ignored, + err_msg=f"scorer {name} behaves differently " + f"when ignoring samples and setting " + f"sample_weight to 0: {weighted} vs {ignored}") + + except TypeError as e: + assert "sample_weight" in str(e), ( + f"scorer {name} raises unhelpful exception when called " + f"with sample weights: {str(e)}") @pytest.mark.parametrize('name', SCORERS) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index 24872bed55e9e..b175ff026f89a 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -26,23 +26,15 @@ def _weighted_percentile(array, sample_weight, percentile=50, percentile: int, default=50 Percentile to compute. Must be value between 0 and 100. - interpolation : {"linear", "lower", "higher", "nearest"}, default="linear" - This optional parameter specifies the interpolation method to - use when the desired percentile lies between two data points - ``i < j``: - * 'linear': ``i + (j - i) * fraction``, where ``fraction`` - is the fractional part of the index surrounded by ``i`` - and ``j``. - * 'lower': ``i``. - * 'higher': ``j``. - * 'nearest': ``i`` or ``j``, whichever is nearest. + interpolation : {"linear", "lower", "higher"}, default="linear" + FIXME Returns ------- percentile_value : int if `array` 1D, ndarray if `array` 2D Weighted percentile. 
""" - possible_interpolation = ("linear", "lower", "higher", "nearest") + possible_interpolation = ("linear", "lower", "higher") if interpolation not in possible_interpolation: raise ValueError( f"'interpolation' should be one of " @@ -71,60 +63,38 @@ def _weighted_percentile(array, sample_weight, percentile=50, def _squeeze_arr(arr, n_dim): return arr[0] if n_dim == 1 else arr - if interpolation == "nearest": - # compute by nearest-rank method - # The rank correspond to P / 100 * N; P in [0, 100] - # Here, N is replaced by the sum of the weight and P will be adjusted - # by the weights - adjusted_percentile = percentile * weight_cdf[-1] - percentile_value_idx = np.array([ - np.searchsorted(weight_cdf[:, i], adjusted_percentile[i]) - for i in range(weight_cdf.shape[1]) + adjusted_percentile = (weight_cdf - sorted_weights) + with np.errstate(invalid="ignore"): + adjusted_percentile /= ((weight_cdf[-1] * (n_rows - 1)) / (n_rows)) + + if interpolation in ("lower", "higher"): + percentile_idx = np.array([ + np.searchsorted(adjusted_percentile[:, col], percentile[col], + side="left") + for col in range(n_cols) ]) - percentile_value_idx = np.apply_along_axis( + if interpolation == "lower" and np.all(percentile < 1): + # P = 100 is a corner case for "lower" + percentile_idx -= 1 + + percentile_idx = np.apply_along_axis( lambda x: np.clip(x, 0, n_rows - 1), axis=0, - arr=percentile_value_idx + arr=percentile_idx ) percentile_value = array[ - sorted_idx[percentile_value_idx, np.arange(n_cols)], + sorted_idx[percentile_idx, np.arange(n_cols)], np.arange(n_cols) ] return _squeeze_arr(percentile_value, n_dim) - elif interpolation in ("linear", "lower", "higher"): - # compute by linear interpolation between closest ranks method - # NumPy uses the variant p = (x - 1) / (N - 1); x in [1, N] - # Here, N is replaced by the sum of the weights and (1) is taking into - # account the weights. 
- adjusted_percentile = (weight_cdf - sorted_weights) - with np.errstate(invalid="ignore"): - adjusted_percentile /= ((weight_cdf[-1] * (n_rows - 1)) / (n_rows)) - - if interpolation in ("lower", "higher"): - percentile_idx = np.array([ - np.searchsorted(adjusted_percentile[:, col], percentile[col], - side="left") - for col in range(adjusted_percentile.shape[1]) - ]) - if interpolation == "lower" and np.all(percentile < 1): - # P = 100 is a corner case for "lower" - percentile_idx -= 1 - percentile_idx = np.apply_along_axis( - lambda x: np.clip(x, 0, n_rows - 1), axis=0, - arr=percentile_idx - ) - percentile = np.nan_to_num([ - adjusted_percentile[percentile_idx[i], i] - for i in range(adjusted_percentile.shape[1]) - ], nan=percentile) - + else: # interpolation == "linear" percentile_value = np.array([ np.interp( x=percentile[col], xp=adjusted_percentile[:, col], fp=array[sorted_idx[:, col], col], ) - for col in range(array.shape[1]) + for col in range(n_cols) ]) return _squeeze_arr(percentile_value, n_dim) diff --git a/sklearn/utils/tests/test_stats.py b/sklearn/utils/tests/test_stats.py index b3052e9a01f0c..c10ce3e074bb1 100644 --- a/sklearn/utils/tests/test_stats.py +++ b/sklearn/utils/tests/test_stats.py @@ -13,13 +13,11 @@ def test_weighted_percentile(): y[50] = 1 sw = np.ones(102, dtype=np.float64) sw[-1] = 0.0 - score = _weighted_percentile(y, sw, 50, interpolation="nearest") + score = _weighted_percentile(y, sw, 50, interpolation="lower") assert score == pytest.approx(1) -@pytest.mark.parametrize( - "interpolation", ["linear", "lower", "higher", "nearest"] -) +@pytest.mark.parametrize("interpolation", ["linear", "lower", "higher"]) def test_weighted_percentile_equal(interpolation): y = np.empty(102, dtype=np.float64) y.fill(0.0) @@ -29,9 +27,7 @@ def test_weighted_percentile_equal(interpolation): assert score == 0 -@pytest.mark.parametrize( - "interpolation", ["linear", "lower", "higher", "nearest"] -) +@pytest.mark.parametrize("interpolation", ["linear", "lower", "higher"]) def test_weighted_percentile_zero_weight(interpolation): y = np.empty(102, dtype=np.float64) y.fill(1.0) @@ -41,8 +37,7 @@ def test_weighted_percentile_zero_weight(interpolation): assert pytest.approx(score) == 1.0 -@pytest.mark.parametrize("interpolation", ["linear", "nearest"]) -def test_weighted_median_equal_weights(interpolation): +def test_weighted_median_equal_weights(): # Checks weighted percentile=0.5 is same as median when weights equal rng = np.random.RandomState(0) # Odd size as _weighted_percentile takes lower weighted percentile @@ -50,12 +45,11 @@ def test_weighted_median_equal_weights(interpolation): weights = np.ones(x.shape) median = np.median(x) - w_median = _weighted_percentile(x, weights, interpolation=interpolation) + w_median = _weighted_percentile(x, weights) assert median == pytest.approx(w_median) -@pytest.mark.parametrize("interpolation", ["linear", "nearest"]) -def test_weighted_median_integer_weights(interpolation): +def test_weighted_median_integer_weights(): # Checks weighted percentile=0.5 is same as median when manually weight # data rng = np.random.RandomState(0) @@ -64,14 +58,12 @@ def test_weighted_median_integer_weights(interpolation): x_manual = np.repeat(x, weights) median = np.median(x_manual) - w_median = _weighted_percentile(x, weights, interpolation=interpolation) + w_median = _weighted_percentile(x, weights) assert median == pytest.approx(w_median) -@pytest.mark.parametrize( - "interpolation", ["linear", "lower", "higher", "nearest"] -) 
+@pytest.mark.parametrize("interpolation", ["linear", "lower", "higher"]) def test_weighted_percentile_2d(interpolation): # Check for when array 2D and sample_weight 1D rng = np.random.RandomState(0) From 0e857a96df97f8a2f97020584b343bed6b612d64 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 29 May 2020 13:10:27 +0200 Subject: [PATCH 12/46] iter --- sklearn/utils/stats.py | 3 ++- sklearn/utils/tests/test_stats.py | 10 +++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index b175ff026f89a..b4e70ed6c9bc2 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -59,13 +59,14 @@ def _weighted_percentile(array, sample_weight, percentile=50, sorted_weights = _take_along_axis(sample_weight, sorted_idx, axis=0) percentile = np.array([percentile / 100] * n_cols) weight_cdf = stable_cumsum(sorted_weights, axis=0) + non_zero = np.count_nonzero(sorted_weights, axis=0) def _squeeze_arr(arr, n_dim): return arr[0] if n_dim == 1 else arr adjusted_percentile = (weight_cdf - sorted_weights) with np.errstate(invalid="ignore"): - adjusted_percentile /= ((weight_cdf[-1] * (n_rows - 1)) / (n_rows)) + adjusted_percentile /= ((weight_cdf[-1] * (non_zero - 1)) / (non_zero)) if interpolation in ("lower", "higher"): percentile_idx = np.array([ diff --git a/sklearn/utils/tests/test_stats.py b/sklearn/utils/tests/test_stats.py index c10ce3e074bb1..6ee8aaae3c19f 100644 --- a/sklearn/utils/tests/test_stats.py +++ b/sklearn/utils/tests/test_stats.py @@ -5,7 +5,11 @@ from sklearn.utils.stats import _weighted_percentile -def test_weighted_percentile(): +@pytest.mark.parametrize( + "interpolation, expected_median", + [("lower", 0), ("linear", 1), ("higher", 1)] +) +def test_weighted_percentile(interpolation, expected_median): y = np.empty(102, dtype=np.float64) y[:50] = 0 y[-51:] = 2 @@ -13,8 +17,8 @@ def test_weighted_percentile(): y[50] = 1 sw = np.ones(102, dtype=np.float64) sw[-1] = 0.0 - score = _weighted_percentile(y, sw, 50, interpolation="lower") - assert score == pytest.approx(1) + score = _weighted_percentile(y, sw, 50, interpolation=interpolation) + assert score == pytest.approx(expected_median) @pytest.mark.parametrize("interpolation", ["linear", "lower", "higher"]) From 5988234b2a3208c565081fafc02a301c503eafbd Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 29 May 2020 13:17:08 +0200 Subject: [PATCH 13/46] improve documentation --- sklearn/utils/stats.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index b4e70ed6c9bc2..c07d7016e0957 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -16,22 +16,31 @@ def _weighted_percentile(array, sample_weight, percentile=50, Parameters ---------- - array : 1D or 2D array + array : ndarray of shape (n,) or (n, m) Values to take the weighted percentile of. - sample_weight: 1D or 2D array + sample_weight: ndarray of (n,) or (n, m) Weights for each value in `array`. Must be same shape as `array` or of shape `(array.shape[0],)`. - percentile: int, default=50 + percentile: inr or float, default=50 Percentile to compute. Must be value between 0 and 100. 
interpolation : {"linear", "lower", "higher"}, default="linear" - FIXME + The interpolation method to use when the percentile lies between + data points `i` and `j`: + + * `"linear"`: `i + (j - i) * fraction`, where `fraction` is the + fractional part of the index surrounded by `i` and `j`; + * `"lower"`: i`; + * `"higher"`: `j`. + + .. versionadded: 0.24 Returns ------- - percentile_value : int if `array` 1D, ndarray if `array` 2D + percentile_value : float or int if `array` of shape (n,), otherwise\ + ndarray of shape (m,) Weighted percentile. """ possible_interpolation = ("linear", "lower", "higher") From 81008731571d2669be3dc7e5b38fdb5da9c0f2b6 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 29 May 2020 13:48:06 +0200 Subject: [PATCH 14/46] iter --- sklearn/utils/stats.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index c07d7016e0957..16bd359a1f8f4 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -92,6 +92,8 @@ def _squeeze_arr(arr, n_dim): lambda x: np.clip(x, 0, n_rows - 1), axis=0, arr=percentile_idx ) + percentile_idx = np.nan_to_num(percentile_idx, nan=0) + percentile_value = array[ sorted_idx[percentile_idx, np.arange(n_cols)], np.arange(n_cols) From cc4a17277a560af92141c99521e5d45f6cd23a9e Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 29 May 2020 14:52:36 +0200 Subject: [PATCH 15/46] iter --- sklearn/utils/stats.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index 16bd359a1f8f4..80785de4370ee 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -109,4 +109,7 @@ def _squeeze_arr(arr, n_dim): ) for col in range(n_cols) ]) + + nan_value = np.isnan(percentile_value) + percentile_value[nan_value] = array[0, nan_value] return _squeeze_arr(percentile_value, n_dim) From 4e100a93678b7ff81a760578ee444135fa5c4321 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 29 May 2020 15:04:22 +0200 Subject: [PATCH 16/46] parametrize debug --- sklearn/ensemble/tests/test_gradient_boosting.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index 5461258887054..64374c670453f 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -1133,7 +1133,8 @@ def test_probability_exponential(): assert_array_equal(y_pred, true_result) -def test_non_uniform_weights_toy_edge_case_reg(): +@pytest.mark.parametrize("loss", ['huber', 'ls', 'lad', 'quantile']) +def test_non_uniform_weights_toy_edge_case_reg(loss): X = [[1, 0], [1, 0], [1, 0], @@ -1141,11 +1142,11 @@ def test_non_uniform_weights_toy_edge_case_reg(): y = [0, 0, 1, 0] # ignore the first 2 training samples by setting their weight to 0 sample_weight = [0, 0, 1, 1] - for loss in ('huber', 'ls', 'lad', 'quantile'): - gb = GradientBoostingRegressor(learning_rate=1.0, n_estimators=2, - loss=loss) - gb.fit(X, y, sample_weight=sample_weight) - assert gb.predict([[1, 0]])[0] > 0.5 + gb = GradientBoostingRegressor( + learning_rate=1.0, n_estimators=2, loss=loss, + ) + gb.fit(X, y, sample_weight=sample_weight) + assert gb.predict([[1, 0]])[0] > 0.5 def test_non_uniform_weights_toy_edge_case_clf(): From 25e5d2436a05cf1b9ed11581e9dfea4cc85aefe9 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 1 Jun 2020 10:41:23 +0200 Subject: [PATCH 17/46] iter --- 
.../ensemble/tests/test_gradient_boosting.py | 4 +++- sklearn/utils/stats.py | 3 +-- sklearn/utils/tests/test_stats.py | 17 +++++++++++++++++ 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index 64374c670453f..223c4f1f95382 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -1143,9 +1143,11 @@ def test_non_uniform_weights_toy_edge_case_reg(loss): # ignore the first 2 training samples by setting their weight to 0 sample_weight = [0, 0, 1, 1] gb = GradientBoostingRegressor( - learning_rate=1.0, n_estimators=2, loss=loss, + learning_rate=0.01, n_estimators=10, loss=loss, ) gb.fit(X, y, sample_weight=sample_weight) + print(gb.predict([[1, 0]])) + print(gb.train_score_) assert gb.predict([[1, 0]])[0] > 0.5 diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index 80785de4370ee..94399c9f708d2 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -68,14 +68,13 @@ def _weighted_percentile(array, sample_weight, percentile=50, sorted_weights = _take_along_axis(sample_weight, sorted_idx, axis=0) percentile = np.array([percentile / 100] * n_cols) weight_cdf = stable_cumsum(sorted_weights, axis=0) - non_zero = np.count_nonzero(sorted_weights, axis=0) def _squeeze_arr(arr, n_dim): return arr[0] if n_dim == 1 else arr adjusted_percentile = (weight_cdf - sorted_weights) with np.errstate(invalid="ignore"): - adjusted_percentile /= ((weight_cdf[-1] * (non_zero - 1)) / (non_zero)) + adjusted_percentile /= weight_cdf[-1] - sorted_weights if interpolation in ("lower", "higher"): percentile_idx = np.array([ diff --git a/sklearn/utils/tests/test_stats.py b/sklearn/utils/tests/test_stats.py index 6ee8aaae3c19f..72f4e21ae8f4f 100644 --- a/sklearn/utils/tests/test_stats.py +++ b/sklearn/utils/tests/test_stats.py @@ -135,3 +135,20 @@ def test_weighted_percentile_wrong_interpolation(): X = np.random.randn(10) sample_weight = np.ones(X.shape) _weighted_percentile(X, sample_weight, 50, interpolation="xxxx") + + +@pytest.mark.parametrize("percentile", np.arange(2.5, 100, 2.5)) +def test_weighted_percentile_non_unit_weight(percentile): + # check the cumulative sum of the weight on the left and right side of the + # percentile + rng = np.random.RandomState(42) + X = rng.randn(1000) + X.sort() + sample_weight = rng.random(X.shape) + sample_weight = sample_weight / sample_weight.sum() + sample_weight *= 100 + + percentile_value = _weighted_percentile(X, sample_weight, percentile) + X_percentile_idx = np.searchsorted(X, percentile_value) + assert sample_weight[:X_percentile_idx - 1].sum() < percentile + assert sample_weight[:X_percentile_idx + 1].sum() > percentile From ff2a6e009e4354e56b3da6ec35932fc53a48f26c Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 2 Jun 2020 14:11:21 +0200 Subject: [PATCH 18/46] case we have a single weight non null --- sklearn/utils/stats.py | 24 ++++++++++++++++++++---- sklearn/utils/tests/test_stats.py | 13 +++++++++++++ 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index 94399c9f708d2..c06a8c0616c08 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -1,3 +1,5 @@ +from collections.abc import Iterable + import numpy as np from .extmath import stable_cumsum @@ -75,6 +77,7 @@ def _squeeze_arr(arr, n_dim): adjusted_percentile = (weight_cdf - sorted_weights) with np.errstate(invalid="ignore"): 
adjusted_percentile /= weight_cdf[-1] - sorted_weights + adjusted_percentile = np.nan_to_num(adjusted_percentile, nan=1) if interpolation in ("lower", "higher"): percentile_idx = np.array([ @@ -97,7 +100,7 @@ def _squeeze_arr(arr, n_dim): sorted_idx[percentile_idx, np.arange(n_cols)], np.arange(n_cols) ] - return _squeeze_arr(percentile_value, n_dim) + percentile_value = _squeeze_arr(percentile_value, n_dim) else: # interpolation == "linear" percentile_value = np.array([ @@ -109,6 +112,19 @@ def _squeeze_arr(arr, n_dim): for col in range(n_cols) ]) - nan_value = np.isnan(percentile_value) - percentile_value[nan_value] = array[0, nan_value] - return _squeeze_arr(percentile_value, n_dim) + percentile_value = _squeeze_arr(percentile_value, n_dim) + + single_sample_weight = np.count_nonzero(sample_weight, axis=0) + if np.any(single_sample_weight == 1): + if not isinstance(percentile_value, Iterable): + percentile_value = _squeeze_arr( + array[np.nonzero(sample_weight)], n_dim + ) + else: + percentile_value = np.array([ + array[np.flatnonzero(sample_weight[:, col])[0], col] + if n_nonzero == 1 else percentile_value[col] + for col, n_nonzero in enumerate(single_sample_weight) + ]) + + return percentile_value diff --git a/sklearn/utils/tests/test_stats.py b/sklearn/utils/tests/test_stats.py index 72f4e21ae8f4f..7a7ba68c30113 100644 --- a/sklearn/utils/tests/test_stats.py +++ b/sklearn/utils/tests/test_stats.py @@ -152,3 +152,16 @@ def test_weighted_percentile_non_unit_weight(percentile): X_percentile_idx = np.searchsorted(X, percentile_value) assert sample_weight[:X_percentile_idx - 1].sum() < percentile assert sample_weight[:X_percentile_idx + 1].sum() > percentile + + +@pytest.mark.parametrize("n_features", [None, 2]) +def test_weighted_percentile_single_weight(n_features): + rng = np.random.RandomState(42) + X = rng.randn(10) if n_features is None else rng.randn(10, n_features) + X.sort(axis=0) + sample_weight = np.zeros(X.shape) + pos_weight_idx = 4 + sample_weight[pos_weight_idx] = 1 + + percentile_value = _weighted_percentile(X, sample_weight, 50) + assert percentile_value == pytest.approx(X[pos_weight_idx]) From 201f0c7150bc7f64a9f815123f5bb80e8c9f20cb Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 2 Jun 2020 14:15:38 +0200 Subject: [PATCH 19/46] update test --- sklearn/ensemble/tests/test_gradient_boosting.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index 223c4f1f95382..f463192c4c22a 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -1135,20 +1135,15 @@ def test_probability_exponential(): @pytest.mark.parametrize("loss", ['huber', 'ls', 'lad', 'quantile']) def test_non_uniform_weights_toy_edge_case_reg(loss): - X = [[1, 0], - [1, 0], - [1, 0], - [0, 1]] + X = [[0], [0], [1], [0]] y = [0, 0, 1, 0] # ignore the first 2 training samples by setting their weight to 0 sample_weight = [0, 0, 1, 1] gb = GradientBoostingRegressor( - learning_rate=0.01, n_estimators=10, loss=loss, + learning_rate=0.01, n_estimators=5, loss=loss, ) gb.fit(X, y, sample_weight=sample_weight) - print(gb.predict([[1, 0]])) - print(gb.train_score_) - assert gb.predict([[1, 0]])[0] > 0.5 + assert gb.predict([[1]])[0] > 0.5 def test_non_uniform_weights_toy_edge_case_clf(): From d8a4a73ab51015116566fd9986fb713de18889a1 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 2 Jun 2020 14:19:00 +0200 
Subject: [PATCH 20/46] compat old numpy --- sklearn/utils/stats.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index c06a8c0616c08..80e0e0a653ee0 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -77,7 +77,8 @@ def _squeeze_arr(arr, n_dim): adjusted_percentile = (weight_cdf - sorted_weights) with np.errstate(invalid="ignore"): adjusted_percentile /= weight_cdf[-1] - sorted_weights - adjusted_percentile = np.nan_to_num(adjusted_percentile, nan=1) + nan_mask = np.isnan(adjusted_percentile) + adjusted_percentile[nan_mask] = 1 if interpolation in ("lower", "higher"): percentile_idx = np.array([ From 450f4c810abd21741de933dd61cc38bb49260305 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 2 Jun 2020 14:22:42 +0200 Subject: [PATCH 21/46] iter --- sklearn/utils/stats.py | 2 ++ sklearn/utils/tests/test_stats.py | 9 +++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index 80e0e0a653ee0..fd18f1a2cc20b 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -117,6 +117,8 @@ def _squeeze_arr(arr, n_dim): single_sample_weight = np.count_nonzero(sample_weight, axis=0) if np.any(single_sample_weight == 1): + # edge case where a single weight is non-null in which case the + # previous methods will fail if not isinstance(percentile_value, Iterable): percentile_value = _squeeze_arr( array[np.nonzero(sample_weight)], n_dim diff --git a/sklearn/utils/tests/test_stats.py b/sklearn/utils/tests/test_stats.py index 7a7ba68c30113..7ec2c9f46716e 100644 --- a/sklearn/utils/tests/test_stats.py +++ b/sklearn/utils/tests/test_stats.py @@ -155,7 +155,10 @@ def test_weighted_percentile_non_unit_weight(percentile): @pytest.mark.parametrize("n_features", [None, 2]) -def test_weighted_percentile_single_weight(n_features): +@pytest.mark.parametrize("interpolation", ["linear", "higher", "lower"]) +@pytest.mark.parametrize("percentile", np.arange(0, 101, 25)) +def test_weighted_percentile_single_weight(n_features, interpolation, + percentile): rng = np.random.RandomState(42) X = rng.randn(10) if n_features is None else rng.randn(10, n_features) X.sort(axis=0) @@ -163,5 +166,7 @@ def test_weighted_percentile_single_weight(n_features): pos_weight_idx = 4 sample_weight[pos_weight_idx] = 1 - percentile_value = _weighted_percentile(X, sample_weight, 50) + percentile_value = _weighted_percentile( + X, sample_weight, percentile=percentile, interpolation=interpolation + ) assert percentile_value == pytest.approx(X[pos_weight_idx]) From 19588fbb8c838158df0a279425764d7eccbf7476 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 2 Jun 2020 14:30:09 +0200 Subject: [PATCH 22/46] loss decreasing assert --- sklearn/ensemble/tests/test_gradient_boosting.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index f463192c4c22a..702b75dfa261e 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -1140,10 +1140,12 @@ def test_non_uniform_weights_toy_edge_case_reg(loss): # ignore the first 2 training samples by setting their weight to 0 sample_weight = [0, 0, 1, 1] gb = GradientBoostingRegressor( - learning_rate=0.01, n_estimators=5, loss=loss, + learning_rate=0.01, n_estimators=50, loss=loss, ) gb.fit(X, y, sample_weight=sample_weight) assert gb.predict([[1]])[0] > 0.5 + # 
check that the loss is always decreasing + assert np.all(np.diff(gb.train_score_) <= 0) def test_non_uniform_weights_toy_edge_case_clf(): From 2d3d9fb65047ae0184c68e852efe6294a2926f26 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 2 Jun 2020 14:39:06 +0200 Subject: [PATCH 23/46] iter --- sklearn/utils/stats.py | 5 +++++ sklearn/utils/tests/test_stats.py | 11 +++++++++++ 2 files changed, 16 insertions(+) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index fd18f1a2cc20b..a9e3bc0ad6904 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -53,6 +53,11 @@ def _weighted_percentile(array, sample_weight, percentile=50, f"instead." ) + if np.any(np.count_nonzero(sample_weight, axis=0) < 1): + raise ValueError( + "All weights cannot be null when computing a weighted percentile." + ) + n_dim = array.ndim if n_dim == 0: return array[()] diff --git a/sklearn/utils/tests/test_stats.py b/sklearn/utils/tests/test_stats.py index 7ec2c9f46716e..51e0ad4c6d297 100644 --- a/sklearn/utils/tests/test_stats.py +++ b/sklearn/utils/tests/test_stats.py @@ -170,3 +170,14 @@ def test_weighted_percentile_single_weight(n_features, interpolation, X, sample_weight, percentile=percentile, interpolation=interpolation ) assert percentile_value == pytest.approx(X[pos_weight_idx]) + + +@pytest.mark.parametrize("n_features", [None, 2]) +def test_weighted_percentile_all_null_weight(n_features): + rng = np.random.RandomState(42) + X = rng.randn(10) if n_features is None else rng.randn(10, n_features) + sample_weight = np.zeros(X.shape) + + err_msg = "All weights cannot be null when computing a weighted percentile" + with pytest.raises(ValueError, match=err_msg): + _weighted_percentile(X, sample_weight, 50) From 89e8cccaf02d6d2b6006df73097b04c9bbd3acf0 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 2 Jun 2020 15:31:25 +0200 Subject: [PATCH 24/46] remove a test --- sklearn/utils/tests/test_stats.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/sklearn/utils/tests/test_stats.py b/sklearn/utils/tests/test_stats.py index 51e0ad4c6d297..c3df215fc1d87 100644 --- a/sklearn/utils/tests/test_stats.py +++ b/sklearn/utils/tests/test_stats.py @@ -31,16 +31,6 @@ def test_weighted_percentile_equal(interpolation): assert score == 0 -@pytest.mark.parametrize("interpolation", ["linear", "lower", "higher"]) -def test_weighted_percentile_zero_weight(interpolation): - y = np.empty(102, dtype=np.float64) - y.fill(1.0) - sw = np.ones(102, dtype=np.float64) - sw.fill(0.0) - score = _weighted_percentile(y, sw, 50, interpolation=interpolation) - assert pytest.approx(score) == 1.0 - - def test_weighted_median_equal_weights(): # Checks weighted percentile=0.5 is same as median when weights equal rng = np.random.RandomState(0) From ef9d882f700541470ef9421500f82151c8af92ee Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 2 Jun 2020 15:48:45 +0200 Subject: [PATCH 25/46] iter --- sklearn/utils/stats.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index a9e3bc0ad6904..100c969d99072 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -100,7 +100,6 @@ def _squeeze_arr(arr, n_dim): lambda x: np.clip(x, 0, n_rows - 1), axis=0, arr=percentile_idx ) - percentile_idx = np.nan_to_num(percentile_idx, nan=0) percentile_value = array[ sorted_idx[percentile_idx, np.arange(n_cols)], From 84d782ef4b8524f9677ca44166239636f235d867 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 2 Jun 2020 16:15:30 +0200 Subject: 
[PATCH 26/46] tst old numpy --- sklearn/utils/tests/test_stats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/tests/test_stats.py b/sklearn/utils/tests/test_stats.py index c3df215fc1d87..99ccae5a95f3e 100644 --- a/sklearn/utils/tests/test_stats.py +++ b/sklearn/utils/tests/test_stats.py @@ -134,7 +134,7 @@ def test_weighted_percentile_non_unit_weight(percentile): rng = np.random.RandomState(42) X = rng.randn(1000) X.sort() - sample_weight = rng.random(X.shape) + sample_weight = rng.uniform(1, 30, X.shape) sample_weight = sample_weight / sample_weight.sum() sample_weight *= 100 From 882a354a55cf9326a36c1d8aa4ff75c4d0eeee75 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 2 Jun 2020 16:57:06 +0200 Subject: [PATCH 27/46] iter --- sklearn/ensemble/tests/test_gradient_boosting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index 702b75dfa261e..fa7252aa32488 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -1135,7 +1135,7 @@ def test_probability_exponential(): @pytest.mark.parametrize("loss", ['huber', 'ls', 'lad', 'quantile']) def test_non_uniform_weights_toy_edge_case_reg(loss): - X = [[0], [0], [1], [0]] + X = [[1], [1], [1], [0]] y = [0, 0, 1, 0] # ignore the first 2 training samples by setting their weight to 0 sample_weight = [0, 0, 1, 1] From 79b719f8475e97c9ba03df2b434c1915c0dc439d Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 4 Jun 2020 11:12:11 +0200 Subject: [PATCH 28/46] TST add to check the equivalence repeated/weights --- sklearn/utils/tests/test_stats.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/sklearn/utils/tests/test_stats.py b/sklearn/utils/tests/test_stats.py index 99ccae5a95f3e..2995350e21b9f 100644 --- a/sklearn/utils/tests/test_stats.py +++ b/sklearn/utils/tests/test_stats.py @@ -171,3 +171,22 @@ def test_weighted_percentile_all_null_weight(n_features): err_msg = "All weights cannot be null when computing a weighted percentile" with pytest.raises(ValueError, match=err_msg): _weighted_percentile(X, sample_weight, 50) + + +@pytest.mark.parametrize("percentile", np.arange(0, 101, 25)) +def test_weighted_percentile_equivalence_weights_repeated_samples(percentile): + X_repeated = np.array([1, 2, 2, 3, 3, 3, 4, 4]) + sample_weight_unit = np.ones(X_repeated.shape[0]) + p_npy_repeated = np.percentile(X_repeated, percentile) + p_sklearn_repeated = _weighted_percentile( + X_repeated, sample_weight_unit, percentile + ) + + assert p_sklearn_repeated == pytest.approx(p_npy_repeated) + + X = np.array([1, 2, 3, 4]) + sample_weight = np.array([1, 2, 3, 2]) + p_sklearn_weighted = _weighted_percentile(X, sample_weight, percentile) + + assert p_sklearn_weighted == pytest.approx(p_npy_repeated) + assert p_sklearn_weighted == pytest.approx(p_sklearn_repeated) From babd758e99b7ecf29a11da624d1b2d1e7669eeaa Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 4 Jun 2020 11:17:41 +0200 Subject: [PATCH 29/46] try all interpolation --- sklearn/utils/tests/test_stats.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/sklearn/utils/tests/test_stats.py b/sklearn/utils/tests/test_stats.py index 2995350e21b9f..9119b19050630 100644 --- a/sklearn/utils/tests/test_stats.py +++ b/sklearn/utils/tests/test_stats.py @@ -173,20 +173,26 @@ def 
test_weighted_percentile_all_null_weight(n_features):
     _weighted_percentile(X, sample_weight, 50)
 
 
-@pytest.mark.parametrize("percentile", np.arange(0, 101, 25))
-def test_weighted_percentile_equivalence_weights_repeated_samples(percentile):
+@pytest.mark.parametrize("interpolation", ["linear", "lower", "higher"])
+@pytest.mark.parametrize("percentile", [0, 25, 50, 75, 100])
+def test_weighted_percentile_equivalence_weights_repeated_samples(
+    interpolation, percentile,
+):
     X_repeated = np.array([1, 2, 2, 3, 3, 3, 4, 4])
     sample_weight_unit = np.ones(X_repeated.shape[0])
     p_npy_repeated = np.percentile(X_repeated, percentile)
     p_sklearn_repeated = _weighted_percentile(
-        X_repeated, sample_weight_unit, percentile
+        X_repeated, sample_weight_unit, percentile,
+        interpolation=interpolation,
     )
 
     assert p_sklearn_repeated == pytest.approx(p_npy_repeated)
 
     X = np.array([1, 2, 3, 4])
     sample_weight = np.array([1, 2, 3, 2])
-    p_sklearn_weighted = _weighted_percentile(X, sample_weight, percentile)
+    p_sklearn_weighted = _weighted_percentile(
+        X, sample_weight, percentile, interpolation=interpolation,
+    )
 
     assert p_sklearn_weighted == pytest.approx(p_npy_repeated)
     assert p_sklearn_weighted == pytest.approx(p_sklearn_repeated)

From c26403e4bd4e9ba2d3ac576a3d67a69c6e520f8e Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Thu, 4 Jun 2020 11:41:43 +0200
Subject: [PATCH 30/46] add comments on method

---
 sklearn/utils/stats.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py
index 100c969d99072..d5166af9da8a0 100644
--- a/sklearn/utils/stats.py
+++ b/sklearn/utils/stats.py
@@ -74,14 +74,24 @@ def _weighted_percentile(array, sample_weight, percentile=50,
     sorted_idx = np.argsort(array, axis=0)
     sorted_weights = _take_along_axis(sample_weight, sorted_idx, axis=0)
     percentile = np.array([percentile / 100] * n_cols)
-    weight_cdf = stable_cumsum(sorted_weights, axis=0)
+    cum_weigths = stable_cumsum(sorted_weights, axis=0)
 
     def _squeeze_arr(arr, n_dim):
         return arr[0] if n_dim == 1 else arr
 
+    # A percentile can be computed with three different alternatives:
+    # https://en.wikipedia.org/wiki/Percentile
+    # These alternatives depend on the value of a parameter C. NumPy uses
+    # the variant where C=0, which gives a strictly monotonically
+    # increasing function defined as:
+    # P = (x - 1) / (N - 1); x in [1, N]
+    # The weighted percentile changes this formula by taking into account
+    # the weights instead of the data frequency.
+    # P_w = (x - w) / (S_w - w), x in [1, N], w being the weight and S_w
+    # being the sum of the weights.
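To make the formula in the comment above concrete, here is a small worked
example (an illustration only, not part of the patch; the weights below are
made up) of how the adjusted cumulative weights are built for one column:

    import numpy as np

    w = np.array([1.0, 2.0, 3.0, 2.0])   # sorted weights of a single column
    S = np.cumsum(w)                     # array([1., 3., 6., 8.])
    adjusted = (S - w) / (S[-1] - w)     # array([0., 0.1667, 0.6, 1.])

    # A requested percentile of 0.5 lies between adjusted[1] (~0.17) and
    # adjusted[2] (0.6), i.e. between the 2nd and 3rd sorted values; the
    # `interpolation` mode then decides whether the lower value, the higher
    # value, or a linear mix of the two is returned.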
+ adjusted_percentile = (cum_weigths - sorted_weights) with np.errstate(invalid="ignore"): - adjusted_percentile /= weight_cdf[-1] - sorted_weights + adjusted_percentile /= cum_weigths[-1] - sorted_weights nan_mask = np.isnan(adjusted_percentile) adjusted_percentile[nan_mask] = 1 From e2692475ff28aefc2d9e972705d5270b19d66959 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 17 Jun 2020 11:42:26 +0200 Subject: [PATCH 31/46] wip --- sklearn/ensemble/tests/test_gradient_boosting.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index 1504384bf58ef..4f86f1a9536fc 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -53,6 +53,7 @@ rng = np.random.RandomState(0) # also load the diabetes dataset # and randomly permute it +X_reg, y_reg = make_regression(n_samples=500, n_features=10, noise=4) diabetes = datasets.load_diabetes() perm = rng.permutation(diabetes.target.size) diabetes.data = diabetes.data[perm] @@ -215,11 +216,7 @@ def test_classification_synthetic(loss): def check_california(loss, subsample): # Check consistency on dataset california house prices with least squares # and least absolute deviation. - california = datasets.fetch_california_housing() - perm = rng.permutation(500) - california.data = california.data[perm] - california.target = california.target[perm] - ones = np.ones(len(california.target)) + ones = np.ones(len(y_reg)) last_y_pred = None for sample_weight in None, ones, 2 * ones: clf = GradientBoostingRegressor(n_estimators=100, @@ -230,13 +227,12 @@ def check_california(loss, subsample): random_state=1) assert_raises(ValueError, clf.predict, california.data) - clf.fit(california.data, california.target, - sample_weight=sample_weight) - leaves = clf.apply(california.data) + clf.fit(X_reg, y_reg, sample_weight=sample_weight) + leaves = clf.apply(X_reg) assert leaves.shape == (500, 100) - y_pred = clf.predict(california.data) - mse = mean_squared_error(california.target, y_pred) + y_pred = clf.predict(X_reg) + mse = mean_squared_error(y_reg, y_pred) assert mse < 0.1 if last_y_pred is not None: From 70de4b92eca1078fdd412677d7f5d0c94a13f707 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 17 Jun 2020 12:35:22 +0200 Subject: [PATCH 32/46] use make regression --- .../ensemble/tests/test_gradient_boosting.py | 61 ++++++++++--------- 1 file changed, 33 insertions(+), 28 deletions(-) diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index 5273fdebda792..efa8c41ed9a0b 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -3,6 +3,7 @@ """ import warnings import numpy as np +from numpy.testing import assert_allclose from scipy.sparse import csr_matrix from scipy.sparse import csc_matrix @@ -18,7 +19,7 @@ from sklearn.ensemble import GradientBoostingClassifier from sklearn.ensemble import GradientBoostingRegressor from sklearn.ensemble._gradient_boosting import predict_stages -from sklearn.preprocessing import OneHotEncoder +from sklearn.preprocessing import OneHotEncoder, StandardScaler from sklearn.svm import LinearSVC from sklearn.metrics import mean_squared_error from sklearn.model_selection import train_test_split @@ -49,15 +50,13 @@ T = [[-1, -1], [2, 2], [3, 2]] true_result = [-1, 1, 1] -rng = np.random.RandomState(0) -# also load the diabetes 
dataset -# and randomly permute it -X_reg, y_reg = make_regression(n_samples=500, n_features=10, noise=4) -diabetes = datasets.load_diabetes() -perm = rng.permutation(diabetes.target.size) -diabetes.data = diabetes.data[perm] -diabetes.target = diabetes.target[perm] +# also make regression dataset +X_reg, y_reg = make_regression( + n_samples=500, n_features=10, n_informative=8, noise=10, random_state=7 +) +y_reg = StandardScaler().fit_transform(y_reg.reshape((-1, 1))) +rng = np.random.RandomState(0) # also load the iris dataset # and randomly permute it iris = datasets.load_iris() @@ -212,11 +211,12 @@ def test_classification_synthetic(loss): check_classification_synthetic(loss) -def check_california(loss, subsample): - # Check consistency on dataset california house prices with least squares +def check_regression_dataset(loss, subsample): + # Check consistency on regression dataset with least squares # and least absolute deviation. ones = np.ones(len(y_reg)) last_y_pred = None + from sklearn.metrics import r2_score for sample_weight in None, ones, 2 * ones: clf = GradientBoostingRegressor(n_estimators=100, loss=loss, @@ -225,17 +225,22 @@ def check_california(loss, subsample): min_samples_split=2, random_state=1) - assert_raises(ValueError, clf.predict, california.data) + assert_raises(ValueError, clf.predict, X_reg) clf.fit(X_reg, y_reg, sample_weight=sample_weight) leaves = clf.apply(X_reg) assert leaves.shape == (500, 100) y_pred = clf.predict(X_reg) mse = mean_squared_error(y_reg, y_pred) - assert mse < 0.1 + assert mse < 0.04 if last_y_pred is not None: - assert_array_almost_equal(last_y_pred, y_pred, decimal=0) + # FIXME: `rtol=65` is very permissive. This is due to the fact that + # GBRT with and without `sample_weight` do not use the same + # implementation of the median during the initialization with the + # `DummyRegressor`. In the future, we should make sure that both + # implementations should be the same. See PR #17377 for more. 
+ assert_allclose(last_y_pred, y_pred, rtol=65) last_y_pred = y_pred @@ -243,8 +248,8 @@ def check_california(loss, subsample): @pytest.mark.network @pytest.mark.parametrize('loss', ('ls', 'lad', 'huber')) @pytest.mark.parametrize('subsample', (1.0, 0.5)) -def test_california(loss, subsample): - check_california(loss, subsample) +def test_regression_dataset(loss, subsample): + check_regression_dataset(loss, subsample) def check_iris(subsample, sample_weight): @@ -311,8 +316,8 @@ def test_regression_synthetic(): def test_feature_importances(): - X = np.array(diabetes.data, dtype=np.float32) - y = np.array(diabetes.target, dtype=np.float32) + X = np.array(X_reg, dtype=np.float32) + y = np.array(y_reg, dtype=np.float32) clf = GradientBoostingRegressor(n_estimators=100, max_depth=5, min_samples_split=2, random_state=1) @@ -599,14 +604,14 @@ def test_quantile_loss(): max_depth=4, alpha=0.5, random_state=7) - clf_quantile.fit(diabetes.data, diabetes.target) - y_quantile = clf_quantile.predict(diabetes.data) + clf_quantile.fit(X_reg, y_reg) + y_quantile = clf_quantile.predict(X_reg) clf_lad = GradientBoostingRegressor(n_estimators=100, loss='lad', max_depth=4, random_state=7) - clf_lad.fit(diabetes.data, diabetes.target) - y_lad = clf_lad.predict(diabetes.data) + clf_lad.fit(X_reg, y_reg) + y_lad = clf_lad.predict(X_reg) assert_array_almost_equal(y_quantile, y_lad, decimal=4) @@ -1013,7 +1018,7 @@ def test_complete_regression(): est = GradientBoostingRegressor(n_estimators=20, max_depth=None, random_state=1, max_leaf_nodes=k + 1) - est.fit(diabetes.data, diabetes.target) + est.fit(X_reg, y_reg) tree = est.estimators_[-1, 0].tree_ assert (tree.children_left[tree.children_left == TREE_LEAF].shape[0] == @@ -1025,14 +1030,14 @@ def test_zero_estimator_reg(): est = GradientBoostingRegressor(n_estimators=20, max_depth=1, random_state=1, init='zero') - est.fit(diabetes.data, diabetes.target) - y_pred = est.predict(diabetes.data) - mse = mean_squared_error(diabetes.target, y_pred) - assert_almost_equal(mse, 3664.0, decimal=0) + est.fit(X_reg, y_reg) + y_pred = est.predict(X_reg) + mse = mean_squared_error(y_reg, y_pred) + assert_almost_equal(mse, 0.52, decimal=2) est = GradientBoostingRegressor(n_estimators=20, max_depth=1, random_state=1, init='foobar') - assert_raises(ValueError, est.fit, diabetes.data, diabetes.target) + assert_raises(ValueError, est.fit, X_reg, y_reg) def test_zero_estimator_clf(): From 65f11e9ddc8a2c9840720db580fce9b968b3c518 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 17 Jun 2020 13:01:28 +0200 Subject: [PATCH 33/46] fix lint --- sklearn/ensemble/tests/test_gradient_boosting.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index efa8c41ed9a0b..0da7b3c221a6a 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -216,7 +216,6 @@ def check_regression_dataset(loss, subsample): # and least absolute deviation. 
ones = np.ones(len(y_reg)) last_y_pred = None - from sklearn.metrics import r2_score for sample_weight in None, ones, 2 * ones: clf = GradientBoostingRegressor(n_estimators=100, loss=loss, From 7a618484e7071a671a31d4979403bb0e4b83eee5 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 17 Jun 2020 13:19:11 +0200 Subject: [PATCH 34/46] up rtol --- sklearn/ensemble/tests/test_gradient_boosting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index 0da7b3c221a6a..c3da784b409bc 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -234,12 +234,12 @@ def check_regression_dataset(loss, subsample): assert mse < 0.04 if last_y_pred is not None: - # FIXME: `rtol=65` is very permissive. This is due to the fact that + # FIXME: `rtol=75` is very permissive. This is due to the fact that # GBRT with and without `sample_weight` do not use the same # implementation of the median during the initialization with the # `DummyRegressor`. In the future, we should make sure that both # implementations should be the same. See PR #17377 for more. - assert_allclose(last_y_pred, y_pred, rtol=65) + assert_allclose(last_y_pred, y_pred, rtol=75) last_y_pred = y_pred From bda24faa02f23430d06165e7674135807332cbdd Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 17 Jun 2020 13:26:35 +0200 Subject: [PATCH 35/46] [empty] CI From 475db417aec783e3c4592a84bff4c7d32370c7f4 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 17 Jun 2020 14:04:50 +0200 Subject: [PATCH 36/46] try rtol float --- sklearn/ensemble/tests/test_gradient_boosting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index c3da784b409bc..dc532007a1494 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -239,7 +239,7 @@ def check_regression_dataset(loss, subsample): # implementation of the median during the initialization with the # `DummyRegressor`. In the future, we should make sure that both # implementations should be the same. See PR #17377 for more. - assert_allclose(last_y_pred, y_pred, rtol=75) + assert_allclose(last_y_pred, y_pred, rtol=75.) last_y_pred = y_pred From b631eb497da84e5f77dd0383a8941c6b3f8d3b2c Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 17 Jun 2020 16:23:13 +0200 Subject: [PATCH 37/46] reduc rtol --- sklearn/ensemble/tests/test_gradient_boosting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index dc532007a1494..0da7b3c221a6a 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -234,12 +234,12 @@ def check_regression_dataset(loss, subsample): assert mse < 0.04 if last_y_pred is not None: - # FIXME: `rtol=75` is very permissive. This is due to the fact that + # FIXME: `rtol=65` is very permissive. This is due to the fact that # GBRT with and without `sample_weight` do not use the same # implementation of the median during the initialization with the # `DummyRegressor`. In the future, we should make sure that both # implementations should be the same. See PR #17377 for more. - assert_allclose(last_y_pred, y_pred, rtol=75.) 
+ assert_allclose(last_y_pred, y_pred, rtol=65) last_y_pred = y_pred From 5e3b31c9545f2b9188f9aebe20ed64edab3a2616 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 17 Jun 2020 21:50:19 +0200 Subject: [PATCH 38/46] rtol=100 --- sklearn/ensemble/tests/test_gradient_boosting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index 0da7b3c221a6a..b0ff1192db3bf 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -239,7 +239,7 @@ def check_regression_dataset(loss, subsample): # implementation of the median during the initialization with the # `DummyRegressor`. In the future, we should make sure that both # implementations should be the same. See PR #17377 for more. - assert_allclose(last_y_pred, y_pred, rtol=65) + assert_allclose(last_y_pred, y_pred, rtol=100) last_y_pred = y_pred From 97281b21f4478d9e57fc8b1a098a195ef6fb3422 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 18 Jun 2020 11:43:05 +0200 Subject: [PATCH 39/46] suggestions --- sklearn/ensemble/tests/test_gradient_boosting.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index b0ff1192db3bf..fb6a2e72da434 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -19,7 +19,7 @@ from sklearn.ensemble import GradientBoostingClassifier from sklearn.ensemble import GradientBoostingRegressor from sklearn.ensemble._gradient_boosting import predict_stages -from sklearn.preprocessing import OneHotEncoder, StandardScaler +from sklearn.preprocessing import OneHotEncoder, scale from sklearn.svm import LinearSVC from sklearn.metrics import mean_squared_error from sklearn.model_selection import train_test_split @@ -54,7 +54,7 @@ X_reg, y_reg = make_regression( n_samples=500, n_features=10, n_informative=8, noise=10, random_state=7 ) -y_reg = StandardScaler().fit_transform(y_reg.reshape((-1, 1))) +y_reg = scale(y_reg) rng = np.random.RandomState(0) # also load the iris dataset @@ -216,7 +216,7 @@ def check_regression_dataset(loss, subsample): # and least absolute deviation. ones = np.ones(len(y_reg)) last_y_pred = None - for sample_weight in None, ones, 2 * ones: + for sample_weight in [None, ones, 2 * ones]: clf = GradientBoostingRegressor(n_estimators=100, loss=loss, max_depth=4, From 899d56dfc9aba52f8ff45088d1c9c3b656398353 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 22 Jun 2020 14:54:54 +0200 Subject: [PATCH 40/46] use weighted_percentile everywhere --- sklearn/dummy.py | 20 ++++++++---- sklearn/ensemble/_gb.py | 2 ++ sklearn/ensemble/_gb_losses.py | 25 ++++++++++----- .../ensemble/_hist_gradient_boosting/loss.py | 21 ++++++++----- .../ensemble/tests/test_gradient_boosting.py | 1 + sklearn/metrics/_regression.py | 10 ++++-- sklearn/tests/test_dummy.py | 8 +++-- sklearn/utils/stats.py | 19 ++++++++++-- sklearn/utils/tests/test_stats.py | 31 ++++++++++++------- 9 files changed, 96 insertions(+), 41 deletions(-) diff --git a/sklearn/dummy.py b/sklearn/dummy.py index cee7294ab5afd..cf93630f41411 100644 --- a/sklearn/dummy.py +++ b/sklearn/dummy.py @@ -492,9 +492,13 @@ def fit(self, X, y, sample_weight=None): if sample_weight is None: self.constant_ = np.median(y, axis=0) else: - self.constant_ = [_weighted_percentile(y[:, k], sample_weight, - percentile=50.) 
- for k in range(self.n_outputs_)] + self.constant_ = [ + _weighted_percentile( + y[:, k], sample_weight, percentile=50., + interpolation="nearest", + ) + for k in range(self.n_outputs_) + ] elif self.strategy == "quantile": if self.quantile is None or not np.isscalar(self.quantile): @@ -505,9 +509,13 @@ def fit(self, X, y, sample_weight=None): if sample_weight is None: self.constant_ = np.percentile(y, axis=0, q=percentile) else: - self.constant_ = [_weighted_percentile(y[:, k], sample_weight, - percentile=percentile) - for k in range(self.n_outputs_)] + self.constant_ = [ + _weighted_percentile( + y[:, k], sample_weight, percentile=percentile, + interpolation="nearest", + ) + for k in range(self.n_outputs_) + ] elif self.strategy == "constant": if self.constant is None: diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index 439500c1917d8..b1cdf6935ea37 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -208,6 +208,8 @@ def _fit_stage(self, i, X, y, raw_predictions, sample_weight, sample_mask, sample_weight = sample_weight * sample_mask.astype(np.float64) X = X_csr if X_csr is not None else X + print(sample_weight) + print(residual) tree.fit(X, residual, sample_weight=sample_weight, check_input=False, X_idx_sorted=X_idx_sorted) diff --git a/sklearn/ensemble/_gb_losses.py b/sklearn/ensemble/_gb_losses.py index 7bd5faca1d7d9..d37c92496c95e 100644 --- a/sklearn/ensemble/_gb_losses.py +++ b/sklearn/ensemble/_gb_losses.py @@ -317,8 +317,11 @@ def _update_terminal_region(self, tree, terminal_regions, leaf, X, y, sample_weight = sample_weight.take(terminal_region, axis=0) diff = (y.take(terminal_region, axis=0) - raw_predictions.take(terminal_region, axis=0)) - tree.value[leaf, 0, 0] = _weighted_percentile(diff, sample_weight, - percentile=50) + tree.value[leaf, 0, 0] = _weighted_percentile( + diff, sample_weight, percentile=50, interpolation="nearest", + ) + # print(diff) + # print(sample_weight) class HuberLossFunction(RegressionLossFunction): @@ -368,10 +371,14 @@ def __call__(self, y, raw_predictions, sample_weight=None): gamma = self.gamma if gamma is None: if sample_weight is None: - gamma = np.percentile(np.abs(diff), self.alpha * 100) + gamma = np.percentile( + np.abs(diff), self.alpha * 100, interpolation="nearest", + ) else: - gamma = _weighted_percentile(np.abs(diff), sample_weight, - self.alpha * 100) + gamma = _weighted_percentile( + np.abs(diff), sample_weight=sample_weight, + percentile=self.alpha * 100, interpolation="nearest", + ) gamma_mask = np.abs(diff) <= gamma if sample_weight is None: @@ -424,7 +431,9 @@ def _update_terminal_region(self, tree, terminal_regions, leaf, X, y, gamma = self.gamma diff = (y.take(terminal_region, axis=0) - raw_predictions.take(terminal_region, axis=0)) - median = _weighted_percentile(diff, sample_weight, percentile=50) + median = _weighted_percentile( + diff, sample_weight, percentile=50, interpolation="nearest", + ) diff_minus_median = diff - median tree.value[leaf, 0] = median + np.mean( np.sign(diff_minus_median) * @@ -506,7 +515,9 @@ def _update_terminal_region(self, tree, terminal_regions, leaf, X, y, - raw_predictions.take(terminal_region, axis=0)) sample_weight = sample_weight.take(terminal_region, axis=0) - val = _weighted_percentile(diff, sample_weight, self.percentile) + val = _weighted_percentile( + diff, sample_weight, self.percentile, interpolation="nearest", + ) tree.value[leaf, 0] = val diff --git a/sklearn/ensemble/_hist_gradient_boosting/loss.py b/sklearn/ensemble/_hist_gradient_boosting/loss.py 
index f256408bf01fb..99060b0e83eb8 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/loss.py +++ b/sklearn/ensemble/_hist_gradient_boosting/loss.py @@ -224,9 +224,11 @@ def pointwise_loss(self, y_true, raw_predictions): def get_baseline_prediction(self, y_train, sample_weight, prediction_dim): if sample_weight is None: - return np.median(y_train) + return np.percentile(y_train, 50, interpolation="nearest") else: - return _weighted_percentile(y_train, sample_weight, 50) + return _weighted_percentile( + y_train, sample_weight, 50, interpolation="nearest", + ) @staticmethod def inverse_link_function(raw_predictions): @@ -258,13 +260,16 @@ def update_leaves_values(self, grower, y_true, raw_predictions, for leaf in grower.finalized_leaves: indices = leaf.sample_indices if sample_weight is None: - median_res = np.median(y_true[indices] - - raw_predictions[indices]) + median_res = np.percentile( + y_true[indices] - raw_predictions[indices], 50, + interpolation="nearest", + ) else: - median_res = _weighted_percentile(y_true[indices] - - raw_predictions[indices], - sample_weight=sample_weight, - percentile=50) + median_res = _weighted_percentile( + y_true[indices] - raw_predictions[indices], + sample_weight=sample_weight, percentile=50, + interpolation="nearest", + ) leaf.value = grower.shrinkage * median_res # Note that the regularization is ignored here diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index fa7252aa32488..d31ea741f342b 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -1143,6 +1143,7 @@ def test_non_uniform_weights_toy_edge_case_reg(loss): learning_rate=0.01, n_estimators=50, loss=loss, ) gb.fit(X, y, sample_weight=sample_weight) + print(gb.train_score_) assert gb.predict([[1]])[0] > 0.5 # check that the loss is always decreasing assert np.all(np.diff(gb.train_score_) <= 0) diff --git a/sklearn/metrics/_regression.py b/sklearn/metrics/_regression.py index e805bdc099d1f..373113567f0da 100644 --- a/sklearn/metrics/_regression.py +++ b/sklearn/metrics/_regression.py @@ -401,11 +401,15 @@ def median_absolute_error(y_true, y_pred, *, multioutput='uniform_average', y_type, y_true, y_pred, multioutput = _check_reg_targets( y_true, y_pred, multioutput) if sample_weight is None: - output_errors = np.median(np.abs(y_pred - y_true), axis=0) + output_errors = np.percentile( + np.abs(y_pred - y_true), 50., interpolation="nearest", axis=0, + ) else: sample_weight = _check_sample_weight(sample_weight, y_pred) - output_errors = _weighted_percentile(np.abs(y_pred - y_true), - sample_weight=sample_weight) + output_errors = _weighted_percentile( + np.abs(y_pred - y_true), percentile=50, + sample_weight=sample_weight, interpolation="nearest", + ) if isinstance(multioutput, str): if multioutput == 'raw_values': return output_errors diff --git a/sklearn/tests/test_dummy.py b/sklearn/tests/test_dummy.py index 280ade175bc4a..d2b30a0f74f81 100644 --- a/sklearn/tests/test_dummy.py +++ b/sklearn/tests/test_dummy.py @@ -665,11 +665,15 @@ def test_dummy_regressor_sample_weight(n_samples=10): assert est.constant_ == np.average(y, weights=sample_weight) est = DummyRegressor(strategy="median").fit(X, y, sample_weight) - assert est.constant_ == _weighted_percentile(y, sample_weight, 50.) 
+ assert est.constant_ == _weighted_percentile( + y, sample_weight, 50., interpolation="nearest", + ) est = DummyRegressor(strategy="quantile", quantile=.95).fit(X, y, sample_weight) - assert est.constant_ == _weighted_percentile(y, sample_weight, 95.) + assert est.constant_ == _weighted_percentile( + y, sample_weight, 95., interpolation="nearest", + ) def test_dummy_regressor_on_3D_array(): diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index d5166af9da8a0..379fe276f53cb 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -35,7 +35,8 @@ def _weighted_percentile(array, sample_weight, percentile=50, * `"linear"`: `i + (j - i) * fraction`, where `fraction` is the fractional part of the index surrounded by `i` and `j`; * `"lower"`: i`; - * `"higher"`: `j`. + * `"higher"`: `j`; + * `"nearest"`: `i` or `j`, whichever is nearest. .. versionadded: 0.24 @@ -45,7 +46,7 @@ def _weighted_percentile(array, sample_weight, percentile=50, ndarray of shape (m,) Weighted percentile. """ - possible_interpolation = ("linear", "lower", "higher") + possible_interpolation = ("linear", "lower", "higher", "nearest") if interpolation not in possible_interpolation: raise ValueError( f"'interpolation' should be one of " @@ -95,7 +96,7 @@ def _squeeze_arr(arr, n_dim): nan_mask = np.isnan(adjusted_percentile) adjusted_percentile[nan_mask] = 1 - if interpolation in ("lower", "higher"): + if interpolation in ("lower", "higher", "nearest"): percentile_idx = np.array([ np.searchsorted(adjusted_percentile[:, col], percentile[col], side="left") @@ -105,6 +106,18 @@ def _squeeze_arr(arr, n_dim): if interpolation == "lower" and np.all(percentile < 1): # P = 100 is a corner case for "lower" percentile_idx -= 1 + elif interpolation == "nearest" and np.all(percentile < 1): + for col in range(n_cols): + error_higher = abs( + adjusted_percentile[percentile_idx[col], col] - + percentile[col] + ) + error_lower = abs( + adjusted_percentile[percentile_idx[col] - 1, col] - + percentile[col] + ) + if error_higher >= error_lower: + percentile_idx[col] -= 1 percentile_idx = np.apply_along_axis( lambda x: np.clip(x, 0, n_rows - 1), axis=0, diff --git a/sklearn/utils/tests/test_stats.py b/sklearn/utils/tests/test_stats.py index 9119b19050630..400d0162e6c4f 100644 --- a/sklearn/utils/tests/test_stats.py +++ b/sklearn/utils/tests/test_stats.py @@ -21,10 +21,11 @@ def test_weighted_percentile(interpolation, expected_median): assert score == pytest.approx(expected_median) -@pytest.mark.parametrize("interpolation", ["linear", "lower", "higher"]) -def test_weighted_percentile_equal(interpolation): - y = np.empty(102, dtype=np.float64) - y.fill(0.0) +@pytest.mark.parametrize( + "interpolation", ["linear", "lower", "higher", "nearest"] +) +def test_weighted_percentile_constant_data(interpolation): + y = np.zeros(102, dtype=np.float64) sw = np.ones(102, dtype=np.float64) sw[-1] = 0.0 score = _weighted_percentile(y, sw, 50, interpolation=interpolation) @@ -57,7 +58,9 @@ def test_weighted_median_integer_weights(): assert median == pytest.approx(w_median) -@pytest.mark.parametrize("interpolation", ["linear", "lower", "higher"]) +@pytest.mark.parametrize( + "interpolation", ["linear", "lower", "higher", "nearest"] +) def test_weighted_percentile_2d(interpolation): # Check for when array 2D and sample_weight 1D rng = np.random.RandomState(0) @@ -103,7 +106,9 @@ def test_weighted_percentile_np_median(): assert sklearn_median == pytest.approx(np_median) -@pytest.mark.parametrize("interpolation", ["linear", "lower", "higher"]) 
+@pytest.mark.parametrize( + "interpolation", ["linear", "lower", "higher", "nearest"] +) @pytest.mark.parametrize("percentile", np.arange(0, 101, 2.5)) def test_weighted_percentile_np_percentile(interpolation, percentile): rng = np.random.RandomState(0) @@ -145,7 +150,9 @@ def test_weighted_percentile_non_unit_weight(percentile): @pytest.mark.parametrize("n_features", [None, 2]) -@pytest.mark.parametrize("interpolation", ["linear", "higher", "lower"]) +@pytest.mark.parametrize( + "interpolation", ["linear", "higher", "lower", "nearest"] +) @pytest.mark.parametrize("percentile", np.arange(0, 101, 25)) def test_weighted_percentile_single_weight(n_features, interpolation, percentile): @@ -173,14 +180,14 @@ def test_weighted_percentile_all_null_weight(n_features): _weighted_percentile(X, sample_weight, 50) -@pytest.mark.parametrize("interpolation", ["linear", "lower", "higher"]) @pytest.mark.parametrize("percentile", [0, 25, 50, 75, 100]) -def test_weighted_percentile_equivalence_weights_repeated_samples( - interpolation, percentile, -): +def test_weighted_percentile_equivalence_weights_repeated_samples(percentile): + interpolation = "nearest" X_repeated = np.array([1, 2, 2, 3, 3, 3, 4, 4]) sample_weight_unit = np.ones(X_repeated.shape[0]) - p_npy_repeated = np.percentile(X_repeated, percentile) + p_npy_repeated = np.percentile( + X_repeated, percentile, interpolation=interpolation + ) p_sklearn_repeated = _weighted_percentile( X_repeated, sample_weight_unit, percentile, interpolation=interpolation, From cd4344bd357f763d52fa820a5657458121c62e3f Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 22 Jun 2020 23:33:02 +0200 Subject: [PATCH 41/46] iter --- sklearn/ensemble/_gb.py | 2 -- sklearn/ensemble/_gb_losses.py | 2 -- sklearn/ensemble/tests/test_gradient_boosting.py | 2 +- 3 files changed, 1 insertion(+), 5 deletions(-) diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index b1cdf6935ea37..439500c1917d8 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -208,8 +208,6 @@ def _fit_stage(self, i, X, y, raw_predictions, sample_weight, sample_mask, sample_weight = sample_weight * sample_mask.astype(np.float64) X = X_csr if X_csr is not None else X - print(sample_weight) - print(residual) tree.fit(X, residual, sample_weight=sample_weight, check_input=False, X_idx_sorted=X_idx_sorted) diff --git a/sklearn/ensemble/_gb_losses.py b/sklearn/ensemble/_gb_losses.py index d37c92496c95e..cdab5c6042ef6 100644 --- a/sklearn/ensemble/_gb_losses.py +++ b/sklearn/ensemble/_gb_losses.py @@ -320,8 +320,6 @@ def _update_terminal_region(self, tree, terminal_regions, leaf, X, y, tree.value[leaf, 0, 0] = _weighted_percentile( diff, sample_weight, percentile=50, interpolation="nearest", ) - # print(diff) - # print(sample_weight) class HuberLossFunction(RegressionLossFunction): diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index d31ea741f342b..83d374c143333 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -1140,7 +1140,7 @@ def test_non_uniform_weights_toy_edge_case_reg(loss): # ignore the first 2 training samples by setting their weight to 0 sample_weight = [0, 0, 1, 1] gb = GradientBoostingRegressor( - learning_rate=0.01, n_estimators=50, loss=loss, + learning_rate=0.1, n_estimators=200, loss=loss, ) gb.fit(X, y, sample_weight=sample_weight) print(gb.train_score_) From 2de439be4d07b2de0b10c0fcdfce641e05bc2839 Mon Sep 17 00:00:00 2001 
From: Guillaume Lemaitre Date: Mon, 22 Jun 2020 23:35:02 +0200 Subject: [PATCH 42/46] iter --- sklearn/ensemble/tests/test_gradient_boosting.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index c728f955561f2..2cb668311328e 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -1143,7 +1143,6 @@ def test_non_uniform_weights_toy_edge_case_reg(loss): learning_rate=0.1, n_estimators=200, loss=loss, ) gb.fit(X, y, sample_weight=sample_weight) - print(gb.train_score_) assert gb.predict([[1]])[0] > 0.5 # check that the loss is always decreasing assert np.all(np.diff(gb.train_score_) <= 0) From c02f1ea18cfada6d7614bf74e046a1c9f203f159 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 23 Jun 2020 00:03:12 +0200 Subject: [PATCH 43/46] iter --- sklearn/dummy.py | 4 +++- sklearn/ensemble/tests/test_gradient_boosting.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/sklearn/dummy.py b/sklearn/dummy.py index cf93630f41411..3547a7c5404c8 100644 --- a/sklearn/dummy.py +++ b/sklearn/dummy.py @@ -507,7 +507,9 @@ def fit(self, X, y, sample_weight=None): percentile = self.quantile * 100.0 if sample_weight is None: - self.constant_ = np.percentile(y, axis=0, q=percentile) + self.constant_ = np.percentile( + y, q=percentile, interpolation="nearest", axis=0 + ) else: self.constant_ = [ _weighted_percentile( diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index 537d2e4c0c9a6..819afcfeab1c5 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -239,8 +239,8 @@ def check_regression_dataset(loss, subsample): # implementation of the median during the initialization with the # `DummyRegressor`. In the future, we should make sure that both # implementations should be the same. See PR #17377 for more. - assert_allclose(last_y_pred, y_pred, rtol=100) - + # assert_allclose(last_y_pred, y_pred) + pass last_y_pred = y_pred From d47b1628e630176a6d0b89e8b04933a015781062 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 23 Jun 2020 00:05:13 +0200 Subject: [PATCH 44/46] iter --- sklearn/ensemble/tests/test_gradient_boosting.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index 819afcfeab1c5..57642791f70b0 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -234,13 +234,7 @@ def check_regression_dataset(loss, subsample): assert mse < 0.04 if last_y_pred is not None: - # FIXME: `rtol=65` is very permissive. This is due to the fact that - # GBRT with and without `sample_weight` do not use the same - # implementation of the median during the initialization with the - # `DummyRegressor`. In the future, we should make sure that both - # implementations should be the same. See PR #17377 for more. 
- # assert_allclose(last_y_pred, y_pred) - pass + assert_allclose(last_y_pred, y_pred) last_y_pred = y_pred From 089da6b9d0025c68354c050b66358d84530336a4 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 23 Jun 2020 10:16:13 +0200 Subject: [PATCH 45/46] iter --- sklearn/dummy.py | 4 +- .../tests/test_loss.py | 4 +- .../test_gradient_boosting_loss_functions.py | 75 ++++++++++--------- sklearn/tests/test_dummy.py | 27 +++++-- 4 files changed, 67 insertions(+), 43 deletions(-) diff --git a/sklearn/dummy.py b/sklearn/dummy.py index 3547a7c5404c8..5c4bddbcf2d1b 100644 --- a/sklearn/dummy.py +++ b/sklearn/dummy.py @@ -490,7 +490,9 @@ def fit(self, X, y, sample_weight=None): elif self.strategy == "median": if sample_weight is None: - self.constant_ = np.median(y, axis=0) + self.constant_ = np.percentile( + y, 50, interpolation="nearest", axis=0 + ) else: self.constant_ = [ _weighted_percentile( diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_loss.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_loss.py index 029b00a85affe..295af7198ce32 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_loss.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_loss.py @@ -183,7 +183,9 @@ def test_baseline_least_absolute_deviation(): # Make sure baseline prediction is the median of all targets assert np.allclose(loss.inverse_link_function(baseline_prediction), baseline_prediction) - assert baseline_prediction == pytest.approx(np.median(y_train)) + assert baseline_prediction == pytest.approx( + np.percentile(y_train, 50, interpolation="nearest") + ) def test_baseline_poisson(): diff --git a/sklearn/ensemble/tests/test_gradient_boosting_loss_functions.py b/sklearn/ensemble/tests/test_gradient_boosting_loss_functions.py index b5bc17eeeb14c..ec4be61cef6c2 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting_loss_functions.py +++ b/sklearn/ensemble/tests/test_gradient_boosting_loss_functions.py @@ -65,41 +65,41 @@ def test_sample_weight_smoke(): assert_almost_equal(loss_wo_sw, loss_w_sw) -def test_sample_weight_init_estimators(): +@pytest.mark.parametrize("Loss", LOSS_FUNCTIONS.values()) +def test_sample_weight_init_estimators(Loss): # Smoke test for init estimators with sample weights. 
rng = check_random_state(13) - X = rng.rand(100, 2) - sample_weight = np.ones(100) - reg_y = rng.rand(100) - - clf_y = rng.randint(0, 2, size=100) - - for Loss in LOSS_FUNCTIONS.values(): - if Loss is None: - continue - if issubclass(Loss, RegressionLossFunction): - k = 1 - y = reg_y - else: - k = 2 - y = clf_y - if Loss.is_multi_class: - # skip multiclass - continue - - loss = Loss(k) - init_est = loss.init_estimator() - init_est.fit(X, y) - out = loss.get_init_raw_predictions(X, init_est) - assert out.shape == (y.shape[0], 1) - - sw_init_est = loss.init_estimator() - sw_init_est.fit(X, y, sample_weight=sample_weight) - sw_out = loss.get_init_raw_predictions(X, sw_init_est) - assert sw_out.shape == (y.shape[0], 1) - - # check if predictions match - assert_allclose(out, sw_out, rtol=1e-2) + X = rng.rand(101, 2) + sample_weight = np.ones(101) + reg_y = rng.rand(101) + + clf_y = rng.randint(0, 2, size=101) + + if Loss is None: + return + if issubclass(Loss, RegressionLossFunction): + k = 1 + y = reg_y + else: + k = 2 + y = clf_y + if Loss.is_multi_class: + # skip multiclass + return + + loss = Loss(k) + init_est = loss.init_estimator() + init_est.fit(X, y) + out = loss.get_init_raw_predictions(X, init_est) + assert out.shape == (y.shape[0], 1) + + sw_init_est = loss.init_estimator() + sw_init_est.fit(X, y, sample_weight=sample_weight) + sw_out = loss.get_init_raw_predictions(X, sw_init_est) + assert sw_out.shape == (y.shape[0], 1) + + # check if predictions match + assert_allclose(out, sw_out) def test_quantile_loss_function(): @@ -202,7 +202,9 @@ def test_init_raw_predictions_values(): init_estimator = loss.init_estimator().fit(X, y) raw_predictions = loss.get_init_raw_predictions(y, init_estimator) # Make sure baseline prediction is the median of all targets - assert_almost_equal(raw_predictions, np.median(y)) + assert_almost_equal( + raw_predictions, np.percentile(y, 50, interpolation="nearest") + ) # Quantile loss for alpha in (.1, .5, .9): @@ -210,7 +212,10 @@ def test_init_raw_predictions_values(): init_estimator = loss.init_estimator().fit(X, y) raw_predictions = loss.get_init_raw_predictions(y, init_estimator) # Make sure baseline prediction is the alpha-quantile of all targets - assert_almost_equal(raw_predictions, np.percentile(y, alpha * 100)) + assert_almost_equal( + raw_predictions, + np.percentile(y, alpha * 100, interpolation="nearest") + ) y = rng.randint(0, 2, size=n_samples) diff --git a/sklearn/tests/test_dummy.py b/sklearn/tests/test_dummy.py index d2b30a0f74f81..935b707e01624 100644 --- a/sklearn/tests/test_dummy.py +++ b/sklearn/tests/test_dummy.py @@ -311,7 +311,10 @@ def test_median_strategy_regressor(): reg = DummyRegressor(strategy="median") reg.fit(X, y) - assert_array_equal(reg.predict(X), [np.median(y)] * len(X)) + assert_array_equal( + reg.predict(X), + [np.percentile(y, 50, interpolation="nearest")] * len(X) + ) def test_median_strategy_multioutput_regressor(): @@ -321,7 +324,9 @@ def test_median_strategy_multioutput_regressor(): X_learn = random_state.randn(10, 10) y_learn = random_state.randn(10, 5) - median = np.median(y_learn, axis=0).reshape((1, -1)) + median = np.percentile( + y_learn, 50, interpolation="nearest", axis=0 + ).reshape((1, -1)) X_test = random_state.randn(20, 10) y_test = random_state.randn(20, 5) @@ -346,7 +351,10 @@ def test_quantile_strategy_regressor(): reg = DummyRegressor(strategy="quantile", quantile=0.5) reg.fit(X, y) - assert_array_equal(reg.predict(X), [np.median(y)] * len(X)) + assert_array_equal( + reg.predict(X), + 
[np.percentile(y, 50, interpolation="nearest")] * len(X) + ) reg = DummyRegressor(strategy="quantile", quantile=0) reg.fit(X, y) @@ -358,7 +366,10 @@ def test_quantile_strategy_regressor(): reg = DummyRegressor(strategy="quantile", quantile=0.3) reg.fit(X, y) - assert_array_equal(reg.predict(X), [np.percentile(y, q=30)] * len(X)) + assert_array_equal( + reg.predict(X), + [np.percentile(y, q=30, interpolation="nearest")] * len(X) + ) def test_quantile_strategy_multioutput_regressor(): @@ -368,8 +379,12 @@ def test_quantile_strategy_multioutput_regressor(): X_learn = random_state.randn(10, 10) y_learn = random_state.randn(10, 5) - median = np.median(y_learn, axis=0).reshape((1, -1)) - quantile_values = np.percentile(y_learn, axis=0, q=80).reshape((1, -1)) + median = np.percentile( + y_learn, 50, interpolation="nearest", axis=0 + ).reshape((1, -1)) + quantile_values = np.percentile( + y_learn, axis=0, q=80, interpolation="nearest" + ).reshape((1, -1)) X_test = random_state.randn(20, 10) y_test = random_state.randn(20, 5) From 548bda17cd7428e7aef7f3afc335171dd7facf6e Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 23 Jun 2020 10:20:15 +0200 Subject: [PATCH 46/46] change name variable --- sklearn/ensemble/tests/test_gradient_boosting.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index 57642791f70b0..f98ba288334c3 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -217,19 +217,18 @@ def check_regression_dataset(loss, subsample): ones = np.ones(len(y_reg)) last_y_pred = None for sample_weight in [None, ones, 2 * ones]: - clf = GradientBoostingRegressor(n_estimators=100, + reg = GradientBoostingRegressor(n_estimators=100, loss=loss, max_depth=4, subsample=subsample, min_samples_split=2, random_state=1) - assert_raises(ValueError, clf.predict, X_reg) - clf.fit(X_reg, y_reg, sample_weight=sample_weight) - leaves = clf.apply(X_reg) + reg.fit(X_reg, y_reg, sample_weight=sample_weight) + leaves = reg.apply(X_reg) assert leaves.shape == (500, 100) - y_pred = clf.predict(X_reg) + y_pred = reg.predict(X_reg) mse = mean_squared_error(y_reg, y_pred) assert mse < 0.04
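Taken together, the series converges on `interpolation="nearest"` so that
giving a sample an integer weight behaves like repeating it. A minimal usage
sketch against this branch (not part of the patches; it assumes the private
helper is importable from `sklearn.utils.stats` as in the tests above):

    import numpy as np
    from sklearn.utils.stats import _weighted_percentile

    x = np.array([1, 2, 3, 4])
    w = np.array([1, 2, 3, 2])
    x_repeated = np.array([1, 2, 2, 3, 3, 3, 4, 4])  # each value repeated w times

    p_weighted = _weighted_percentile(x, w, 75, interpolation="nearest")
    p_repeated = np.percentile(x_repeated, 75, interpolation="nearest")
    # expected to be equal, cf. the repeated/weights equivalence test above
    assert p_weighted == p_repeated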