Skip to content

[WIP] BUG make _weighted_percentile(data, ones, 50) consistent with numpy.median(data) #17377

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 52 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
2ecc96d
check diabetes
lucyleeow Apr 16, 2020
deacbf5
use diabetes and cali
lucyleeow Apr 16, 2020
f257ff1
pytest network
lucyleeow Apr 16, 2020
e0b00ac
Merge remote-tracking branch 'origin/master' into pr/lucyleeow/16937
glemaitre May 28, 2020
31a116e
BUG make _weighted_percentile behave as NumPy
glemaitre May 28, 2020
6c8a405
iter
glemaitre May 28, 2020
23af759
revert setup.cfg
glemaitre May 28, 2020
3be7c09
iter
glemaitre May 29, 2020
06aeab1
iter
glemaitre May 29, 2020
f389292
iter
glemaitre May 29, 2020
9e1222f
iter
glemaitre May 29, 2020
9314aee
iter
glemaitre May 29, 2020
0e857a9
iter
glemaitre May 29, 2020
5988234
improve documentation
glemaitre May 29, 2020
8100873
iter
glemaitre May 29, 2020
cc4a172
iter
glemaitre May 29, 2020
4e100a9
parametrize debug
glemaitre May 29, 2020
25e5d24
iter
glemaitre Jun 1, 2020
ff2a6e0
case we have a single weight non null
glemaitre Jun 2, 2020
201f0c7
update test
glemaitre Jun 2, 2020
d8a4a73
compat old numpy
glemaitre Jun 2, 2020
450f4c8
iter
glemaitre Jun 2, 2020
19588fb
loss decreasing assert
glemaitre Jun 2, 2020
2d3d9fb
iter
glemaitre Jun 2, 2020
89e8ccc
remove a test
glemaitre Jun 2, 2020
ef9d882
iter
glemaitre Jun 2, 2020
84d782e
tst old numpy
glemaitre Jun 2, 2020
882a354
iter
glemaitre Jun 2, 2020
79b719f
TST add to check the equivalence repeated/weights
glemaitre Jun 4, 2020
babd758
try all interpolation
glemaitre Jun 4, 2020
c26403e
add comments on method
glemaitre Jun 4, 2020
e269247
wip
lucyleeow Jun 17, 2020
e6ce12b
Merge branch 'test_grad_boost' of github.com:lucyleeow/scikit-learn i…
lucyleeow Jun 17, 2020
5bf3b8f
Merge branch 'master' into test_grad_boost
lucyleeow Jun 17, 2020
70de4b9
use make regression
lucyleeow Jun 17, 2020
65f11e9
fix lint
lucyleeow Jun 17, 2020
7a61848
up rtol
lucyleeow Jun 17, 2020
bda24fa
[empty] CI
lucyleeow Jun 17, 2020
475db41
try rtol float
lucyleeow Jun 17, 2020
b631eb4
reduc rtol
lucyleeow Jun 17, 2020
5e3b31c
rtol=100
lucyleeow Jun 17, 2020
97281b2
suggestions
lucyleeow Jun 18, 2020
899d56d
use weighted_percentile everywhere
glemaitre Jun 22, 2020
cd4344b
iter
glemaitre Jun 22, 2020
2327f98
Merge remote-tracking branch 'origin/master' into is/17370
glemaitre Jun 22, 2020
2de439b
iter
glemaitre Jun 22, 2020
45f6123
Merge remote-tracking branch 'lucyleeow/test_grad_boost' into is/17370
glemaitre Jun 22, 2020
c02f1ea
iter
glemaitre Jun 22, 2020
d47b162
iter
glemaitre Jun 22, 2020
089da6b
iter
glemaitre Jun 23, 2020
548bda1
change name variable
glemaitre Jun 23, 2020
7ca2d47
Merge remote-tracking branch 'origin/master' into is/17370
glemaitre Jun 23, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions sklearn/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,11 +490,17 @@ def fit(self, X, y, sample_weight=None):

elif self.strategy == "median":
if sample_weight is None:
self.constant_ = np.median(y, axis=0)
self.constant_ = np.percentile(
y, 50, interpolation="nearest", axis=0
)
else:
self.constant_ = [_weighted_percentile(y[:, k], sample_weight,
percentile=50.)
for k in range(self.n_outputs_)]
self.constant_ = [
_weighted_percentile(
y[:, k], sample_weight, percentile=50.,
interpolation="nearest",
)
for k in range(self.n_outputs_)
]

elif self.strategy == "quantile":
if self.quantile is None or not np.isscalar(self.quantile):
Expand All @@ -503,11 +509,17 @@ def fit(self, X, y, sample_weight=None):

percentile = self.quantile * 100.0
if sample_weight is None:
self.constant_ = np.percentile(y, axis=0, q=percentile)
self.constant_ = np.percentile(
y, q=percentile, interpolation="nearest", axis=0
)
else:
self.constant_ = [_weighted_percentile(y[:, k], sample_weight,
percentile=percentile)
for k in range(self.n_outputs_)]
self.constant_ = [
_weighted_percentile(
y[:, k], sample_weight, percentile=percentile,
interpolation="nearest",
)
for k in range(self.n_outputs_)
]

elif self.strategy == "constant":
if self.constant is None:
Expand Down
23 changes: 16 additions & 7 deletions sklearn/ensemble/_gb_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,9 @@ def _update_terminal_region(self, tree, terminal_regions, leaf, X, y,
sample_weight = sample_weight.take(terminal_region, axis=0)
diff = (y.take(terminal_region, axis=0) -
raw_predictions.take(terminal_region, axis=0))
tree.value[leaf, 0, 0] = _weighted_percentile(diff, sample_weight,
percentile=50)
tree.value[leaf, 0, 0] = _weighted_percentile(
diff, sample_weight, percentile=50, interpolation="nearest",
)


class HuberLossFunction(RegressionLossFunction):
Expand Down Expand Up @@ -368,10 +369,14 @@ def __call__(self, y, raw_predictions, sample_weight=None):
gamma = self.gamma
if gamma is None:
if sample_weight is None:
gamma = np.percentile(np.abs(diff), self.alpha * 100)
gamma = np.percentile(
np.abs(diff), self.alpha * 100, interpolation="nearest",
)
else:
gamma = _weighted_percentile(np.abs(diff), sample_weight,
self.alpha * 100)
gamma = _weighted_percentile(
np.abs(diff), sample_weight=sample_weight,
percentile=self.alpha * 100, interpolation="nearest",
)

gamma_mask = np.abs(diff) <= gamma
if sample_weight is None:
Expand Down Expand Up @@ -424,7 +429,9 @@ def _update_terminal_region(self, tree, terminal_regions, leaf, X, y,
gamma = self.gamma
diff = (y.take(terminal_region, axis=0)
- raw_predictions.take(terminal_region, axis=0))
median = _weighted_percentile(diff, sample_weight, percentile=50)
median = _weighted_percentile(
diff, sample_weight, percentile=50, interpolation="nearest",
)
diff_minus_median = diff - median
tree.value[leaf, 0] = median + np.mean(
np.sign(diff_minus_median) *
Expand Down Expand Up @@ -506,7 +513,9 @@ def _update_terminal_region(self, tree, terminal_regions, leaf, X, y,
- raw_predictions.take(terminal_region, axis=0))
sample_weight = sample_weight.take(terminal_region, axis=0)

val = _weighted_percentile(diff, sample_weight, self.percentile)
val = _weighted_percentile(
diff, sample_weight, self.percentile, interpolation="nearest",
)
tree.value[leaf, 0] = val


Expand Down
21 changes: 13 additions & 8 deletions sklearn/ensemble/_hist_gradient_boosting/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,11 @@ def pointwise_loss(self, y_true, raw_predictions):

def get_baseline_prediction(self, y_train, sample_weight, prediction_dim):
if sample_weight is None:
return np.median(y_train)
return np.percentile(y_train, 50, interpolation="nearest")
else:
return _weighted_percentile(y_train, sample_weight, 50)
return _weighted_percentile(
y_train, sample_weight, 50, interpolation="nearest",
)

@staticmethod
def inverse_link_function(raw_predictions):
Expand Down Expand Up @@ -258,13 +260,16 @@ def update_leaves_values(self, grower, y_true, raw_predictions,
for leaf in grower.finalized_leaves:
indices = leaf.sample_indices
if sample_weight is None:
median_res = np.median(y_true[indices]
- raw_predictions[indices])
median_res = np.percentile(
y_true[indices] - raw_predictions[indices], 50,
interpolation="nearest",
)
else:
median_res = _weighted_percentile(y_true[indices]
- raw_predictions[indices],
sample_weight=sample_weight,
percentile=50)
median_res = _weighted_percentile(
y_true[indices] - raw_predictions[indices],
sample_weight=sample_weight, percentile=50,
interpolation="nearest",
)
leaf.value = grower.shrinkage * median_res
# Note that the regularization is ignored here

Expand Down
4 changes: 3 additions & 1 deletion sklearn/ensemble/_hist_gradient_boosting/tests/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,9 @@ def test_baseline_least_absolute_deviation():
# Make sure baseline prediction is the median of all targets
assert np.allclose(loss.inverse_link_function(baseline_prediction),
baseline_prediction)
assert baseline_prediction == pytest.approx(np.median(y_train))
assert baseline_prediction == pytest.approx(
np.percentile(y_train, 50, interpolation="nearest")
)


def test_baseline_poisson():
Expand Down
32 changes: 13 additions & 19 deletions sklearn/ensemble/tests/test_gradient_boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
import warnings
import numpy as np
from numpy.testing import assert_allclose

from scipy.sparse import csr_matrix
from scipy.sparse import csc_matrix
Expand Down Expand Up @@ -229,17 +230,10 @@ def check_regression_dataset(loss, subsample):

y_pred = reg.predict(X_reg)
mse = mean_squared_error(y_reg, y_pred)
assert mse < 0.04
assert mse < 0.05

if last_y_pred is not None:
# FIXME: We temporarily bypass this test. This is due to the fact
# that GBRT with and without `sample_weight` do not use the same
# implementation of the median during the initialization with the
# `DummyRegressor`. In the future, we should make sure that both
# implementations should be the same. See PR #17377 for more.
# assert_allclose(last_y_pred, y_pred)
pass

assert_allclose(last_y_pred, y_pred)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@lucyleeow I added your test here. Using the strategy nearest everywhere seems to be the winner here to be able to keep the sample_weight semantic rights.

last_y_pred = y_pred


Expand Down Expand Up @@ -1137,19 +1131,19 @@ def test_probability_exponential():
assert_array_equal(y_pred, true_result)


def test_non_uniform_weights_toy_edge_case_reg():
X = [[1, 0],
[1, 0],
[1, 0],
[0, 1]]
@pytest.mark.parametrize("loss", ['huber', 'ls', 'lad', 'quantile'])
def test_non_uniform_weights_toy_edge_case_reg(loss):
X = [[1], [1], [1], [0]]
y = [0, 0, 1, 0]
# ignore the first 2 training samples by setting their weight to 0
sample_weight = [0, 0, 1, 1]
for loss in ('huber', 'ls', 'lad', 'quantile'):
gb = GradientBoostingRegressor(learning_rate=1.0, n_estimators=2,
loss=loss)
gb.fit(X, y, sample_weight=sample_weight)
assert gb.predict([[1, 0]])[0] > 0.5
gb = GradientBoostingRegressor(
learning_rate=0.1, n_estimators=200, loss=loss,
)
gb.fit(X, y, sample_weight=sample_weight)
assert gb.predict([[1]])[0] > 0.5
# check that the loss is always decreasing
assert np.all(np.diff(gb.train_score_) <= 0)


def test_non_uniform_weights_toy_edge_case_clf():
Expand Down
75 changes: 40 additions & 35 deletions sklearn/ensemble/tests/test_gradient_boosting_loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,41 +65,41 @@ def test_sample_weight_smoke():
assert_almost_equal(loss_wo_sw, loss_w_sw)


def test_sample_weight_init_estimators():
@pytest.mark.parametrize("Loss", LOSS_FUNCTIONS.values())
def test_sample_weight_init_estimators(Loss):
# Smoke test for init estimators with sample weights.
rng = check_random_state(13)
X = rng.rand(100, 2)
sample_weight = np.ones(100)
reg_y = rng.rand(100)

clf_y = rng.randint(0, 2, size=100)

for Loss in LOSS_FUNCTIONS.values():
if Loss is None:
continue
if issubclass(Loss, RegressionLossFunction):
k = 1
y = reg_y
else:
k = 2
y = clf_y
if Loss.is_multi_class:
# skip multiclass
continue

loss = Loss(k)
init_est = loss.init_estimator()
init_est.fit(X, y)
out = loss.get_init_raw_predictions(X, init_est)
assert out.shape == (y.shape[0], 1)

sw_init_est = loss.init_estimator()
sw_init_est.fit(X, y, sample_weight=sample_weight)
sw_out = loss.get_init_raw_predictions(X, sw_init_est)
assert sw_out.shape == (y.shape[0], 1)

# check if predictions match
assert_allclose(out, sw_out, rtol=1e-2)
X = rng.rand(101, 2)
sample_weight = np.ones(101)
reg_y = rng.rand(101)

clf_y = rng.randint(0, 2, size=101)

if Loss is None:
return
if issubclass(Loss, RegressionLossFunction):
k = 1
y = reg_y
else:
k = 2
y = clf_y
if Loss.is_multi_class:
# skip multiclass
return

loss = Loss(k)
init_est = loss.init_estimator()
init_est.fit(X, y)
out = loss.get_init_raw_predictions(X, init_est)
assert out.shape == (y.shape[0], 1)

sw_init_est = loss.init_estimator()
sw_init_est.fit(X, y, sample_weight=sample_weight)
sw_out = loss.get_init_raw_predictions(X, sw_init_est)
assert sw_out.shape == (y.shape[0], 1)

# check if predictions match
assert_allclose(out, sw_out)


def test_quantile_loss_function():
Expand Down Expand Up @@ -202,15 +202,20 @@ def test_init_raw_predictions_values():
init_estimator = loss.init_estimator().fit(X, y)
raw_predictions = loss.get_init_raw_predictions(y, init_estimator)
# Make sure baseline prediction is the median of all targets
assert_almost_equal(raw_predictions, np.median(y))
assert_almost_equal(
raw_predictions, np.percentile(y, 50, interpolation="nearest")
)

# Quantile loss
for alpha in (.1, .5, .9):
loss = QuantileLossFunction(n_classes=1, alpha=alpha)
init_estimator = loss.init_estimator().fit(X, y)
raw_predictions = loss.get_init_raw_predictions(y, init_estimator)
# Make sure baseline prediction is the alpha-quantile of all targets
assert_almost_equal(raw_predictions, np.percentile(y, alpha * 100))
assert_almost_equal(
raw_predictions,
np.percentile(y, alpha * 100, interpolation="nearest")
)

y = rng.randint(0, 2, size=n_samples)

Expand Down
10 changes: 7 additions & 3 deletions sklearn/metrics/_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,11 +401,15 @@ def median_absolute_error(y_true, y_pred, *, multioutput='uniform_average',
y_type, y_true, y_pred, multioutput = _check_reg_targets(
y_true, y_pred, multioutput)
if sample_weight is None:
output_errors = np.median(np.abs(y_pred - y_true), axis=0)
output_errors = np.percentile(
np.abs(y_pred - y_true), 50., interpolation="nearest", axis=0,
)
else:
sample_weight = _check_sample_weight(sample_weight, y_pred)
output_errors = _weighted_percentile(np.abs(y_pred - y_true),
sample_weight=sample_weight)
output_errors = _weighted_percentile(
np.abs(y_pred - y_true), percentile=50,
sample_weight=sample_weight, interpolation="nearest",
)
if isinstance(multioutput, str):
if multioutput == 'raw_values':
return output_errors
Expand Down
59 changes: 32 additions & 27 deletions sklearn/metrics/tests/test_score_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,44 +520,49 @@ def test_classification_scorer_sample_weight():
f"with sample weights: {str(e)}")


@ignore_warnings
def test_regression_scorer_sample_weight():
# Test that regression scorers support sample_weight or raise sensible
# errors

@pytest.fixture
def regressor_and_data():
# Odd number of test samples req for neg_median_absolute_error
X, y = make_regression(n_samples=101, n_features=20, random_state=0)
y = _require_positive_y(y)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

sample_weight = np.ones_like(y_test)
sample_weight_test = np.ones_like(y_test)
# Odd number req for neg_median_absolute_error
sample_weight[:11] = 0
sample_weight_test[:11] = 0

reg = DecisionTreeRegressor(random_state=0)
reg.fit(X_train, y_train)

for name, scorer in SCORERS.items():
if name not in REGRESSION_SCORERS:
# skip classification scorers
continue
try:
weighted = scorer(reg, X_test, y_test,
sample_weight=sample_weight)
ignored = scorer(reg, X_test[11:], y_test[11:])
unweighted = scorer(reg, X_test, y_test)
assert weighted != unweighted, (
f"scorer {name} behaves identically when called with "
f"sample weights: {weighted} vs {unweighted}")
assert_almost_equal(weighted, ignored,
err_msg=f"scorer {name} behaves differently "
f"when ignoring samples and setting "
f"sample_weight to 0: {weighted} vs {ignored}")
return reg, X_test, y_test, sample_weight_test

except TypeError as e:
assert "sample_weight" in str(e), (
f"scorer {name} raises unhelpful exception when called "
f"with sample weights: {str(e)}")

@ignore_warnings
@pytest.mark.parametrize("name, scorer", SCORERS.items())
def test_regression_scorer_sample_weight(regressor_and_data, name, scorer):
# Test that regression scorers support sample_weight or raise sensible
# errors
reg, X_test, y_test, sample_weight = regressor_and_data

if name not in REGRESSION_SCORERS:
# skip classification scorers
return
try:
weighted = scorer(reg, X_test, y_test, sample_weight=sample_weight)
ignored = scorer(reg, X_test[11:], y_test[11:])
unweighted = scorer(reg, X_test, y_test)
assert weighted != unweighted, (
f"scorer {name} behaves identically when called with "
f"sample weights: {weighted} vs {unweighted}")
assert_almost_equal(weighted, ignored,
err_msg=f"scorer {name} behaves differently "
f"when ignoring samples and setting "
f"sample_weight to 0: {weighted} vs {ignored}")

except TypeError as e:
assert "sample_weight" in str(e), (
f"scorer {name} raises unhelpful exception when called "
f"with sample weights: {str(e)}")


@pytest.mark.parametrize('name', SCORERS)
Expand Down
Loading