@@ -22,7 +22,8 @@
 
 import sklearn
 from sklearn import clone, datasets
-from sklearn.datasets import make_classification, make_hastie_10_2
+from sklearn.base import is_classifier
+from sklearn.datasets import make_classification, make_hastie_10_2, make_regression
 from sklearn.decomposition import TruncatedSVD
 from sklearn.dummy import DummyRegressor
 from sklearn.ensemble import (
@@ -46,6 +47,7 @@
 from sklearn.model_selection import GridSearchCV, cross_val_score, train_test_split
 from sklearn.svm import LinearSVC
 from sklearn.tree._classes import SPARSE_SPLITTERS
+from sklearn.utils import shuffle
 from sklearn.utils._testing import (
     _convert_container,
     assert_allclose,
@@ -55,6 +57,10 @@
     ignore_warnings,
     skip_if_no_parallel,
 )
+from sklearn.utils.estimator_checks import (
+    _enforce_estimator_tags_X,
+    _enforce_estimator_tags_y,
+)
 from sklearn.utils.fixes import COO_CONTAINERS, CSC_CONTAINERS, CSR_CONTAINERS
 from sklearn.utils.multiclass import type_of_target
 from sklearn.utils.parallel import Parallel
@@ -1973,6 +1979,159 @@ def test_importance_reg_match_onehot_classi(global_random_seed):
 )
 
 
+@pytest.mark.parametrize("est_name", FOREST_CLASSIFIERS_REGRESSORS)
+def test_feature_importance_with_sample_weights(est_name, global_random_seed):
+    # From https://github.com/snath-xoc/sample-weight-audit-nondet/blob/main/src/sample_weight_audit/data.py#L53
+
+    # Strategy: sample 2 datasets, each with n_features // 2 features:
+    # - the first one has int(0.5 * n_samples) samples, mostly with zero or
+    #   one weights.
+    # - the second one has the remaining samples, with higher weights.
+    #
+    # Each dataset's features are horizontally padded with feature values
+    # resampled at random from the other dataset, independently of the
+    # targets. Then the two datasets are vertically stacked and the result
+    # is shuffled.
+    #
+    # The sum of weights of the second dataset is roughly an order of
+    # magnitude larger than the sum of weights of the first dataset, so
+    # weight-aware estimators should mostly ignore the features of the first
+    # dataset to learn their prediction function.
+    n_samples = 250
+    n_features = 4
+    n_classes = 2
+    max_sample_weight = 5
+
+    rng = check_random_state(global_random_seed)
+    n_samples_sw = int(0.5 * n_samples)  # small weights
+    n_samples_lw = n_samples - n_samples_sw  # large weights
+    n_features_sw = n_features // 2
+    n_features_lw = n_features - n_features_sw
+
+    # Construct the sample weights: mostly zeros and some ones for the first
+    # dataset, and some random integers larger than one for the second dataset.
+    sample_weight_sw = np.where(rng.random(n_samples_sw) < 0.2, 1, 0)
+    sample_weight_lw = rng.randint(2, max_sample_weight, size=n_samples_lw)
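+    # In expectation, the first dataset contributes ~0.2 * 125 = 25 units of
+    # weight while the second contributes 125 * mean({2, 3, 4}) = 375, so the
+    # first dataset accounts for well under 30% of the total weight.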
+    total_weight_sum = np.sum(sample_weight_sw) + np.sum(sample_weight_lw)
+    assert np.sum(sample_weight_sw) < 0.3 * total_weight_sum
+
+    est = FOREST_CLASSIFIERS_REGRESSORS[est_name](
+        n_estimators=50,
+        bootstrap=True,
+        oob_score=True,
+        random_state=rng,
+    )
+    if not is_classifier(est):
+        X_sw, y_sw = make_regression(
+            n_samples=n_samples_sw,
+            n_features=n_features_sw,
+            random_state=rng,
+        )
+        X_lw, y_lw = make_regression(
+            n_samples=n_samples_lw,
+            n_features=n_features_lw,
+            random_state=rng,  # rng has advanced, so this draws a different dataset
+        )
+    else:
+        X_sw, y_sw = make_classification(
+            n_samples=n_samples_sw,
+            n_features=n_features_sw,
+            n_informative=n_features_sw,
+            n_redundant=0,
+            n_repeated=0,
+            n_classes=n_classes,
+            random_state=rng,
+        )
+        X_lw, y_lw = make_classification(
+            n_samples=n_samples_lw,
+            n_features=n_features_lw,
+            n_informative=n_features_lw,
+            n_redundant=0,
+            n_repeated=0,
+            n_classes=n_classes,
+            random_state=rng,  # rng has advanced, so this draws a different dataset
+        )
+
+    # Horizontally pad the features with feature values marginally sampled
+    # from the other dataset.
+    pad_sw_idx = rng.choice(n_samples_lw, size=n_samples_sw, replace=True)
+    X_sw_padded = np.hstack([X_sw, np.take(X_lw, pad_sw_idx, axis=0)])
+
+    pad_lw_idx = rng.choice(n_samples_sw, size=n_samples_lw, replace=True)
+    X_lw_padded = np.hstack([np.take(X_sw, pad_lw_idx, axis=0), X_lw])
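+    # Note: the padding columns are drawn independently of the target of the
+    # rows they are attached to, so they carry no predictive signal there.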
+
+    # Vertically stack the two datasets and shuffle them.
+    X = np.concatenate([X_sw_padded, X_lw_padded], axis=0)
+    y = np.concatenate([y_sw, y_lw])
+
+    X = _enforce_estimator_tags_X(est, X)
+    y = _enforce_estimator_tags_y(est, y)
+    sample_weight = np.concatenate([sample_weight_sw, sample_weight_lw])
+    X, y, sample_weight = shuffle(X, y, sample_weight, random_state=rng)
+
+    est.fit(X, y, sample_weight)
+
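+    # ufi_feature_importances_ and mdi_oob_feature_importances_ are the
+    # OOB-based feature importance variants under test; both should assign
+    # most of the importance to the heavily weighted second half of the
+    # features. Sums over each half are compared because individual
+    # importances are noisy while the aggregate ranking is stable.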
+    ufi_feature_importance = est.ufi_feature_importances_
+    mdi_oob_feature_importance = est.mdi_oob_feature_importances_
+    assert (
+        ufi_feature_importance[:n_features_sw].sum()
+        < ufi_feature_importance[n_features_sw:].sum()
+    )
+    assert (
+        mdi_oob_feature_importance[:n_features_sw].sum()
+        < mdi_oob_feature_importance[n_features_sw:].sum()
+    )
+
+
+@pytest.mark.parametrize("est_name", FOREST_CLASSIFIERS_REGRESSORS)
+def test_feature_importance_sample_weight_equals_repeated(est_name, global_random_seed):
+    # Check that setting a sample weight to zero / an integer is equivalent
+    # to removing / repeating the corresponding samples.
+    params = dict(
+        n_estimators=100,
+        bootstrap=True,
+        oob_score=True,
+        max_features=1.0,
+        random_state=global_random_seed,
+    )
+
+    est_weighted = FOREST_CLASSIFIERS_REGRESSORS[est_name](**params)
+    est_repeated = FOREST_CLASSIFIERS_REGRESSORS[est_name](**params)
+
+    n_samples = 100
+    n_features = 2
+    X, y = make_classification(
+        n_samples=n_samples,
+        n_features=n_features,
+        n_informative=n_features,
+        n_redundant=0,
+        random_state=global_random_seed,
+    )
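+    # For the regressors in FOREST_CLASSIFIERS_REGRESSORS, the binary targets
+    # are simply treated as regression values.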
+    # Draw integer weights in {0, 1}: a weight of zero should be equivalent
+    # to dropping the sample, a weight of one to keeping it once.
+    rng = check_random_state(global_random_seed)
+    sw = rng.randint(0, 2, size=n_samples)
+
+    X_weighted = X
+    y_weighted = y
+    # Repeat each sample sw[i] times: weight 0 drops the sample and weight 1
+    # keeps it exactly once.
+    X_repeated = X_weighted.repeat(repeats=sw, axis=0)
+    y_repeated = y_weighted.repeat(repeats=sw)
+
+    X_weighted, y_weighted, sw = shuffle(X_weighted, y_weighted, sw, random_state=0)
+
+    est_repeated.fit(X_repeated, y=y_repeated, sample_weight=None)
+    est_weighted.fit(X_weighted, y=y_weighted, sample_weight=sw)
+
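+    # With bootstrap resampling, the two fits are only statistically
+    # equivalent, hence the loose absolute tolerance below.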
+    assert_allclose(
+        est_repeated.feature_importances_, est_weighted.feature_importances_, atol=1e-1
+    )
+    assert_allclose(
+        est_repeated.ufi_feature_importances_,
+        est_weighted.ufi_feature_importances_,
+        atol=1e-1,
+    )
+    assert_allclose(
+        est_repeated.mdi_oob_feature_importances_,
+        est_weighted.mdi_oob_feature_importances_,
+        atol=1e-1,
+    )
+
+
 @pytest.mark.parametrize("name", FOREST_CLASSIFIERS_REGRESSORS)
 def test_max_samples_bootstrap(name):
     # Check invalid `max_samples` values