
[MRG] FIX ignore null weight when computing estimator error in AdaBoostRegressor #14294

Merged
merged 30 commits on Oct 25, 2019
Commits (30, changes from all commits)
a645c4f
FIX normalize with max of samples with non-null weights in AdaBoostRe…
glemaitre Jul 5, 2019
cb52f65
iter
glemaitre Jul 8, 2019
6557806
Merge remote-tracking branch 'origin/master' into is/fix_adaboost_reg…
glemaitre Jul 8, 2019
805a6b5
update PR number
glemaitre Jul 8, 2019
b0f2966
PEP8
glemaitre Jul 8, 2019
0f317d9
iter
glemaitre Jul 8, 2019
6ee41fd
iter
glemaitre Jul 10, 2019
344e46f
Merge remote-tracking branch 'origin/master' into is/fix_adaboost_reg…
glemaitre Sep 10, 2019
ea46a7c
ignore line coverage
glemaitre Sep 10, 2019
6902eff
FIX use _check_sample_weight to validate sample_weight
glemaitre Sep 10, 2019
a5dec84
address jeremie comments
glemaitre Sep 12, 2019
ddafd73
Merge remote-tracking branch 'origin/master' into is/fix_adaboost_reg…
glemaitre Sep 12, 2019
810c637
fix inplace/mask copy operation
glemaitre Sep 12, 2019
fd50b08
add default value for wrapper
glemaitre Sep 12, 2019
da5a2a0
PEP8
glemaitre Sep 12, 2019
1da7cc7
Apply suggestions from code review
glemaitre Sep 12, 2019
ff2d23d
apply jeremie comments
glemaitre Sep 12, 2019
335b62a
increase coverage
glemaitre Sep 13, 2019
68084b5
PEP8
glemaitre Sep 13, 2019
cd65a83
Merge branch 'master' into is/fix_adaboost_regressor
glemaitre Sep 24, 2019
2f8054e
address adrin comments
glemaitre Oct 3, 2019
54b3d89
Merge remote-tracking branch 'origin/master' into is/fix_adaboost_reg…
glemaitre Oct 3, 2019
f446018
Merge remote-tracking branch 'glemaitre/is/fix_adaboost_regressor' in…
glemaitre Oct 3, 2019
cbcd681
change import for mocking
glemaitre Oct 3, 2019
362ac2a
fix
glemaitre Oct 3, 2019
a9812fe
Merge remote-tracking branch 'origin/master' into is/fix_adaboost_reg…
glemaitre Oct 22, 2019
31a6276
PEP8
glemaitre Oct 22, 2019
d494b3f
address comment adrin
glemaitre Oct 23, 2019
2e239d1
Merge remote-tracking branch 'origin/master' into is/fix_adaboost_reg…
glemaitre Oct 23, 2019
d353765
Merge remote-tracking branch 'glemaitre/is/fix_adaboost_regressor' in…
glemaitre Oct 23, 2019
4 changes: 4 additions & 0 deletions doc/whats_new/v0.22.rst
@@ -255,6 +255,10 @@ Changelog
now raise consistent error messages.
:pr:`15084` by :user:`Guillaume Lemaitre <glemaitre>`.

- |Fix| :class:`ensemble.AdaBoostRegressor` where the loss should be normalized
by the max of the samples with non-null weights only.
:pr:`14294` by :user:`Guillaume Lemaitre <glemaitre>`.

:mod:`sklearn.feature_extraction`
.................................

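
The changelog entry above summarizes the user-facing effect of this PR: a sample carrying a zero weight should no longer distort the normalization of the boosting loss. Below is a minimal sketch of that behaviour; the toy data and the LinearRegression base estimator are illustrative assumptions that mirror the new test_adaboostregressor_sample_weight test rather than code taken from the PR itself.

import numpy as np
from sklearn.ensemble import AdaBoostRegressor
from sklearn.linear_model import LinearRegression

rng = np.random.RandomState(0)
X = np.linspace(0, 100, num=200).reshape(-1, 1)
y = 0.8 * X.ravel() + 0.2 + rng.rand(X.shape[0]) * 1e-4

# Turn the last sample into an outlier and give it a zero weight.
X[-1] *= 10
y[-1] = 10_000
sample_weight = np.ones_like(y)
sample_weight[-1] = 0

reg_zero_weight = AdaBoostRegressor(
    base_estimator=LinearRegression(), n_estimators=1, random_state=0
).fit(X, y, sample_weight=sample_weight)
reg_no_outlier = AdaBoostRegressor(
    base_estimator=LinearRegression(), n_estimators=1, random_state=0
).fit(X[:-1], y[:-1])

# With the fix, the zero-weighted outlier no longer inflates the error
# normalization, so both models should score similarly on the clean data.
print(reg_zero_weight.score(X[:-1], y[:-1]))
print(reg_no_outlier.score(X[:-1], y[:-1]))
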
51 changes: 21 additions & 30 deletions sklearn/ensemble/_weight_boosting.py
@@ -38,6 +38,7 @@
from ..utils.extmath import stable_cumsum
from ..metrics import accuracy_score, r2_score
from ..utils.validation import check_is_fitted
from ..utils.validation import _check_sample_weight
from ..utils.validation import has_fit_parameter
from ..utils.validation import _num_samples

@@ -117,20 +118,10 @@ def fit(self, X, y, sample_weight=None):

X, y = self._validate_data(X, y)

if sample_weight is None:
# Initialize weights to 1 / n_samples
sample_weight = np.empty(_num_samples(X), dtype=np.float64)
sample_weight[:] = 1. / _num_samples(X)
else:
sample_weight = check_array(sample_weight, ensure_2d=False)
# Normalize existing weights
sample_weight = sample_weight / sample_weight.sum(dtype=np.float64)

# Check that the sample weights sum is positive
if sample_weight.sum() <= 0:
raise ValueError(
"Attempting to fit with a non-positive "
"weighted number of samples.")
sample_weight = _check_sample_weight(sample_weight, X, np.float64)
sample_weight /= sample_weight.sum()
if np.any(sample_weight < 0):
raise ValueError("sample_weight cannot contain negative weights")

# Check parameters
self._validate_estimator()
@@ -1029,13 +1020,10 @@ def _boost(self, iboost, X, y, sample_weight, random_state):
estimator = self._make_estimator(random_state=random_state)

# Weighted sampling of the training set with replacement
# For NumPy >= 1.7.0 use np.random.choice
cdf = stable_cumsum(sample_weight)
cdf /= cdf[-1]
uniform_samples = random_state.random_sample(_num_samples(X))
bootstrap_idx = cdf.searchsorted(uniform_samples, side='right')
# searchsorted returns a scalar
bootstrap_idx = np.array(bootstrap_idx, copy=False)
bootstrap_idx = random_state.choice(
np.arange(_num_samples(X)), size=_num_samples(X), replace=True,
p=sample_weight
)

# Fit on the bootstrapped sample and obtain a prediction
# for all samples in the training set
@@ -1045,18 +1033,21 @@ def _boost(self, iboost, X, y, sample_weight, random_state):
y_predict = estimator.predict(X)

error_vect = np.abs(y_predict - y)
error_max = error_vect.max()
sample_mask = sample_weight > 0
masked_sample_weight = sample_weight[sample_mask]
masked_error_vector = error_vect[sample_mask]

if error_max != 0.:
error_vect /= error_max
error_max = masked_error_vector.max()
if error_max != 0:
masked_error_vector /= error_max

if self.loss == 'square':
error_vect **= 2
masked_error_vector **= 2
elif self.loss == 'exponential':
error_vect = 1. - np.exp(- error_vect)
masked_error_vector = 1. - np.exp(-masked_error_vector)

# Calculate the average loss
estimator_error = (sample_weight * error_vect).sum()
estimator_error = (masked_sample_weight * masked_error_vector).sum()

if estimator_error <= 0:
# Stop if fit is perfect
@@ -1074,9 +1065,9 @@ def _boost(self, iboost, X, y, sample_weight, random_state):
estimator_weight = self.learning_rate * np.log(1. / beta)

if not iboost == self.n_estimators - 1:
sample_weight *= np.power(
beta,
(1. - error_vect) * self.learning_rate)
sample_weight[sample_mask] *= np.power(
beta, (1. - masked_error_vector) * self.learning_rate
)

return sample_weight, estimator_weight, estimator_error

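
The _boost hunks above carry the substance of the fix: bootstrap indices are now drawn with random_state.choice using the normalized sample weights, and the error vector is masked to the samples with strictly positive weight before being normalized by its maximum. A standalone sketch of that masked error computation follows; the function name and simplified signature are assumptions for illustration, while the masking and loss handling follow the diff.

import numpy as np


def masked_average_loss(y_true, y_pred, sample_weight, loss="linear"):
    # Sketch of the fixed error computation (hypothetical helper, not part of
    # the PR): sample_weight is assumed to be already normalized to sum to 1,
    # as done in AdaBoostRegressor.fit.
    error_vect = np.abs(y_pred - y_true)

    # Only samples with a strictly positive weight take part in the
    # normalization, so a zero-weighted outlier cannot shrink the relative
    # error of every other sample.
    sample_mask = sample_weight > 0
    masked_sample_weight = sample_weight[sample_mask]
    masked_error_vector = error_vect[sample_mask]

    error_max = masked_error_vector.max()
    if error_max != 0:
        masked_error_vector /= error_max

    if loss == "square":
        masked_error_vector **= 2
    elif loss == "exponential":
        masked_error_vector = 1.0 - np.exp(-masked_error_vector)

    # Weighted average loss over the kept samples.
    return (masked_sample_weight * masked_error_vector).sum()
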
17 changes: 2 additions & 15 deletions sklearn/ensemble/tests/test_gradient_boosting.py
@@ -24,6 +24,7 @@
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.utils import check_random_state, tosequence
from sklearn.utils._mocking import NoSampleWeightWrapper
from sklearn.utils.testing import assert_almost_equal
from sklearn.utils.testing import assert_array_almost_equal
from sklearn.utils.testing import assert_array_equal
@@ -1292,20 +1293,6 @@ def test_early_stopping_stratified():
gbc.fit(X, y)


class _NoSampleWeightWrapper(BaseEstimator):
def __init__(self, est):
self.est = est

def fit(self, X, y):
self.est.fit(X, y)

def predict(self, X):
return self.est.predict(X)

def predict_proba(self, X):
return self.est.predict_proba(X)


def _make_multiclass():
return make_classification(n_classes=3, n_clusters_per_class=1)

@@ -1330,7 +1317,7 @@ def test_gradient_boosting_with_init(gb, dataset_maker, init_estimator):
gb(init=init_est).fit(X, y, sample_weight=sample_weight)

# init does not support sample weights
init_est = _NoSampleWeightWrapper(init_estimator())
init_est = NoSampleWeightWrapper(init_estimator())
gb(init=init_est).fit(X, y) # ok no sample weights
with pytest.raises(ValueError,
match="estimator.*does not support sample weights"):
97 changes: 77 additions & 20 deletions sklearn/ensemble/tests/test_weight_boosting.py
@@ -3,24 +3,29 @@
import numpy as np
import pytest

from scipy.sparse import csc_matrix
from scipy.sparse import csr_matrix
from scipy.sparse import coo_matrix
from scipy.sparse import dok_matrix
from scipy.sparse import lil_matrix

from sklearn.utils.testing import assert_array_equal, assert_array_less
from sklearn.utils.testing import assert_array_almost_equal
from sklearn.utils.testing import assert_raises, assert_raises_regexp

from sklearn.base import BaseEstimator
from sklearn.base import clone
from sklearn.dummy import DummyClassifier, DummyRegressor
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import AdaBoostClassifier
from sklearn.ensemble import AdaBoostRegressor
from sklearn.ensemble._weight_boosting import _samme_proba
from scipy.sparse import csc_matrix
from scipy.sparse import csr_matrix
from scipy.sparse import coo_matrix
from scipy.sparse import dok_matrix
from scipy.sparse import lil_matrix
from sklearn.svm import SVC, SVR
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.utils import shuffle
from sklearn.utils._mocking import NoSampleWeightWrapper
from sklearn import datasets


@@ -137,9 +142,10 @@ def test_iris():
np.abs(clf_samme.predict_proba(iris.data) - prob_samme))


def test_boston():
@pytest.mark.parametrize('loss', ['linear', 'square', 'exponential'])
def test_boston(loss):
# Check consistency on dataset boston house prices.
reg = AdaBoostRegressor(random_state=0)
reg = AdaBoostRegressor(loss=loss, random_state=0)
reg.fit(boston.data, boston.target)
score = reg.score(boston.data, boston.target)
assert score > 0.85
@@ -304,16 +310,6 @@ def test_base_estimator():
clf.fit, X_fail, y_fail)


def test_sample_weight_missing():
from sklearn.cluster import KMeans

clf = AdaBoostClassifier(KMeans(), algorithm="SAMME")
assert_raises(ValueError, clf.fit, X, y_regr)

clf = AdaBoostRegressor(KMeans())
assert_raises(ValueError, clf.fit, X, y_regr)


def test_sparse_classification():
# Check classification with sparse input.

@@ -486,9 +482,6 @@ def test_multidimensional_X():
Check that the AdaBoost estimators can work with n-dimensional
data matrix
"""

from sklearn.dummy import DummyClassifier, DummyRegressor

rng = np.random.RandomState(0)

X = rng.randn(50, 3, 3)
@@ -505,6 +498,56 @@ def test_multidimensional_X():
boost.predict(X)


@pytest.mark.parametrize("algorithm", ['SAMME', 'SAMME.R'])
def test_adaboostclassifier_without_sample_weight(algorithm):
X, y = iris.data, iris.target
base_estimator = NoSampleWeightWrapper(DummyClassifier())
clf = AdaBoostClassifier(
base_estimator=base_estimator, algorithm=algorithm
)
err_msg = ("{} doesn't support sample_weight"
.format(base_estimator.__class__.__name__))
with pytest.raises(ValueError, match=err_msg):
clf.fit(X, y)


def test_adaboostregressor_sample_weight():
# check that giving weight will have an influence on the error computed
# for a weak learner
rng = np.random.RandomState(42)
X = np.linspace(0, 100, num=1000)
y = (.8 * X + 0.2) + (rng.rand(X.shape[0]) * 0.0001)
X = X.reshape(-1, 1)

# add an arbitrary outlier
X[-1] *= 10
y[-1] = 10000

# random_state=0 ensures that the underlying bootstrap will use the outlier
regr_no_outlier = AdaBoostRegressor(
base_estimator=LinearRegression(), n_estimators=1, random_state=0
)
regr_with_weight = clone(regr_no_outlier)
regr_with_outlier = clone(regr_no_outlier)

# fit 3 models:
# - a model containing the outlier
# - a model without the outlier
# - a model containing the outlier but with a null sample-weight
regr_with_outlier.fit(X, y)
regr_no_outlier.fit(X[:-1], y[:-1])
sample_weight = np.ones_like(y)
sample_weight[-1] = 0
regr_with_weight.fit(X, y, sample_weight=sample_weight)

score_with_outlier = regr_with_outlier.score(X[:-1], y[:-1])
score_no_outlier = regr_no_outlier.score(X[:-1], y[:-1])
score_with_weight = regr_with_weight.score(X[:-1], y[:-1])

assert score_with_outlier < score_no_outlier
assert score_with_outlier < score_with_weight
assert score_no_outlier == pytest.approx(score_with_weight)

@pytest.mark.parametrize("algorithm", ["SAMME", "SAMME.R"])
def test_adaboost_consistent_predict(algorithm):
# check that predict_proba and predict give consistent results
@@ -520,3 +563,17 @@ def test_adaboost_consistent_predict(algorithm):
np.argmax(model.predict_proba(X_test), axis=1),
model.predict(X_test)
)


@pytest.mark.parametrize(
'model, X, y',
[(AdaBoostClassifier(), iris.data, iris.target),
(AdaBoostRegressor(), boston.data, boston.target)]
)
def test_adaboost_negative_weight_error(model, X, y):
sample_weight = np.ones_like(y)
sample_weight[-1] = -10

err_msg = "sample_weight cannot contain negative weight"
with pytest.raises(ValueError, match=err_msg):
model.fit(X, y, sample_weight=sample_weight)
24 changes: 24 additions & 0 deletions sklearn/utils/_mocking.py
@@ -135,3 +135,27 @@ def score(self, X=None, Y=None):

def _more_tags(self):
return {'_skip_test': True, 'X_types': ['1dlabel']}


class NoSampleWeightWrapper(BaseEstimator):
"""Wrap estimator which will not expose `sample_weight`.

Parameters
----------
est : estimator, default=None
The estimator to wrap.
"""
def __init__(self, est=None):
self.est = est

def fit(self, X, y):
return self.est.fit(X, y)

def predict(self, X):
return self.est.predict(X)

def predict_proba(self, X):
return self.est.predict_proba(X)

def _more_tags(self):
return {'_skip_test': True} # pragma: no cover
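
For context, NoSampleWeightWrapper replaces the private _NoSampleWeightWrapper previously defined in test_gradient_boosting.py and is also used by the new AdaBoost test. A minimal usage sketch, assuming the DummyClassifier base estimator from the test above:

from sklearn.dummy import DummyClassifier
from sklearn.ensemble import AdaBoostClassifier
from sklearn.utils._mocking import NoSampleWeightWrapper

# AdaBoostClassifier requires its base estimator to accept sample_weight in
# fit, so fitting this wrapped estimator is expected to raise a ValueError
# ("... doesn't support sample_weight"), as exercised in
# test_adaboostclassifier_without_sample_weight.
clf = AdaBoostClassifier(
    base_estimator=NoSampleWeightWrapper(DummyClassifier()), algorithm="SAMME"
)
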
4 changes: 2 additions & 2 deletions sklearn/utils/validation.py
@@ -1061,8 +1061,8 @@ def _check_sample_weight(sample_weight, X, dtype=None):
if dtype is None:
dtype = [np.float64, np.float32]
sample_weight = check_array(
sample_weight, accept_sparse=False,
ensure_2d=False, dtype=dtype, order="C"
sample_weight, accept_sparse=False, ensure_2d=False, dtype=dtype,
order="C"
)
if sample_weight.ndim != 1:
raise ValueError("Sample weights must be 1D array or scalar")