
Commit 241de66

add sample weight tests
1 parent 6fcf61c commit 241de66
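
Review note: both tests added by this commit share the "sample_weight" name fragment, so they can be run in isolation with pytest's keyword filter. A minimal sketch, assuming an editable build of this branch (the -k expression is an assumption about the new test names, not part of the commit):

# Run only the tests added by this commit through pytest's Python API.
# Assumes scikit-learn is built from this branch in the current environment.
import pytest

pytest.main(
    [
        "sklearn/ensemble/tests/test_forest.py",
        "-k", "sample_weight",  # matches both new tests
        "-q",
    ]
)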

File tree

1 file changed: +160 −1 lines changed

sklearn/ensemble/tests/test_forest.py

Lines changed: 160 additions & 1 deletion
@@ -22,7 +22,8 @@
 
 import sklearn
 from sklearn import clone, datasets
-from sklearn.datasets import make_classification, make_hastie_10_2
+from sklearn.base import is_classifier
+from sklearn.datasets import make_classification, make_hastie_10_2, make_regression
 from sklearn.decomposition import TruncatedSVD
 from sklearn.dummy import DummyRegressor
 from sklearn.ensemble import (
@@ -46,6 +47,7 @@
 from sklearn.model_selection import GridSearchCV, cross_val_score, train_test_split
 from sklearn.svm import LinearSVC
 from sklearn.tree._classes import SPARSE_SPLITTERS
+from sklearn.utils import shuffle
 from sklearn.utils._testing import (
     _convert_container,
     assert_allclose,
@@ -55,6 +57,10 @@
     ignore_warnings,
     skip_if_no_parallel,
 )
+from sklearn.utils.estimator_checks import (
+    _enforce_estimator_tags_X,
+    _enforce_estimator_tags_y,
+)
 from sklearn.utils.fixes import COO_CONTAINERS, CSC_CONTAINERS, CSR_CONTAINERS
 from sklearn.utils.multiclass import type_of_target
 from sklearn.utils.parallel import Parallel
@@ -1973,6 +1979,159 @@ def test_importance_reg_match_onehot_classi(global_random_seed):
     )


+@pytest.mark.parametrize("est_name", FOREST_CLASSIFIERS_REGRESSORS)
+def test_feature_importance_with_sample_weights(est_name, global_random_seed):
+    # From https://github.com/snath-xoc/sample-weight-audit-nondet/blob/main/src/sample_weight_audit/data.py#L53
+
+    # Strategy: sample 2 datasets, each with n_features // 2 features:
+    # - the first one has half of the samples, with mostly zero or one weights;
+    # - the second one has the remaining samples, with higher weights.
+    #
+    # The features of the two datasets are horizontally stacked with random
+    # feature values sampled independently from the other dataset. Then the two
+    # datasets are vertically stacked and the result is shuffled.
+    #
+    # The sum of the weights of the second dataset is roughly an order of
+    # magnitude larger than that of the first dataset, so weight-aware
+    # estimators should mostly ignore the features of the first dataset when
+    # learning their prediction function.
+    n_samples = 250
+    n_features = 4
+    n_classes = 2
+    max_sample_weight = 5
+
+    rng = check_random_state(global_random_seed)
+    n_samples_sw = int(0.5 * n_samples)  # small weights
+    n_samples_lw = n_samples - n_samples_sw  # large weights
+    n_features_sw = n_features // 2
+    n_features_lw = n_features - n_features_sw
+
+    # Construct the sample weights: mostly zeros and some ones for the first
+    # dataset, and some random integers larger than one for the second dataset.
+    sample_weight_sw = np.where(rng.random(n_samples_sw) < 0.2, 1, 0)
+    sample_weight_lw = rng.randint(2, max_sample_weight, size=n_samples_lw)
+    total_weight_sum = np.sum(sample_weight_sw) + np.sum(sample_weight_lw)
+    assert np.sum(sample_weight_sw) < 0.3 * total_weight_sum
+
+    est = FOREST_CLASSIFIERS_REGRESSORS[est_name](
+        n_estimators=50,
+        bootstrap=True,
+        oob_score=True,
+        random_state=rng,
+    )
+    if not is_classifier(est):
+        X_sw, y_sw = make_regression(
+            n_samples=n_samples_sw,
+            n_features=n_features_sw,
+            random_state=rng,
+        )
+        X_lw, y_lw = make_regression(
+            n_samples=n_samples_lw,
+            n_features=n_features_lw,
+            random_state=rng,  # rng differs here because it was mutated above
+        )
+    else:
+        X_sw, y_sw = make_classification(
+            n_samples=n_samples_sw,
+            n_features=n_features_sw,
+            n_informative=n_features_sw,
+            n_redundant=0,
+            n_repeated=0,
+            n_classes=n_classes,
+            random_state=rng,
+        )
+        X_lw, y_lw = make_classification(
+            n_samples=n_samples_lw,
+            n_features=n_features_lw,
+            n_informative=n_features_lw,
+            n_redundant=0,
+            n_repeated=0,
+            n_classes=n_classes,
+            random_state=rng,  # rng differs here because it was mutated above
+        )
+
+    # Horizontally pad the features with feature values marginally sampled
+    # from the other dataset.
+    pad_sw_idx = rng.choice(n_samples_lw, size=n_samples_sw, replace=True)
+    X_sw_padded = np.hstack([X_sw, np.take(X_lw, pad_sw_idx, axis=0)])
+
+    pad_lw_idx = rng.choice(n_samples_sw, size=n_samples_lw, replace=True)
+    X_lw_padded = np.hstack([np.take(X_sw, pad_lw_idx, axis=0), X_lw])
+
+    # Vertically stack the two datasets and shuffle them.
+    X = np.concatenate([X_sw_padded, X_lw_padded], axis=0)
+    y = np.concatenate([y_sw, y_lw])
+
+    X = _enforce_estimator_tags_X(est, X)
+    y = _enforce_estimator_tags_y(est, y)
+    sample_weight = np.concatenate([sample_weight_sw, sample_weight_lw])
+    X, y, sample_weight = shuffle(X, y, sample_weight, random_state=rng)
+
+    est.fit(X, y, sample_weight)
+
+    ufi_feature_importance = est.ufi_feature_importances_
+    mdi_oob_feature_importance = est.mdi_oob_feature_importances_
+    assert (
+        ufi_feature_importance[:n_features_sw].sum()
+        < ufi_feature_importance[n_features_sw:].sum()
+    )
+    assert (
+        mdi_oob_feature_importance[:n_features_sw].sum()
+        < mdi_oob_feature_importance[n_features_sw:].sum()
+    )
+
+
+@pytest.mark.parametrize("est_name", FOREST_CLASSIFIERS_REGRESSORS)
+def test_feature_importance_sample_weight_equals_repeated(est_name, global_random_seed):
+    # Check that setting a sample weight to zero / to an integer is
+    # equivalent to removing / repeating the corresponding samples.
+    params = dict(
+        n_estimators=100,
+        bootstrap=True,
+        oob_score=True,
+        max_features=1.0,
+        random_state=global_random_seed,
+    )
+
+    est_weighted = FOREST_CLASSIFIERS_REGRESSORS[est_name](**params)
+    est_repeated = FOREST_CLASSIFIERS_REGRESSORS[est_name](**params)
+
+    rng = check_random_state(global_random_seed)
+    n_samples = 100
+    n_features = 2
+    X, y = make_classification(
+        n_samples=n_samples,
+        n_features=n_features,
+        n_informative=n_features,
+        n_redundant=0,
+        random_state=rng,
+    )
+    # Use random 0/1 integer weights.
+    sw = rng.randint(0, 2, size=n_samples)
+
+    X_weighted = X
+    y_weighted = y
+    # repeat samples according to weights
+    X_repeated = X_weighted.repeat(repeats=sw, axis=0)
+    y_repeated = y_weighted.repeat(repeats=sw)
+
+    X_weighted, y_weighted, sw = shuffle(X_weighted, y_weighted, sw, random_state=0)
+
+    est_repeated.fit(X_repeated, y=y_repeated, sample_weight=None)
+    est_weighted.fit(X_weighted, y=y_weighted, sample_weight=sw)
+
+    assert_allclose(
+        est_repeated.feature_importances_, est_weighted.feature_importances_, atol=1e-1
+    )
+    assert_allclose(
+        est_repeated.ufi_feature_importances_,
+        est_weighted.ufi_feature_importances_,
+        atol=1e-1,
+    )
+    assert_allclose(
+        est_repeated.mdi_oob_feature_importances_,
+        est_weighted.mdi_oob_feature_importances_,
+        atol=1e-1,
+    )
+
+
 @pytest.mark.parametrize("name", FOREST_CLASSIFIERS_REGRESSORS)
 def test_max_samples_bootstrap(name):
     # Check invalid `max_samples` values
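
Reviewer context, not part of the diff: ufi_feature_importances_ and mdi_oob_feature_importances_ are attributes specific to this branch (apparently debiased, OOB-based alternatives to the impurity-based feature_importances_; note the tests require bootstrap=True and oob_score=True), while released scikit-learn already ships a weight-aware reference in sklearn.inspection.permutation_importance. A minimal sketch of the kind of cross-check the new tests automate, using only stock APIs plus a guarded access to the branch-specific attributes:

# Compare impurity-based MDI importances against weight-aware permutation
# importances on weighted toy data; print the branch-specific attributes
# only if they exist (they are not available in released scikit-learn).
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance

X, y = make_classification(
    n_samples=300, n_features=4, n_informative=2, n_redundant=0, random_state=0
)
sw = np.random.RandomState(0).randint(0, 5, size=y.shape[0]).astype(float)

est = RandomForestClassifier(
    n_estimators=100, bootstrap=True, oob_score=True, random_state=0
).fit(X, y, sample_weight=sw)

perm = permutation_importance(est, X, y, sample_weight=sw, n_repeats=10, random_state=0)
print("MDI :", est.feature_importances_)
print("perm:", perm.importances_mean)

for attr in ("ufi_feature_importances_", "mdi_oob_feature_importances_"):
    if hasattr(est, attr):
        print(attr, getattr(est, attr))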
