From 1b0a9607c97e438741c330cb5ec5760fc927ec2c Mon Sep 17 00:00:00 2001 From: Pat O'Reilly Date: Fri, 26 Aug 2016 16:53:07 +0100 Subject: [PATCH 001/174] Monotonic features in tree-based models DecisionTrees RandomForests GradientBoosting --- sklearn/ensemble/forest.py | 68 ++++++++++++++++++++--- sklearn/ensemble/gradient_boosting.py | 28 ++++++++-- sklearn/tree/_splitter.pxd | 5 ++ sklearn/tree/_splitter.pyx | 61 +++++++++++++++++++- sklearn/tree/tree.py | 80 +++++++++++++++++++++++---- 5 files changed, 214 insertions(+), 28 deletions(-) diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index 4d4f04bc12408..ac7fb9a65e0d0 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -854,6 +854,12 @@ class RandomForestClassifier(ForestClassifier): Note that these weights will be multiplied with sample_weight (passed through the fit method) if sample_weight is specified. + increasing : list of ints, optional (default=None) + Indices of features to have a monotonically increasing effect. + + decreasing : list of ints, optional (default=None) + Indices of features to have a monotonically decreasing effect. + Attributes ---------- estimators_ : list of DecisionTreeClassifier @@ -910,7 +916,9 @@ def __init__(self, random_state=None, verbose=0, warm_start=False, - class_weight=None): + class_weight=None, + increasing=None, + decreasing=None): super(RandomForestClassifier, self).__init__( base_estimator=DecisionTreeClassifier(), n_estimators=n_estimators, @@ -934,6 +942,8 @@ def __init__(self, self.max_features = max_features self.max_leaf_nodes = max_leaf_nodes self.min_impurity_split = min_impurity_split + self.increasing = increasing + self.decreasing = decreasing class RandomForestRegressor(ForestRegressor): @@ -1039,6 +1049,12 @@ class RandomForestRegressor(ForestRegressor): and add more estimators to the ensemble, otherwise, just fit a whole new forest. + increasing : list of ints, optional (default=None) + Indices of features to have a monotonically increasing effect. + + decreasing : list of ints, optional (default=None) + Indices of features to have a monotonically decreasing effect. + Attributes ---------- estimators_ : list of DecisionTreeRegressor @@ -1083,14 +1099,16 @@ def __init__(self, n_jobs=1, random_state=None, verbose=0, - warm_start=False): + warm_start=False, + increasing=None, + decreasing=None): super(RandomForestRegressor, self).__init__( base_estimator=DecisionTreeRegressor(), n_estimators=n_estimators, estimator_params=("criterion", "max_depth", "min_samples_split", "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_split", - "random_state"), + "random_state", "increasing", "decreasing"), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, @@ -1106,6 +1124,8 @@ def __init__(self, self.max_features = max_features self.max_leaf_nodes = max_leaf_nodes self.min_impurity_split = min_impurity_split + self.increasing = increasing + self.decreasing = decreasing class ExtraTreesClassifier(ForestClassifier): @@ -1222,6 +1242,12 @@ class ExtraTreesClassifier(ForestClassifier): Note that these weights will be multiplied with sample_weight (passed through the fit method) if sample_weight is specified. + increasing : list of ints, optional (default=None) + Indices of features to have a monotonically increasing effect. + + decreasing : list of ints, optional (default=None) + Indices of features to have a monotonically decreasing effect. 
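As a usage sketch of the keyword API added in this patch (illustrative only: the `increasing`/`decreasing` keywords exist only on this branch, and the toy data below is invented):

    >>> import numpy as np
    >>> from sklearn.ensemble import RandomForestRegressor
    >>> rng = np.random.RandomState(0)
    >>> X = rng.rand(200, 3)
    >>> y = 2 * X[:, 0] - X[:, 1] + 0.1 * rng.randn(200)
    >>> # predictions constrained to rise with feature 0 and fall with feature 1
    >>> reg = RandomForestRegressor(n_estimators=10, random_state=0,
    ...                             increasing=[0], decreasing=[1]).fit(X, y)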
+ Attributes ---------- estimators_ : list of DecisionTreeClassifier @@ -1281,14 +1307,16 @@ def __init__(self, random_state=None, verbose=0, warm_start=False, - class_weight=None): + class_weight=None, + increasing=None, + decreasing=None): super(ExtraTreesClassifier, self).__init__( base_estimator=ExtraTreeClassifier(), n_estimators=n_estimators, estimator_params=("criterion", "max_depth", "min_samples_split", "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_split", - "random_state"), + "random_state", "increasing", "decreasing"), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, @@ -1305,6 +1333,8 @@ def __init__(self, self.max_features = max_features self.max_leaf_nodes = max_leaf_nodes self.min_impurity_split = min_impurity_split + self.increasing = increasing + self.decreasing = decreasing class ExtraTreesRegressor(ForestRegressor): @@ -1407,6 +1437,12 @@ class ExtraTreesRegressor(ForestRegressor): and add more estimators to the ensemble, otherwise, just fit a whole new forest. + increasing : list of ints, optional (default=None) + Indices of features to have a monotonically increasing effect. + + decreasing : list of ints, optional (default=None) + Indices of features to have a monotonically decreasing effect. + Attributes ---------- estimators_ : list of DecisionTreeRegressor @@ -1453,14 +1489,16 @@ def __init__(self, n_jobs=1, random_state=None, verbose=0, - warm_start=False): + warm_start=False, + increasing=None, + decreasin=None): super(ExtraTreesRegressor, self).__init__( base_estimator=ExtraTreeRegressor(), n_estimators=n_estimators, estimator_params=("criterion", "max_depth", "min_samples_split", "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_split", - "random_state"), + "random_state", "increasing", "decreasing"), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, @@ -1476,6 +1514,8 @@ def __init__(self, self.max_features = max_features self.max_leaf_nodes = max_leaf_nodes self.min_impurity_split = min_impurity_split + self.increasing = increasing + self.decreasing = decreasing class RandomTreesEmbedding(BaseForest): @@ -1556,6 +1596,12 @@ class RandomTreesEmbedding(BaseForest): and add more estimators to the ensemble, otherwise, just fit a whole new forest. + increasing : list of ints, optional (default=None) + Indices of features to have a monotonically increasing effect. + + decreasing : list of ints, optional (default=None) + Indices of features to have a monotonically decreasing effect. 
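Why listing the two new names in ``estimator_params`` is all the forest classes need: ``BaseEnsemble._make_estimator`` copies every attribute named there from the forest onto each cloned tree. Roughly (a simplified sketch of that mechanism, not the actual ensemble code; the helper name is hypothetical):

    from sklearn.base import clone

    def make_constrained_tree(forest):
        # mirrors what _make_estimator does with the names in estimator_params
        tree = clone(forest.base_estimator_)
        tree.set_params(**{p: getattr(forest, p) for p in forest.estimator_params})
        return tree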
+ Attributes ---------- estimators_ : list of DecisionTreeClassifier @@ -1583,14 +1629,16 @@ def __init__(self, n_jobs=1, random_state=None, verbose=0, - warm_start=False): + warm_start=False, + increasing=None, + decreasing=None): super(RandomTreesEmbedding, self).__init__( base_estimator=ExtraTreeRegressor(), n_estimators=n_estimators, estimator_params=("criterion", "max_depth", "min_samples_split", "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_split", - "random_state"), + "random_state", "increasing", "decreasing"), bootstrap=False, oob_score=False, n_jobs=n_jobs, @@ -1607,6 +1655,8 @@ def __init__(self, self.max_leaf_nodes = max_leaf_nodes self.min_impurity_split = min_impurity_split self.sparse_output = sparse_output + self.increasing = increasing + self.decreasing = decreasing def _set_oob_score(self, X, y): raise NotImplementedError("OOB score not supported by tree embedding") diff --git a/sklearn/ensemble/gradient_boosting.py b/sklearn/ensemble/gradient_boosting.py index 4ea8ef8e4e779..7e26e20232c27 100644 --- a/sklearn/ensemble/gradient_boosting.py +++ b/sklearn/ensemble/gradient_boosting.py @@ -724,7 +724,7 @@ def __init__(self, loss, learning_rate, n_estimators, criterion, min_samples_split, min_samples_leaf, min_weight_fraction_leaf, max_depth, min_impurity_split, init, subsample, max_features, random_state, alpha=0.9, verbose=0, max_leaf_nodes=None, - warm_start=False, presort='auto'): + warm_start=False, presort='auto', increasing=None, decreasing=None): self.n_estimators = n_estimators self.learning_rate = learning_rate @@ -744,6 +744,8 @@ def __init__(self, loss, learning_rate, n_estimators, criterion, self.max_leaf_nodes = max_leaf_nodes self.warm_start = warm_start self.presort = presort + self.increasing = increasing + self.decreasing = decreasing self.estimators_ = np.empty((0, 0), dtype=np.object) @@ -773,7 +775,9 @@ def _fit_stage(self, i, X, y, y_pred, sample_weight, sample_mask, max_features=self.max_features, max_leaf_nodes=self.max_leaf_nodes, random_state=random_state, - presort=self.presort) + presort=self.presort, + increasing=self.increasing, + decreasing=self.decreasing) if self.subsample < 1.0: # no inplace multiplication! @@ -1395,6 +1399,12 @@ class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin): .. versionadded:: 0.17 *presort* parameter. + increasing : list of ints, optional (default=None) + Indices of features to have a monotonically increasing effect. + + decreasing : list of ints, optional (default=None) + Indices of features to have a monotonically decreasing effect. + Attributes ---------- feature_importances_ : array, shape = [n_features] @@ -1447,7 +1457,7 @@ def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100, max_depth=3, min_impurity_split=1e-7, init=None, random_state=None, max_features=None, verbose=0, max_leaf_nodes=None, warm_start=False, - presort='auto'): + presort='auto', increasing=None, decreasing=None): super(GradientBoostingClassifier, self).__init__( loss=loss, learning_rate=learning_rate, n_estimators=n_estimators, @@ -1460,7 +1470,7 @@ def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100, max_leaf_nodes=max_leaf_nodes, min_impurity_split=min_impurity_split, warm_start=warm_start, - presort=presort) + presort=presort, increasing=increasing, decreasing=decreasing) def _validate_y(self, y): check_classification_targets(y) @@ -1761,6 +1771,12 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin): .. 
versionadded:: 0.17 optional parameter *presort*. + increasing : list of ints, optional (default=None) + Indices of features to have a monotonically increasing effect. + + decreasing : list of ints, optional (default=None) + Indices of features to have a monotonically decreasing effect. + Attributes ---------- feature_importances_ : array, shape = [n_features] @@ -1809,7 +1825,7 @@ def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100, min_samples_leaf=1, min_weight_fraction_leaf=0., max_depth=3, min_impurity_split=1e-7, init=None, random_state=None, max_features=None, alpha=0.9, verbose=0, max_leaf_nodes=None, - warm_start=False, presort='auto'): + warm_start=False, presort='auto', increasing=None, decreasing=None): super(GradientBoostingRegressor, self).__init__( loss=loss, learning_rate=learning_rate, n_estimators=n_estimators, @@ -1820,7 +1836,7 @@ def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100, max_features=max_features, min_impurity_split=min_impurity_split, random_state=random_state, alpha=alpha, verbose=verbose, max_leaf_nodes=max_leaf_nodes, warm_start=warm_start, - presort=presort) + presort=presort, increasing=increasing, decreasing=decreasing) def predict(self, X): """Predict regression target for X. diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 4b97d01614b9f..52951fa6b184d 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -60,6 +60,11 @@ cdef class Splitter: cdef bint presort # Whether to use presorting, only # allowed on dense data + cdef INT32_t* monotonic # Monotonicity constraints + # -1: monotonically decreasing + # 0: no constraint + # +1: monotonically increasing + cdef DOUBLE_t* y cdef SIZE_t y_stride cdef DOUBLE_t* sample_weight diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 0617508aab236..5ffaf638b46b3 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -60,7 +60,7 @@ cdef class Splitter: def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, bint presort): + object random_state, bint presort, np.ndarray[INT32_t] monotonic): """ Parameters ---------- @@ -101,6 +101,7 @@ cdef class Splitter: self.min_weight_leaf = min_weight_leaf self.random_state = random_state self.presort = presort + self.monotonic = monotonic.data def __dealloc__(self): """Destructor.""" @@ -242,7 +243,7 @@ cdef class BaseDenseSplitter(Splitter): def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, bint presort): + object random_state, bint presort, np.ndarray[INT32_t] monotonic): self.X = NULL self.X_sample_stride = 0 @@ -251,6 +252,7 @@ cdef class BaseDenseSplitter(Splitter): self.X_idx_sorted_stride = 0 self.sample_mask = NULL self.presort = presort + self.monotonic = monotonic.data def __dealloc__(self): """Destructor.""" @@ -396,6 +398,8 @@ cdef class BestSplitter(BaseDenseSplitter): current.feature = features[f_j] feature_offset = self.X_feature_stride * current.feature + monotonic_constraint = self.monotonic[current.feature] + # Sort samples along that feature; either by utilizing # presorting, or by copying the values into an array and # sorting the array in a manner which utilizes the cache more @@ -457,6 +461,17 @@ cdef class BestSplitter(BaseDenseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue + # Monotonic check + if monotonic_constraint != 0: + # TODO: is 0.0 divisor 
possible? + left = self.criterion.sum_left[0]/self.criterion.weighted_n_left + right = self.criterion.sum_right[0]/self.criterion.weighted_n_right + if monotonic_constraint == -1: + if left < right: + continue + elif left > right: + continue + current_proxy_improvement = self.criterion.proxy_impurity_improvement() if current_proxy_improvement > best_proxy_improvement: @@ -723,6 +738,8 @@ cdef class RandomSplitter(BaseDenseSplitter): current.feature = features[f_j] feature_stride = X_feature_stride * current.feature + monotonic_constraint = self.monotonic[current.feature] + # Find min, max min_feature_value = X[X_sample_stride * samples[start] + feature_stride] max_feature_value = min_feature_value @@ -789,6 +806,17 @@ cdef class RandomSplitter(BaseDenseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue + # Monotonic check + if monotonic_constraint != 0: + # TODO: is 0.0 divisor possible? + left = self.criterion.sum_left[0]/self.criterion.weighted_n_left + right = self.criterion.sum_right[0]/self.criterion.weighted_n_right + if monotonic_constraint == -1: + if left < right: + continue + elif left > right: + continue + current_proxy_improvement = self.criterion.proxy_impurity_improvement() if current_proxy_improvement > best_proxy_improvement: @@ -848,7 +876,7 @@ cdef class BaseSparseSplitter(Splitter): def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, bint presort): + object random_state, bint presort, np.ndarray[INT32_t] monotonic): # Parent __cinit__ is automatically called self.X_data = NULL @@ -859,6 +887,7 @@ cdef class BaseSparseSplitter(Splitter): self.index_to_samples = NULL self.sorted_samples = NULL + self.monotonic = monotonic.data def __dealloc__(self): """Deallocate memory.""" @@ -1269,6 +1298,8 @@ cdef class BestSparseSplitter(BaseSparseSplitter): &end_negative, &start_positive, &is_samples_sorted) + monotonic_constraint = self.monotonic[current.feature] + # Sort the positive and negative parts of `Xf` sort(Xf + start, samples + start, end_negative - start) sort(Xf + start_positive, samples + start_positive, @@ -1342,6 +1373,17 @@ cdef class BestSparseSplitter(BaseSparseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue + # Monotonic check + if monotonic_constraint != 0: + # TODO: is 0.0 divisor possible? + left = self.criterion.sum_left[0]/self.criterion.weighted_n_left + right = self.criterion.sum_right[0]/self.criterion.weighted_n_right + if monotonic_constraint == -1: + if left < right: + continue + elif left > right: + continue + current_proxy_improvement = self.criterion.proxy_impurity_improvement() if current_proxy_improvement > best_proxy_improvement: @@ -1500,6 +1542,8 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): &end_negative, &start_positive, &is_samples_sorted) + monotonic_constraint = self.monotonic[current.feature] + # Add one or two zeros in Xf, if there is any if end_negative < start_positive: start_positive -= 1 @@ -1570,6 +1614,17 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue + # Monotonic check + if monotonic_constraint != 0: + # TODO: is 0.0 divisor possible? 
+ left = self.criterion.sum_left[0]/self.criterion.weighted_n_left + right = self.criterion.sum_right[0]/self.criterion.weighted_n_right + if monotonic_constraint == -1: + if left < right: + continue + elif left > right: + continue + current_proxy_improvement = self.criterion.proxy_impurity_improvement() if current_proxy_improvement > best_proxy_improvement: diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py index 1ec678931752a..a320c01577f2f 100644 --- a/sklearn/tree/tree.py +++ b/sklearn/tree/tree.py @@ -92,7 +92,9 @@ def __init__(self, random_state, min_impurity_split, class_weight=None, - presort=False): + presort=False, + increasing=None, + decreasing=None): self.criterion = criterion self.splitter = splitter self.max_depth = max_depth @@ -105,6 +107,8 @@ def __init__(self, self.min_impurity_split = min_impurity_split self.class_weight = class_weight self.presort = presort + self.increasing = increasing + self.decreasing = decreasing self.n_features_ = None self.n_outputs_ = None @@ -350,6 +354,33 @@ def fit(self, X, y, sample_weight=None, check_input=True, SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS + def _encode_monotonic(increasing, decreasing): + if increasing is None: increasing = [] + if decreasing is None: decreasing = [] + def is_int_in_range(feature): + return isinstance(feature, int) and 0 <= feature < self.n_features_ + def is_valid(features): + return (isinstance(features, list) and + all(is_int_in_range(feature) for feature in features)) + if not is_valid(increasing): + raise ValueError("increasing should be a list of ints in the range [0,n_features].") + if not is_valid(decreasing): + raise ValueError("decreasing should be a list of ints in the range [0,n_features].") + if increasing and decreasing: + intersection = set(increasing) & set(decreasing) + if intersection: + raise ValueError("The following features cannot be both increasing and decreasing: " + str(list(intersection))) + monotonic = np.zeros(self.n_features_, dtype=np.int32) + if increasing: + for feature in increasing: + monotonic[feature] = 1 + if decreasing: + for feature in decreasing: + monotonic[feature] = -1 + return monotonic + + monotonic = _encode_monotonic(self.increasing, self.decreasing) + splitter = self.splitter if not isinstance(self.splitter, Splitter): splitter = SPLITTERS[self.splitter](criterion, @@ -357,7 +388,8 @@ def fit(self, X, y, sample_weight=None, check_input=True, min_samples_leaf, min_weight_leaf, random_state, - self.presort) + self.presort, + monotonic) self.tree_ = Tree(self.n_features_, self.n_classes_, self.n_outputs_) @@ -626,6 +658,12 @@ class DecisionTreeClassifier(BaseDecisionTree, ClassifierMixin): When using either a smaller dataset or a restricted depth, this may speed up the training. + increasing : list of ints, optional (default=None) + Indices of features to have a monotonically increasing effect. + + decreasing : list of ints, optional (default=None) + Indices of features to have a monotonically decreasing effect. 
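A pure-Python sketch (illustrative only; function names are hypothetical) of the two pieces added above — the -1/0/+1 encoding built by ``_encode_monotonic`` and the per-split test the Cython splitters apply before considering a candidate threshold:

    import numpy as np

    def encode_monotonic(n_features, increasing=None, decreasing=None):
        # one int per feature: +1 increasing, -1 decreasing, 0 unconstrained
        monotonic = np.zeros(n_features, dtype=np.int32)
        for f in increasing or []:
            monotonic[f] = 1
        for f in decreasing or []:
            monotonic[f] = -1
        return monotonic

    def split_allowed(constraint, sum_left, weighted_n_left,
                      sum_right, weighted_n_right):
        # constraint: -1 (decreasing), 0 (unconstrained) or +1 (increasing)
        if constraint == 0:
            return True
        left_mean = sum_left / weighted_n_left    # mean target below the threshold
        right_mean = sum_right / weighted_n_right  # mean target above the threshold
        if constraint == 1:
            # increasing: the low-value child must not predict more than the high-value child
            return left_mean <= right_mean
        # decreasing: the reverse
        return left_mean >= right_mean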
+ Attributes ---------- classes_ : array of shape = [n_classes] or a list of such arrays @@ -698,7 +736,9 @@ def __init__(self, max_leaf_nodes=None, min_impurity_split=1e-7, class_weight=None, - presort=False): + presort=False, + increasing=None, + decreasing=None): super(DecisionTreeClassifier, self).__init__( criterion=criterion, splitter=splitter, @@ -711,7 +751,9 @@ def __init__(self, class_weight=class_weight, random_state=random_state, min_impurity_split=min_impurity_split, - presort=presort) + presort=presort, + increasing=increasing, + decreasing=decreasing) def predict_proba(self, X, check_input=True): """Predict class probabilities of the input samples X. @@ -875,6 +917,12 @@ class DecisionTreeRegressor(BaseDecisionTree, RegressorMixin): When using either a smaller dataset or a restricted depth, this may speed up the training. + increasing : list of ints, optional (default=None) + Indices of features to have a monotonically increasing effect. + + decreasing : list of ints, optional (default=None) + Indices of features to have a monotonically decreasing effect. + Attributes ---------- feature_importances_ : array of shape = [n_features] @@ -938,7 +986,9 @@ def __init__(self, random_state=None, max_leaf_nodes=None, min_impurity_split=1e-7, - presort=False): + presort=False, + increasing=None, + decreasing=None): super(DecisionTreeRegressor, self).__init__( criterion=criterion, splitter=splitter, @@ -950,7 +1000,9 @@ def __init__(self, max_leaf_nodes=max_leaf_nodes, random_state=random_state, min_impurity_split=min_impurity_split, - presort=presort) + presort=presort, + increasing=increasing, + decreasing=decreasing) class ExtraTreeClassifier(DecisionTreeClassifier): @@ -988,7 +1040,9 @@ def __init__(self, random_state=None, max_leaf_nodes=None, min_impurity_split=1e-7, - class_weight=None): + class_weight=None, + increasing=None, + decreasing=None): super(ExtraTreeClassifier, self).__init__( criterion=criterion, splitter=splitter, @@ -1000,7 +1054,9 @@ def __init__(self, max_leaf_nodes=max_leaf_nodes, class_weight=class_weight, min_impurity_split=min_impurity_split, - random_state=random_state) + random_state=random_state, + increasing=increasing, + decreasing=decreasing) class ExtraTreeRegressor(DecisionTreeRegressor): @@ -1037,7 +1093,9 @@ def __init__(self, max_features="auto", random_state=None, min_impurity_split=1e-7, - max_leaf_nodes=None): + max_leaf_nodes=None, + increasing=None, + decreasing=None): super(ExtraTreeRegressor, self).__init__( criterion=criterion, splitter=splitter, @@ -1048,4 +1106,6 @@ def __init__(self, max_features=max_features, max_leaf_nodes=max_leaf_nodes, min_impurity_split=min_impurity_split, - random_state=random_state) + random_state=random_state, + increasing=increasing, + decreasing=decreasing) From 953651f248ca30261de53ed048af1708a8bfa69d Mon Sep 17 00:00:00 2001 From: Pat O'Reilly Date: Sun, 28 Aug 2016 16:27:53 +0100 Subject: [PATCH 002/174] refactoring monotonic check into a function --- sklearn/tree/_splitter.pxd | 4 ++- sklearn/tree/_splitter.pyx | 65 ++++++++++++++------------------------ 2 files changed, 27 insertions(+), 42 deletions(-) diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 52951fa6b184d..ecdf5ee6958f5 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -100,4 +100,6 @@ cdef class Splitter: cdef void node_value(self, double* dest) nogil - cdef double node_impurity(self) nogil \ No newline at end of file + cdef double node_impurity(self) nogil + + cdef inline bint 
split_passes_monotonic_check(self, INT32_t monotonic_constraint) nogil \ No newline at end of file diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 5ffaf638b46b3..fbc7dd8f5e43a 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -229,6 +229,22 @@ cdef class Splitter: return self.criterion.node_impurity() + cdef inline bint split_passes_monotonic_check(self, INT32_t monotonic_constraint) nogil: + """Check monotonic constraint is satisfied at the current split""" + if monotonic_constraint == 0: # No constraint + return 1 + else: + left = self.criterion.sum_left[0]/self.criterion.weighted_n_left + right = self.criterion.sum_right[0]/self.criterion.weighted_n_right + if monotonic_constraint == -1: # Decreasing constraint + if left < right: # Fails + return 0 + else: + return 1 + elif left > right: # Increasing constraint fails + return 0 + else: + return 1 cdef class BaseDenseSplitter(Splitter): cdef DTYPE_t* X @@ -461,16 +477,8 @@ cdef class BestSplitter(BaseDenseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue - # Monotonic check - if monotonic_constraint != 0: - # TODO: is 0.0 divisor possible? - left = self.criterion.sum_left[0]/self.criterion.weighted_n_left - right = self.criterion.sum_right[0]/self.criterion.weighted_n_right - if monotonic_constraint == -1: - if left < right: - continue - elif left > right: - continue + if not self.split_passes_monotonic_check(monotonic_constraint): + continue current_proxy_improvement = self.criterion.proxy_impurity_improvement() @@ -525,7 +533,6 @@ cdef class BestSplitter(BaseDenseSplitter): split[0] = best n_constant_features[0] = n_total_constants - # Sort n-element arrays pointed to by Xf and samples, simultaneously, # by the values in Xf. Algorithm: Introsort (Musser, SP&E, 1997). cdef inline void sort(DTYPE_t* Xf, SIZE_t* samples, SIZE_t n) nogil: @@ -806,16 +813,8 @@ cdef class RandomSplitter(BaseDenseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue - # Monotonic check - if monotonic_constraint != 0: - # TODO: is 0.0 divisor possible? - left = self.criterion.sum_left[0]/self.criterion.weighted_n_left - right = self.criterion.sum_right[0]/self.criterion.weighted_n_right - if monotonic_constraint == -1: - if left < right: - continue - elif left > right: - continue + if not self.split_passes_monotonic_check(monotonic_constraint): + continue current_proxy_improvement = self.criterion.proxy_impurity_improvement() @@ -1373,16 +1372,8 @@ cdef class BestSparseSplitter(BaseSparseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue - # Monotonic check - if monotonic_constraint != 0: - # TODO: is 0.0 divisor possible? - left = self.criterion.sum_left[0]/self.criterion.weighted_n_left - right = self.criterion.sum_right[0]/self.criterion.weighted_n_right - if monotonic_constraint == -1: - if left < right: - continue - elif left > right: - continue + if not self.split_passes_monotonic_check(monotonic_constraint): + continue current_proxy_improvement = self.criterion.proxy_impurity_improvement() @@ -1614,16 +1605,8 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue - # Monotonic check - if monotonic_constraint != 0: - # TODO: is 0.0 divisor possible? 
- left = self.criterion.sum_left[0]/self.criterion.weighted_n_left - right = self.criterion.sum_right[0]/self.criterion.weighted_n_right - if monotonic_constraint == -1: - if left < right: - continue - elif left > right: - continue + if not self.split_passes_monotonic_check(monotonic_constraint): + continue current_proxy_improvement = self.criterion.proxy_impurity_improvement() From 63c3f9c85d9fe7d06f6e7c694bdde06fc73bbf9b Mon Sep 17 00:00:00 2001 From: Pat O'Reilly Date: Fri, 2 Sep 2016 16:46:54 +0100 Subject: [PATCH 003/174] tidy up monotonic check function --- sklearn/tree/_splitter.pyx | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index fbc7dd8f5e43a..e446310f21de9 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -231,20 +231,22 @@ cdef class Splitter: cdef inline bint split_passes_monotonic_check(self, INT32_t monotonic_constraint) nogil: """Check monotonic constraint is satisfied at the current split""" + if monotonic_constraint == 0: # No constraint return 1 else: left = self.criterion.sum_left[0]/self.criterion.weighted_n_left right = self.criterion.sum_right[0]/self.criterion.weighted_n_right - if monotonic_constraint == -1: # Decreasing constraint + if monotonic_constraint == -1: # Monotonically decreasing constraint if left < right: # Fails return 0 - else: + else: # Passes + return 1 + else: # Monotonically increasing constraint + if left > right: # Fails + return 0 + else: # Passes return 1 - elif left > right: # Increasing constraint fails - return 0 - else: - return 1 cdef class BaseDenseSplitter(Splitter): cdef DTYPE_t* X From 6407e99593e1d56aaa40afa8ee3addfcb7e29406 Mon Sep 17 00:00:00 2001 From: Pat O'Reilly Date: Wed, 7 Sep 2016 10:41:07 +0100 Subject: [PATCH 004/174] fix typo 'decreasin' --- sklearn/ensemble/forest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index ac7fb9a65e0d0..d90d364948675 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -1491,7 +1491,7 @@ def __init__(self, verbose=0, warm_start=False, increasing=None, - decreasin=None): + decreasing=None): super(ExtraTreesRegressor, self).__init__( base_estimator=ExtraTreeRegressor(), n_estimators=n_estimators, From 119d5a384f8349a6108f69592df9fd82ef3959b4 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Sun, 14 Apr 2019 21:43:49 +0200 Subject: [PATCH 005/174] Clean up dirty merge --- sklearn/ensemble/gradient_boosting.py | 28 ++++++++++++--------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/sklearn/ensemble/gradient_boosting.py b/sklearn/ensemble/gradient_boosting.py index b977a54b12663..7ffcfbda963cb 100644 --- a/sklearn/ensemble/gradient_boosting.py +++ b/sklearn/ensemble/gradient_boosting.py @@ -1936,13 +1936,6 @@ class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin): .. versionadded:: 0.17 *presort* parameter. -<<<<<<< HEAD - increasing : list of ints, optional (default=None) - Indices of features to have a monotonically increasing effect. - - decreasing : list of ints, optional (default=None) - Indices of features to have a monotonically decreasing effect. -======= validation_fraction : float, optional, default 0.1 The proportion of training data to set aside as validation set for early stopping. Must be between 0 and 1. 
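A note on the ``TODO: is 0.0 divisor possible?`` comments left in the splitters above: assuming both child weights are positive (exactly the question the TODO raises), the two mean comparisons in ``split_passes_monotonic_check`` collapse to a single division-free test,

    (sum_left[0] * weighted_n_right - sum_right[0] * weighted_n_left) * monotonic_constraint <= 0

because multiplying the inequality between the two means by the positive product of the child weights preserves its direction, and the factor of -1 for a decreasing constraint flips it back to "<= 0". This is the form the series eventually adopts in patches 015/016 below, and it stays well defined even when a child's weight is zero.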
@@ -1967,7 +1960,12 @@ class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin): number), the training stops. .. versionadded:: 0.20 ->>>>>>> master + + increasing : list of ints, optional (default=None) + Indices of features to have a monotonically increasing effect. + + decreasing : list of ints, optional (default=None) + Indices of features to have a monotonically decreasing effect. Attributes ---------- @@ -2414,13 +2412,6 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin): .. versionadded:: 0.17 optional parameter *presort*. -<<<<<<< HEAD - increasing : list of ints, optional (default=None) - Indices of features to have a monotonically increasing effect. - - decreasing : list of ints, optional (default=None) - Indices of features to have a monotonically decreasing effect. -======= validation_fraction : float, optional, default 0.1 The proportion of training data to set aside as validation set for early stopping. Must be between 0 and 1. @@ -2446,7 +2437,12 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin): .. versionadded:: 0.20 ->>>>>>> master + increasing : list of ints, optional (default=None) + Indices of features to have a monotonically increasing effect. + + decreasing : list of ints, optional (default=None) + Indices of features to have a monotonically decreasing effect. + Attributes ---------- From c331e82ca3bedb6b28be8a2446865890356dcd18 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Tue, 16 Apr 2019 00:40:15 +0200 Subject: [PATCH 006/174] PEP8 --- sklearn/ensemble/gradient_boosting.py | 6 ++--- sklearn/tree/tree.py | 33 ++++++++++++++++++--------- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/sklearn/ensemble/gradient_boosting.py b/sklearn/ensemble/gradient_boosting.py index 7ffcfbda963cb..6efaba41a434d 100644 --- a/sklearn/ensemble/gradient_boosting.py +++ b/sklearn/ensemble/gradient_boosting.py @@ -1236,7 +1236,7 @@ def _fit_stage(self, i, X, y, raw_predictions, sample_weight, sample_mask, max_leaf_nodes=self.max_leaf_nodes, random_state=random_state, presort=self.presort, - increasing=self.increasing, + increasing=self.increasing, decreasing=self.decreasing) if self.subsample < 1.0: @@ -2038,7 +2038,7 @@ def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100, max_leaf_nodes=None, warm_start=False, presort='auto', validation_fraction=0.1, n_iter_no_change=None, tol=1e-4, increasing=None, - decreasing=None): + decreasing=None): super().__init__( loss=loss, learning_rate=learning_rate, n_estimators=n_estimators, @@ -2053,7 +2053,7 @@ def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100, min_impurity_split=min_impurity_split, warm_start=warm_start, presort=presort, validation_fraction=validation_fraction, - n_iter_no_change=n_iter_no_change, tol=tol,increasing=increasing, + n_iter_no_change=n_iter_no_change, tol=tol, increasing=increasing, decreasing=decreasing) def _validate_y(self, y, sample_weight): diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py index b30de403199d4..8a20f6a40e1fb 100644 --- a/sklearn/tree/tree.py +++ b/sklearn/tree/tree.py @@ -354,29 +354,40 @@ def fit(self, X, y, sample_weight=None, check_input=True, SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS def _encode_monotonic(increasing, decreasing): - if increasing is None: increasing = [] - if decreasing is None: decreasing = [] - def is_int_in_range(feature): - return isinstance(feature, int) and 0 <= feature < self.n_features_ + increasing = [] if increasing is 
None else increasing + decreasing = [] if decreasing is None else decreasing + + def is_int_in_range(feature_): + return isinstance(feature_, int) and \ + 0 <= feature_ < self.n_features_ + def is_valid(features): return (isinstance(features, list) and all(is_int_in_range(feature) for feature in features)) + if not is_valid(increasing): - raise ValueError("increasing should be a list of ints in the range [0,n_features].") + raise ValueError("increasing should be a list of ints in " + "the range [0,n_features].") + if not is_valid(decreasing): - raise ValueError("decreasing should be a list of ints in the range [0,n_features].") + raise ValueError("decreasing should be a list of ints in " + "the range [0,n_features].") + if increasing and decreasing: intersection = set(increasing) & set(decreasing) if intersection: - raise ValueError("The following features cannot be both increasing and decreasing: " + str(list(intersection))) - monotonic = np.zeros(self.n_features_, dtype=np.int32) + raise ValueError("The following features cannot be both " + "increasing and decreasing: " + + str(list(intersection))) + + monotonic_ = np.zeros(self.n_features_, dtype=np.int32) if increasing: for feature in increasing: - monotonic[feature] = 1 + monotonic_[feature] = 1 if decreasing: for feature in decreasing: - monotonic[feature] = -1 - return monotonic + monotonic_[feature] = -1 + return monotonic_ monotonic = _encode_monotonic(self.increasing, self.decreasing) From ab6d1854938dcf948bc8eb4f200efe826719e092 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Tue, 16 Apr 2019 00:43:09 +0200 Subject: [PATCH 007/174] PEP8 --- sklearn/tree/tree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py index 8a20f6a40e1fb..dcb64ec4b7228 100644 --- a/sklearn/tree/tree.py +++ b/sklearn/tree/tree.py @@ -388,7 +388,7 @@ def is_valid(features): for feature in decreasing: monotonic_[feature] = -1 return monotonic_ - + monotonic = _encode_monotonic(self.increasing, self.decreasing) splitter = self.splitter From ce9377fc9d03a5883f526f627df4849ae0ae2996 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Tue, 16 Apr 2019 00:44:09 +0200 Subject: [PATCH 008/174] PEP8 --- sklearn/tree/tree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py index dcb64ec4b7228..4455f958bc347 100644 --- a/sklearn/tree/tree.py +++ b/sklearn/tree/tree.py @@ -390,7 +390,7 @@ def is_valid(features): return monotonic_ monotonic = _encode_monotonic(self.increasing, self.decreasing) - + splitter = self.splitter if not isinstance(self.splitter, Splitter): splitter = SPLITTERS[self.splitter](criterion, From fe215797f2fca2ba6e9c902ea17f554575177904 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Thu, 25 Apr 2019 23:25:56 +0200 Subject: [PATCH 009/174] Update docstrings for increasing/decreasing parameters --- sklearn/ensemble/forest.py | 30 ++++++++++++++++++--------- sklearn/ensemble/gradient_boosting.py | 13 +++++++----- sklearn/tree/tree.py | 12 +++++++---- 3 files changed, 36 insertions(+), 19 deletions(-) diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index 64b225b232df4..27b935da1a63e 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -914,10 +914,12 @@ class RandomForestClassifier(ForestClassifier): through the fit method) if sample_weight is specified. increasing : list of ints, optional (default=None) - Indices of features to have a monotonically increasing effect. 
+ Indices of features constrained to have a monotonically increasing + effect on the predicted variable. decreasing : list of ints, optional (default=None) - Indices of features to have a monotonically decreasing effect. + Indices of features constrained to have a monotonically decreasing + effect on the predicted variable. Attributes ---------- @@ -1189,10 +1191,12 @@ class RandomForestRegressor(ForestRegressor): new forest. See :term:`the Glossary `. increasing : list of ints, optional (default=None) - Indices of features to have a monotonically increasing effect. + Indices of features constrained to have a monotonically increasing + effect on the predicted variable. decreasing : list of ints, optional (default=None) - Indices of features to have a monotonically decreasing effect. + Indices of features constrained to have a monotonically decreasing + effect on the predicted variable. Attributes ---------- @@ -1475,10 +1479,12 @@ class ExtraTreesClassifier(ForestClassifier): through the fit method) if sample_weight is specified. increasing : list of ints, optional (default=None) - Indices of features to have a monotonically increasing effect. + Indices of features constrained to have a monotonically increasing + effect on the predicted variable. decreasing : list of ints, optional (default=None) - Indices of features to have a monotonically decreasing effect. + Indices of features constrained to have a monotonically decreasing + effect on the predicted variable. Attributes ---------- @@ -1721,10 +1727,12 @@ class ExtraTreesRegressor(ForestRegressor): new forest. See :term:`the Glossary `. increasing : list of ints, optional (default=None) - Indices of features to have a monotonically increasing effect. + Indices of features constrained to have a monotonically increasing + effect on the predicted variable. decreasing : list of ints, optional (default=None) - Indices of features to have a monotonically decreasing effect. + Indices of features constrained to have a monotonically decreasing + effect on the predicted variable. Attributes ---------- @@ -1930,10 +1938,12 @@ class RandomTreesEmbedding(BaseForest): new forest. See :term:`the Glossary `. increasing : list of ints, optional (default=None) - Indices of features to have a monotonically increasing effect. + Indices of features constrained to have a monotonically increasing + effect on the predicted variable. decreasing : list of ints, optional (default=None) - Indices of features to have a monotonically decreasing effect. + Indices of features constrained to have a monotonically decreasing + effect on the predicted variable. Attributes ---------- diff --git a/sklearn/ensemble/gradient_boosting.py b/sklearn/ensemble/gradient_boosting.py index 6efaba41a434d..de25f645e319e 100644 --- a/sklearn/ensemble/gradient_boosting.py +++ b/sklearn/ensemble/gradient_boosting.py @@ -1962,10 +1962,12 @@ class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin): .. versionadded:: 0.20 increasing : list of ints, optional (default=None) - Indices of features to have a monotonically increasing effect. + Indices of features constrained to have a monotonically increasing + effect on the predicted variable. decreasing : list of ints, optional (default=None) - Indices of features to have a monotonically decreasing effect. + Indices of features constrained to have a monotonically decreasing + effect on the predicted variable. Attributes ---------- @@ -2438,11 +2440,12 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin): .. 
versionadded:: 0.20 increasing : list of ints, optional (default=None) - Indices of features to have a monotonically increasing effect. + Indices of features constrained to have a monotonically increasing + effect on the predicted variable. decreasing : list of ints, optional (default=None) - Indices of features to have a monotonically decreasing effect. - + Indices of features constrained to have a monotonically decreasing + effect on the predicted variable. Attributes ---------- diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py index 4455f958bc347..0d4c6f33b2d16 100644 --- a/sklearn/tree/tree.py +++ b/sklearn/tree/tree.py @@ -710,10 +710,12 @@ class DecisionTreeClassifier(BaseDecisionTree, ClassifierMixin): speed up the training. increasing : list of ints, optional (default=None) - Indices of features to have a monotonically increasing effect. + Indices of features constrained to have a monotonically increasing + effect on the predicted variable. decreasing : list of ints, optional (default=None) - Indices of features to have a monotonically decreasing effect. + Indices of features constrained to have a monotonically decreasing + effect on the predicted variable. Attributes ---------- @@ -1072,10 +1074,12 @@ class DecisionTreeRegressor(BaseDecisionTree, RegressorMixin): speed up the training. increasing : list of ints, optional (default=None) - Indices of features to have a monotonically increasing effect. + Indices of features constrained to have a monotonically increasing + effect on the predicted variable. decreasing : list of ints, optional (default=None) - Indices of features to have a monotonically decreasing effect. + Indices of features constrained to have a monotonically decreasing + effect on the predicted variable. Attributes ---------- From 544cced71bad060d46aa5662c01acf6f3a3fead4 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Fri, 26 Apr 2019 00:44:59 +0200 Subject: [PATCH 010/174] Cleanup --- sklearn/ensemble/forest.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index 5ff902c18d687..0ba4fd8e8068c 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -972,11 +972,12 @@ class labels (multi-output problem). ... random_state=0) >>> clf.fit(X, y) # doctest: +NORMALIZE_WHITESPACE RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini', - max_depth=2, max_features='auto', max_leaf_nodes=None, - min_impurity_decrease=0.0, min_impurity_split=None, - min_samples_leaf=1, min_samples_split=2, - min_weight_fraction_leaf=0.0, n_estimators=100, n_jobs=None, - oob_score=False, random_state=0, verbose=0, warm_start=False) + decreasing=None, increasing=None, max_depth=2, + max_features='auto', max_leaf_nodes=None, + min_impurity_decrease=0.0, min_impurity_split=None, + min_samples_leaf=1, min_samples_split=2, + min_weight_fraction_leaf=0.0, n_estimators=100, n_jobs=None, + oob_score=False, random_state=0, verbose=0, warm_start=False) >>> print(clf.feature_importances_) [0.14205973 0.76664038 0.0282433 0.06305659] >>> print(clf.predict([[0, 0, 0, 0]])) @@ -1236,12 +1237,13 @@ class RandomForestRegressor(ForestRegressor): >>> regr = RandomForestRegressor(max_depth=2, random_state=0, ... 
n_estimators=100) >>> regr.fit(X, y) # doctest: +NORMALIZE_WHITESPACE - RandomForestRegressor(bootstrap=True, criterion='mse', max_depth=2, - max_features='auto', max_leaf_nodes=None, - min_impurity_decrease=0.0, min_impurity_split=None, - min_samples_leaf=1, min_samples_split=2, - min_weight_fraction_leaf=0.0, n_estimators=100, n_jobs=None, - oob_score=False, random_state=0, verbose=0, warm_start=False) + RandomForestRegressor(bootstrap=True, criterion='mse', decreasing=None, + increasing=None, max_depth=2, max_features='auto', + max_leaf_nodes=None, min_impurity_decrease=0.0, + min_impurity_split=None, min_samples_leaf=1, + min_samples_split=2, min_weight_fraction_leaf=0.0, + n_estimators=100, n_jobs=None, oob_score=False, random_state=0, + verbose=0, warm_start=False) >>> print(regr.feature_importances_) [0.18146984 0.81473937 0.00145312 0.00233767] >>> print(regr.predict([[0, 0, 0, 0]])) From 20c79297a9ae175ed7fffd3c517c506d9391be6d Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Fri, 26 Apr 2019 01:32:54 +0200 Subject: [PATCH 011/174] Cleanup --- sklearn/ensemble/forest.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index 0ba4fd8e8068c..0f0b068cc1763 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -976,8 +976,9 @@ class labels (multi-output problem). max_features='auto', max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, min_samples_leaf=1, min_samples_split=2, - min_weight_fraction_leaf=0.0, n_estimators=100, n_jobs=None, - oob_score=False, random_state=0, verbose=0, warm_start=False) + min_weight_fraction_leaf=0.0, n_estimators=100, + n_jobs=None, oob_score=False, random_state=0, + verbose=0, warm_start=False) >>> print(clf.feature_importances_) [0.14205973 0.76664038 0.0282433 0.06305659] >>> print(clf.predict([[0, 0, 0, 0]])) From 49747c6d1782c6481b038ea0d30aa6d90ca808c4 Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Mon, 17 Jun 2019 18:44:19 +0200 Subject: [PATCH 012/174] Add basic monotonicity tests --- sklearn/tree/tests/test_tree.py | 42 +++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 8d64150cab8f5..0a8e45fbbd7f3 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -1858,3 +1858,45 @@ def test_decision_tree_memmap(): with TempMemmap((X, y)) as (X_read_only, y_read_only): DecisionTreeClassifier().fit(X_read_only, y_read_only) + +def test_montonic_constraints(): + X, y = datasets.make_hastie_10_2(n_samples=1000, random_state=123) + X0 = np.copy(X) + for name, TreeRegressor in REG_TREES.items(): + est = TreeRegressor(max_depth=None, decreasing=[0], increasing=[1]).fit(X, y) + + y0 = est.predict(X0) + + # decreasing constraint + X1 = np.copy(X) + X1[:,0] += 10 + y1 = est.predict(X1) + # y1 should always be lower than y0 + assert_less_equal(np.max(y1 - y0), 0) + + # increasing constraint + X1 = np.copy(X) + X1[:, 1] += 10 + y1 = est.predict(X1) + # y1 should always be greater than y0 + assert_greater_equal(np.min(y1 - y0), 0) + + for name, TreeClassifier in CLF_TREES.items(): + est = TreeClassifier(max_depth=None, decreasing=[0], increasing=[1]).fit(X, y) + + X0 = np.copy(X) + y0 = est.predict_proba(X0) + + # decreasing constraint + X1 = np.copy(X) + X1[:,0] += 10 + y1 = est.predict_proba(X1) + # y1 should always be lower than y0 + assert_less_equal(np.max(y1 - y0), 0) + + # increasing constraint + X1 = np.copy(X) + X1[:, 1] += 10 + y1 = est.predict_proba(X1) + # y1 should always be greater than y0 + assert_greater_equal(np.min(y1 - y0), 0) From 37b536bd116600020acb30377a179a44560c08ed Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Mon, 17 Jun 2019 18:48:26 +0200 Subject: [PATCH 013/174] PEP8 --- sklearn/tree/tests/test_tree.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 0a8e45fbbd7f3..b57df83da52bd 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -1869,7 +1869,7 @@ def test_montonic_constraints(): # decreasing constraint X1 = np.copy(X) - X1[:,0] += 10 + X1[:, 0] += 10 y1 = est.predict(X1) # y1 should always be lower than y0 assert_less_equal(np.max(y1 - y0), 0) @@ -1889,7 +1889,7 @@ def test_montonic_constraints(): # decreasing constraint X1 = np.copy(X) - X1[:,0] += 10 + X1[:, 0] += 10 y1 = est.predict_proba(X1) # y1 should always be lower than y0 assert_less_equal(np.max(y1 - y0), 0) From ef2fb1cad9866d82353114b409bf9e6904f7e860 Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Mon, 17 Jun 2019 18:50:54 +0200 Subject: [PATCH 014/174] PEP8 --- sklearn/tree/tests/test_tree.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index b57df83da52bd..3b23e9e951015 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -1859,11 +1859,13 @@ def test_decision_tree_memmap(): with TempMemmap((X, y)) as (X_read_only, y_read_only): DecisionTreeClassifier().fit(X_read_only, y_read_only) + def test_montonic_constraints(): X, y = datasets.make_hastie_10_2(n_samples=1000, random_state=123) X0 = np.copy(X) for name, TreeRegressor in REG_TREES.items(): - est = TreeRegressor(max_depth=None, decreasing=[0], increasing=[1]).fit(X, y) + est = TreeRegressor(max_depth=None, + decreasing=[0], increasing=[1]).fit(X, y) y0 = est.predict(X0) @@ -1882,7 +1884,8 @@ def test_montonic_constraints(): assert_greater_equal(np.min(y1 - y0), 0) for name, TreeClassifier in CLF_TREES.items(): - est = TreeClassifier(max_depth=None, decreasing=[0], increasing=[1]).fit(X, y) + est = TreeClassifier(max_depth=None, + decreasing=[0], increasing=[1]).fit(X, y) X0 = np.copy(X) y0 = est.predict_proba(X0) From 072ef991b9d5ce783884721c35ad75d38036d53b Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Thu, 20 Jun 2019 18:46:24 +0200 Subject: [PATCH 015/174] Simplify split_passes_monotonic_check function --- sklearn/tree/_splitter.pyx | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index d994bef5c4f30..e228eadab8afa 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -241,18 +241,10 @@ cdef class Splitter: if monotonic_constraint == 0: # No constraint return 1 else: - left = self.criterion.sum_left[0]/self.criterion.weighted_n_left - right = self.criterion.sum_right[0]/self.criterion.weighted_n_right - if monotonic_constraint == -1: # Monotonically decreasing constraint - if left < right: # Fails - return 0 - else: # Passes - return 1 - else: # Monotonically increasing constraint - if left > right: # Fails - return 0 - else: # Passes - return 1 + delta = (self.criterion.sum_left[0] * self.criterion.weighted_n_right \ + - self.criterion.sum_right[0] * self.criterion.weighted_n_left) \ + * monotonic_constraint + return delta >= 0 cdef class BaseDenseSplitter(Splitter): cdef const DTYPE_t[:, :] X From 605e1c823d668a33a00704a71c673d22313da00a Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Thu, 20 Jun 2019 19:01:17 +0200 Subject: [PATCH 016/174] Fix split_passes_monotonic_check function --- sklearn/tree/_splitter.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index e228eadab8afa..f4f89d52018bf 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -244,7 +244,7 @@ cdef class Splitter: delta = (self.criterion.sum_left[0] * self.criterion.weighted_n_right \ - self.criterion.sum_right[0] * self.criterion.weighted_n_left) \ * monotonic_constraint - return delta >= 0 + return delta <= 0 cdef class BaseDenseSplitter(Splitter): cdef const DTYPE_t[:, :] X From 018d6c24bc60cf5a4c412ce3480193e858916eb7 Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Thu, 20 Jun 2019 19:25:58 +0200 Subject: [PATCH 017/174] Fix tests (still something broken though) --- sklearn/tree/tests/test_tree.py | 52 ++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 3b23e9e951015..bb1742a611577 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -1861,45 +1861,49 @@ def test_decision_tree_memmap(): def test_montonic_constraints(): - X, y = datasets.make_hastie_10_2(n_samples=1000, random_state=123) - X0 = np.copy(X) + X, y = datasets.make_hastie_10_2(n_samples=100, random_state=0) + train = np.arange(90) + test = np.arange(90, 100) + X_train = X[train] + y_train = y[train] + X_test_0 = np.copy(X[test]) + X_test_1 = np.copy(X_test_0) + X_test_1[:, 1] += 10 + X_test_2 = np.copy(X_test_0) + X_test_2[:, 2] += 10 for name, TreeRegressor in REG_TREES.items(): est = TreeRegressor(max_depth=None, - decreasing=[0], increasing=[1]).fit(X, y) - - y0 = est.predict(X0) + decreasing=[1], increasing=[2]) + if hasattr(est, "random_state"): + est.set_params(**{"random_state": 0}) + est.fit(X_train, y_train) + y0 = est.predict(X_test_0) # decreasing constraint - X1 = np.copy(X) - X1[:, 0] += 10 - y1 = est.predict(X1) + y1 = est.predict(X_test_1) # y1 should always be lower than y0 assert_less_equal(np.max(y1 - y0), 0) # increasing constraint - X1 = np.copy(X) - X1[:, 1] += 10 - y1 = est.predict(X1) - # y1 should always be greater than y0 - assert_greater_equal(np.min(y1 - y0), 0) + y2 = est.predict(X_test_2) + # y2 should always be greater than y0 + assert_greater_equal(np.min(y2 - y0), 0) for name, TreeClassifier in CLF_TREES.items(): est = TreeClassifier(max_depth=None, - decreasing=[0], increasing=[1]).fit(X, y) + decreasing=[1], increasing=[2]) + if hasattr(est, "random_state"): + est.set_params(**{"random_state": 0}) + est.fit(X_train, y_train) - X0 = np.copy(X) - y0 = est.predict_proba(X0) + y0 = est.predict_proba(X_test_0)[:, 0] # decreasing constraint - X1 = np.copy(X) - X1[:, 0] += 10 - y1 = est.predict_proba(X1) + y1 = est.predict_proba(X_test_1)[:, 0] # y1 should always be lower than y0 assert_less_equal(np.max(y1 - y0), 0) # increasing constraint - X1 = np.copy(X) - X1[:, 1] += 10 - y1 = est.predict_proba(X1) - # y1 should always be greater than y0 - assert_greater_equal(np.min(y1 - y0), 0) + y2 = est.predict_proba(X_test_2)[:, 0] + # y2 should always be greater than y0 + assert_greater_equal(np.min(y2 - y0), 0) From b217ec8521cd550068ba4e6d0ace969752868561 Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Wed, 29 Jan 2020 12:49:00 +0100 Subject: [PATCH 018/174] Fix PEP8 and missing import in tests --- sklearn/ensemble/_forest.py | 4 ++-- sklearn/ensemble/_gb.py | 8 ++++---- sklearn/tree/tests/test_tree.py | 2 ++ 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index 420ba8235947c..9e8b619d1529b 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -2018,8 +2018,8 @@ def __init__(self, warm_start=False, ccp_alpha=0.0, max_samples=None, - increasing = None, - decreasing = None): + increasing=None, + decreasing=None): super().__init__( base_estimator=ExtraTreeRegressor(), n_estimators=n_estimators, diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index 27b1011e4abea..d88543121ebce 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -209,7 +209,7 @@ def _fit_stage(self, i, X, y, raw_predictions, sample_weight, sample_mask, max_features=self.max_features, max_leaf_nodes=self.max_leaf_nodes, random_state=random_state, - ccp_alpha = self.ccp_alpha, + ccp_alpha=self.ccp_alpha, increasing=self.increasing, decreasing=self.decreasing) @@ -1524,9 +1524,9 @@ def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100, max_depth=3, min_impurity_decrease=0., min_impurity_split=None, init=None, random_state=None, max_features=None, alpha=0.9, verbose=0, max_leaf_nodes=None, - warm_start=False, presort='deprecated', validation_fraction=0.1, - n_iter_no_change=None, tol=1e-4, ccp_alpha=0.0, increasing=None, - decreasing=None): + warm_start=False, presort='deprecated', + validation_fraction=0.1, n_iter_no_change=None, tol=1e-4, + ccp_alpha=0.0, increasing=None, decreasing=None): super().__init__( loss=loss, learning_rate=learning_rate, n_estimators=n_estimators, diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 77d63def596ca..0e00a59a91d62 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -21,6 +21,8 @@ from sklearn.utils._testing import assert_array_equal from sklearn.utils._testing import assert_array_almost_equal from sklearn.utils._testing import assert_almost_equal +from sklearn.utils._testing import assert_greater_equal +from sklearn.utils._testing import assert_less_equal from sklearn.utils._testing import assert_warns from sklearn.utils._testing import assert_warns_message from sklearn.utils._testing import ignore_warnings From be8b38a845349be965fa540f6d62750b23e9f520 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 29 Jan 2020 13:43:44 +0100 Subject: [PATCH 019/174] Fix faulty Splitter call in _classes.py --- sklearn/tree/_classes.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 14b1bb04d5b27..975ffe3f00b2b 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -380,8 +380,7 @@ def is_valid(features): min_samples_leaf, min_weight_leaf, random_state, - monotonic, - random_state) + monotonic) if is_classifier(self): self.tree_ = Tree(self.n_features_, @@ -1751,4 +1750,3 @@ def __init__(self, ccp_alpha=ccp_alpha, increasing=increasing, decreasing=decreasing) - From 5aacb405a687b98fca9a8b1e9c882711da133707 Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Wed, 29 Jan 2020 14:18:27 +0100 Subject: [PATCH 020/174] Remove merge bruises --- sklearn/ensemble/_forest.py | 36 ++++++++++++++++-------------------- sklearn/ensemble/_gb.py | 18 ++++++++---------- 2 files changed, 24 insertions(+), 30 deletions(-) diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index 9e8b619d1529b..7328619c0e732 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -1307,15 +1307,6 @@ class RandomForestRegressor(ForestRegressor): and add more estimators to the ensemble, otherwise, just fit a whole new forest. See :term:`the Glossary `. -<<<<<<< HEAD:sklearn/ensemble/forest.py - increasing : list of ints, optional (default=None) - Indices of features constrained to have a monotonically increasing - effect on the predicted variable. - - decreasing : list of ints, optional (default=None) - Indices of features constrained to have a monotonically decreasing - effect on the predicted variable. -======= ccp_alpha : non-negative float, default=0.0 Complexity parameter used for Minimal Cost-Complexity Pruning. The subtree with the largest cost complexity that is smaller than @@ -1334,7 +1325,14 @@ class RandomForestRegressor(ForestRegressor): `max_samples` should be in the interval `(0, 1)`. .. versionadded:: 0.22 ->>>>>>> master:sklearn/ensemble/_forest.py + + increasing : list of ints, optional (default=None) + Indices of features constrained to have a monotonically increasing + effect on the predicted variable. + + decreasing : list of ints, optional (default=None) + Indices of features constrained to have a monotonically decreasing + effect on the predicted variable. Attributes ---------- @@ -1629,15 +1627,6 @@ class ExtraTreesClassifier(ForestClassifier): Note that these weights will be multiplied with sample_weight (passed through the fit method) if sample_weight is specified. -<<<<<<< HEAD:sklearn/ensemble/forest.py - increasing : list of ints, optional (default=None) - Indices of features constrained to have a monotonically increasing - effect on the predicted variable. - - decreasing : list of ints, optional (default=None) - Indices of features constrained to have a monotonically decreasing - effect on the predicted variable. -======= ccp_alpha : non-negative float, default=0.0 Complexity parameter used for Minimal Cost-Complexity Pruning. The subtree with the largest cost complexity that is smaller than @@ -1656,7 +1645,14 @@ class ExtraTreesClassifier(ForestClassifier): `max_samples` should be in the interval `(0, 1)`. .. versionadded:: 0.22 ->>>>>>> master:sklearn/ensemble/_forest.py + + increasing : list of ints, optional (default=None) + Indices of features constrained to have a monotonically increasing + effect on the predicted variable. + + decreasing : list of ints, optional (default=None) + Indices of features constrained to have a monotonically decreasing + effect on the predicted variable. Attributes ---------- diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index d88543121ebce..c8b31c516e1e0 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -1446,15 +1446,6 @@ class GradientBoostingRegressor(RegressorMixin, BaseGradientBoosting): .. versionadded:: 0.20 -<<<<<<< HEAD:sklearn/ensemble/gradient_boosting.py - increasing : list of ints, optional (default=None) - Indices of features constrained to have a monotonically increasing - effect on the predicted variable. 
- - decreasing : list of ints, optional (default=None) - Indices of features constrained to have a monotonically decreasing - effect on the predicted variable. -======= ccp_alpha : non-negative float, optional (default=0.0) Complexity parameter used for Minimal Cost-Complexity Pruning. The subtree with the largest cost complexity that is smaller than @@ -1462,7 +1453,14 @@ class GradientBoostingRegressor(RegressorMixin, BaseGradientBoosting): :ref:`minimal_cost_complexity_pruning` for details. .. versionadded:: 0.22 ->>>>>>> master:sklearn/ensemble/_gb.py + + increasing : list of ints, optional (default=None) + Indices of features constrained to have a monotonically increasing + effect on the predicted variable. + + decreasing : list of ints, optional (default=None) + Indices of features constrained to have a monotonically decreasing + effect on the predicted variable. Attributes ---------- From ef0588964b2a49478643e392013734617ba52773 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 29 Jan 2020 15:09:01 +0100 Subject: [PATCH 021/174] Remove deprecated asserts --- sklearn/tree/tests/test_tree.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 0e00a59a91d62..acef6e7e03883 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -21,8 +21,6 @@ from sklearn.utils._testing import assert_array_equal from sklearn.utils._testing import assert_array_almost_equal from sklearn.utils._testing import assert_almost_equal -from sklearn.utils._testing import assert_greater_equal -from sklearn.utils._testing import assert_less_equal from sklearn.utils._testing import assert_warns from sklearn.utils._testing import assert_warns_message from sklearn.utils._testing import ignore_warnings @@ -1975,12 +1973,12 @@ def test_montonic_constraints(): # decreasing constraint y1 = est.predict(X_test_1) # y1 should always be lower than y0 - assert_less_equal(np.max(y1 - y0), 0) + assert(np.max(y1 - y0) <= 0) # increasing constraint y2 = est.predict(X_test_2) # y2 should always be greater than y0 - assert_greater_equal(np.min(y2 - y0), 0) + assert(np.min(y2 - y0) >= 0) for name, TreeClassifier in CLF_TREES.items(): est = TreeClassifier(max_depth=None, @@ -1994,12 +1992,12 @@ def test_montonic_constraints(): # decreasing constraint y1 = est.predict_proba(X_test_1)[:, 0] # y1 should always be lower than y0 - assert_less_equal(np.max(y1 - y0), 0) + assert(np.max(y1 - y0) <= 0) # increasing constraint y2 = est.predict_proba(X_test_2)[:, 0] # y2 should always be greater than y0 - assert_greater_equal(np.min(y2 - y0), 0) + assert(np.min(y2 - y0) >= 0) def test_classes_deprecated(): From 8ac9945f6ea5abf406e93ec194eaeb4645110af9 Mon Sep 17 00:00:00 2001 From: dsleo Date: Wed, 29 Jan 2020 18:50:05 +0100 Subject: [PATCH 022/174] changing decreasing/increasing to monotonic_cst --- sklearn/ensemble/_forest.py | 98 ++++++++++++--------------------- sklearn/ensemble/_gb.py | 38 +++++-------- sklearn/tree/_classes.py | 92 +++++++------------------------ sklearn/tree/tests/test_tree.py | 10 ++-- 4 files changed, 75 insertions(+), 163 deletions(-) diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index 7328619c0e732..930404c0b0d93 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -1022,13 +1022,10 @@ class RandomForestClassifier(ForestClassifier): .. 
versionadded:: 0.22 - increasing : list of ints, optional (default=None) - Indices of features constrained to have a monotonically increasing - effect on the predicted variable. - - decreasing : list of ints, optional (default=None) - Indices of features constrained to have a monotonically decreasing - effect on the predicted variable. + monotonic_cst : array-like of int of shape (n_features), default=None + Indicates the monotonic constraint to enforce on each feature. -1, 1 + and 0 respectively correspond to a positive constraint, negative + constraint and no constraint. Attributes ---------- @@ -1127,8 +1124,7 @@ def __init__(self, class_weight=None, ccp_alpha=0.0, max_samples=None, - increasing=None, - decreasing=None): + monotonic_cst=None): super().__init__( base_estimator=DecisionTreeClassifier(), @@ -1137,7 +1133,7 @@ def __init__(self, "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_decrease", "min_impurity_split", - "random_state", "ccp_alpha"), + "random_state", "ccp_alpha", "monotonic_cst"), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, @@ -1156,8 +1152,7 @@ def __init__(self, self.max_leaf_nodes = max_leaf_nodes self.min_impurity_decrease = min_impurity_decrease self.min_impurity_split = min_impurity_split - self.increasing = increasing - self.decreasing = decreasing + self.monotonic_cst = monotonic_cst self.ccp_alpha = ccp_alpha @@ -1326,13 +1321,10 @@ class RandomForestRegressor(ForestRegressor): .. versionadded:: 0.22 - increasing : list of ints, optional (default=None) - Indices of features constrained to have a monotonically increasing - effect on the predicted variable. - - decreasing : list of ints, optional (default=None) - Indices of features constrained to have a monotonically decreasing - effect on the predicted variable. + monotonic_cst : array-like of int of shape (n_features), default=None + Indicates the monotonic constraint to enforce on each feature. -1, 1 + and 0 respectively correspond to a positive constraint, negative + constraint and no constraint. Attributes ---------- @@ -1425,8 +1417,7 @@ def __init__(self, warm_start=False, ccp_alpha=0.0, max_samples=None, - increasing=None, - decreasing=None): + monotonic_cst=None): super().__init__( base_estimator=DecisionTreeRegressor(), n_estimators=n_estimators, @@ -1434,8 +1425,7 @@ def __init__(self, "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_decrease", "min_impurity_split", - "random_state", "ccp_alpha", "increasing", - "decreasing"), + "random_state", "ccp_alpha", "monotonic_cst"), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, @@ -1454,8 +1444,7 @@ def __init__(self, self.min_impurity_decrease = min_impurity_decrease self.min_impurity_split = min_impurity_split self.ccp_alpha = ccp_alpha - self.increasing = increasing - self.decreasing = decreasing + self.monotonic_cst = monotonic_cst class ExtraTreesClassifier(ForestClassifier): @@ -1646,13 +1635,10 @@ class ExtraTreesClassifier(ForestClassifier): .. versionadded:: 0.22 - increasing : list of ints, optional (default=None) - Indices of features constrained to have a monotonically increasing - effect on the predicted variable. - - decreasing : list of ints, optional (default=None) - Indices of features constrained to have a monotonically decreasing - effect on the predicted variable. + monotonic_cst : array-like of int of shape (n_features), default=None + Indicates the monotonic constraint to enforce on each feature. 
-1, 1 + and 0 respectively correspond to a positive constraint, negative + constraint and no constraint. Attributes ---------- @@ -1742,8 +1728,7 @@ def __init__(self, class_weight=None, ccp_alpha=0.0, max_samples=None, - increasing=None, - decreasing=None): + monotonic_cst=None): super().__init__( base_estimator=ExtraTreeClassifier(), n_estimators=n_estimators, @@ -1751,8 +1736,7 @@ def __init__(self, "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_decrease", "min_impurity_split", - "random_state", "ccp_alpha", "increasing", - "decreasing"), + "random_state", "ccp_alpha", "monotonic_cst"), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, @@ -1772,8 +1756,7 @@ def __init__(self, self.min_impurity_decrease = min_impurity_decrease self.min_impurity_split = min_impurity_split self.ccp_alpha = ccp_alpha - self.increasing = increasing - self.decreasing = decreasing + self.monotonic_cst = monotonic_cst class ExtraTreesRegressor(ForestRegressor): @@ -1942,13 +1925,10 @@ class ExtraTreesRegressor(ForestRegressor): .. versionadded:: 0.22 - increasing : list of ints, optional (default=None) - Indices of features constrained to have a monotonically increasing - effect on the predicted variable. - - decreasing : list of ints, optional (default=None) - Indices of features constrained to have a monotonically decreasing - effect on the predicted variable. + monotonic_cst : array-like of int of shape (n_features), default=None + Indicates the monotonic constraint to enforce on each feature. -1, 1 + and 0 respectively correspond to a positive constraint, negative + constraint and no constraint. Attributes ---------- @@ -2014,8 +1994,7 @@ def __init__(self, warm_start=False, ccp_alpha=0.0, max_samples=None, - increasing=None, - decreasing=None): + monotonic_cst=None): super().__init__( base_estimator=ExtraTreeRegressor(), n_estimators=n_estimators, @@ -2023,8 +2002,7 @@ def __init__(self, "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_decrease", "min_impurity_split", - "random_state", "ccp_alpha", "increasing", - "decreasing"), + "random_state", "ccp_alpha", "monotonic_cst"), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, @@ -2043,8 +2021,7 @@ def __init__(self, self.min_impurity_decrease = min_impurity_decrease self.min_impurity_split = min_impurity_split self.ccp_alpha = ccp_alpha - self.increasing = increasing - self.decreasing = decreasing + self.monotonic_cst = monotonic_cst class RandomTreesEmbedding(BaseForest): @@ -2165,13 +2142,10 @@ class RandomTreesEmbedding(BaseForest): and add more estimators to the ensemble, otherwise, just fit a whole new forest. See :term:`the Glossary `. - increasing : list of ints, optional (default=None) - Indices of features constrained to have a monotonically increasing - effect on the predicted variable. - - decreasing : list of ints, optional (default=None) - Indices of features constrained to have a monotonically decreasing - effect on the predicted variable. + monotonic_cst : array-like of int of shape (n_features), default=None + Indicates the monotonic constraint to enforce on each feature. -1, 1 + and 0 respectively correspond to a positive constraint, negative + constraint and no constraint. 
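A minimal usage sketch of the new parameter, mirroring the assertions in test_montonic_constraints elsewhere in this series. It assumes the sign convention exercised by those tests (+1 for a monotonically increasing effect, -1 for decreasing, 0 for no constraint); DecisionTreeRegressor is used only because it is the simplest estimator gaining the parameter.

import numpy as np
from sklearn.datasets import make_regression
from sklearn.tree import DecisionTreeRegressor

X, y = make_regression(n_samples=200, n_features=3, random_state=0)
cst = np.array([1, -1, 0], dtype=np.int32)  # feature 0 increasing, feature 1 decreasing
est = DecisionTreeRegressor(monotonic_cst=cst, random_state=0).fit(X, y)

X_shift = np.copy(X)
X_shift[:, 0] += 10.0
# Raising an "increasing" feature must never lower any prediction.
assert np.min(est.predict(X_shift) - est.predict(X)) >= 0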
Attributes ---------- @@ -2205,8 +2179,7 @@ def __init__(self, random_state=None, verbose=0, warm_start=False, - increasing=None, - decreasing=None): + monotonic_cst=None): super().__init__( base_estimator=ExtraTreeRegressor(), n_estimators=n_estimators, @@ -2214,7 +2187,7 @@ def __init__(self, "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_decrease", "min_impurity_split", - "random_state", "increasing", "decreasing"), + "random_state", "monotonic_cst"), bootstrap=False, oob_score=False, n_jobs=n_jobs, @@ -2231,8 +2204,7 @@ def __init__(self, self.min_impurity_decrease = min_impurity_decrease self.min_impurity_split = min_impurity_split self.sparse_output = sparse_output - self.increasing = increasing - self.decreasing = decreasing + self.monotonic_cst = monotonic_cst def _set_oob_score(self, X, y): raise NotImplementedError("OOB score not supported by tree embedding") diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index c8b31c516e1e0..44d6174c07be9 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -147,7 +147,7 @@ def __init__(self, loss, learning_rate, n_estimators, criterion, random_state, alpha=0.9, verbose=0, max_leaf_nodes=None, warm_start=False, presort='deprecated', validation_fraction=0.1, n_iter_no_change=None, - tol=1e-4, increasing=None, decreasing=None): + tol=1e-4, monotonic_cst=None): self.n_estimators = n_estimators self.learning_rate = learning_rate @@ -172,8 +172,7 @@ def __init__(self, loss, learning_rate, n_estimators, criterion, self.validation_fraction = validation_fraction self.n_iter_no_change = n_iter_no_change self.tol = tol - self.increasing = increasing - self.decreasing = decreasing + self.monotonic_cst = monotonic_cst def _fit_stage(self, i, X, y, raw_predictions, sample_weight, sample_mask, random_state, X_idx_sorted, X_csc=None, X_csr=None): @@ -210,8 +209,7 @@ def _fit_stage(self, i, X, y, raw_predictions, sample_weight, sample_mask, max_leaf_nodes=self.max_leaf_nodes, random_state=random_state, ccp_alpha=self.ccp_alpha, - increasing=self.increasing, - decreasing=self.decreasing) + monotonic_cst=self.monotonic_cst) if self.subsample < 1.0: @@ -967,13 +965,10 @@ class GradientBoostingClassifier(ClassifierMixin, BaseGradientBoosting): .. versionadded:: 0.22 - increasing : list of ints, optional (default=None) - Indices of features constrained to have a monotonically increasing - effect on the predicted variable. - - decreasing : list of ints, optional (default=None) - Indices of features constrained to have a monotonically decreasing - effect on the predicted variable. + monotonic_cst : array-like of int of shape (n_features), default=None + Indicates the monotonic constraint to enforce on each feature. -1, 1 + and 0 respectively correspond to a positive constraint, negative + constraint and no constraint. 
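Because _fit_stage (above) passes the same monotonic_cst to every per-stage DecisionTreeRegressor, and a sum of functions that are each monotone in a feature stays monotone in that feature, the boosted model as a whole inherits the constraint. A hedged usage sketch against this branch's API (monotonic_cst on GradientBoostingRegressor exists only in this patch series):

import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor

X, y = make_regression(n_samples=300, n_features=3, random_state=0)
gbr = GradientBoostingRegressor(n_estimators=50, random_state=0,
                                monotonic_cst=np.array([1, -1, 0], dtype=np.int32))
gbr.fit(X, y)

X_shift = np.copy(X)
X_shift[:, 1] += 10.0
# Feature 1 carries a decreasing constraint, so predictions must not rise.
assert np.max(gbr.predict(X_shift) - gbr.predict(X)) <= 0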
Attributes ---------- @@ -1051,7 +1046,7 @@ def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100, max_leaf_nodes=None, warm_start=False, presort='deprecated', validation_fraction=0.1, n_iter_no_change=None, tol=1e-4, ccp_alpha=0.0, - increasing=None, decreasing=None): + monotonic_cst=None): super().__init__( loss=loss, learning_rate=learning_rate, n_estimators=n_estimators, @@ -1067,7 +1062,7 @@ def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100, warm_start=warm_start, presort=presort, validation_fraction=validation_fraction, n_iter_no_change=n_iter_no_change, tol=tol, ccp_alpha=ccp_alpha, - increasing=increasing, decreasing=decreasing) + monotonic_cst=monotonic_cst) def _validate_y(self, y, sample_weight): check_classification_targets(y) @@ -1454,13 +1449,10 @@ class GradientBoostingRegressor(RegressorMixin, BaseGradientBoosting): .. versionadded:: 0.22 - increasing : list of ints, optional (default=None) - Indices of features constrained to have a monotonically increasing - effect on the predicted variable. - - decreasing : list of ints, optional (default=None) - Indices of features constrained to have a monotonically decreasing - effect on the predicted variable. + monotonic_cst : array-like of int of shape (n_features), default=None + Indicates the monotonic constraint to enforce on each feature. -1, 1 + and 0 respectively correspond to a positive constraint, negative + constraint and no constraint. Attributes ---------- @@ -1524,7 +1516,7 @@ def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100, max_features=None, alpha=0.9, verbose=0, max_leaf_nodes=None, warm_start=False, presort='deprecated', validation_fraction=0.1, n_iter_no_change=None, tol=1e-4, - ccp_alpha=0.0, increasing=None, decreasing=None): + ccp_alpha=0.0, monotonic_cst=None): super().__init__( loss=loss, learning_rate=learning_rate, n_estimators=n_estimators, @@ -1539,7 +1531,7 @@ def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100, max_leaf_nodes=max_leaf_nodes, warm_start=warm_start, presort=presort, validation_fraction=validation_fraction, n_iter_no_change=n_iter_no_change, tol=tol, ccp_alpha=ccp_alpha, - increasing=increasing, decreasing=decreasing) + monotonic_cst=monotonic_cst) def predict(self, X): """Predict regression target for X. 
diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 975ffe3f00b2b..2331ea2e4fc89 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -96,8 +96,7 @@ def __init__(self, min_impurity_split, class_weight=None, presort='deprecated', - increasing=None, - decreasing=None, + monotonic_cst=None, ccp_alpha=0.0): self.criterion = criterion @@ -113,8 +112,7 @@ def __init__(self, self.min_impurity_split = min_impurity_split self.class_weight = class_weight self.presort = presort - self.increasing = increasing - self.decreasing = decreasing + self.monotonic_cst = monotonic_cst self.ccp_alpha = ccp_alpha def get_depth(self): @@ -335,44 +333,6 @@ def fit(self, X, y, sample_weight=None, check_input=True, SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS - def _encode_monotonic(increasing, decreasing): - increasing = [] if increasing is None else increasing - decreasing = [] if decreasing is None else decreasing - - def is_int_in_range(feature_): - return isinstance(feature_, int) and \ - 0 <= feature_ < self.n_features_ - - def is_valid(features): - return (isinstance(features, list) and - all(is_int_in_range(feature) for feature in features)) - - if not is_valid(increasing): - raise ValueError("increasing should be a list of ints in " - "the range [0,n_features].") - - if not is_valid(decreasing): - raise ValueError("decreasing should be a list of ints in " - "the range [0,n_features].") - - if increasing and decreasing: - intersection = set(increasing) & set(decreasing) - if intersection: - raise ValueError("The following features cannot be both " - "increasing and decreasing: " - + str(list(intersection))) - - monotonic_ = np.zeros(self.n_features_, dtype=np.int32) - if increasing: - for feature in increasing: - monotonic_[feature] = 1 - if decreasing: - for feature in decreasing: - monotonic_[feature] = -1 - return monotonic_ - - monotonic = _encode_monotonic(self.increasing, self.decreasing) - splitter = self.splitter if not isinstance(self.splitter, Splitter): splitter = SPLITTERS[self.splitter](criterion, @@ -380,7 +340,7 @@ def is_valid(features): min_samples_leaf, min_weight_leaf, random_state, - monotonic) + self.monotonic_cst) if is_classifier(self): self.tree_ = Tree(self.n_features_, @@ -773,13 +733,10 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree): .. versionadded:: 0.22 - increasing : list of ints, optional (default=None) - Indices of features constrained to have a monotonically increasing - effect on the predicted variable. - - decreasing : list of ints, optional (default=None) - Indices of features constrained to have a monotonically decreasing - effect on the predicted variable. + monotonic_cst : array-like of int of shape (n_features), default=None + Indicates the monotonic constraint to enforce on each feature. -1, 1 + and 0 respectively correspond to a positive constraint, negative + constraint and no constraint.. Attributes ---------- @@ -866,8 +823,7 @@ def __init__(self, min_impurity_split=None, class_weight=None, presort='deprecated', - increasing=None, - decreasing=None, + monotonic_cst=None, ccp_alpha=0.0): super().__init__( @@ -884,8 +840,7 @@ def __init__(self, min_impurity_decrease=min_impurity_decrease, min_impurity_split=min_impurity_split, presort=presort, - increasing=increasing, - decreasing=decreasing, + monotonic_cst=monotonic_cst, ccp_alpha=ccp_alpha) def fit(self, X, y, sample_weight=None, check_input=True, @@ -1144,13 +1099,10 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree): .. 
versionadded:: 0.22 - increasing : list of ints, optional (default=None) - Indices of features constrained to have a monotonically increasing - effect on the predicted variable. - - decreasing : list of ints, optional (default=None) - Indices of features constrained to have a monotonically decreasing - effect on the predicted variable. + monotonic_cst : array-like of int of shape (n_features), default=None + Indicates the monotonic constraint to enforce on each feature. -1, 1 + and 0 respectively correspond to a positive constraint, negative + constraint and no constraint. Attributes ---------- @@ -1229,8 +1181,7 @@ def __init__(self, min_impurity_split=None, presort='deprecated', ccp_alpha=0.0, - increasing=None, - decreasing=None): + monotonic_cst=None): super().__init__( criterion=criterion, splitter=splitter, @@ -1245,8 +1196,7 @@ def __init__(self, min_impurity_split=min_impurity_split, presort=presort, ccp_alpha=ccp_alpha, - increasing=increasing, - decreasing=decreasing) + monotonic_cst=monotonic_cst) def fit(self, X, y, sample_weight=None, check_input=True, X_idx_sorted=None): @@ -1516,8 +1466,7 @@ def __init__(self, min_impurity_split=None, class_weight=None, ccp_alpha=0.0, - increasing=None, - decreasing=None): + monotonic_cst=None): super().__init__( criterion=criterion, @@ -1533,8 +1482,7 @@ def __init__(self, min_impurity_split=min_impurity_split, random_state=random_state, ccp_alpha=ccp_alpha, - increasing=increasing, - decreasing=decreasing) + monotonic_cst=monotonic_cst) class ExtraTreeRegressor(DecisionTreeRegressor): @@ -1733,8 +1681,7 @@ def __init__(self, min_impurity_split=None, random_state=None, ccp_alpha=0.0, - increasing=None, - decreasing=None): + monotonic_cst=None): super().__init__( criterion=criterion, splitter=splitter, @@ -1748,5 +1695,4 @@ def __init__(self, min_impurity_split=min_impurity_split, random_state=random_state, ccp_alpha=ccp_alpha, - increasing=increasing, - decreasing=decreasing) + monotonic_cst=monotonic_cst) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index acef6e7e03883..51441abd521cb 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -1962,9 +1962,12 @@ def test_montonic_constraints(): X_test_1[:, 1] += 10 X_test_2 = np.copy(X_test_0) X_test_2[:, 2] += 10 + monotonic_cst = np.zeros(X.shape[0]) + monotonic_cst[0] = -1 + monotonic_cst[1] = 1 + for name, TreeRegressor in REG_TREES.items(): - est = TreeRegressor(max_depth=None, - decreasing=[1], increasing=[2]) + est = TreeRegressor(max_depth=None, monotonic_cst=monotonic_cst) if hasattr(est, "random_state"): est.set_params(**{"random_state": 0}) est.fit(X_train, y_train) @@ -1981,8 +1984,7 @@ def test_montonic_constraints(): assert(np.min(y2 - y0) >= 0) for name, TreeClassifier in CLF_TREES.items(): - est = TreeClassifier(max_depth=None, - decreasing=[1], increasing=[2]) + est = TreeClassifier(max_depth=None, monotonic_cst=monotonic_cst) if hasattr(est, "random_state"): est.set_params(**{"random_state": 0}) est.fit(X_train, y_train) From 8a687c177906844d3291e07bbf25416d55b5cd95 Mon Sep 17 00:00:00 2001 From: "Samuel O. 
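Commit 022 drops the _encode_monotonic helper along with the increasing/decreasing keyword pair. For readers tracking the API change, the old list-based arguments map onto the new array roughly as the removed helper did; the function below is a hedged reconstruction without the original's validation, not code from the patch.

import numpy as np

def encode_monotonic(increasing, decreasing, n_features):
    """Translate the old increasing/decreasing lists into a monotonic_cst array."""
    cst = np.zeros(n_features, dtype=np.int32)
    for j in (increasing or []):
        cst[j] = 1
    for j in (decreasing or []):
        cst[j] = -1
    return cst

# For example, encode_monotonic(increasing=[2], decreasing=[1], n_features=10)
# yields the array form of the earlier decreasing=[1], increasing=[2] test setup.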
Ronsin" Date: Wed, 29 Jan 2020 20:49:22 +0100 Subject: [PATCH 023/174] Fix segfault caused by sending None instead of np.array to Splitter init --- sklearn/tree/_classes.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 2331ea2e4fc89..7115077489a2c 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -334,6 +334,23 @@ def fit(self, X, y, sample_weight=None, check_input=True, SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS splitter = self.splitter + if self.monotonic_cst is None: + self.monotonic_cst = np.full(shape=X.shape[1], + fill_value=0, + dtype=np.int32) + else: + self.monotonic_cst = np.asarray(self.monotonic_cst, dtype=np.int32) + + if self.monotonic_cst.shape[0] != X.shape[1]: + raise ValueError( + "monotonic_cst has shape {} but the input data " + "X has {} features.".format( + self.monotonic_cst.shape[0], X.shape[1] + ) + ) + if np.any(self.monotonic_cst < -1) or np.any(self.monotonic_cst > 1): + raise ValueError( + "monotonic_cst must be None or an array-like of -1, 0 or 1.") if not isinstance(self.splitter, Splitter): splitter = SPLITTERS[self.splitter](criterion, self.max_features_, From 900ed04f1f02c4a43bcc6bb6e65eaa94150bf7c7 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 29 Jan 2020 20:50:02 +0100 Subject: [PATCH 024/174] Fix tests --- sklearn/tree/tests/test_tree.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 51441abd521cb..43cde1bbd5d22 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -1959,10 +1959,10 @@ def test_montonic_constraints(): y_train = y[train] X_test_0 = np.copy(X[test]) X_test_1 = np.copy(X_test_0) - X_test_1[:, 1] += 10 + X_test_1[:, 0] += 10 X_test_2 = np.copy(X_test_0) - X_test_2[:, 2] += 10 - monotonic_cst = np.zeros(X.shape[0]) + X_test_2[:, 1] += 10 + monotonic_cst = np.zeros(X.shape[1]) monotonic_cst[0] = -1 monotonic_cst[1] = 1 From 15fe7b54f353eb6fbf27615622ef6d24ec606322 Mon Sep 17 00:00:00 2001 From: dsleo Date: Thu, 30 Jan 2020 10:55:24 +0100 Subject: [PATCH 025/174] fix linting --- sklearn/tree/_classes.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 7115077489a2c..11fecea46a58b 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -335,20 +335,20 @@ def fit(self, X, y, sample_weight=None, check_input=True, splitter = self.splitter if self.monotonic_cst is None: - self.monotonic_cst = np.full(shape=X.shape[1], + monotonic_cst = np.full(shape=X.shape[1], fill_value=0, dtype=np.int32) else: - self.monotonic_cst = np.asarray(self.monotonic_cst, dtype=np.int32) + monotonic_cst = np.asarray(self.monotonic_cst, dtype=np.int32) - if self.monotonic_cst.shape[0] != X.shape[1]: + if monotonic_cst.shape[0] != X.shape[1]: raise ValueError( "monotonic_cst has shape {} but the input data " "X has {} features.".format( - self.monotonic_cst.shape[0], X.shape[1] + monotonic_cst.shape[0], X.shape[1] ) ) - if np.any(self.monotonic_cst < -1) or np.any(self.monotonic_cst > 1): + if np.any(monotonic_cst < -1) or np.any(monotonic_cst > 1): raise ValueError( "monotonic_cst must be None or an array-like of -1, 0 or 1.") if not isinstance(self.splitter, Splitter): @@ -357,7 +357,7 @@ def fit(self, X, y, sample_weight=None, check_input=True, min_samples_leaf, min_weight_leaf, random_state, - 
self.monotonic_cst) + monotonic_cst) if is_classifier(self): self.tree_ = Tree(self.n_features_, From 803efe0520d883df3eeb8f0d6d70bef8eafae2d4 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Thu, 30 Jan 2020 11:56:06 +0100 Subject: [PATCH 026/174] Preliminary pass at adding upper and lower bounds on splits --- sklearn/tree/_splitter.pxd | 5 ++++- sklearn/tree/_splitter.pyx | 26 ++++++++++++++++---------- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index a591543f24cc4..af99032463219 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -97,4 +97,7 @@ cdef class Splitter: cdef double node_impurity(self) nogil - cdef inline bint split_passes_monotonic_check(self, INT32_t monotonic_constraint) nogil + cdef inline bint split_passes_monotonic_check(self, + INT32_t monotonic_cst, + double lower_bound, + double upper_bound) nogil diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index dfc5da4a9e162..3f29495a62ccd 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -237,16 +237,22 @@ cdef class Splitter: return self.criterion.node_impurity() - cdef inline bint split_passes_monotonic_check(self, INT32_t monotonic_constraint) nogil: + cdef inline bint split_passes_monotonic_check(self, INT32_t monotonic_cst, + double lower_bound, + double upper_bound) nogil: """Check monotonic constraint is satisfied at the current split""" - if monotonic_constraint == 0: # No constraint + if monotonic_cst == 0: # No constraint return 1 else: - delta = (self.criterion.sum_left[0] * self.criterion.weighted_n_right \ - - self.criterion.sum_right[0] * self.criterion.weighted_n_left) \ - * monotonic_constraint - return delta <= 0 + b_middle = (self.criterion.sum_left[0] * self.criterion.weighted_n_right + - self.criterion.sum_right[0] * self.criterion.weighted_n_left) \ + * monotonic_cst + #b_left = (self.criterion.sum_left[0] - lower_bound * self.criterion.weighted_n_left) \ + # * monotonic_cst >= 0 + #b_right = (self.criterion.sum_right[0] - upper_bound * self.criterion.weighted_n_right) \ + # * monotonic_cst <= 0 + return b_middle <= 0 # & b_left & b_right cdef class BaseDenseSplitter(Splitter): cdef const DTYPE_t[:, :] X @@ -437,7 +443,7 @@ cdef class BestSplitter(BaseDenseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue - if not self.split_passes_monotonic_check(monotonic_constraint): + if not self.split_passes_monotonic_check(monotonic_constraint, -INFINITY, INFINITY): continue current_proxy_improvement = self.criterion.proxy_impurity_improvement() @@ -761,7 +767,7 @@ cdef class RandomSplitter(BaseDenseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue - if not self.split_passes_monotonic_check(monotonic_constraint): + if not self.split_passes_monotonic_check(monotonic_constraint, -INFINITY, INFINITY): continue current_proxy_improvement = self.criterion.proxy_impurity_improvement() @@ -1313,7 +1319,7 @@ cdef class BestSparseSplitter(BaseSparseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue - if not self.split_passes_monotonic_check(monotonic_constraint): + if not self.split_passes_monotonic_check(monotonic_constraint, -INFINITY, INFINITY): continue current_proxy_improvement = self.criterion.proxy_impurity_improvement() @@ -1549,7 +1555,7 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue - if not 
self.split_passes_monotonic_check(monotonic_constraint): + if not self.split_passes_monotonic_check(monotonic_constraint, -INFINITY, INFINITY): continue current_proxy_improvement = self.criterion.proxy_impurity_improvement() From 23da3e7eae95f881afd6e5982e9094c8460cd608 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Fri, 31 Jan 2020 17:52:02 +0100 Subject: [PATCH 027/174] Work towards adding upper and lower bounds on splits and nodes --- sklearn/tree/_splitter.pxd | 8 ++++-- sklearn/tree/_splitter.pyx | 51 +++++++++++++++++++++----------------- sklearn/tree/_tree.pxd | 6 ++++- sklearn/tree/_tree.pyx | 29 +++++++++++++++------- 4 files changed, 59 insertions(+), 35 deletions(-) diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index af99032463219..fb7f1dccffd56 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -30,6 +30,8 @@ cdef struct SplitRecord: double improvement # Impurity improvement given parent node. double impurity_left # Impurity of the left split. double impurity_right # Impurity of the right split. + double lower_bound # Lower bound on value of left child for monotonicity + double upper_bound # Upper bound on value of right child for monotonicity cdef class Splitter: # The splitter searches in the input space for a feature and a threshold @@ -59,7 +61,7 @@ cdef class Splitter: cdef const DOUBLE_t[:, ::1] y cdef DOUBLE_t* sample_weight - cdef INT32_t* monotonic # Monotonicity constraints + cdef INT32_t* monotonic_cst # Monotonicity constraints # -1: monotonically decreasing # 0: no constraint # +1: monotonically increasing @@ -91,7 +93,9 @@ cdef class Splitter: cdef int node_split(self, double impurity, # Impurity of the node SplitRecord* split, - SIZE_t* n_constant_features) nogil except -1 + SIZE_t* n_constant_features, + double lower_bound, + double upper_bound) nogil except -1 cdef void node_value(self, double* dest) nogil diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 3f29495a62ccd..9a4354d62ddfe 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -60,7 +60,7 @@ cdef class Splitter: def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, np.ndarray[INT32_t] monotonic): + object random_state, np.ndarray[INT32_t] monotonic_cst): """ Parameters ---------- @@ -98,7 +98,7 @@ cdef class Splitter: self.min_samples_leaf = min_samples_leaf self.min_weight_leaf = min_weight_leaf self.random_state = random_state - self.monotonic = monotonic.data + self.monotonic_cst = monotonic_cst.data def __dealloc__(self): """Destructor.""" @@ -216,7 +216,8 @@ cdef class Splitter: return 0 cdef int node_split(self, double impurity, SplitRecord* split, - SIZE_t* n_constant_features) nogil except -1: + SIZE_t* n_constant_features, double lower_bound, + double upper_bound) nogil except -1: """Find the best split on node samples[start:end]. This is a placeholder method. 
The majority of computation will be done @@ -245,14 +246,14 @@ cdef class Splitter: if monotonic_cst == 0: # No constraint return 1 else: - b_middle = (self.criterion.sum_left[0] * self.criterion.weighted_n_right - - self.criterion.sum_right[0] * self.criterion.weighted_n_left) \ - * monotonic_cst - #b_left = (self.criterion.sum_left[0] - lower_bound * self.criterion.weighted_n_left) \ - # * monotonic_cst >= 0 - #b_right = (self.criterion.sum_right[0] - upper_bound * self.criterion.weighted_n_right) \ - # * monotonic_cst <= 0 - return b_middle <= 0 # & b_left & b_right + delta_middle = (self.criterion.sum_left[0] * self.criterion.weighted_n_right + - self.criterion.sum_right[0] * self.criterion.weighted_n_left) \ + * monotonic_cst + delta_left = (self.criterion.sum_left[0] - lower_bound * self.criterion.weighted_n_left) \ + * monotonic_cst + delta_right = (self.criterion.sum_right[0] - upper_bound * self.criterion.weighted_n_right) \ + * monotonic_cst + return (delta_middle <= 0) & (delta_left >= 0) & (delta_right <= 0) cdef class BaseDenseSplitter(Splitter): cdef const DTYPE_t[:, :] X @@ -265,12 +266,12 @@ cdef class BaseDenseSplitter(Splitter): def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, np.ndarray[INT32_t] monotonic): + object random_state, np.ndarray[INT32_t] monotonic_cst): self.X_idx_sorted_ptr = NULL self.X_idx_sorted_stride = 0 self.sample_mask = NULL - self.monotonic = monotonic.data + self.monotonic_cst = monotonic_cst.data cdef int init(self, object X, @@ -300,7 +301,8 @@ cdef class BestSplitter(BaseDenseSplitter): self.random_state), self.__getstate__()) cdef int node_split(self, double impurity, SplitRecord* split, - SIZE_t* n_constant_features) nogil except -1: + SIZE_t* n_constant_features, double lower_bound, + double upper_bound) nogil except -1: """Find the best split on node samples[start:end] Returns -1 in case of failure to allocate memory (and raise MemoryError) @@ -393,7 +395,7 @@ cdef class BestSplitter(BaseDenseSplitter): # f_j in the interval [n_total_constants, f_i[ current.feature = features[f_j] - monotonic_constraint = self.monotonic[current.feature] + monotonic_constraint = self.monotonic_cst[current.feature] # Sort samples along that feature; by # copying the values into an array and # sorting the array in a manner which utilizes the cache more @@ -619,7 +621,8 @@ cdef class RandomSplitter(BaseDenseSplitter): self.random_state), self.__getstate__()) cdef int node_split(self, double impurity, SplitRecord* split, - SIZE_t* n_constant_features) nogil except -1: + SIZE_t* n_constant_features, double lower_bound, + double upper_bound) nogil except -1: """Find the best random split on node samples[start:end] Returns -1 in case of failure to allocate memory (and raise MemoryError) @@ -706,7 +709,7 @@ cdef class RandomSplitter(BaseDenseSplitter): current.feature = features[f_j] - monotonic_constraint = self.monotonic[current.feature] + monotonic_constraint = self.monotonic_cst[current.feature] # Find min, max min_feature_value = self.X[samples[start], current.feature] @@ -824,7 +827,7 @@ cdef class BaseSparseSplitter(Splitter): def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, np.ndarray[INT32_t] monotonic): + object random_state, np.ndarray[INT32_t] monotonic_cst): # Parent __cinit__ is automatically called self.X_data = NULL @@ -835,7 +838,7 @@ cdef class BaseSparseSplitter(Splitter): 
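In plain Python, the delta_middle / delta_left / delta_right test above reduces to comparing the two children's weighted means with each other and with the bounds inherited from ancestor splits. A hedged, scalar-output summary follows; the function is illustrative only and not part of the patch.

def split_passes_monotonic_check(sum_left, weighted_n_left,
                                 sum_right, weighted_n_right,
                                 monotonic_cst, lower_bound, upper_bound):
    if monotonic_cst == 0:
        return True  # unconstrained feature
    mean_left = sum_left / weighted_n_left
    mean_right = sum_right / weighted_n_right
    # For +1 (increasing): left mean <= right mean, left mean >= lower_bound,
    # right mean <= upper_bound. Multiplying by monotonic_cst flips the
    # inequalities when the constraint is -1, exactly as the deltas above do.
    return (monotonic_cst * (mean_left - mean_right) <= 0
            and monotonic_cst * (mean_left - lower_bound) >= 0
            and monotonic_cst * (mean_right - upper_bound) <= 0)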
self.index_to_samples = NULL self.sorted_samples = NULL - self.monotonic = monotonic.data + self.monotonic_cst = monotonic_cst.data def __dealloc__(self): """Deallocate memory.""" @@ -1143,7 +1146,8 @@ cdef class BestSparseSplitter(BaseSparseSplitter): self.random_state), self.__getstate__()) cdef int node_split(self, double impurity, SplitRecord* split, - SIZE_t* n_constant_features) nogil except -1: + SIZE_t* n_constant_features, double lower_bound, + double upper_bound) nogil except -1: """Find the best split on node samples[start:end], using sparse features Returns -1 in case of failure to allocate memory (and raise MemoryError) @@ -1245,7 +1249,7 @@ cdef class BestSparseSplitter(BaseSparseSplitter): &end_negative, &start_positive, &is_samples_sorted) - monotonic_constraint = self.monotonic[current.feature] + monotonic_constraint = self.monotonic_cst[current.feature] # Sort the positive and negative parts of `Xf` sort(Xf + start, samples + start, end_negative - start) @@ -1377,7 +1381,8 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): self.random_state), self.__getstate__()) cdef int node_split(self, double impurity, SplitRecord* split, - SIZE_t* n_constant_features) nogil except -1: + SIZE_t* n_constant_features, double lower_bound, + double upper_bound) nogil except -1: """Find a random split on node samples[start:end], using sparse features Returns -1 in case of failure to allocate memory (and raise MemoryError) @@ -1483,7 +1488,7 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): &end_negative, &start_positive, &is_samples_sorted) - monotonic_constraint = self.monotonic[current.feature] + monotonic_constraint = self.monotonic_cst[current.feature] # Add one or two zeros in Xf, if there is any if end_negative < start_positive: diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 14b03103deff0..cc9417852ffa0 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -32,6 +32,8 @@ cdef struct Node: DOUBLE_t impurity # Impurity of the node (i.e., the value of the criterion) SIZE_t n_node_samples # Number of samples at the node DOUBLE_t weighted_n_node_samples # Weighted number of samples at the node + DOUBLE_t children_lower_bound # Lower bound for the node's children values for monotonicity constraints + DOUBLE_t children_upper_bound # Upper bound for the node's children values for monotonicity constraints cdef class Tree: @@ -58,7 +60,9 @@ cdef class Tree: cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf, SIZE_t feature, double threshold, double impurity, SIZE_t n_node_samples, - double weighted_n_samples) nogil except -1 + double weighted_n_samples, + double children_lower_bound, + double children_upper_bound) nogil except -1 cdef int _resize(self, SIZE_t capacity) nogil except -1 cdef int _resize_c(self, SIZE_t capacity=*) nogil except -1 diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index bbe2c8a796578..77da95899243e 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -70,9 +70,9 @@ cdef SIZE_t INITIAL_STACK_SIZE = 10 # Repeat struct definition for numpy NODE_DTYPE = np.dtype({ 'names': ['left_child', 'right_child', 'feature', 'threshold', 'impurity', - 'n_node_samples', 'weighted_n_node_samples'], + 'n_node_samples', 'weighted_n_node_samples', 'children_lower_bound', 'children_upper_bound'], 'formats': [np.intp, np.intp, np.intp, np.float64, np.float64, np.intp, - np.float64], + np.float64, np.float64, np.float64], 'offsets': [ &( NULL).left_child, &( NULL).right_child, @@ -80,7 +80,9 @@ 
NODE_DTYPE = np.dtype({ &( NULL).threshold, &( NULL).impurity, &( NULL).n_node_samples, - &( NULL).weighted_n_node_samples + &( NULL).weighted_n_node_samples, + &( NULL).children_lower_bound, + &( NULL).children_upper_bound ] }) @@ -233,7 +235,8 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): (impurity <= min_impurity_split)) if not is_leaf: - splitter.node_split(impurity, &split, &n_constant_features) + # TODO add proper lower and upper bounds + splitter.node_split(impurity, &split, &n_constant_features, -INFINITY, INFINITY) # If EPSILON=0 in the below comparison, float precision # issues stop splitting, producing trees that are # dissimilar to v0.18 @@ -241,9 +244,10 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): (split.improvement + EPSILON < min_impurity_decrease)) + # TODO add proper lower and upper bounds node_id = tree._add_node(parent, is_left, is_leaf, split.feature, split.threshold, impurity, n_node_samples, - weighted_n_node_samples) + weighted_n_node_samples, -INFINITY, INFINITY) if node_id == SIZE_MAX: rc = -1 @@ -457,18 +461,20 @@ cdef class BestFirstTreeBuilder(TreeBuilder): impurity <= min_impurity_split) if not is_leaf: - splitter.node_split(impurity, &split, &n_constant_features) + # TODO add proper lower and upper bounds + splitter.node_split(impurity, &split, &n_constant_features, -INFINITY, INFINITY) # If EPSILON=0 in the below comparison, float precision issues stop # splitting early, producing trees that are dissimilar to v0.18 is_leaf = (is_leaf or split.pos >= end or split.improvement + EPSILON < min_impurity_decrease) + # TODO add proper lower and upper bounds node_id = tree._add_node(parent - tree.nodes if parent != NULL else _TREE_UNDEFINED, is_left, is_leaf, split.feature, split.threshold, impurity, n_node_samples, - weighted_n_node_samples) + weighted_n_node_samples, -INFINITY, INFINITY) if node_id == SIZE_MAX: return -1 @@ -726,7 +732,9 @@ cdef class Tree: cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf, SIZE_t feature, double threshold, double impurity, SIZE_t n_node_samples, - double weighted_n_node_samples) nogil except -1: + double weighted_n_node_samples, + double children_lower_bound, + double children_upper_bound) nogil except -1: """Add a node to the tree. The new node registers itself as the child of its parent. @@ -743,6 +751,8 @@ cdef class Tree: node.impurity = impurity node.n_node_samples = n_node_samples node.weighted_n_node_samples = weighted_n_node_samples + node.children_lower_bound = children_lower_bound + node.children_upper_bound = children_upper_bound if parent != _TREE_UNDEFINED: if is_left: @@ -1616,7 +1626,8 @@ cdef _build_pruned_tree( new_node_id = tree._add_node( parent, is_left, is_leaf, node.feature, node.threshold, node.impurity, node.n_node_samples, - node.weighted_n_node_samples) + node.weighted_n_node_samples, node.children_lower_bound, + node.children_upper_bound) if new_node_id == SIZE_MAX: rc = -1 From f73c820214b3140b5f67ba9a84be567795930e6b Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Fri, 31 Jan 2020 18:19:00 +0100 Subject: [PATCH 028/174] YOLO --- sklearn/tree/_splitter.pyx | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 9a4354d62ddfe..27e95e2931299 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -445,6 +445,7 @@ cdef class BestSplitter(BaseDenseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue + # TODO if not self.split_passes_monotonic_check(monotonic_constraint, -INFINITY, INFINITY): continue @@ -770,6 +771,7 @@ cdef class RandomSplitter(BaseDenseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue + # TODO if not self.split_passes_monotonic_check(monotonic_constraint, -INFINITY, INFINITY): continue @@ -1323,6 +1325,7 @@ cdef class BestSparseSplitter(BaseSparseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue + # TODO if not self.split_passes_monotonic_check(monotonic_constraint, -INFINITY, INFINITY): continue @@ -1560,6 +1563,7 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue + # TODO if not self.split_passes_monotonic_check(monotonic_constraint, -INFINITY, INFINITY): continue From d43a4a921e6e1578e9b09e0e796d854579533360 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Mon, 3 Feb 2020 19:54:49 +0100 Subject: [PATCH 029/174] Do the easier TODOs --- sklearn/tree/_splitter.pyx | 12 ++++-------- sklearn/tree/_tree.pyx | 13 +++++++++---- sklearn/tree/_utils.pxd | 2 ++ 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 27e95e2931299..edd4eecf9528a 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -445,8 +445,7 @@ cdef class BestSplitter(BaseDenseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue - # TODO - if not self.split_passes_monotonic_check(monotonic_constraint, -INFINITY, INFINITY): + if not self.split_passes_monotonic_check(monotonic_constraint, lower_bound, upper_bound): continue current_proxy_improvement = self.criterion.proxy_impurity_improvement() @@ -771,8 +770,7 @@ cdef class RandomSplitter(BaseDenseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue - # TODO - if not self.split_passes_monotonic_check(monotonic_constraint, -INFINITY, INFINITY): + if not self.split_passes_monotonic_check(monotonic_constraint, lower_bound, upper_bound): continue current_proxy_improvement = self.criterion.proxy_impurity_improvement() @@ -1325,8 +1323,7 @@ cdef class BestSparseSplitter(BaseSparseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue - # TODO - if not self.split_passes_monotonic_check(monotonic_constraint, -INFINITY, INFINITY): + if not self.split_passes_monotonic_check(monotonic_constraint, lower_bound, upper_bound): continue current_proxy_improvement = self.criterion.proxy_impurity_improvement() @@ -1563,8 +1560,7 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue - # TODO - if not self.split_passes_monotonic_check(monotonic_constraint, -INFINITY, INFINITY): + if not self.split_passes_monotonic_check(monotonic_constraint, lower_bound, upper_bound): continue current_proxy_improvement = self.criterion.proxy_impurity_improvement() diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 77da95899243e..ef8ae25ac4a95 100644 --- a/sklearn/tree/_tree.pyx +++ 
b/sklearn/tree/_tree.pyx @@ -191,6 +191,8 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef SIZE_t node_id cdef double impurity = INFINITY + cdef double children_lower_bound = -INFINITY + cdef double children_upper_bound = INFINITY cdef SIZE_t n_constant_features cdef bint is_leaf cdef bint first = 1 @@ -218,6 +220,8 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): is_left = stack_record.is_left impurity = stack_record.impurity n_constant_features = stack_record.n_constant_features + children_lower_bound = stack_record.children_lower_bound + children_upper_bound = stack_record.children_upper_bound n_node_samples = end - start splitter.node_reset(start, end, &weighted_n_node_samples) @@ -235,8 +239,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): (impurity <= min_impurity_split)) if not is_leaf: - # TODO add proper lower and upper bounds - splitter.node_split(impurity, &split, &n_constant_features, -INFINITY, INFINITY) + splitter.node_split(impurity, &split, &n_constant_features, children_lower_bound, children_upper_bound) # If EPSILON=0 in the below comparison, float precision # issues stop splitting, producing trees that are # dissimilar to v0.18 @@ -447,6 +450,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cdef bint is_leaf cdef SIZE_t n_left, n_right cdef double imp_diff + cdef double children_lower_bound = parent.children_lower_bound if parent != NULL else -INFINITY + cdef double children_upper_bound = parent.children_upper_bound if parent != NULL else INFINITY splitter.node_reset(start, end, &weighted_n_node_samples) @@ -461,8 +466,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder): impurity <= min_impurity_split) if not is_leaf: - # TODO add proper lower and upper bounds - splitter.node_split(impurity, &split, &n_constant_features, -INFINITY, INFINITY) + splitter.node_split(impurity, &split, &n_constant_features, children_lower_bound, + children_upper_bound) # If EPSILON=0 in the below comparison, float precision issues stop # splitting early, producing trees that are dissimilar to v0.18 is_leaf = (is_leaf or split.pos >= end or diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index e60e148871385..c6c991c1514e3 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -75,6 +75,8 @@ cdef struct StackRecord: bint is_left double impurity SIZE_t n_constant_features + double children_lower_bound + double children_upper_bound cdef class Stack: cdef SIZE_t capacity From 822bd5b60ca8c0b6b8e26c1bf7616fc7cfb9db95 Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Mon, 10 Feb 2020 20:14:08 +0100 Subject: [PATCH 030/174] Roll back adding bound information to nodes --- sklearn/tree/_tree.pxd | 7 ++----- sklearn/tree/_tree.pyx | 24 +++++++----------------- 2 files changed, 9 insertions(+), 22 deletions(-) diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index cc9417852ffa0..a65ff567735ec 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -32,8 +32,6 @@ cdef struct Node: DOUBLE_t impurity # Impurity of the node (i.e., the value of the criterion) SIZE_t n_node_samples # Number of samples at the node DOUBLE_t weighted_n_node_samples # Weighted number of samples at the node - DOUBLE_t children_lower_bound # Lower bound for the node's children values for monotonicity constraints - DOUBLE_t children_upper_bound # Upper bound for the node's children values for monotonicity constraints cdef class Tree: @@ -60,9 +58,8 @@ cdef class Tree: cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf, SIZE_t feature, double threshold, double impurity, SIZE_t n_node_samples, - double weighted_n_samples, - double children_lower_bound, - double children_upper_bound) nogil except -1 + double weighted_n_samples) nogil except -1 + cdef int _resize(self, SIZE_t capacity) nogil except -1 cdef int _resize_c(self, SIZE_t capacity=*) nogil except -1 diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index ef8ae25ac4a95..570bbf1f506f0 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -70,7 +70,7 @@ cdef SIZE_t INITIAL_STACK_SIZE = 10 # Repeat struct definition for numpy NODE_DTYPE = np.dtype({ 'names': ['left_child', 'right_child', 'feature', 'threshold', 'impurity', - 'n_node_samples', 'weighted_n_node_samples', 'children_lower_bound', 'children_upper_bound'], + 'n_node_samples', 'weighted_n_node_samples'], 'formats': [np.intp, np.intp, np.intp, np.float64, np.float64, np.intp, np.float64, np.float64, np.float64], 'offsets': [ @@ -80,9 +80,7 @@ NODE_DTYPE = np.dtype({ &( NULL).threshold, &( NULL).impurity, &( NULL).n_node_samples, - &( NULL).weighted_n_node_samples, - &( NULL).children_lower_bound, - &( NULL).children_upper_bound + &( NULL).weighted_n_node_samples ] }) @@ -247,10 +245,9 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): (split.improvement + EPSILON < min_impurity_decrease)) - # TODO add proper lower and upper bounds node_id = tree._add_node(parent, is_left, is_leaf, split.feature, split.threshold, impurity, n_node_samples, - weighted_n_node_samples, -INFINITY, INFINITY) + weighted_n_node_samples) if node_id == SIZE_MAX: rc = -1 @@ -450,8 +447,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cdef bint is_leaf cdef SIZE_t n_left, n_right cdef double imp_diff - cdef double children_lower_bound = parent.children_lower_bound if parent != NULL else -INFINITY - cdef double children_upper_bound = parent.children_upper_bound if parent != NULL else INFINITY + cdef double node_value splitter.node_reset(start, end, &weighted_n_node_samples) @@ -473,13 +469,12 @@ cdef class BestFirstTreeBuilder(TreeBuilder): is_leaf = (is_leaf or split.pos >= end or split.improvement + EPSILON < min_impurity_decrease) - # TODO add proper lower and upper bounds node_id = tree._add_node(parent - tree.nodes if parent != NULL else _TREE_UNDEFINED, is_left, is_leaf, split.feature, split.threshold, impurity, n_node_samples, - weighted_n_node_samples, -INFINITY, INFINITY) + weighted_n_node_samples) if node_id == SIZE_MAX: return -1 @@ -737,9 +732,7 @@ cdef class Tree: cdef SIZE_t _add_node(self, SIZE_t parent, bint 
is_left, bint is_leaf, SIZE_t feature, double threshold, double impurity, SIZE_t n_node_samples, - double weighted_n_node_samples, - double children_lower_bound, - double children_upper_bound) nogil except -1: + double weighted_n_node_samples) nogil except -1: """Add a node to the tree. The new node registers itself as the child of its parent. @@ -756,8 +749,6 @@ cdef class Tree: node.impurity = impurity node.n_node_samples = n_node_samples node.weighted_n_node_samples = weighted_n_node_samples - node.children_lower_bound = children_lower_bound - node.children_upper_bound = children_upper_bound if parent != _TREE_UNDEFINED: if is_left: @@ -1631,8 +1622,7 @@ cdef _build_pruned_tree( new_node_id = tree._add_node( parent, is_left, is_leaf, node.feature, node.threshold, node.impurity, node.n_node_samples, - node.weighted_n_node_samples, node.children_lower_bound, - node.children_upper_bound) + node.weighted_n_node_samples) if new_node_id == SIZE_MAX: rc = -1 From e4854165e08184861a094061bcde420d548aa6aa Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Mon, 10 Feb 2020 20:26:40 +0100 Subject: [PATCH 031/174] Add bound information to intermediate tree-building structures --- sklearn/tree/_tree.pyx | 47 +++++++++++++++++++++++++++-------------- sklearn/tree/_utils.pxd | 8 +++++-- sklearn/tree/_utils.pyx | 10 +++++++-- 3 files changed, 45 insertions(+), 20 deletions(-) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 570bbf1f506f0..3cd2b967e7eea 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -189,8 +189,9 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef SIZE_t node_id cdef double impurity = INFINITY - cdef double children_lower_bound = -INFINITY - cdef double children_upper_bound = INFINITY + cdef double children_lower_bound + cdef double children_upper_bound + cdef double node_value cdef SIZE_t n_constant_features cdef bint is_leaf cdef bint first = 1 @@ -202,7 +203,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): with nogil: # push root node onto stack - rc = stack.push(0, n_node_samples, 0, _TREE_UNDEFINED, 0, INFINITY, 0) + rc = stack.push(0, n_node_samples, 0, _TREE_UNDEFINED, 0, INFINITY, 0, -INFINITY, INFINITY) if rc == -1: # got return code -1 - out-of-memory with gil: @@ -257,16 +258,17 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): # inspection and interpretation splitter.node_value(tree.value + node_id * tree.value_stride) + node_value = splitter.criterion.sum_total[node_id] / weighted_n_node_samples if not is_leaf: # Push right child on stack rc = stack.push(split.pos, end, depth + 1, node_id, 0, - split.impurity_right, n_constant_features) + split.impurity_right, n_constant_features, node_value, children_upper_bound) if rc == -1: break # Push left child on stack rc = stack.push(start, split.pos, depth + 1, node_id, 1, - split.impurity_left, n_constant_features) + split.impurity_left, n_constant_features, children_lower_bound, node_value) if rc == -1: break @@ -293,7 +295,8 @@ cdef inline int _add_to_frontier(PriorityHeapRecord* rec, """ return frontier.push(rec.node_id, rec.start, rec.end, rec.pos, rec.depth, rec.is_leaf, rec.improvement, rec.impurity, - rec.impurity_left, rec.impurity_right) + rec.impurity_left, rec.impurity_right, + rec.children_lower_bound, rec.children_upper_bound) cdef class BestFirstTreeBuilder(TreeBuilder): @@ -335,6 +338,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cdef SIZE_t min_samples_leaf = self.min_samples_leaf cdef double min_weight_leaf = self.min_weight_leaf cdef SIZE_t min_samples_split 
= self.min_samples_split + cdef double children_lower_bound + cdef double children_upper_bound # Recursive partition (without actual recursion) splitter.init(X, y, sample_weight_ptr, X_idx_sorted) @@ -359,6 +364,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): # add root to frontier rc = self._add_split_node(splitter, tree, 0, n_node_samples, INFINITY, IS_FIRST, IS_LEFT, NULL, 0, + -INFINITY, INFINITY, &split_node_left) if rc >= 0: rc = _add_to_frontier(&split_node_left, frontier) @@ -382,6 +388,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): else: # Node is expandable + node_value = splitter.criterion.sum_total[0] / splitter.criterion.weighted_n_node_samples # Decrement number of split nodes available max_split_nodes -= 1 @@ -392,6 +399,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder): record.impurity_left, IS_NOT_FIRST, IS_LEFT, node, record.depth + 1, + record.children_lower_bound, + record.children_upper_bound, &split_node_left) if rc == -1: break @@ -405,6 +414,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder): record.impurity_right, IS_NOT_FIRST, IS_NOT_LEFT, node, record.depth + 1, + record.children_lower_bound, + record.children_upper_bound, &split_node_right) if rc == -1: break @@ -433,7 +444,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cdef inline int _add_split_node(self, Splitter splitter, Tree tree, SIZE_t start, SIZE_t end, double impurity, bint is_first, bint is_left, Node* parent, - SIZE_t depth, + SIZE_t depth, double children_lower_bound, + double children_upper_bound, PriorityHeapRecord* res) nogil except -1: """Adds node w/ partition ``[start, end)`` to the frontier. """ cdef SplitRecord split @@ -486,6 +498,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): res.end = end res.depth = depth res.impurity = impurity + node_value = splitter.criterion.sum_total[node_id] / weighted_n_node_samples if not is_leaf: # is split node @@ -502,6 +515,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder): res.improvement = 0.0 res.impurity_left = impurity res.impurity_right = impurity + res.children_lower_bound = children_lower_bound if is_left else node_value + res.children_upper_bound = node_value if is_left else children_upper_bound return 0 @@ -1369,7 +1384,7 @@ cdef _cost_complexity_prune(unsigned char[:] leaves_in_subtree, # OUT weighted_n_node_samples[i] * impurity[i] / total_sum_weights) # Push root node, using StackRecord.start as node id - rc = stack.push(0, 0, 0, -1, 0, 0, 0) + rc = stack.push(0, 0, 0, -1, 0, 0, 0, -INFINITY, INFINITY) if rc == -1: with gil: raise MemoryError("pruning tree") @@ -1382,12 +1397,12 @@ cdef _cost_complexity_prune(unsigned char[:] leaves_in_subtree, # OUT # ... 
and child_r[node_idx] == _TREE_LEAF: leaves_in_subtree[node_idx] = 1 else: - rc = stack.push(child_l[node_idx], 0, 0, node_idx, 0, 0, 0) + rc = stack.push(child_l[node_idx], 0, 0, node_idx, 0, 0, 0, -INFINITY, INFINITY) if rc == -1: with gil: raise MemoryError("pruning tree") - rc = stack.push(child_r[node_idx], 0, 0, node_idx, 0, 0, 0) + rc = stack.push(child_r[node_idx], 0, 0, node_idx, 0, 0, 0, -INFINITY, INFINITY) if rc == -1: with gil: raise MemoryError("pruning tree") @@ -1430,7 +1445,7 @@ cdef _cost_complexity_prune(unsigned char[:] leaves_in_subtree, # OUT break # stack uses only the start variable - rc = stack.push(pruned_branch_node_idx, 0, 0, 0, 0, 0, 0) + rc = stack.push(pruned_branch_node_idx, 0, 0, 0, 0, 0, 0, -INFINITY, INFINITY) if rc == -1: with gil: raise MemoryError("pruning tree") @@ -1448,11 +1463,11 @@ cdef _cost_complexity_prune(unsigned char[:] leaves_in_subtree, # OUT if child_l[node_idx] != _TREE_LEAF: # ... and child_r[node_idx] != _TREE_LEAF: - rc = stack.push(child_l[node_idx], 0, 0, 0, 0, 0, 0) + rc = stack.push(child_l[node_idx], 0, 0, 0, 0, 0, 0, -INFINITY, INFINITY) if rc == -1: with gil: raise MemoryError("pruning tree") - rc = stack.push(child_r[node_idx], 0, 0, 0, 0, 0, 0) + rc = stack.push(child_r[node_idx], 0, 0, 0, 0, 0, 0, -INFINITY, INFINITY) if rc == -1: with gil: raise MemoryError("pruning tree") @@ -1603,7 +1618,7 @@ cdef _build_pruned_tree( with nogil: # push root node onto stack - rc = stack.push(0, 0, 0, _TREE_UNDEFINED, 0, 0.0, 0) + rc = stack.push(0, 0, 0, _TREE_UNDEFINED, 0, 0.0, 0, -INFINITY, INFINITY) if rc == -1: with gil: raise MemoryError("pruning tree") @@ -1636,13 +1651,13 @@ cdef _build_pruned_tree( if not is_leaf: # Push right child on stack rc = stack.push( - node.right_child, 0, depth + 1, new_node_id, 0, 0.0, 0) + node.right_child, 0, depth + 1, new_node_id, 0, 0.0, 0, -INFINITY, INFINITY) if rc == -1: break # push left child on stack rc = stack.push( - node.left_child, 0, depth + 1, new_node_id, 1, 0.0, 0) + node.left_child, 0, depth + 1, new_node_id, 1, 0.0, 0, -INFINITY, INFINITY) if rc == -1: break diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index c6c991c1514e3..8b36ce4684eb0 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -86,7 +86,8 @@ cdef class Stack: cdef bint is_empty(self) nogil cdef int push(self, SIZE_t start, SIZE_t end, SIZE_t depth, SIZE_t parent, bint is_left, double impurity, - SIZE_t n_constant_features) nogil except -1 + SIZE_t n_constant_features, double children_lower_bound, + double children_upper_bound) nogil except -1 cdef int pop(self, StackRecord* res) nogil @@ -106,6 +107,8 @@ cdef struct PriorityHeapRecord: double impurity_left double impurity_right double improvement + double children_lower_bound + double children_upper_bound cdef class PriorityHeap: cdef SIZE_t capacity @@ -118,7 +121,8 @@ cdef class PriorityHeap: cdef int push(self, SIZE_t node_id, SIZE_t start, SIZE_t end, SIZE_t pos, SIZE_t depth, bint is_leaf, double improvement, double impurity, double impurity_left, - double impurity_right) nogil except -1 + double impurity_right, double children_lower_bound, + double children_upper_bound) nogil except -1 cdef int pop(self, PriorityHeapRecord* res) nogil # ============================================================================= diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index e6552debd3149..1cfdbb265d700 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -111,7 +111,8 @@ cdef class Stack: cdef int push(self, 
SIZE_t start, SIZE_t end, SIZE_t depth, SIZE_t parent, bint is_left, double impurity, - SIZE_t n_constant_features) nogil except -1: + SIZE_t n_constant_features, double children_lower_bound, + double children_upper_bound) nogil except -1: """Push a new element onto the stack. Return -1 in case of failure to allocate memory (and raise MemoryError) @@ -134,6 +135,8 @@ cdef class Stack: stack[top].is_left = is_left stack[top].impurity = impurity stack[top].n_constant_features = n_constant_features + stack[top].children_lower_bound = children_lower_bound + stack[top].children_upper_bound = children_upper_bound # Increment stack pointer self.top = top + 1 @@ -227,7 +230,8 @@ cdef class PriorityHeap: cdef int push(self, SIZE_t node_id, SIZE_t start, SIZE_t end, SIZE_t pos, SIZE_t depth, bint is_leaf, double improvement, double impurity, double impurity_left, - double impurity_right) nogil except -1: + double impurity_right, double children_lower_bound, + double children_upper_bound) nogil except -1: """Push record on the priority heap. Return -1 in case of failure to allocate memory (and raise MemoryError) @@ -254,6 +258,8 @@ cdef class PriorityHeap: heap[heap_ptr].impurity_left = impurity_left heap[heap_ptr].impurity_right = impurity_right heap[heap_ptr].improvement = improvement + heap[heap_ptr].children_lower_bound = children_lower_bound + heap[heap_ptr].children_upper_bound = children_upper_bound # Heapify up self.heapify_up(heap, heap_ptr) From 32e4c70d2cfc6a32591d39e94ea119b4be70b255 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Mon, 27 Apr 2020 22:20:41 +0200 Subject: [PATCH 032/174] Add monotonicity tests for tree with single feature --- sklearn/tree/tests/test_tree.py | 73 +++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 4afc77399e06c..4a386f8573b56 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -1985,6 +1985,79 @@ def test_montonic_constraints(): assert(np.min(y2 - y0) >= 0) +def is_increasing(a): + return (np.diff(a) >= 0.0).all() + + +def is_decreasing(a): + return (np.diff(a) <= 0.0).all() + + +def assert_children_values_monotonic_bounded(tree_, monotonic_cst): + # Flip values so that only need to check for increasing constraint + values = monotonic_cst * tree_.value + + for i in range(tree_.node_count): + if tree_.feature[i] >= 0: + # Check monotonicity + i_left = tree_.children_left[i] + i_right = tree_.children_right[i] + assert(values[i_left] <= values[i_right]) + val_middle = (values[i_left] + values[i_right]) / 2 + # Check bounds + if tree_.feature[i_left] >= 0: + i_left_right = tree_.children_right[i_left] + assert(values[i_left_right] <= val_middle) + if tree_.feature[i_right] >= 0: + i_right_left = tree_.children_left[i_right] + assert(val_middle <= values[i_right_left]) + + +def assert_leaves_values_monotonic(tree, monotonic_cst): + leaf_values = [] + for i in range(tree.node_count): + if tree.feature[i] < 0: + leaf_values.append(float(tree.value[i])) + if monotonic_cst == 1: + assert is_increasing(leaf_values) + elif monotonic_cst == -1: + assert is_decreasing(leaf_values) + + +@pytest.mark.parametrize('monotonic_cst', (-1, 1)) +@pytest.mark.parametrize('seed', range(4)) +def test_nodes_values(monotonic_cst, seed): + # Adaptation from test_nodes_values in test_montonic_constraints.py + # Build a single tree with only one feature, and make sure the nodes + # values respect the monotonic constraints. 
+ + # Considering the following tree with a monotonic POS constraint, we + # should have: + # + # root + # / \ + # 5 10 # middle = 7.5 + # / \ / \ + # a b c d + # + # a <= b <= middle <= c <= d (assert_children_values_monotonic_bounded) + # a <= b <= c <= d (assert_leaves_values_monotonic) + # + # The last one is a consequence of the first, but can't hurt to check + + rng = np.random.RandomState(seed) + n_samples = 1000 + n_features = 1 + X = rng.rand(n_samples, n_features) + y = rng.rand(n_samples) + + clf = DecisionTreeRegressor(monotonic_cst=[monotonic_cst]) + clf.fit(X, y) + + assert_children_values_monotonic_bounded(clf.tree_, monotonic_cst) + assert_leaves_values_monotonic(clf.tree_, monotonic_cst) + + def test_classes_deprecated(): X = [[0, 0], [2, 2], [4, 6], [10, 11]] y = [0.5, 2.5, 3.5, 5.5] From 8c3984240c260897a40ddf917135c182112526a8 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Tue, 28 Apr 2020 13:04:43 +0200 Subject: [PATCH 033/174] Add test for best first tree builder --- sklearn/tree/tests/test_tree.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 4a386f8573b56..17dd9d8ac1302 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -2025,8 +2025,9 @@ def assert_leaves_values_monotonic(tree, monotonic_cst): @pytest.mark.parametrize('monotonic_cst', (-1, 1)) +@pytest.mark.parametrize('depth_first', (True, False)) @pytest.mark.parametrize('seed', range(4)) -def test_nodes_values(monotonic_cst, seed): +def test_nodes_values(monotonic_cst, depth_first, seed): # Adaptation from test_nodes_values in test_montonic_constraints.py # Build a single tree with only one feature, and make sure the nodes # values respect the monotonic constraints. @@ -2051,7 +2052,12 @@ def test_nodes_values(monotonic_cst, seed): X = rng.rand(n_samples, n_features) y = rng.rand(n_samples) - clf = DecisionTreeRegressor(monotonic_cst=[monotonic_cst]) + if depth_first: + # No max_leaf_nodes, default depth first tree builder + clf = DecisionTreeRegressor(monotonic_cst=[monotonic_cst]) + else: + # max_leaf_nodes triggers depth first tree builder + clf = DecisionTreeRegressor(monotonic_cst=[monotonic_cst], max_leaf_nodes=5) clf.fit(X, y) assert_children_values_monotonic_bounded(clf.tree_, monotonic_cst) From 89fd890f1bc254fd665593a0c4de1decd6ce5d5a Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Tue, 28 Apr 2020 16:12:01 +0200 Subject: [PATCH 034/174] Fix monotonicity checks related to bounds --- sklearn/tree/_splitter.pyx | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index edd4eecf9528a..8f187d1439c67 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -238,22 +238,22 @@ cdef class Splitter: return self.criterion.node_impurity() - cdef inline bint split_passes_monotonic_check(self, INT32_t monotonic_cst, - double lower_bound, - double upper_bound) nogil: + cdef inline bint split_passes_monotonicity_check(self, INT32_t monotonic_cst, + double lower_bound, + double upper_bound) nogil: """Check monotonic constraint is satisfied at the current split""" - + cdef bint check_lower, check_upper, check_monotonic + check_lower = (self.criterion.sum_left[0] >= lower_bound * self.criterion.weighted_n_left) \ + & (self.criterion.sum_right[0] >= lower_bound * self.criterion.weighted_n_right) + check_upper = (self.criterion.sum_left[0] <= upper_bound * self.criterion.weighted_n_left) \ + & (self.criterion.sum_right[0] <= upper_bound * self.criterion.weighted_n_right) if monotonic_cst == 0: # No constraint - return 1 + return check_lower & check_upper else: - delta_middle = (self.criterion.sum_left[0] * self.criterion.weighted_n_right - - self.criterion.sum_right[0] * self.criterion.weighted_n_left) \ - * monotonic_cst - delta_left = (self.criterion.sum_left[0] - lower_bound * self.criterion.weighted_n_left) \ - * monotonic_cst - delta_right = (self.criterion.sum_right[0] - upper_bound * self.criterion.weighted_n_right) \ - * monotonic_cst - return (delta_middle <= 0) & (delta_left >= 0) & (delta_right <= 0) + check_monotonic = (self.criterion.sum_left[0] * self.criterion.weighted_n_right + - self.criterion.sum_right[0] * self.criterion.weighted_n_left) \ + * monotonic_cst <= 0 + return check_monotonic cdef class BaseDenseSplitter(Splitter): cdef const DTYPE_t[:, :] X @@ -445,7 +445,7 @@ cdef class BestSplitter(BaseDenseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue - if not self.split_passes_monotonic_check(monotonic_constraint, lower_bound, upper_bound): + if not self.split_passes_monotonicity_check(monotonic_constraint, lower_bound, upper_bound): continue current_proxy_improvement = self.criterion.proxy_impurity_improvement() @@ -770,7 +770,7 @@ cdef class RandomSplitter(BaseDenseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue - if not self.split_passes_monotonic_check(monotonic_constraint, lower_bound, upper_bound): + if not self.split_passes_monotonicity_check(monotonic_constraint, lower_bound, upper_bound): continue current_proxy_improvement = self.criterion.proxy_impurity_improvement() @@ -1323,7 +1323,7 @@ cdef class BestSparseSplitter(BaseSparseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue - if not self.split_passes_monotonic_check(monotonic_constraint, lower_bound, upper_bound): + if not self.split_passes_monotonicity_check(monotonic_constraint, lower_bound, upper_bound): continue current_proxy_improvement = self.criterion.proxy_impurity_improvement() @@ -1560,7 +1560,7 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue - if not self.split_passes_monotonic_check(monotonic_constraint, lower_bound, upper_bound): + if not self.split_passes_monotonicity_check(monotonic_constraint, lower_bound, 
upper_bound): continue current_proxy_improvement = self.criterion.proxy_impurity_improvement() From 678b1f8976b4335f2632fcbacd958577cfc05728 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Tue, 28 Apr 2020 17:27:11 +0200 Subject: [PATCH 035/174] Fix depth-first tree building with decreasing constraint --- sklearn/tree/_tree.pyx | 35 ++++++++++++++++++++++++++++------- sklearn/tree/_utils.pxd | 4 ++-- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index e0ca4d1c1bd60..a426cee9614b2 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -189,8 +189,12 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef SIZE_t node_id cdef double impurity = INFINITY - cdef double children_lower_bound - cdef double children_upper_bound + cdef double lower_bound + cdef double upper_bound + cdef double lower_bound_left + cdef double upper_bound_left + cdef double lower_bound_right + cdef double upper_bound_right cdef double node_value cdef SIZE_t n_constant_features cdef bint is_leaf @@ -219,8 +223,8 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): is_left = stack_record.is_left impurity = stack_record.impurity n_constant_features = stack_record.n_constant_features - children_lower_bound = stack_record.children_lower_bound - children_upper_bound = stack_record.children_upper_bound + lower_bound = stack_record.lower_bound + upper_bound = stack_record.upper_bound n_node_samples = end - start splitter.node_reset(start, end, &weighted_n_node_samples) @@ -238,7 +242,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): (impurity <= min_impurity_split)) if not is_leaf: - splitter.node_split(impurity, &split, &n_constant_features, children_lower_bound, children_upper_bound) + splitter.node_split(impurity, &split, &n_constant_features, lower_bound, upper_bound) # If EPSILON=0 in the below comparison, float precision # issues stop splitting, producing trees that are # dissimilar to v0.18 @@ -260,15 +264,32 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): node_value = splitter.criterion.sum_total[node_id] / weighted_n_node_samples if not is_leaf: + if splitter.monotonic_cst[split.feature] == 0: + left_child_min = lower_bound + left_child_max = upper_bound + right_child_min = lower_bound + right_child_max = upper_bound + elif splitter.monotonic_cst[split.feature] == 1: + left_child_min = lower_bound + left_child_max = node_value + right_child_min = node_value + right_child_max = upper_bound + elif splitter.monotonic_cst[split.feature] == -1: + left_child_min = node_value + left_child_max = upper_bound + right_child_min = lower_bound + right_child_max = node_value + # Push right child on stack rc = stack.push(split.pos, end, depth + 1, node_id, 0, - split.impurity_right, n_constant_features, node_value, children_upper_bound) + split.impurity_right, n_constant_features, + right_child_min, right_child_max) if rc == -1: break # Push left child on stack rc = stack.push(start, split.pos, depth + 1, node_id, 1, - split.impurity_left, n_constant_features, children_lower_bound, node_value) + split.impurity_left, n_constant_features, left_child_min, left_child_max) if rc == -1: break diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index 8b36ce4684eb0..3c35fc884895b 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -75,8 +75,8 @@ cdef struct StackRecord: bint is_left double impurity SIZE_t n_constant_features - double children_lower_bound - double children_upper_bound + double lower_bound + double upper_bound 
cdef class Stack: cdef SIZE_t capacity From 0ce8b3ab725622d0719bc5444e7a00c24e2d76da Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Tue, 28 Apr 2020 19:55:42 +0200 Subject: [PATCH 036/174] Clean up --- sklearn/tree/_splitter.pxd | 8 ++++---- sklearn/tree/_tree.pyx | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index fb7f1dccffd56..9acd2e1c93aa5 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -101,7 +101,7 @@ cdef class Splitter: cdef double node_impurity(self) nogil - cdef inline bint split_passes_monotonic_check(self, - INT32_t monotonic_cst, - double lower_bound, - double upper_bound) nogil + cdef inline bint split_passes_monotonicity_check(self, + INT32_t monotonic_cst, + double lower_bound, + double upper_bound) nogil diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index a426cee9614b2..fd098011dd6d7 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -270,15 +270,15 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): right_child_min = lower_bound right_child_max = upper_bound elif splitter.monotonic_cst[split.feature] == 1: - left_child_min = lower_bound - left_child_max = node_value - right_child_min = node_value - right_child_max = upper_bound + left_child_min = lower_bound + left_child_max = node_value + right_child_min = node_value + right_child_max = upper_bound elif splitter.monotonic_cst[split.feature] == -1: - left_child_min = node_value - left_child_max = upper_bound - right_child_min = lower_bound - right_child_max = node_value + left_child_min = node_value + left_child_max = upper_bound + right_child_min = lower_bound + right_child_max = node_value # Push right child on stack rc = stack.push(split.pos, end, depth + 1, node_id, 0, From 2fac624f0a6890d589420ee463a27fec93570d6c Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 29 Apr 2020 00:48:51 +0200 Subject: [PATCH 037/174] Add tests for random splitter --- sklearn/tree/tests/test_tree.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 17dd9d8ac1302..c9369c0966823 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -2025,9 +2025,10 @@ def assert_leaves_values_monotonic(tree, monotonic_cst): @pytest.mark.parametrize('monotonic_cst', (-1, 1)) +@pytest.mark.parametrize('splitter', ("best", "random")) @pytest.mark.parametrize('depth_first', (True, False)) @pytest.mark.parametrize('seed', range(4)) -def test_nodes_values(monotonic_cst, depth_first, seed): +def test_nodes_values(monotonic_cst, splitter, depth_first, seed): # Adaptation from test_nodes_values in test_montonic_constraints.py # Build a single tree with only one feature, and make sure the nodes # values respect the monotonic constraints. 
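# A minimal standalone sketch, not part of the diffs above, of what this
# parametrized test exercises; it assumes the monotonic_cst parameter that
# this series introduces on DecisionTreeRegressor. Both the depth-first
# builder (the default) and the best-first builder (selected when
# max_leaf_nodes is set) should yield monotonically increasing predictions
# along a sorted 1-D grid, for the "best" as well as the "random" splitter.
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.RandomState(0)
X = rng.rand(100, 1)                        # single feature in [0, 1)
y = rng.rand(100)
grid = np.arange(0.0, 1.0, 0.01).reshape(-1, 1)

for extra in ({}, {"max_leaf_nodes": 10}):  # depth-first, then best-first
    for splitter in ("best", "random"):
        reg = DecisionTreeRegressor(splitter=splitter, monotonic_cst=[1],
                                    random_state=0, **extra)
        reg.fit(X, y)
        pred = reg.predict(grid)
        assert (np.diff(pred) >= 0.0).all()  # increasing constraint holds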
@@ -2054,10 +2055,10 @@ def test_nodes_values(monotonic_cst, depth_first, seed): if depth_first: # No max_leaf_nodes, default depth first tree builder - clf = DecisionTreeRegressor(monotonic_cst=[monotonic_cst]) + clf = DecisionTreeRegressor(splitter=splitter, monotonic_cst=[monotonic_cst]) else: # max_leaf_nodes triggers depth first tree builder - clf = DecisionTreeRegressor(monotonic_cst=[monotonic_cst], max_leaf_nodes=5) + clf = DecisionTreeRegressor(splitter=splitter, monotonic_cst=[monotonic_cst], max_leaf_nodes=5) clf.fit(X, y) assert_children_values_monotonic_bounded(clf.tree_, monotonic_cst) From cafb3092413e3736ae0de13690299259a127bb1c Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 29 Apr 2020 01:32:32 +0200 Subject: [PATCH 038/174] Clean up Stack --- sklearn/tree/_utils.pyx | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 1cfdbb265d700..042daab74437a 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -110,9 +110,8 @@ cdef class Stack: return self.top <= 0 cdef int push(self, SIZE_t start, SIZE_t end, SIZE_t depth, SIZE_t parent, - bint is_left, double impurity, - SIZE_t n_constant_features, double children_lower_bound, - double children_upper_bound) nogil except -1: + bint is_left, double impurity, SIZE_t n_constant_features, + double lower_bound, double upper_bound) nogil except -1: """Push a new element onto the stack. Return -1 in case of failure to allocate memory (and raise MemoryError) @@ -135,8 +134,8 @@ cdef class Stack: stack[top].is_left = is_left stack[top].impurity = impurity stack[top].n_constant_features = n_constant_features - stack[top].children_lower_bound = children_lower_bound - stack[top].children_upper_bound = children_upper_bound + stack[top].lower_bound = lower_bound + stack[top].upper_bound = upper_bound # Increment stack pointer self.top = top + 1 From 4601690c16d44a95bcc7b9a28e0ecd1936470882 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 29 Apr 2020 01:38:59 +0200 Subject: [PATCH 039/174] Clean up PriorityHeap --- sklearn/tree/_utils.pxd | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index 3c35fc884895b..16964ffe07244 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -107,8 +107,8 @@ cdef struct PriorityHeapRecord: double impurity_left double impurity_right double improvement - double children_lower_bound - double children_upper_bound + double lower_bound + double upper_bound cdef class PriorityHeap: cdef SIZE_t capacity @@ -121,8 +121,8 @@ cdef class PriorityHeap: cdef int push(self, SIZE_t node_id, SIZE_t start, SIZE_t end, SIZE_t pos, SIZE_t depth, bint is_leaf, double improvement, double impurity, double impurity_left, - double impurity_right, double children_lower_bound, - double children_upper_bound) nogil except -1 + double impurity_right, double lower_bound, + double upper_bound) nogil except -1 cdef int pop(self, PriorityHeapRecord* res) nogil # ============================================================================= From 3ee90f606be1f469e2d4b4a6074f1fa98ba4f66a Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Wed, 29 Apr 2020 01:40:22 +0200 Subject: [PATCH 040/174] Clean up PriorityHeap --- sklearn/tree/_utils.pyx | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 042daab74437a..392ba46040646 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -229,8 +229,8 @@ cdef class PriorityHeap: cdef int push(self, SIZE_t node_id, SIZE_t start, SIZE_t end, SIZE_t pos, SIZE_t depth, bint is_leaf, double improvement, double impurity, double impurity_left, - double impurity_right, double children_lower_bound, - double children_upper_bound) nogil except -1: + double impurity_right, double lower_bound, + double upper_bound) nogil except -1: """Push record on the priority heap. Return -1 in case of failure to allocate memory (and raise MemoryError) @@ -257,8 +257,8 @@ cdef class PriorityHeap: heap[heap_ptr].impurity_left = impurity_left heap[heap_ptr].impurity_right = impurity_right heap[heap_ptr].improvement = improvement - heap[heap_ptr].children_lower_bound = children_lower_bound - heap[heap_ptr].children_upper_bound = children_upper_bound + heap[heap_ptr].lower_bound = lower_bound + heap[heap_ptr].upper_bound = upper_bound # Heapify up self.heapify_up(heap, heap_ptr) From 1bfdc4c98b68a9877b90b7f211288c7e5d7af113 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 29 Apr 2020 16:30:12 +0200 Subject: [PATCH 041/174] Fix wrong test for leaves and seed trees --- sklearn/tree/tests/test_tree.py | 28 +++++++++------------------- 1 file changed, 9 insertions(+), 19 deletions(-) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index c9369c0966823..91ceccae74d51 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -1985,13 +1985,8 @@ def test_montonic_constraints(): assert(np.min(y2 - y0) >= 0) -def is_increasing(a): - return (np.diff(a) >= 0.0).all() - - -def is_decreasing(a): - return (np.diff(a) <= 0.0).all() - +def is_monotonic(a, cst): + return (cst * np.diff(a) >= 0.0).all() def assert_children_values_monotonic_bounded(tree_, monotonic_cst): # Flip values so that only need to check for increasing constraint @@ -2013,15 +2008,10 @@ def assert_children_values_monotonic_bounded(tree_, monotonic_cst): assert(val_middle <= values[i_right_left]) -def assert_leaves_values_monotonic(tree, monotonic_cst): - leaf_values = [] - for i in range(tree.node_count): - if tree.feature[i] < 0: - leaf_values.append(float(tree.value[i])) - if monotonic_cst == 1: - assert is_increasing(leaf_values) - elif monotonic_cst == -1: - assert is_decreasing(leaf_values) +def assert_tree_monotonic(clf, monotonic_cst): + X_grid = np.arange(0, 1, 0.01).reshape(-1, 1) + y_pred_grid = clf.predict(X_grid) + assert is_monotonic(y_pred_grid, monotonic_cst) @pytest.mark.parametrize('monotonic_cst', (-1, 1)) @@ -2055,14 +2045,14 @@ def test_nodes_values(monotonic_cst, splitter, depth_first, seed): if depth_first: # No max_leaf_nodes, default depth first tree builder - clf = DecisionTreeRegressor(splitter=splitter, monotonic_cst=[monotonic_cst]) + clf = DecisionTreeRegressor(splitter=splitter, monotonic_cst=[monotonic_cst], random_state=seed) else: # max_leaf_nodes triggers depth first tree builder - clf = DecisionTreeRegressor(splitter=splitter, monotonic_cst=[monotonic_cst], max_leaf_nodes=5) + clf = DecisionTreeRegressor(splitter=splitter, monotonic_cst=[monotonic_cst], max_leaf_nodes=n_samples, random_state=seed) clf.fit(X, y) 
assert_children_values_monotonic_bounded(clf.tree_, monotonic_cst) - assert_leaves_values_monotonic(clf.tree_, monotonic_cst) + assert_tree_monotonic(clf, monotonic_cst) def test_classes_deprecated(): From 3366a30ed54abcce2e5c3a2b8cc0e37927737be6 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 29 Apr 2020 19:22:07 +0200 Subject: [PATCH 042/174] Make tests consistents with implemented bounding strategy --- sklearn/tree/tests/test_tree.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 91ceccae74d51..429704a872389 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -1993,19 +1993,19 @@ def assert_children_values_monotonic_bounded(tree_, monotonic_cst): values = monotonic_cst * tree_.value for i in range(tree_.node_count): - if tree_.feature[i] >= 0: + if tree_.children_left[i] > i and tree_.children_right[i] > i: # Check monotonicity i_left = tree_.children_left[i] i_right = tree_.children_right[i] - assert(values[i_left] <= values[i_right]) - val_middle = (values[i_left] + values[i_right]) / 2 + assert(float(values[i_left]) <= float(values[i_right])) + val_middle = float(values[i]) # Check bounds - if tree_.feature[i_left] >= 0: + if tree_.children_left[i_left] > i_left and tree_.children_right[i_right] > i_right: i_left_right = tree_.children_right[i_left] - assert(values[i_left_right] <= val_middle) + assert(float(values[i_left_right]) <= val_middle) if tree_.feature[i_right] >= 0: i_right_left = tree_.children_left[i_right] - assert(val_middle <= values[i_right_left]) + assert(val_middle <= float(values[i_right_left])) def assert_tree_monotonic(clf, monotonic_cst): @@ -2028,14 +2028,11 @@ def test_nodes_values(monotonic_cst, splitter, depth_first, seed): # # root # / \ - # 5 10 # middle = 7.5 - # / \ / \ - # a b c d + # 5 10 + # / \ / \ + # a b c d # - # a <= b <= middle <= c <= d (assert_children_values_monotonic_bounded) - # a <= b <= c <= d (assert_leaves_values_monotonic) - # - # The last one is a consequence of the first, but can't hurt to check + # a <= b <= root <= c <= d (assert_children_values_monotonic_bounded) rng = np.random.RandomState(seed) n_samples = 1000 From 4634deab0e95bc97023d3dd70544c80a41868877 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 29 Apr 2020 19:30:18 +0200 Subject: [PATCH 043/174] Fix monotonicity checks --- sklearn/tree/_splitter.pxd | 10 ++++------ sklearn/tree/_splitter.pyx | 14 ++++++-------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 9acd2e1c93aa5..f828b77a2940b 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -30,8 +30,8 @@ cdef struct SplitRecord: double improvement # Impurity improvement given parent node. double impurity_left # Impurity of the left split. double impurity_right # Impurity of the right split. 
- double lower_bound # Lower bound on value of left child for monotonicity - double upper_bound # Upper bound on value of right child for monotonicity + double lower_bound # Lower bound on value of both children for monotonicity + double upper_bound # Upper bound on value of both children for monotonicity cdef class Splitter: # The splitter searches in the input space for a feature and a threshold @@ -101,7 +101,5 @@ cdef class Splitter: cdef double node_impurity(self) nogil - cdef inline bint split_passes_monotonicity_check(self, - INT32_t monotonic_cst, - double lower_bound, - double upper_bound) nogil + cdef inline bint check_monotonicity(self, INT32_t monotonic_cst, + double lower_bound, double upper_bound) nogil diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 8f187d1439c67..1a1741d5931e9 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -238,9 +238,7 @@ cdef class Splitter: return self.criterion.node_impurity() - cdef inline bint split_passes_monotonicity_check(self, INT32_t monotonic_cst, - double lower_bound, - double upper_bound) nogil: + cdef inline bint check_monotonicity(self, INT32_t monotonic_cst, double lower_bound, double upper_bound) nogil: """Check monotonic constraint is satisfied at the current split""" cdef bint check_lower, check_upper, check_monotonic check_lower = (self.criterion.sum_left[0] >= lower_bound * self.criterion.weighted_n_left) \ @@ -253,7 +251,7 @@ cdef class Splitter: check_monotonic = (self.criterion.sum_left[0] * self.criterion.weighted_n_right - self.criterion.sum_right[0] * self.criterion.weighted_n_left) \ * monotonic_cst <= 0 - return check_monotonic + return check_lower & check_upper & check_monotonic cdef class BaseDenseSplitter(Splitter): cdef const DTYPE_t[:, :] X @@ -445,7 +443,7 @@ cdef class BestSplitter(BaseDenseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue - if not self.split_passes_monotonicity_check(monotonic_constraint, lower_bound, upper_bound): + if not self.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): continue current_proxy_improvement = self.criterion.proxy_impurity_improvement() @@ -770,7 +768,7 @@ cdef class RandomSplitter(BaseDenseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue - if not self.split_passes_monotonicity_check(monotonic_constraint, lower_bound, upper_bound): + if not self.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): continue current_proxy_improvement = self.criterion.proxy_impurity_improvement() @@ -1323,7 +1321,7 @@ cdef class BestSparseSplitter(BaseSparseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue - if not self.split_passes_monotonicity_check(monotonic_constraint, lower_bound, upper_bound): + if not self.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): continue current_proxy_improvement = self.criterion.proxy_impurity_improvement() @@ -1560,7 +1558,7 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue - if not self.split_passes_monotonicity_check(monotonic_constraint, lower_bound, upper_bound): + if not self.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): continue current_proxy_improvement = self.criterion.proxy_impurity_improvement() From a7b41339769f8dff89221743c1f2d71863409fb1 Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Wed, 29 Apr 2020 19:31:54 +0200 Subject: [PATCH 044/174] Clean up depth-first tree building --- sklearn/tree/_tree.pyx | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index fd098011dd6d7..001370548b468 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -195,7 +195,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef double upper_bound_left cdef double lower_bound_right cdef double upper_bound_right - cdef double node_value + cdef double middle_value cdef SIZE_t n_constant_features cdef bint is_leaf cdef bint first = 1 @@ -262,7 +262,8 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): # inspection and interpretation splitter.node_value(tree.value + node_id * tree.value_stride) - node_value = splitter.criterion.sum_total[node_id] / weighted_n_node_samples + middle_value = tree.value[node_id] + if not is_leaf: if splitter.monotonic_cst[split.feature] == 0: left_child_min = lower_bound @@ -271,14 +272,14 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): right_child_max = upper_bound elif splitter.monotonic_cst[split.feature] == 1: left_child_min = lower_bound - left_child_max = node_value - right_child_min = node_value + left_child_max = middle_value + right_child_min = middle_value right_child_max = upper_bound elif splitter.monotonic_cst[split.feature] == -1: - left_child_min = node_value + left_child_min = middle_value left_child_max = upper_bound right_child_min = lower_bound - right_child_max = node_value + right_child_max = middle_value # Push right child on stack rc = stack.push(split.pos, end, depth + 1, node_id, 0, From c2cecf5e7dce8807550bc515a2ac0ecabc84d305 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 29 Apr 2020 19:33:03 +0200 Subject: [PATCH 045/174] Fix and clean up best-first tree building --- sklearn/tree/_tree.pyx | 49 ++++++++++++++++++++++++++++++------------ 1 file changed, 35 insertions(+), 14 deletions(-) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 001370548b468..7f0430aa38b5f 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -318,7 +318,7 @@ cdef inline int _add_to_frontier(PriorityHeapRecord* rec, return frontier.push(rec.node_id, rec.start, rec.end, rec.pos, rec.depth, rec.is_leaf, rec.improvement, rec.impurity, rec.impurity_left, rec.impurity_right, - rec.children_lower_bound, rec.children_upper_bound) + rec.lower_bound, rec.upper_bound) cdef class BestFirstTreeBuilder(TreeBuilder): @@ -360,8 +360,6 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cdef SIZE_t min_samples_leaf = self.min_samples_leaf cdef double min_weight_leaf = self.min_weight_leaf cdef SIZE_t min_samples_split = self.min_samples_split - cdef double children_lower_bound - cdef double children_upper_bound # Recursive partition (without actual recursion) splitter.init(X, y, sample_weight_ptr, X_idx_sorted) @@ -370,6 +368,11 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cdef PriorityHeapRecord record cdef PriorityHeapRecord split_node_left cdef PriorityHeapRecord split_node_right + cdef double middle_value + cdef double left_child_min + cdef double left_child_max + cdef double right_child_min + cdef double right_child_max cdef SIZE_t n_node_samples = splitter.n_samples cdef SIZE_t max_split_nodes = max_leaf_nodes - 1 @@ -410,7 +413,23 @@ cdef class BestFirstTreeBuilder(TreeBuilder): else: # Node is expandable - node_value = splitter.criterion.sum_total[0] / splitter.criterion.weighted_n_node_samples + middle_value = 
tree.value[record.node_id] + + if splitter.monotonic_cst[node.feature] == 0: + left_child_min = record.lower_bound + left_child_max = record.upper_bound + right_child_min = record.lower_bound + right_child_max = record.upper_bound + elif splitter.monotonic_cst[node.feature] == 1: + left_child_min = record.lower_bound + left_child_max = middle_value + right_child_min = middle_value + right_child_max = record.upper_bound + elif splitter.monotonic_cst[node.feature] == -1: + left_child_min = middle_value + left_child_max = record.upper_bound + right_child_min = record.lower_bound + right_child_max = middle_value # Decrement number of split nodes available max_split_nodes -= 1 @@ -421,8 +440,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder): record.impurity_left, IS_NOT_FIRST, IS_LEFT, node, record.depth + 1, - record.children_lower_bound, - record.children_upper_bound, + left_child_min, + left_child_max, &split_node_left) if rc == -1: break @@ -436,8 +455,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder): record.impurity_right, IS_NOT_FIRST, IS_NOT_LEFT, node, record.depth + 1, - record.children_lower_bound, - record.children_upper_bound, + right_child_min, + right_child_max, &split_node_right) if rc == -1: break @@ -466,8 +485,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cdef inline int _add_split_node(self, Splitter splitter, Tree tree, SIZE_t start, SIZE_t end, double impurity, bint is_first, bint is_left, Node* parent, - SIZE_t depth, double children_lower_bound, - double children_upper_bound, + SIZE_t depth, double lower_bound, + double upper_bound, PriorityHeapRecord* res) nogil except -1: """Adds node w/ partition ``[start, end)`` to the frontier. """ cdef SplitRecord split @@ -496,8 +515,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder): impurity <= min_impurity_split) if not is_leaf: - splitter.node_split(impurity, &split, &n_constant_features, children_lower_bound, - children_upper_bound) + splitter.node_split(impurity, &split, &n_constant_features, lower_bound, + upper_bound) # If EPSILON=0 in the below comparison, float precision issues stop # splitting early, producing trees that are dissimilar to v0.18 is_leaf = (is_leaf or split.pos >= end or @@ -529,6 +548,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder): res.improvement = split.improvement res.impurity_left = split.impurity_left res.impurity_right = split.impurity_right + res.lower_bound = lower_bound + res.upper_bound = upper_bound else: # is leaf => 0 improvement @@ -537,8 +558,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder): res.improvement = 0.0 res.impurity_left = impurity res.impurity_right = impurity - res.children_lower_bound = children_lower_bound if is_left else node_value - res.children_upper_bound = node_value if is_left else children_upper_bound + res.lower_bound = lower_bound + res.upper_bound = upper_bound return 0 From 4764cf4357a7d21a87bb8ded44b03bbceeb9c0b8 Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Thu, 30 Apr 2020 01:37:33 +0200 Subject: [PATCH 046/174] Flake8 --- sklearn/tree/tests/test_tree.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 429704a872389..1f5eef528ede1 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -2000,7 +2000,7 @@ def assert_children_values_monotonic_bounded(tree_, monotonic_cst): assert(float(values[i_left]) <= float(values[i_right])) val_middle = float(values[i]) # Check bounds - if tree_.children_left[i_left] > i_left and tree_.children_right[i_right] > i_right: + if tree_.feature[i_left] >= 0: i_left_right = tree_.children_right[i_left] assert(float(values[i_left_right]) <= val_middle) if tree_.feature[i_right] >= 0: @@ -2042,10 +2042,15 @@ def test_nodes_values(monotonic_cst, splitter, depth_first, seed): if depth_first: # No max_leaf_nodes, default depth first tree builder - clf = DecisionTreeRegressor(splitter=splitter, monotonic_cst=[monotonic_cst], random_state=seed) + clf = DecisionTreeRegressor(splitter=splitter, + monotonic_cst=[monotonic_cst], + random_state=seed) else: # max_leaf_nodes triggers depth first tree builder - clf = DecisionTreeRegressor(splitter=splitter, monotonic_cst=[monotonic_cst], max_leaf_nodes=n_samples, random_state=seed) + clf = DecisionTreeRegressor(splitter=splitter, + monotonic_cst=[monotonic_cst], + max_leaf_nodes=n_samples, + random_state=seed) clf.fit(X, y) assert_children_values_monotonic_bounded(clf.tree_, monotonic_cst) From dcc9f1862c9eac5b749c7af426656a359e51b9f7 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Thu, 30 Apr 2020 11:31:30 +0200 Subject: [PATCH 047/174] Formatting --- sklearn/tree/tests/test_tree.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 1f5eef528ede1..f3370a670ebd4 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -1988,6 +1988,7 @@ def test_montonic_constraints(): def is_monotonic(a, cst): return (cst * np.diff(a) >= 0.0).all() + def assert_children_values_monotonic_bounded(tree_, monotonic_cst): # Flip values so that only need to check for increasing constraint values = monotonic_cst * tree_.value @@ -2028,9 +2029,9 @@ def test_nodes_values(monotonic_cst, splitter, depth_first, seed): # # root # / \ - # 5 10 - # / \ / \ - # a b c d + # 5 10 + # / \ / \ + # a b c d # # a <= b <= root <= c <= d (assert_children_values_monotonic_bounded) From 84b912bb4ce3b630451013bda7e92506695bde4b Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Thu, 30 Apr 2020 11:50:31 +0200 Subject: [PATCH 048/174] Raise ValueError on multiclass --- sklearn/tree/_classes.py | 6 ++++++ sklearn/tree/tests/test_tree.py | 16 ++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 794693558dcc1..03e8bc700706f 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -347,6 +347,12 @@ def fit(self, X, y, sample_weight=None, check_input=True, fill_value=0, dtype=np.int32) else: + if self.n_outputs_ > 1: + raise ValueError("Monotonic constraints are not supported with " + "multiple output") + if is_classifier(self) and self.n_classes_[0] > 2: + raise ValueError("Monotonic constraints are not supported with " + "multiclass classification") monotonic_cst = np.asarray(self.monotonic_cst, dtype=np.int32) if monotonic_cst.shape[0] != X.shape[1]: diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index f3370a670ebd4..405e9527a1550 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -1985,6 +1985,22 @@ def test_montonic_constraints(): assert(np.min(y2 - y0) >= 0) +def test_multiclass_raises(): + X, y = datasets.make_hastie_10_2(n_samples=100, random_state=0) + y[0] = 0 + monotonic_cst = np.zeros(X.shape[1]) + monotonic_cst[0] = -1 + monotonic_cst[1] = 1 + for name, TreeClassifier in CLF_TREES.items(): + est = TreeClassifier(max_depth=None, monotonic_cst=monotonic_cst, + random_state=0) + if hasattr(est, "random_state"): + est.set_params(**{"random_state": 0}) + + with pytest.raises(ValueError): + est.fit(X, y) + + def is_monotonic(a, cst): return (cst * np.diff(a) >= 0.0).all() From 60a4b8c09f12be9b47d1872486d3a61751f05dac Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Thu, 30 Apr 2020 12:14:01 +0200 Subject: [PATCH 049/174] Linting --- sklearn/tree/_classes.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 03e8bc700706f..f6573e897e39d 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -348,11 +348,11 @@ def fit(self, X, y, sample_weight=None, check_input=True, dtype=np.int32) else: if self.n_outputs_ > 1: - raise ValueError("Monotonic constraints are not supported with " - "multiple output") + raise ValueError("Monotonic constraints are not supported with" + " multiple output") if is_classifier(self) and self.n_classes_[0] > 2: - raise ValueError("Monotonic constraints are not supported with " - "multiclass classification") + raise ValueError("Monotonic constraints are not supported with" + " multiclass classification") monotonic_cst = np.asarray(self.monotonic_cst, dtype=np.int32) if monotonic_cst.shape[0] != X.shape[1]: From 3ccb5406970919fdd7b013e576976f5f71fda9e1 Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Thu, 30 Apr 2020 12:31:18 +0200 Subject: [PATCH 050/174] Fix previous merge --- sklearn/tree/tests/test_tree.py | 71 ++++++++++++++++----------------- 1 file changed, 34 insertions(+), 37 deletions(-) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 405e9527a1550..0dd85a5dc551a 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -1930,8 +1930,40 @@ def test_prune_tree_raises_negative_ccp_alpha(): clf.set_params(ccp_alpha=-1.0) clf._prune_tree() - with TempMemmap((X, y)) as (X_read_only, y_read_only): - DecisionTreeClassifier().fit(X_read_only, y_read_only) + +def test_classes_deprecated(): + X = [[0, 0], [2, 2], [4, 6], [10, 11]] + y = [0.5, 2.5, 3.5, 5.5] + clf = DecisionTreeRegressor() + clf = clf.fit(X, y) + + match = ("attribute is to be deprecated from version " + "0.22 and will be removed in 0.24.") + + with pytest.warns(FutureWarning, match=match): + n = len(clf.classes_) + assert n == clf.n_outputs_ + + with pytest.warns(FutureWarning, match=match): + assert len(clf.n_classes_) == clf.n_outputs_ + + +def check_apply_path_readonly(name): + X_readonly = create_memmap_backed_data(X_small.astype(tree._tree.DTYPE, + copy=False)) + y_readonly = create_memmap_backed_data(np.array(y_small, + dtype=tree._tree.DTYPE)) + est = ALL_TREES[name]() + est.fit(X_readonly, y_readonly) + assert_array_equal(est.predict(X_readonly), + est.predict(X_small)) + assert_array_equal(est.decision_path(X_readonly).todense(), + est.decision_path(X_small).todense()) + + +@pytest.mark.parametrize("name", ALL_TREES) +def test_apply_path_readonly_all_trees(name): + check_apply_path_readonly(name) def test_montonic_constraints(): @@ -2072,38 +2104,3 @@ def test_nodes_values(monotonic_cst, splitter, depth_first, seed): assert_children_values_monotonic_bounded(clf.tree_, monotonic_cst) assert_tree_monotonic(clf, monotonic_cst) - - -def test_classes_deprecated(): - X = [[0, 0], [2, 2], [4, 6], [10, 11]] - y = [0.5, 2.5, 3.5, 5.5] - clf = DecisionTreeRegressor() - clf = clf.fit(X, y) - - match = ("attribute is to be deprecated from version " - "0.22 and will be removed in 0.24.") - - with pytest.warns(FutureWarning, match=match): - n = len(clf.classes_) - assert n == clf.n_outputs_ - - with pytest.warns(FutureWarning, match=match): - assert len(clf.n_classes_) == clf.n_outputs_ - - -def check_apply_path_readonly(name): - X_readonly = create_memmap_backed_data(X_small.astype(tree._tree.DTYPE, - copy=False)) - y_readonly = create_memmap_backed_data(np.array(y_small, - dtype=tree._tree.DTYPE)) - est = ALL_TREES[name]() - est.fit(X_readonly, y_readonly) - assert_array_equal(est.predict(X_readonly), - est.predict(X_small)) - assert_array_equal(est.decision_path(X_readonly).todense(), - est.decision_path(X_small).todense()) - - -@pytest.mark.parametrize("name", ALL_TREES) -def test_apply_path_readonly_all_trees(name): - check_apply_path_readonly(name) From 5f75c1c8a3338d6ec74b1a98a487ecec437071f0 Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Mon, 4 May 2020 11:57:16 +0200 Subject: [PATCH 051/174] Remove useless check when splitting on unconstrainted feature --- sklearn/tree/_splitter.pyx | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 1a1741d5931e9..6c5a829c8a9c2 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -241,13 +241,13 @@ cdef class Splitter: cdef inline bint check_monotonicity(self, INT32_t monotonic_cst, double lower_bound, double upper_bound) nogil: """Check monotonic constraint is satisfied at the current split""" cdef bint check_lower, check_upper, check_monotonic - check_lower = (self.criterion.sum_left[0] >= lower_bound * self.criterion.weighted_n_left) \ - & (self.criterion.sum_right[0] >= lower_bound * self.criterion.weighted_n_right) - check_upper = (self.criterion.sum_left[0] <= upper_bound * self.criterion.weighted_n_left) \ - & (self.criterion.sum_right[0] <= upper_bound * self.criterion.weighted_n_right) if monotonic_cst == 0: # No constraint - return check_lower & check_upper + return 1 else: + check_lower = (self.criterion.sum_left[0] >= lower_bound * self.criterion.weighted_n_left) \ + & (self.criterion.sum_right[0] >= lower_bound * self.criterion.weighted_n_right) + check_upper = (self.criterion.sum_left[0] <= upper_bound * self.criterion.weighted_n_left) \ + & (self.criterion.sum_right[0] <= upper_bound * self.criterion.weighted_n_right) check_monotonic = (self.criterion.sum_left[0] * self.criterion.weighted_n_right - self.criterion.sum_right[0] * self.criterion.weighted_n_left) \ * monotonic_cst <= 0 From 05d023c3d6725a923e2b5e508e47fb998140e966 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Tue, 2 Jun 2020 16:23:09 +0200 Subject: [PATCH 052/174] Group monotonic tree tests into dedicated file --- sklearn/tree/tests/test_monotonic_tree.py | 145 ++++++++++++++++++++++ sklearn/tree/tests/test_tree.py | 139 --------------------- 2 files changed, 145 insertions(+), 139 deletions(-) create mode 100644 sklearn/tree/tests/test_monotonic_tree.py diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py new file mode 100644 index 0000000000000..ebeb6dade334e --- /dev/null +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -0,0 +1,145 @@ +import numpy as np +import pytest +from sklearn import datasets +from sklearn.tree import DecisionTreeRegressor +from sklearn.tree.tests.test_tree import REG_TREES, CLF_TREES + + +def test_montonic_constraints(): + X, y = datasets.make_hastie_10_2(n_samples=100, random_state=0) + train = np.arange(90) + test = np.arange(90, 100) + X_train = X[train] + y_train = y[train] + X_test_0 = np.copy(X[test]) + X_test_1 = np.copy(X_test_0) + X_test_1[:, 0] += 10 + X_test_2 = np.copy(X_test_0) + X_test_2[:, 1] += 10 + monotonic_cst = np.zeros(X.shape[1]) + monotonic_cst[0] = -1 + monotonic_cst[1] = 1 + + for name, TreeRegressor in REG_TREES.items(): + est = TreeRegressor(max_depth=None, monotonic_cst=monotonic_cst) + if hasattr(est, "random_state"): + est.set_params(**{"random_state": 0}) + est.fit(X_train, y_train) + + y0 = est.predict(X_test_0) + # decreasing constraint + y1 = est.predict(X_test_1) + # y1 should always be lower than y0 + assert(np.max(y1 - y0) <= 0) + + # increasing constraint + y2 = est.predict(X_test_2) + # y2 should always be greater than y0 + assert(np.min(y2 - y0) >= 0) + + for name, TreeClassifier in CLF_TREES.items(): + est = TreeClassifier(max_depth=None, 
monotonic_cst=monotonic_cst) + if hasattr(est, "random_state"): + est.set_params(**{"random_state": 0}) + est.fit(X_train, y_train) + + y0 = est.predict_proba(X_test_0)[:, 0] + + # decreasing constraint + y1 = est.predict_proba(X_test_1)[:, 0] + # y1 should always be lower than y0 + assert(np.max(y1 - y0) <= 0) + + # increasing constraint + y2 = est.predict_proba(X_test_2)[:, 0] + # y2 should always be greater than y0 + assert(np.min(y2 - y0) >= 0) + + +def test_multiclass_raises(): + X, y = datasets.make_hastie_10_2(n_samples=100, random_state=0) + y[0] = 0 + monotonic_cst = np.zeros(X.shape[1]) + monotonic_cst[0] = -1 + monotonic_cst[1] = 1 + for name, TreeClassifier in CLF_TREES.items(): + est = TreeClassifier(max_depth=None, monotonic_cst=monotonic_cst, + random_state=0) + if hasattr(est, "random_state"): + est.set_params(**{"random_state": 0}) + + with pytest.raises(ValueError): + est.fit(X, y) + + +def is_monotonic(a, cst): + return (cst * np.diff(a) >= 0.0).all() + + +def assert_children_values_monotonic_bounded(tree_, monotonic_cst): + # Flip values so that only need to check for increasing constraint + values = monotonic_cst * tree_.value + + for i in range(tree_.node_count): + if tree_.children_left[i] > i and tree_.children_right[i] > i: + # Check monotonicity + i_left = tree_.children_left[i] + i_right = tree_.children_right[i] + assert(float(values[i_left]) <= float(values[i_right])) + val_middle = float(values[i]) + # Check bounds + if tree_.feature[i_left] >= 0: + i_left_right = tree_.children_right[i_left] + assert(float(values[i_left_right]) <= val_middle) + if tree_.feature[i_right] >= 0: + i_right_left = tree_.children_left[i_right] + assert(val_middle <= float(values[i_right_left])) + + +def assert_tree_monotonic(clf, monotonic_cst): + X_grid = np.arange(0, 1, 0.01).reshape(-1, 1) + y_pred_grid = clf.predict(X_grid) + assert is_monotonic(y_pred_grid, monotonic_cst) + + +@pytest.mark.parametrize('monotonic_cst', (-1, 1)) +@pytest.mark.parametrize('splitter', ("best", "random")) +@pytest.mark.parametrize('depth_first', (True, False)) +@pytest.mark.parametrize('seed', range(4)) +def test_nodes_values(monotonic_cst, splitter, depth_first, seed): + # Adaptation from test_nodes_values in test_montonic_constraints.py + # Build a single tree with only one feature, and make sure the nodes + # values respect the monotonic constraints. 
+ + # Considering the following tree with a monotonic POS constraint, we + # should have: + # + # root + # / \ + # 5 10 + # / \ / \ + # a b c d + # + # a <= b <= root <= c <= d (assert_children_values_monotonic_bounded) + + rng = np.random.RandomState(seed) + n_samples = 1000 + n_features = 1 + X = rng.rand(n_samples, n_features) + y = rng.rand(n_samples) + + if depth_first: + # No max_leaf_nodes, default depth first tree builder + clf = DecisionTreeRegressor(splitter=splitter, + monotonic_cst=[monotonic_cst], + random_state=seed) + else: + # max_leaf_nodes triggers depth first tree builder + clf = DecisionTreeRegressor(splitter=splitter, + monotonic_cst=[monotonic_cst], + max_leaf_nodes=n_samples, + random_state=seed) + clf.fit(X, y) + + assert_children_values_monotonic_bounded(clf.tree_, monotonic_cst) + assert_tree_monotonic(clf, monotonic_cst) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 9c3544c749a27..ba94415781e79 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -1949,142 +1949,3 @@ def check_apply_path_readonly(name): def test_apply_path_readonly_all_trees(name): check_apply_path_readonly(name) - -def test_montonic_constraints(): - X, y = datasets.make_hastie_10_2(n_samples=100, random_state=0) - train = np.arange(90) - test = np.arange(90, 100) - X_train = X[train] - y_train = y[train] - X_test_0 = np.copy(X[test]) - X_test_1 = np.copy(X_test_0) - X_test_1[:, 0] += 10 - X_test_2 = np.copy(X_test_0) - X_test_2[:, 1] += 10 - monotonic_cst = np.zeros(X.shape[1]) - monotonic_cst[0] = -1 - monotonic_cst[1] = 1 - - for name, TreeRegressor in REG_TREES.items(): - est = TreeRegressor(max_depth=None, monotonic_cst=monotonic_cst) - if hasattr(est, "random_state"): - est.set_params(**{"random_state": 0}) - est.fit(X_train, y_train) - - y0 = est.predict(X_test_0) - # decreasing constraint - y1 = est.predict(X_test_1) - # y1 should always be lower than y0 - assert(np.max(y1 - y0) <= 0) - - # increasing constraint - y2 = est.predict(X_test_2) - # y2 should always be greater than y0 - assert(np.min(y2 - y0) >= 0) - - for name, TreeClassifier in CLF_TREES.items(): - est = TreeClassifier(max_depth=None, monotonic_cst=monotonic_cst) - if hasattr(est, "random_state"): - est.set_params(**{"random_state": 0}) - est.fit(X_train, y_train) - - y0 = est.predict_proba(X_test_0)[:, 0] - - # decreasing constraint - y1 = est.predict_proba(X_test_1)[:, 0] - # y1 should always be lower than y0 - assert(np.max(y1 - y0) <= 0) - - # increasing constraint - y2 = est.predict_proba(X_test_2)[:, 0] - # y2 should always be greater than y0 - assert(np.min(y2 - y0) >= 0) - - -def test_multiclass_raises(): - X, y = datasets.make_hastie_10_2(n_samples=100, random_state=0) - y[0] = 0 - monotonic_cst = np.zeros(X.shape[1]) - monotonic_cst[0] = -1 - monotonic_cst[1] = 1 - for name, TreeClassifier in CLF_TREES.items(): - est = TreeClassifier(max_depth=None, monotonic_cst=monotonic_cst, - random_state=0) - if hasattr(est, "random_state"): - est.set_params(**{"random_state": 0}) - - with pytest.raises(ValueError): - est.fit(X, y) - - -def is_monotonic(a, cst): - return (cst * np.diff(a) >= 0.0).all() - - -def assert_children_values_monotonic_bounded(tree_, monotonic_cst): - # Flip values so that only need to check for increasing constraint - values = monotonic_cst * tree_.value - - for i in range(tree_.node_count): - if tree_.children_left[i] > i and tree_.children_right[i] > i: - # Check monotonicity - i_left = tree_.children_left[i] - i_right = 
tree_.children_right[i] - assert(float(values[i_left]) <= float(values[i_right])) - val_middle = float(values[i]) - # Check bounds - if tree_.feature[i_left] >= 0: - i_left_right = tree_.children_right[i_left] - assert(float(values[i_left_right]) <= val_middle) - if tree_.feature[i_right] >= 0: - i_right_left = tree_.children_left[i_right] - assert(val_middle <= float(values[i_right_left])) - - -def assert_tree_monotonic(clf, monotonic_cst): - X_grid = np.arange(0, 1, 0.01).reshape(-1, 1) - y_pred_grid = clf.predict(X_grid) - assert is_monotonic(y_pred_grid, monotonic_cst) - - -@pytest.mark.parametrize('monotonic_cst', (-1, 1)) -@pytest.mark.parametrize('splitter', ("best", "random")) -@pytest.mark.parametrize('depth_first', (True, False)) -@pytest.mark.parametrize('seed', range(4)) -def test_nodes_values(monotonic_cst, splitter, depth_first, seed): - # Adaptation from test_nodes_values in test_montonic_constraints.py - # Build a single tree with only one feature, and make sure the nodes - # values respect the monotonic constraints. - - # Considering the following tree with a monotonic POS constraint, we - # should have: - # - # root - # / \ - # 5 10 - # / \ / \ - # a b c d - # - # a <= b <= root <= c <= d (assert_children_values_monotonic_bounded) - - rng = np.random.RandomState(seed) - n_samples = 1000 - n_features = 1 - X = rng.rand(n_samples, n_features) - y = rng.rand(n_samples) - - if depth_first: - # No max_leaf_nodes, default depth first tree builder - clf = DecisionTreeRegressor(splitter=splitter, - monotonic_cst=[monotonic_cst], - random_state=seed) - else: - # max_leaf_nodes triggers depth first tree builder - clf = DecisionTreeRegressor(splitter=splitter, - monotonic_cst=[monotonic_cst], - max_leaf_nodes=n_samples, - random_state=seed) - clf.fit(X, y) - - assert_children_values_monotonic_bounded(clf.tree_, monotonic_cst) - assert_tree_monotonic(clf, monotonic_cst) From 3fd364fc9e3395ca884086513223e1e2489b2c4c Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Tue, 2 Jun 2020 16:37:17 +0200 Subject: [PATCH 053/174] Fix linting --- sklearn/ensemble/_gb.py | 1 - sklearn/tree/tests/test_tree.py | 1 - 2 files changed, 2 deletions(-) diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index 9e877e30c2314..b9b2ee583ad16 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -205,7 +205,6 @@ def _fit_stage(self, i, X, y, raw_predictions, sample_weight, sample_mask, ccp_alpha=self.ccp_alpha, monotonic_cst=self.monotonic_cst) - if self.subsample < 1.0: # no inplace multiplication! sample_weight = sample_weight * sample_mask.astype(np.float64) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index ba94415781e79..964ca978bbfb0 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -1948,4 +1948,3 @@ def check_apply_path_readonly(name): @pytest.mark.parametrize("name", ALL_TREES) def test_apply_path_readonly_all_trees(name): check_apply_path_readonly(name) - From de6341fa4e29329ba886a434d800f7013b4909a1 Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Wed, 10 Feb 2021 14:57:21 +0100 Subject: [PATCH 054/174] Make test_montonic_constraints harder to pass --- sklearn/tree/tests/test_monotonic_tree.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index ebeb6dade334e..add79a9d5483c 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -5,7 +5,8 @@ from sklearn.tree.tests.test_tree import REG_TREES, CLF_TREES -def test_montonic_constraints(): +@pytest.mark.parametrize('seed', range(4)) +def test_montonic_constraints(seed): X, y = datasets.make_hastie_10_2(n_samples=100, random_state=0) train = np.arange(90) test = np.arange(90, 100) @@ -23,7 +24,7 @@ def test_montonic_constraints(): for name, TreeRegressor in REG_TREES.items(): est = TreeRegressor(max_depth=None, monotonic_cst=monotonic_cst) if hasattr(est, "random_state"): - est.set_params(**{"random_state": 0}) + est.set_params(**{"random_state": seed}) est.fit(X_train, y_train) y0 = est.predict(X_test_0) @@ -40,7 +41,7 @@ def test_montonic_constraints(): for name, TreeClassifier in CLF_TREES.items(): est = TreeClassifier(max_depth=None, monotonic_cst=monotonic_cst) if hasattr(est, "random_state"): - est.set_params(**{"random_state": 0}) + est.set_params(**{"random_state": seed}) est.fit(X_train, y_train) y0 = est.predict_proba(X_test_0)[:, 0] From 3a78b94fa76d79fa1821c29de9517b373f898311 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Fri, 12 Feb 2021 16:19:31 +0100 Subject: [PATCH 055/174] Fix tree regressions by enforcing bounds for all splits --- sklearn/tree/_splitter.pyx | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 9bd3c24d00e94..3654ef48dd45f 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -237,13 +237,13 @@ cdef class Splitter: cdef inline bint check_monotonicity(self, INT32_t monotonic_cst, double lower_bound, double upper_bound) nogil: """Check monotonic constraint is satisfied at the current split""" cdef bint check_lower, check_upper, check_monotonic + check_lower = (self.criterion.sum_left[0] >= lower_bound * self.criterion.weighted_n_left) \ + & (self.criterion.sum_right[0] >= lower_bound * self.criterion.weighted_n_right) + check_upper = (self.criterion.sum_left[0] <= upper_bound * self.criterion.weighted_n_left) \ + & (self.criterion.sum_right[0] <= upper_bound * self.criterion.weighted_n_right) if monotonic_cst == 0: # No constraint - return 1 + return check_lower & check_upper else: - check_lower = (self.criterion.sum_left[0] >= lower_bound * self.criterion.weighted_n_left) \ - & (self.criterion.sum_right[0] >= lower_bound * self.criterion.weighted_n_right) - check_upper = (self.criterion.sum_left[0] <= upper_bound * self.criterion.weighted_n_left) \ - & (self.criterion.sum_right[0] <= upper_bound * self.criterion.weighted_n_right) check_monotonic = (self.criterion.sum_left[0] * self.criterion.weighted_n_right - self.criterion.sum_right[0] * self.criterion.weighted_n_left) \ * monotonic_cst <= 0 From 5090853b26a08245cec98d90efb6ff2c99e664f1 Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Tue, 16 Feb 2021 17:25:44 +0100 Subject: [PATCH 056/174] Make test_montonic_constraints harder to pass --- sklearn/tree/tests/test_monotonic_tree.py | 72 +++++++++++++---------- 1 file changed, 41 insertions(+), 31 deletions(-) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index add79a9d5483c..f1f237419a4ed 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -6,55 +6,65 @@ @pytest.mark.parametrize('seed', range(4)) -def test_montonic_constraints(seed): - X, y = datasets.make_hastie_10_2(n_samples=100, random_state=0) - train = np.arange(90) - test = np.arange(90, 100) +@pytest.mark.parametrize('depth_first', (True, False)) +def test_montonic_constraints(seed, depth_first): + n_samples = 1000 + n_samples_train = 900 + X, y = datasets.make_hastie_10_2(n_samples=n_samples, random_state=0) + train = np.arange(n_samples_train) + test = np.arange(n_samples_train, n_samples) X_train = X[train] y_train = y[train] - X_test_0 = np.copy(X[test]) - X_test_1 = np.copy(X_test_0) - X_test_1[:, 0] += 10 - X_test_2 = np.copy(X_test_0) - X_test_2[:, 1] += 10 + X_test = np.copy(X[test]) + + X_test_incr = np.copy(X_test) + X_test_decr = np.copy(X_test) + X_test_incr[:, 0] += 10 + X_test_decr[:, 1] += 10 monotonic_cst = np.zeros(X.shape[1]) - monotonic_cst[0] = -1 - monotonic_cst[1] = 1 + monotonic_cst[0] = 1 + monotonic_cst[1] = -1 for name, TreeRegressor in REG_TREES.items(): - est = TreeRegressor(max_depth=None, monotonic_cst=monotonic_cst) + if depth_first: + est = TreeRegressor(max_depth=None, monotonic_cst=monotonic_cst) + else: + est = TreeRegressor(max_depth=None, monotonic_cst=monotonic_cst, + max_leaf_nodes=n_samples_train) if hasattr(est, "random_state"): est.set_params(**{"random_state": seed}) est.fit(X_train, y_train) + y = est.predict(X_test) + # increasing constraint + y_incr = est.predict(X_test_incr) + # y_incr should always be greater than y + assert np.all(y_incr >= y) - y0 = est.predict(X_test_0) # decreasing constraint - y1 = est.predict(X_test_1) - # y1 should always be lower than y0 - assert(np.max(y1 - y0) <= 0) - - # increasing constraint - y2 = est.predict(X_test_2) - # y2 should always be greater than y0 - assert(np.min(y2 - y0) >= 0) + y_decr = est.predict(X_test_decr) + # y_decr should always be lower than y + assert np.all(y_decr <= y) for name, TreeClassifier in CLF_TREES.items(): - est = TreeClassifier(max_depth=None, monotonic_cst=monotonic_cst) + if depth_first: + est = TreeClassifier(max_depth=None, monotonic_cst=monotonic_cst) + else: + est = TreeClassifier(max_depth=None, monotonic_cst=monotonic_cst, + max_leaf_nodes=n_samples_train) if hasattr(est, "random_state"): est.set_params(**{"random_state": seed}) est.fit(X_train, y_train) + y = est.predict_proba(X_test)[:, 0] - y0 = est.predict_proba(X_test_0)[:, 0] + # increasing constraint + y_incr = est.predict_proba(X_test_incr)[:, 0] + # y_incr should always be greater than y + assert np.all(y_incr >= y) # decreasing constraint - y1 = est.predict_proba(X_test_1)[:, 0] - # y1 should always be lower than y0 - assert(np.max(y1 - y0) <= 0) - - # increasing constraint - y2 = est.predict_proba(X_test_2)[:, 0] - # y2 should always be greater than y0 - assert(np.min(y2 - y0) >= 0) + y_decr = est.predict_proba(X_test_decr)[:, 0] + # y_decr should always be lower than y + assert np.all(y_decr <= y) def test_multiclass_raises(): From 70f02768c0d99223d21b1d9ff0ed94105dcde0b7 Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Tue, 16 Feb 2021 17:41:43 +0100 Subject: [PATCH 057/174] Fix middle value bound for classifications --- sklearn/tree/_tree.pyx | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index a4d701914b1ed..cbfe094e453e5 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -251,7 +251,14 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): # inspection and interpretation splitter.node_value(tree.value + node_id * tree.value_stride) - middle_value = tree.value[node_id] + if tree.max_n_classes == 1: + # Regression + middle_value = tree.value[node_id] + else: + # Classification + with gil: + middle_value = (tree._get_value_ndarray()[node_id][0][0] + / np.sum(tree._get_value_ndarray()[node_id][0])) if not is_leaf: if splitter.monotonic_cst[split.feature] == 0: @@ -401,8 +408,13 @@ cdef class BestFirstTreeBuilder(TreeBuilder): else: # Node is expandable - middle_value = tree.value[record.node_id] - + if tree.max_n_classes == 1: + # Regression + middle_value = tree.value[record.node_id] + else: + # Classification + with gil: + middle_value = (tree._get_value_ndarray()[record.node_id][0][0] / np.sum(tree._get_value_ndarray()[record.node_id][0])) if splitter.monotonic_cst[node.feature] == 0: left_child_min = record.lower_bound left_child_max = record.upper_bound From 930ed41916a9084624b1cd27d8277153d7a458d4 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Tue, 16 Feb 2021 17:42:52 +0100 Subject: [PATCH 058/174] Clean up, document and comment --- sklearn/ensemble/_gb.py | 2 +- sklearn/tree/_splitter.pyx | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index b357e3351d2b6..409c907a3bfeb 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -1105,7 +1105,7 @@ def __init__(self, *, loss='deviance', learning_rate=0.1, n_estimators=100, min_impurity_split=min_impurity_split, warm_start=warm_start, validation_fraction=validation_fraction, n_iter_no_change=n_iter_no_change, tol=tol, ccp_alpha=ccp_alpha, - monotonic_cst = monotonic_cst) + monotonic_cst=monotonic_cst) def _validate_y(self, y, sample_weight): check_classification_targets(y) diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 3654ef48dd45f..52d648b5cb522 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -82,6 +82,10 @@ cdef class Splitter: random_state : object The user inputted random state to be used for pseudo-randomness + + monotonic_cst : INT32_t* + Monotonicity constraints + """ self.criterion = criterion @@ -422,6 +426,7 @@ cdef class BestSplitter(BaseDenseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue + # Reject if monotonicity constraints are not satisfied if not self.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): continue @@ -748,6 +753,7 @@ cdef class RandomSplitter(BaseDenseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue + # Reject if monotonicity constraints are not satisfied if not self.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): continue @@ -1301,6 +1307,7 @@ cdef class BestSparseSplitter(BaseSparseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue + # Reject if monotonicity constraints are not satisfied if not self.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): continue @@ -1539,6 +1546,7 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): 
(self.criterion.weighted_n_right < min_weight_leaf)): continue + # Reject if monotonicity constraints are not satisfied if not self.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): continue From b4d1757788524f29531cbc3155e84477de870b7e Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Tue, 16 Feb 2021 21:48:00 +0100 Subject: [PATCH 059/174] Clean up --- sklearn/tree/_classes.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index f44f43fef9762..53c1491066cf9 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -99,8 +99,8 @@ def __init__(self, *, min_impurity_decrease, min_impurity_split, class_weight=None, - monotonic_cst=None, - ccp_alpha=0.0): + ccp_alpha=0.0, + monotonic_cst=None): self.criterion = criterion self.splitter = splitter @@ -114,8 +114,8 @@ def __init__(self, *, self.min_impurity_decrease = min_impurity_decrease self.min_impurity_split = min_impurity_split self.class_weight = class_weight - self.monotonic_cst = monotonic_cst self.ccp_alpha = ccp_alpha + self.monotonic_cst = monotonic_cst def get_depth(self): """Return the depth of the decision tree. @@ -777,7 +777,7 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree): monotonic_cst : array-like of int of shape (n_features), default=None Indicates the monotonic constraint to enforce on each feature. -1, 1 and 0 respectively correspond to a positive constraint, negative - constraint and no constraint.. + constraint and no constraint. Attributes ---------- @@ -874,8 +874,8 @@ def __init__(self, *, min_impurity_decrease=0., min_impurity_split=None, class_weight=None, - monotonic_cst=None, - ccp_alpha=0.0): + ccp_alpha=0.0, + monotonic_cst=None): super().__init__( criterion=criterion, @@ -1463,6 +1463,11 @@ class ExtraTreeClassifier(DecisionTreeClassifier): .. versionadded:: 0.22 + monotonic_cst : array-like of int of shape (n_features), default=None + Indicates the monotonic constraint to enforce on each feature. -1, 1 + and 0 respectively correspond to a positive constraint, negative + constraint and no constraint. + Attributes ---------- classes_ : ndarray of shape (n_classes,) or list of ndarray From b73984d774638dbf083ffa24477b59804bfa8899 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Tue, 16 Feb 2021 22:55:18 +0100 Subject: [PATCH 060/174] Fix ExtraTreeRegressor Docstring --- sklearn/tree/_classes.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 53c1491066cf9..6b796e2c457d3 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -1705,6 +1705,11 @@ class ExtraTreeRegressor(DecisionTreeRegressor): .. versionadded:: 0.22 + monotonic_cst : array-like of int of shape (n_features), default=None + Indicates the monotonic constraint to enforce on each feature. -1, 1 + and 0 respectively correspond to a positive constraint, negative + constraint and no constraint. + Attributes ---------- max_features_ : int From 173db431fc8a538fea305152365b46223e7a5905 Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Tue, 16 Feb 2021 23:06:44 +0100 Subject: [PATCH 061/174] Clean up --- sklearn/tree/_classes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 6b796e2c457d3..955adaa70d7ca 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -1779,10 +1779,10 @@ def __init__(self, *, min_samples_leaf=1, min_weight_fraction_leaf=0., max_features="auto", - max_leaf_nodes=None, + random_state=None, min_impurity_decrease=0., min_impurity_split=None, - random_state=None, + max_leaf_nodes=None, ccp_alpha=0.0, monotonic_cst=None): super().__init__( From c4e119dfc88669025dea50389cbc1f4eba83a6a3 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Thu, 18 Feb 2021 11:05:29 +0100 Subject: [PATCH 062/174] Clean up tests --- sklearn/tree/tests/test_monotonic_tree.py | 33 ++++++++++------------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index f1f237419a4ed..31eeec49c7da7 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -83,14 +83,9 @@ def test_multiclass_raises(): est.fit(X, y) -def is_monotonic(a, cst): - return (cst * np.diff(a) >= 0.0).all() - - -def assert_children_values_monotonic_bounded(tree_, monotonic_cst): - # Flip values so that only need to check for increasing constraint - values = monotonic_cst * tree_.value - +def assert_1d_reg_tree_children_monotonic_bounded(tree_, monotonic_sign): + # Flip values to always check for increasing constraint + values = monotonic_sign * tree_.value for i in range(tree_.node_count): if tree_.children_left[i] > i and tree_.children_right[i] > i: # Check monotonicity @@ -107,17 +102,17 @@ def assert_children_values_monotonic_bounded(tree_, monotonic_cst): assert(val_middle <= float(values[i_right_left])) -def assert_tree_monotonic(clf, monotonic_cst): - X_grid = np.arange(0, 1, 0.01).reshape(-1, 1) +def assert_1d_reg_monotonic(clf, monotonic_sign, min_x, max_x, n_steps): + X_grid = np.arange(min_x, max_x, (max_x-min_x)/n_steps).reshape(-1, 1) y_pred_grid = clf.predict(X_grid) - assert is_monotonic(y_pred_grid, monotonic_cst) + assert (monotonic_sign * np.diff(y_pred_grid) >= 0.0).all() -@pytest.mark.parametrize('monotonic_cst', (-1, 1)) +@pytest.mark.parametrize('monotonic_sign', (-1, 1)) @pytest.mark.parametrize('splitter', ("best", "random")) @pytest.mark.parametrize('depth_first', (True, False)) @pytest.mark.parametrize('seed', range(4)) -def test_nodes_values(monotonic_cst, splitter, depth_first, seed): +def test_1d_tree_nodes_values(monotonic_sign, splitter, depth_first, seed): # Adaptation from test_nodes_values in test_montonic_constraints.py # Build a single tree with only one feature, and make sure the nodes # values respect the monotonic constraints. 
@@ -131,7 +126,7 @@ def test_nodes_values(monotonic_cst, splitter, depth_first, seed): # / \ / \ # a b c d # - # a <= b <= root <= c <= d (assert_children_values_monotonic_bounded) + # a <= b <= root <= c <= d rng = np.random.RandomState(seed) n_samples = 1000 @@ -142,15 +137,15 @@ def test_nodes_values(monotonic_cst, splitter, depth_first, seed): if depth_first: # No max_leaf_nodes, default depth first tree builder clf = DecisionTreeRegressor(splitter=splitter, - monotonic_cst=[monotonic_cst], + monotonic_cst=[monotonic_sign], random_state=seed) else: - # max_leaf_nodes triggers depth first tree builder + # max_leaf_nodes triggers best first tree builder clf = DecisionTreeRegressor(splitter=splitter, - monotonic_cst=[monotonic_cst], + monotonic_cst=[monotonic_sign], max_leaf_nodes=n_samples, random_state=seed) clf.fit(X, y) - assert_children_values_monotonic_bounded(clf.tree_, monotonic_cst) - assert_tree_monotonic(clf, monotonic_cst) + assert_1d_reg_tree_children_monotonic_bounded(clf.tree_, monotonic_sign) + assert_1d_reg_monotonic(clf, monotonic_sign, np.min(X), np.max(X), 100) From fc507999e678b32ead5f8d9da5c134919d642ccc Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Thu, 18 Feb 2021 15:30:32 +0100 Subject: [PATCH 063/174] Add test for multi-features trees --- sklearn/tree/tests/test_monotonic_tree.py | 84 +++++++++++++++++++++++ 1 file changed, 84 insertions(+) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 31eeec49c7da7..c7851704b081d 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -149,3 +149,87 @@ def test_1d_tree_nodes_values(monotonic_sign, splitter, depth_first, seed): assert_1d_reg_tree_children_monotonic_bounded(clf.tree_, monotonic_sign) assert_1d_reg_monotonic(clf, monotonic_sign, np.min(X), np.max(X), 100) + + +def assert_nd_reg_tree_children_monotonic_bounded(tree_, monotonic_cst): + upper_bound = np.full(tree_.node_count, np.inf) + lower_bound = np.full(tree_.node_count, -np.inf) + for i in range(tree_.node_count): + feature = tree_.feature[i] + assert tree_.value[i] <= upper_bound[i] + assert tree_.value[i] >= lower_bound[i] + if feature < 0: + # leaf: nothing to do + continue + else: + i_left = tree_.children_left[i] + i_right = tree_.children_right[i] + if monotonic_cst[feature] == 0: + # unconstrained feature: propagate bounds down the tree + upper_bound[i_left] = upper_bound[i] + lower_bound[i_left] = lower_bound[i] + upper_bound[i_right] = upper_bound[i] + lower_bound[i_right] = lower_bound[i] + else: + # constrained feature + # check montonicity + assert float(monotonic_cst[feature] * tree_.value[i_left]) \ + <= float(monotonic_cst[feature] * tree_.value[i_right]) + # update and propagate bounds down the tree + if monotonic_cst[feature] == 1: + upper_bound[i_left] = tree_.value[i] + lower_bound[i_left] = lower_bound[i] + upper_bound[i_right] = upper_bound[i] + lower_bound[i_right] = tree_.value[i] + else: + upper_bound[i_left] = upper_bound[i] + lower_bound[i_left] = tree_.value[i] + upper_bound[i_right] = tree_.value[i] + lower_bound[i_right] = lower_bound[i] + + +@pytest.mark.parametrize('monotonic_sign', (-1, 1)) +@pytest.mark.parametrize('splitter', ("best", "random")) +@pytest.mark.parametrize('depth_first', (True, False)) +@pytest.mark.parametrize('seed', range(4)) +def test_nd_tree_nodes_values(monotonic_sign, splitter, depth_first, seed): + # Build tree with several features, and make sure the nodes + # values respect the monotonic 
constraints. + + # Considering the following tree with a monotonic POS constraint on X[0], + # we should have: + # + # root + # X[0]<=t + # / \ + # a b + # X[0]<=u X[1]<=v + # / \ / \ + # c d e f + # + # i) a <= root <= b + # ii) c <= a <= d <= root + # iii) root <= min(e,f) + # For iii) we check that each node value is within the proper lower and + # upper bounds. + + rng = np.random.RandomState(seed) + n_samples = 10000 + n_features = 2 + monotonic_cst = [monotonic_sign, 0] + X = rng.rand(n_samples, n_features) + y = rng.rand(n_samples) + + if depth_first: + # No max_leaf_nodes, default depth first tree builder + clf = DecisionTreeRegressor(splitter=splitter, + monotonic_cst=monotonic_cst, + random_state=seed) + else: + # max_leaf_nodes triggers best first tree builder + clf = DecisionTreeRegressor(splitter=splitter, + monotonic_cst=monotonic_cst, + max_leaf_nodes=n_samples, + random_state=seed) + clf.fit(X, y) + assert_nd_reg_tree_children_monotonic_bounded(clf.tree_, monotonic_cst) From 847cc4f15fe49bae91cd8cf8703b6d55a8affb8f Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Fri, 18 Jun 2021 17:57:37 +0200 Subject: [PATCH 064/174] Apply suggestions from code review Co-authored-by: Olivier Grisel --- sklearn/tree/_tree.pyx | 14 ++++++++++---- sklearn/tree/tests/test_monotonic_tree.py | 8 ++++---- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 36165f00c0b3b..eecf803ccbdcb 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -256,8 +256,10 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): else: # Classification with gil: - middle_value = (tree._get_value_ndarray()[node_id][0][0] - / np.sum(tree._get_value_ndarray()[node_id][0])) + middle_value = ( + tree._get_value_ndarray()[node_id][0][0] + / np.sum(tree._get_value_ndarray()[node_id][0]) + ) if not is_leaf: if splitter.monotonic_cst[split.feature] == 0: @@ -285,7 +287,8 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): # Push left child on stack rc = stack.push(start, split.pos, depth + 1, node_id, 1, - split.impurity_left, n_constant_features, left_child_min, left_child_max) + split.impurity_left, n_constant_features, left_child_min, + left_child_max) if rc == -1: break @@ -413,7 +416,10 @@ cdef class BestFirstTreeBuilder(TreeBuilder): else: # Classification with gil: - middle_value = (tree._get_value_ndarray()[record.node_id][0][0] / np.sum(tree._get_value_ndarray()[record.node_id][0])) + middle_value = ( + tree._get_value_ndarray()[record.node_id][0][0] + / np.sum(tree._get_value_ndarray()[record.node_id][0]) + ) if splitter.monotonic_cst[node.feature] == 0: left_child_min = record.lower_bound left_child_max = record.upper_bound diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index c7851704b081d..938b6726e67f8 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -91,19 +91,19 @@ def assert_1d_reg_tree_children_monotonic_bounded(tree_, monotonic_sign): # Check monotonicity i_left = tree_.children_left[i] i_right = tree_.children_right[i] - assert(float(values[i_left]) <= float(values[i_right])) + assert values[i_left] <= values[i_right] val_middle = float(values[i]) # Check bounds if tree_.feature[i_left] >= 0: i_left_right = tree_.children_right[i_left] - assert(float(values[i_left_right]) <= val_middle) + assert float(values[i_left_right]) <= val_middle if tree_.feature[i_right] >= 0: i_right_left = tree_.children_left[i_right] - 
assert(val_middle <= float(values[i_right_left])) + assert val_middle <= float(values[i_right_left]) def assert_1d_reg_monotonic(clf, monotonic_sign, min_x, max_x, n_steps): - X_grid = np.arange(min_x, max_x, (max_x-min_x)/n_steps).reshape(-1, 1) + X_grid = np.linspace(min_x, max_x, n_steps).reshape(-1, 1) y_pred_grid = clf.predict(X_grid) assert (monotonic_sign * np.diff(y_pred_grid) >= 0.0).all() From 08f8247b8011c66ea0e7bbe7b93e589c1ad8ff08 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Fri, 18 Jun 2021 18:37:41 +0200 Subject: [PATCH 065/174] Apply more suggestions from code review --- sklearn/tree/tests/test_monotonic_tree.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 938b6726e67f8..008b512069df0 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -88,18 +88,18 @@ def assert_1d_reg_tree_children_monotonic_bounded(tree_, monotonic_sign): values = monotonic_sign * tree_.value for i in range(tree_.node_count): if tree_.children_left[i] > i and tree_.children_right[i] > i: - # Check monotonicity + # Check monotonicity on children i_left = tree_.children_left[i] i_right = tree_.children_right[i] assert values[i_left] <= values[i_right] - val_middle = float(values[i]) - # Check bounds + val_middle = values[i] + # Check bounds on grand-children, filtering out leaf nodes if tree_.feature[i_left] >= 0: i_left_right = tree_.children_right[i_left] - assert float(values[i_left_right]) <= val_middle + assert values[i_left_right] <= val_middle if tree_.feature[i_right] >= 0: i_right_left = tree_.children_left[i_right] - assert val_middle <= float(values[i_right_left]) + assert val_middle <= values[i_right_left] def assert_1d_reg_monotonic(clf, monotonic_sign, min_x, max_x, n_steps): @@ -122,11 +122,11 @@ def test_1d_tree_nodes_values(monotonic_sign, splitter, depth_first, seed): # # root # / \ - # 5 10 - # / \ / \ - # a b c d + # a b + # / \ / \ + # c d e f # - # a <= b <= root <= c <= d + # c <= d <= root <= e <= f rng = np.random.RandomState(seed) n_samples = 1000 @@ -173,8 +173,8 @@ def assert_nd_reg_tree_children_monotonic_bounded(tree_, monotonic_cst): else: # constrained feature # check montonicity - assert float(monotonic_cst[feature] * tree_.value[i_left]) \ - <= float(monotonic_cst[feature] * tree_.value[i_right]) + assert monotonic_cst[feature] * tree_.value[i_left] \ + <= monotonic_cst[feature] * tree_.value[i_right] # update and propagate bounds down the tree if monotonic_cst[feature] == 1: upper_bound[i_left] = tree_.value[i] From 5a923996d590f1787f2f7fcb0f3a5b88aa7379b9 Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Mon, 21 Jun 2021 16:38:19 +0200 Subject: [PATCH 066/174] Test error message for multiclass --- sklearn/tree/tests/test_monotonic_tree.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 008b512069df0..f23b84baddb6e 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -79,7 +79,9 @@ def test_multiclass_raises(): if hasattr(est, "random_state"): est.set_params(**{"random_state": 0}) - with pytest.raises(ValueError): + msg = "Monotonic constraints are not supported with multiclass " \ + "classification" + with pytest.raises(ValueError, match=msg): est.fit(X, y) From c2e6e4e3d5bf093ae672e7a4c8fef1d7c28bb750 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 17 Jun 2021 15:48:26 -0400 Subject: [PATCH 067/174] MAINT Adds target_version to black config (#20293) --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 613d53e25d295..b312612236080 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ requires = [ [tool.black] line-length = 88 +target_version = ['py37', 'py38', 'py39'] exclude = ''' /( \.eggs # exclude a few common directories in the From b88a637acd74e647131945c1032e54ec4261f825 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Mon, 21 Jun 2021 17:00:00 +0200 Subject: [PATCH 068/174] Add formatting changes from black --- sklearn/ensemble/_forest.py | 692 +++++++++++++--------- sklearn/ensemble/_gb.py | 639 ++++++++++++-------- sklearn/tree/_classes.py | 489 ++++++++------- sklearn/tree/tests/test_monotonic_tree.py | 82 +-- 4 files changed, 1139 insertions(+), 763 deletions(-) diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index 5f6b421ef175d..457760f5aca0c 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -54,8 +54,12 @@ class calls the ``fit`` method of each sub-estimator on random samples from ..base import ClassifierMixin, RegressorMixin, MultiOutputMixin from ..metrics import accuracy_score, r2_score from ..preprocessing import OneHotEncoder -from ..tree import (DecisionTreeClassifier, DecisionTreeRegressor, - ExtraTreeClassifier, ExtraTreeRegressor) +from ..tree import ( + DecisionTreeClassifier, + DecisionTreeRegressor, + ExtraTreeClassifier, + ExtraTreeRegressor, +) from ..tree._tree import DTYPE, DOUBLE from ..utils import check_random_state, compute_sample_weight, deprecated from ..exceptions import DataConversionWarning @@ -66,11 +70,13 @@ class calls the ``fit`` method of each sub-estimator on random samples from ..utils.validation import check_is_fitted, _check_sample_weight -__all__ = ["RandomForestClassifier", - "RandomForestRegressor", - "ExtraTreesClassifier", - "ExtraTreesRegressor", - "RandomTreesEmbedding"] +__all__ = [ + "RandomForestClassifier", + "RandomForestRegressor", + "ExtraTreesClassifier", + "ExtraTreesRegressor", + "RandomTreesEmbedding", +] MAX_INT = np.iinfo(np.int32).max @@ -127,8 +133,9 @@ def _generate_sample_indices(random_state, n_samples, n_samples_bootstrap): def _generate_unsampled_indices(random_state, n_samples, n_samples_bootstrap): """ Private function used to forest._set_oob_score function.""" - sample_indices = _generate_sample_indices(random_state, n_samples, - n_samples_bootstrap) + sample_indices = _generate_sample_indices( + random_state, n_samples, n_samples_bootstrap + ) sample_counts = np.bincount(sample_indices, 
minlength=n_samples) unsampled_mask = sample_counts == 0 indices_range = np.arange(n_samples) @@ -137,9 +144,18 @@ def _generate_unsampled_indices(random_state, n_samples, n_samples_bootstrap): return unsampled_indices -def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees, - verbose=0, class_weight=None, - n_samples_bootstrap=None): +def _parallel_build_trees( + tree, + forest, + X, + y, + sample_weight, + tree_idx, + n_trees, + verbose=0, + class_weight=None, + n_samples_bootstrap=None, +): """ Private function used to fit a single tree in parallel.""" if verbose > 1: @@ -152,19 +168,18 @@ def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees, else: curr_sample_weight = sample_weight.copy() - indices = _generate_sample_indices(tree.random_state, n_samples, - n_samples_bootstrap) + indices = _generate_sample_indices( + tree.random_state, n_samples, n_samples_bootstrap + ) sample_counts = np.bincount(indices, minlength=n_samples) curr_sample_weight *= sample_counts - if class_weight == 'subsample': + if class_weight == "subsample": with catch_warnings(): - simplefilter('ignore', DeprecationWarning) - curr_sample_weight *= compute_sample_weight('auto', y, - indices=indices) - elif class_weight == 'balanced_subsample': - curr_sample_weight *= compute_sample_weight('balanced', y, - indices=indices) + simplefilter("ignore", DeprecationWarning) + curr_sample_weight *= compute_sample_weight("auto", y, indices=indices) + elif class_weight == "balanced_subsample": + curr_sample_weight *= compute_sample_weight("balanced", y, indices=indices) tree.fit(X, y, sample_weight=curr_sample_weight, check_input=False) else: @@ -182,22 +197,26 @@ class BaseForest(MultiOutputMixin, BaseEnsemble, metaclass=ABCMeta): """ @abstractmethod - def __init__(self, - base_estimator, - n_estimators=100, *, - estimator_params=tuple(), - bootstrap=False, - oob_score=False, - n_jobs=None, - random_state=None, - verbose=0, - warm_start=False, - class_weight=None, - max_samples=None): + def __init__( + self, + base_estimator, + n_estimators=100, + *, + estimator_params=tuple(), + bootstrap=False, + oob_score=False, + n_jobs=None, + random_state=None, + verbose=0, + warm_start=False, + class_weight=None, + max_samples=None, + ): super().__init__( base_estimator=base_estimator, n_estimators=n_estimators, - estimator_params=estimator_params) + estimator_params=estimator_params, + ) self.bootstrap = bootstrap self.oob_score = oob_score @@ -226,10 +245,11 @@ def apply(self, X): return the index of the leaf x ends up in. 
""" X = self._validate_X_predict(X) - results = Parallel(n_jobs=self.n_jobs, verbose=self.verbose, - **_joblib_parallel_args(prefer="threads"))( - delayed(tree.apply)(X, check_input=False) - for tree in self.estimators_) + results = Parallel( + n_jobs=self.n_jobs, + verbose=self.verbose, + **_joblib_parallel_args(prefer="threads"), + )(delayed(tree.apply)(X, check_input=False) for tree in self.estimators_) return np.array(results).T @@ -259,10 +279,14 @@ def decision_path(self, X): """ X = self._validate_X_predict(X) - indicators = Parallel(n_jobs=self.n_jobs, verbose=self.verbose, - **_joblib_parallel_args(prefer='threads'))( + indicators = Parallel( + n_jobs=self.n_jobs, + verbose=self.verbose, + **_joblib_parallel_args(prefer="threads"), + )( delayed(tree.decision_path)(X, check_input=False) - for tree in self.estimators_) + for tree in self.estimators_ + ) n_nodes = [0] n_nodes.extend([i.shape[1] for i in indicators]) @@ -298,11 +322,10 @@ def fit(self, X, y, sample_weight=None): """ # Validate or convert input data if issparse(y): - raise ValueError( - "sparse multilabel-indicator for y is not supported." - ) - X, y = self._validate_data(X, y, multi_output=True, - accept_sparse="csc", dtype=DTYPE) + raise ValueError("sparse multilabel-indicator for y is not supported.") + X, y = self._validate_data( + X, y, multi_output=True, accept_sparse="csc", dtype=DTYPE + ) if sample_weight is not None: sample_weight = _check_sample_weight(sample_weight, X) @@ -313,10 +336,13 @@ def fit(self, X, y, sample_weight=None): y = np.atleast_1d(y) if y.ndim == 2 and y.shape[1] == 1: - warn("A column-vector y was passed when a 1d array was" - " expected. Please change the shape of y to " - "(n_samples,), for example using ravel().", - DataConversionWarning, stacklevel=2) + warn( + "A column-vector y was passed when a 1d array was" + " expected. Please change the shape of y to " + "(n_samples,), for example using ravel().", + DataConversionWarning, + stacklevel=2, + ) if y.ndim == 1: # reshape is necessary to preserve the data contiguity against vs @@ -325,11 +351,15 @@ def fit(self, X, y, sample_weight=None): if self.criterion == "poisson": if np.any(y < 0): - raise ValueError("Some value(s) of y are negative which is " - "not allowed for Poisson regression.") + raise ValueError( + "Some value(s) of y are negative which is " + "not allowed for Poisson regression." + ) if np.sum(y) <= 0: - raise ValueError("Sum of y is not strictly positive which " - "is necessary for Poisson regression.") + raise ValueError( + "Sum of y is not strictly positive which " + "is necessary for Poisson regression." + ) self.n_outputs_ = y.shape[1] @@ -346,8 +376,7 @@ def fit(self, X, y, sample_weight=None): # Get bootstrap sample size n_samples_bootstrap = _get_n_samples_bootstrap( - n_samples=X.shape[0], - max_samples=self.max_samples + n_samples=X.shape[0], max_samples=self.max_samples ) # Check parameters @@ -359,19 +388,20 @@ def fit(self, X, y, sample_weight=None): "Criterion 'mse' was deprecated in v1.0 and will be " "removed in version 1.2. Use `criterion='squared_error'` " "which is equivalent.", - FutureWarning + FutureWarning, ) elif self.criterion == "mae": warn( "Criterion 'mae' was deprecated in v1.0 and will be " "removed in version 1.2. 
Use `criterion='absolute_error'` " "which is equivalent.", - FutureWarning + FutureWarning, ) if not self.bootstrap and self.oob_score: - raise ValueError("Out of bag estimation only available" - " if bootstrap=True") + raise ValueError( + "Out of bag estimation only available" " if bootstrap=True" + ) random_state = check_random_state(self.random_state) @@ -382,22 +412,27 @@ def fit(self, X, y, sample_weight=None): n_more_estimators = self.n_estimators - len(self.estimators_) if n_more_estimators < 0: - raise ValueError('n_estimators=%d must be larger or equal to ' - 'len(estimators_)=%d when warm_start==True' - % (self.n_estimators, len(self.estimators_))) + raise ValueError( + "n_estimators=%d must be larger or equal to " + "len(estimators_)=%d when warm_start==True" + % (self.n_estimators, len(self.estimators_)) + ) elif n_more_estimators == 0: - warn("Warm-start fitting without increasing n_estimators does not " - "fit new trees.") + warn( + "Warm-start fitting without increasing n_estimators does not " + "fit new trees." + ) else: if self.warm_start and len(self.estimators_) > 0: # We draw from the random state to get the random state we # would have got if we hadn't used a warm_start. random_state.randint(MAX_INT, size=len(self.estimators_)) - trees = [self._make_estimator(append=False, - random_state=random_state) - for i in range(n_more_estimators)] + trees = [ + self._make_estimator(append=False, random_state=random_state) + for i in range(n_more_estimators) + ] # Parallel loop: we prefer the threading backend as the Cython code # for fitting the trees is internally releasing the Python GIL @@ -405,13 +440,25 @@ def fit(self, X, y, sample_weight=None): # that case. However, for joblib 0.12+ we respect any # parallel_backend contexts set at a higher level, # since correctness does not rely on using threads. - trees = Parallel(n_jobs=self.n_jobs, verbose=self.verbose, - **_joblib_parallel_args(prefer='threads'))( + trees = Parallel( + n_jobs=self.n_jobs, + verbose=self.verbose, + **_joblib_parallel_args(prefer="threads"), + )( delayed(_parallel_build_trees)( - t, self, X, y, sample_weight, i, len(trees), - verbose=self.verbose, class_weight=self.class_weight, - n_samples_bootstrap=n_samples_bootstrap) - for i, t in enumerate(trees)) + t, + self, + X, + y, + sample_weight, + i, + len(trees), + verbose=self.verbose, + class_weight=self.class_weight, + n_samples_bootstrap=n_samples_bootstrap, + ) + for i, t in enumerate(trees) + ) # Collect newly grown trees self.estimators_.extend(trees) @@ -466,8 +513,7 @@ def _compute_oob_predictions(self, X, y): (n_samples, 1, n_outputs) The OOB predictions. """ - X = self._validate_data(X, dtype=DTYPE, accept_sparse='csr', - reset=False) + X = self._validate_data(X, dtype=DTYPE, accept_sparse="csr", reset=False) n_samples = y.shape[0] n_outputs = self.n_outputs_ @@ -486,16 +532,17 @@ def _compute_oob_predictions(self, X, y): n_oob_pred = np.zeros((n_samples, n_outputs), dtype=np.int64) n_samples_bootstrap = _get_n_samples_bootstrap( - n_samples, self.max_samples, + n_samples, + self.max_samples, ) for estimator in self.estimators_: unsampled_indices = _generate_unsampled_indices( - estimator.random_state, n_samples, n_samples_bootstrap, + estimator.random_state, + n_samples, + n_samples_bootstrap, ) - y_pred = self._get_oob_predictions( - estimator, X[unsampled_indices, :] - ) + y_pred = self._get_oob_predictions(estimator, X[unsampled_indices, :]) oob_pred[unsampled_indices, ...] 
+= y_pred n_oob_pred[unsampled_indices, :] += 1 @@ -504,7 +551,8 @@ def _compute_oob_predictions(self, X, y): warn( "Some inputs do not have OOB scores. This probably means " "too few trees were used to compute any reliable OOB " - "estimates.", UserWarning + "estimates.", + UserWarning, ) n_oob_pred[n_oob_pred == 0] = 1 oob_pred[..., k] /= n_oob_pred[..., [k]] @@ -545,16 +593,18 @@ def feature_importances_(self): """ check_is_fitted(self) - all_importances = Parallel(n_jobs=self.n_jobs, - **_joblib_parallel_args(prefer='threads'))( - delayed(getattr)(tree, 'feature_importances_') - for tree in self.estimators_ if tree.tree_.node_count > 1) + all_importances = Parallel( + n_jobs=self.n_jobs, **_joblib_parallel_args(prefer="threads") + )( + delayed(getattr)(tree, "feature_importances_") + for tree in self.estimators_ + if tree.tree_.node_count > 1 + ) if not all_importances: return np.zeros(self.n_features_in_, dtype=np.float64) - all_importances = np.mean(all_importances, - axis=0, dtype=np.float64) + all_importances = np.mean(all_importances, axis=0, dtype=np.float64) return all_importances / np.sum(all_importances) # TODO: Remove in 1.2 @@ -593,18 +643,21 @@ class ForestClassifier(ClassifierMixin, BaseForest, metaclass=ABCMeta): """ @abstractmethod - def __init__(self, - base_estimator, - n_estimators=100, *, - estimator_params=tuple(), - bootstrap=False, - oob_score=False, - n_jobs=None, - random_state=None, - verbose=0, - warm_start=False, - class_weight=None, - max_samples=None): + def __init__( + self, + base_estimator, + n_estimators=100, + *, + estimator_params=tuple(), + bootstrap=False, + oob_score=False, + n_jobs=None, + random_state=None, + verbose=0, + warm_start=False, + class_weight=None, + max_samples=None, + ): super().__init__( base_estimator, n_estimators=n_estimators, @@ -616,7 +669,8 @@ def __init__(self, verbose=verbose, warm_start=warm_start, class_weight=class_weight, - max_samples=max_samples) + max_samples=max_samples, + ) @staticmethod def _get_oob_predictions(tree, X): @@ -659,9 +713,7 @@ def _set_oob_score_and_attributes(self, X, y): self.oob_decision_function_ = super()._compute_oob_predictions(X, y) if self.oob_decision_function_.shape[-1] == 1: # drop the n_outputs axis if there is a single output - self.oob_decision_function_ = self.oob_decision_function_.squeeze( - axis=-1 - ) + self.oob_decision_function_ = self.oob_decision_function_.squeeze(axis=-1) self.oob_score_ = accuracy_score( y, np.argmax(self.oob_decision_function_, axis=1) ) @@ -680,40 +732,42 @@ def _validate_y_class_weight(self, y): y_store_unique_indices = np.zeros(y.shape, dtype=int) for k in range(self.n_outputs_): - classes_k, y_store_unique_indices[:, k] = \ - np.unique(y[:, k], return_inverse=True) + classes_k, y_store_unique_indices[:, k] = np.unique( + y[:, k], return_inverse=True + ) self.classes_.append(classes_k) self.n_classes_.append(classes_k.shape[0]) y = y_store_unique_indices if self.class_weight is not None: - valid_presets = ('balanced', 'balanced_subsample') + valid_presets = ("balanced", "balanced_subsample") if isinstance(self.class_weight, str): if self.class_weight not in valid_presets: - raise ValueError('Valid presets for class_weight include ' - '"balanced" and "balanced_subsample".' - 'Given "%s".' - % self.class_weight) + raise ValueError( + "Valid presets for class_weight include " + '"balanced" and "balanced_subsample".' + 'Given "%s".' 
% self.class_weight + ) if self.warm_start: - warn('class_weight presets "balanced" or ' - '"balanced_subsample" are ' - 'not recommended for warm_start if the fitted data ' - 'differs from the full dataset. In order to use ' - '"balanced" weights, use compute_class_weight ' - '("balanced", classes, y). In place of y you can use ' - 'a large enough sample of the full training set ' - 'target to properly estimate the class frequency ' - 'distributions. Pass the resulting weights as the ' - 'class_weight parameter.') - - if (self.class_weight != 'balanced_subsample' or - not self.bootstrap): + warn( + 'class_weight presets "balanced" or ' + '"balanced_subsample" are ' + "not recommended for warm_start if the fitted data " + "differs from the full dataset. In order to use " + '"balanced" weights, use compute_class_weight ' + '("balanced", classes, y). In place of y you can use ' + "a large enough sample of the full training set " + "target to properly estimate the class frequency " + "distributions. Pass the resulting weights as the " + "class_weight parameter." + ) + + if self.class_weight != "balanced_subsample" or not self.bootstrap: if self.class_weight == "balanced_subsample": class_weight = "balanced" else: class_weight = self.class_weight - expanded_class_weight = compute_sample_weight(class_weight, - y_original) + expanded_class_weight = compute_sample_weight(class_weight, y_original) return y, expanded_class_weight @@ -747,13 +801,12 @@ def predict(self, X): n_samples = proba[0].shape[0] # all dtypes should be the same, so just take the first class_type = self.classes_[0].dtype - predictions = np.empty((n_samples, self.n_outputs_), - dtype=class_type) + predictions = np.empty((n_samples, self.n_outputs_), dtype=class_type) for k in range(self.n_outputs_): - predictions[:, k] = self.classes_[k].take(np.argmax(proba[k], - axis=1), - axis=0) + predictions[:, k] = self.classes_[k].take( + np.argmax(proba[k], axis=1), axis=0 + ) return predictions @@ -787,14 +840,19 @@ def predict_proba(self, X): n_jobs, _, _ = _partition_estimators(self.n_estimators, self.n_jobs) # avoid storing the output of every estimator by summing them here - all_proba = [np.zeros((X.shape[0], j), dtype=np.float64) - for j in np.atleast_1d(self.n_classes_)] + all_proba = [ + np.zeros((X.shape[0], j), dtype=np.float64) + for j in np.atleast_1d(self.n_classes_) + ] lock = threading.Lock() - Parallel(n_jobs=n_jobs, verbose=self.verbose, - **_joblib_parallel_args(require="sharedmem"))( - delayed(_accumulate_prediction)(e.predict_proba, X, all_proba, - lock) - for e in self.estimators_) + Parallel( + n_jobs=n_jobs, + verbose=self.verbose, + **_joblib_parallel_args(require="sharedmem"), + )( + delayed(_accumulate_prediction)(e.predict_proba, X, all_proba, lock) + for e in self.estimators_ + ) for proba in all_proba: proba /= len(self.estimators_) @@ -846,17 +904,20 @@ class ForestRegressor(RegressorMixin, BaseForest, metaclass=ABCMeta): """ @abstractmethod - def __init__(self, - base_estimator, - n_estimators=100, *, - estimator_params=tuple(), - bootstrap=False, - oob_score=False, - n_jobs=None, - random_state=None, - verbose=0, - warm_start=False, - max_samples=None): + def __init__( + self, + base_estimator, + n_estimators=100, + *, + estimator_params=tuple(), + bootstrap=False, + oob_score=False, + n_jobs=None, + random_state=None, + verbose=0, + warm_start=False, + max_samples=None, + ): super().__init__( base_estimator, n_estimators=n_estimators, @@ -867,7 +928,8 @@ def __init__(self, random_state=random_state, 
verbose=verbose, warm_start=warm_start, - max_samples=max_samples) + max_samples=max_samples, + ) def predict(self, X): """ @@ -903,10 +965,14 @@ def predict(self, X): # Parallel loop lock = threading.Lock() - Parallel(n_jobs=n_jobs, verbose=self.verbose, - **_joblib_parallel_args(require="sharedmem"))( + Parallel( + n_jobs=n_jobs, + verbose=self.verbose, + **_joblib_parallel_args(require="sharedmem"), + )( delayed(_accumulate_prediction)(e.predict, X, [y_hat], lock) - for e in self.estimators_) + for e in self.estimators_ + ) y_hat /= len(self.estimators_) @@ -947,9 +1013,7 @@ def _set_oob_score_and_attributes(self, X, y): y : ndarray of shape (n_samples, n_outputs) The target matrix. """ - self.oob_prediction_ = super()._compute_oob_predictions(X, y).squeeze( - axis=1 - ) + self.oob_prediction_ = super()._compute_oob_predictions(X, y).squeeze(axis=1) if self.oob_prediction_.shape[-1] == 1: # drop the n_outputs axis if there is a single output self.oob_prediction_ = self.oob_prediction_.squeeze(axis=-1) @@ -972,15 +1036,17 @@ def _compute_partial_dependence_recursion(self, grid, target_features): averaged_predictions : ndarray of shape (n_samples,) The value of the partial dependence function on each grid point. """ - grid = np.asarray(grid, dtype=DTYPE, order='C') - averaged_predictions = np.zeros(shape=grid.shape[0], - dtype=np.float64, order='C') + grid = np.asarray(grid, dtype=DTYPE, order="C") + averaged_predictions = np.zeros( + shape=grid.shape[0], dtype=np.float64, order="C" + ) for tree in self.estimators_: # Note: we don't sum in parallel because the GIL isn't released in # the fast method. tree.tree_.compute_partial_dependence( - grid, target_features, averaged_predictions) + grid, target_features, averaged_predictions + ) # Average over the forest averaged_predictions /= len(self.estimators_) @@ -1260,35 +1326,47 @@ class labels (multi-output problem). 
>>> print(clf.predict([[0, 0, 0, 0]])) [1] """ - def __init__(self, - n_estimators=100, *, - criterion="gini", - max_depth=None, - min_samples_split=2, - min_samples_leaf=1, - min_weight_fraction_leaf=0., - max_features="auto", - max_leaf_nodes=None, - min_impurity_decrease=0., - bootstrap=True, - oob_score=False, - n_jobs=None, - random_state=None, - verbose=0, - warm_start=False, - class_weight=None, - ccp_alpha=0.0, - max_samples=None, - monotonic_cst=None): + + def __init__( + self, + n_estimators=100, + *, + criterion="gini", + max_depth=None, + min_samples_split=2, + min_samples_leaf=1, + min_weight_fraction_leaf=0.0, + max_features="auto", + max_leaf_nodes=None, + min_impurity_decrease=0.0, + bootstrap=True, + oob_score=False, + n_jobs=None, + random_state=None, + verbose=0, + warm_start=False, + class_weight=None, + ccp_alpha=0.0, + max_samples=None, + monotonic_cst=None, + ): super().__init__( base_estimator=DecisionTreeClassifier(), n_estimators=n_estimators, - estimator_params=("criterion", "max_depth", "min_samples_split", - "min_samples_leaf", "min_weight_fraction_leaf", - "max_features", "max_leaf_nodes", - "min_impurity_decrease", "random_state", - "ccp_alpha", "monotonic_cst"), + estimator_params=( + "criterion", + "max_depth", + "min_samples_split", + "min_samples_leaf", + "min_weight_fraction_leaf", + "max_features", + "max_leaf_nodes", + "min_impurity_decrease", + "random_state", + "ccp_alpha", + "monotonic_cst", + ), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, @@ -1296,7 +1374,8 @@ def __init__(self, verbose=verbose, warm_start=warm_start, class_weight=class_weight, - max_samples=max_samples) + max_samples=max_samples, + ) self.criterion = criterion self.max_depth = max_depth @@ -1571,40 +1650,52 @@ class RandomForestRegressor(ForestRegressor): [-8.32987858] """ - def __init__(self, - n_estimators=100, *, - criterion="squared_error", - max_depth=None, - min_samples_split=2, - min_samples_leaf=1, - min_weight_fraction_leaf=0., - max_features="auto", - max_leaf_nodes=None, - min_impurity_decrease=0., - bootstrap=True, - oob_score=False, - n_jobs=None, - random_state=None, - verbose=0, - warm_start=False, - ccp_alpha=0.0, - max_samples=None, - monotonic_cst=None): + def __init__( + self, + n_estimators=100, + *, + criterion="squared_error", + max_depth=None, + min_samples_split=2, + min_samples_leaf=1, + min_weight_fraction_leaf=0.0, + max_features="auto", + max_leaf_nodes=None, + min_impurity_decrease=0.0, + bootstrap=True, + oob_score=False, + n_jobs=None, + random_state=None, + verbose=0, + warm_start=False, + ccp_alpha=0.0, + max_samples=None, + monotonic_cst=None, + ): super().__init__( base_estimator=DecisionTreeRegressor(), n_estimators=n_estimators, - estimator_params=("criterion", "max_depth", "min_samples_split", - "min_samples_leaf", "min_weight_fraction_leaf", - "max_features", "max_leaf_nodes", - "min_impurity_decrease", "random_state", - "ccp_alpha", "monotonic_cst"), + estimator_params=( + "criterion", + "max_depth", + "min_samples_split", + "min_samples_leaf", + "min_weight_fraction_leaf", + "max_features", + "max_leaf_nodes", + "min_impurity_decrease", + "random_state", + "ccp_alpha", + "monotonic_cst", + ), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, random_state=random_state, verbose=verbose, warm_start=warm_start, - max_samples=max_samples) + max_samples=max_samples, + ) self.criterion = criterion self.max_depth = max_depth @@ -1887,34 +1978,45 @@ class labels (multi-output problem). 
array([1]) """ - def __init__(self, - n_estimators=100, *, - criterion="gini", - max_depth=None, - min_samples_split=2, - min_samples_leaf=1, - min_weight_fraction_leaf=0., - max_features="auto", - max_leaf_nodes=None, - min_impurity_decrease=0., - bootstrap=False, - oob_score=False, - n_jobs=None, - random_state=None, - verbose=0, - warm_start=False, - class_weight=None, - ccp_alpha=0.0, - max_samples=None, - monotonic_cst=None): + def __init__( + self, + n_estimators=100, + *, + criterion="gini", + max_depth=None, + min_samples_split=2, + min_samples_leaf=1, + min_weight_fraction_leaf=0.0, + max_features="auto", + max_leaf_nodes=None, + min_impurity_decrease=0.0, + bootstrap=False, + oob_score=False, + n_jobs=None, + random_state=None, + verbose=0, + warm_start=False, + class_weight=None, + ccp_alpha=0.0, + max_samples=None, + monotonic_cst=None, + ): super().__init__( base_estimator=ExtraTreeClassifier(), n_estimators=n_estimators, - estimator_params=("criterion", "max_depth", "min_samples_split", - "min_samples_leaf", "min_weight_fraction_leaf", - "max_features", "max_leaf_nodes", - "min_impurity_decrease", "random_state", - "ccp_alpha", "monotonic_cst"), + estimator_params=( + "criterion", + "max_depth", + "min_samples_split", + "min_samples_leaf", + "min_weight_fraction_leaf", + "max_features", + "max_leaf_nodes", + "min_impurity_decrease", + "random_state", + "ccp_alpha", + "monotonic_cst", + ), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, @@ -1922,7 +2024,8 @@ def __init__(self, verbose=verbose, warm_start=warm_start, class_weight=class_weight, - max_samples=max_samples) + max_samples=max_samples, + ) self.criterion = criterion self.max_depth = max_depth @@ -2182,40 +2285,52 @@ class ExtraTreesRegressor(ForestRegressor): 0.2708... 
""" - def __init__(self, - n_estimators=100, *, - criterion="squared_error", - max_depth=None, - min_samples_split=2, - min_samples_leaf=1, - min_weight_fraction_leaf=0., - max_features="auto", - max_leaf_nodes=None, - min_impurity_decrease=0., - bootstrap=False, - oob_score=False, - n_jobs=None, - random_state=None, - verbose=0, - warm_start=False, - ccp_alpha=0.0, - max_samples=None, - monotonic_cst=None): + def __init__( + self, + n_estimators=100, + *, + criterion="squared_error", + max_depth=None, + min_samples_split=2, + min_samples_leaf=1, + min_weight_fraction_leaf=0.0, + max_features="auto", + max_leaf_nodes=None, + min_impurity_decrease=0.0, + bootstrap=False, + oob_score=False, + n_jobs=None, + random_state=None, + verbose=0, + warm_start=False, + ccp_alpha=0.0, + max_samples=None, + monotonic_cst=None, + ): super().__init__( base_estimator=ExtraTreeRegressor(), n_estimators=n_estimators, - estimator_params=("criterion", "max_depth", "min_samples_split", - "min_samples_leaf", "min_weight_fraction_leaf", - "max_features", "max_leaf_nodes", - "min_impurity_decrease", "random_state", - "ccp_alpha", "monotonic_cst"), + estimator_params=( + "criterion", + "max_depth", + "min_samples_split", + "min_samples_leaf", + "min_weight_fraction_leaf", + "max_features", + "max_leaf_nodes", + "min_impurity_decrease", + "random_state", + "ccp_alpha", + "monotonic_cst", + ), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, random_state=random_state, verbose=verbose, warm_start=warm_start, - max_samples=max_samples) + max_samples=max_samples, + ) self.criterion = criterion self.max_depth = max_depth @@ -2398,35 +2513,46 @@ class RandomTreesEmbedding(BaseForest): criterion = "squared_error" max_features = 1 - def __init__(self, - n_estimators=100, *, - max_depth=5, - min_samples_split=2, - min_samples_leaf=1, - min_weight_fraction_leaf=0., - max_leaf_nodes=None, - min_impurity_decrease=0., - sparse_output=True, - n_jobs=None, - random_state=None, - verbose=0, - warm_start=False, - monotonic_cst=None): + def __init__( + self, + n_estimators=100, + *, + max_depth=5, + min_samples_split=2, + min_samples_leaf=1, + min_weight_fraction_leaf=0.0, + max_leaf_nodes=None, + min_impurity_decrease=0.0, + sparse_output=True, + n_jobs=None, + random_state=None, + verbose=0, + warm_start=False, + monotonic_cst=None, + ): super().__init__( base_estimator=ExtraTreeRegressor(), n_estimators=n_estimators, - estimator_params=("criterion", "max_depth", "min_samples_split", - "min_samples_leaf", "min_weight_fraction_leaf", - "max_features", "max_leaf_nodes", - "min_impurity_decrease", "random_state", - "monotonic_cst"), + estimator_params=( + "criterion", + "max_depth", + "min_samples_split", + "min_samples_leaf", + "min_weight_fraction_leaf", + "max_features", + "max_leaf_nodes", + "min_impurity_decrease", + "random_state", + "monotonic_cst", + ), bootstrap=False, oob_score=False, n_jobs=n_jobs, random_state=random_state, verbose=verbose, warm_start=warm_start, - max_samples=None) + max_samples=None, + ) self.max_depth = max_depth self.min_samples_split = min_samples_split @@ -2494,7 +2620,7 @@ def fit_transform(self, X, y=None, sample_weight=None): X_transformed : sparse matrix of shape (n_samples, n_out) Transformed dataset. """ - X = self._validate_data(X, accept_sparse=['csc']) + X = self._validate_data(X, accept_sparse=["csc"]) if issparse(X): # Pre-sort indices to avoid that each individual tree of the # ensemble sorts the indices. 
diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index 7c191b08ea80d..ecf51091aeeb8 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -66,6 +66,7 @@ class VerboseReporter: (when iteration mod verbose_mod is zero).; if larger than 1 then output is printed for each update. """ + def __init__(self, verbose): self.verbose = verbose @@ -81,20 +82,19 @@ def init(self, est, begin_at_stage=0): stage at which to begin reporting """ # header fields and line format str - header_fields = ['Iter', 'Train Loss'] - verbose_fmt = ['{iter:>10d}', '{train_score:>16.4f}'] + header_fields = ["Iter", "Train Loss"] + verbose_fmt = ["{iter:>10d}", "{train_score:>16.4f}"] # do oob? if est.subsample < 1: - header_fields.append('OOB Improve') - verbose_fmt.append('{oob_impr:>16.4f}') - header_fields.append('Remaining Time') - verbose_fmt.append('{remaining_time:>16s}') + header_fields.append("OOB Improve") + verbose_fmt.append("{oob_impr:>16.4f}") + header_fields.append("Remaining Time") + verbose_fmt.append("{remaining_time:>16s}") # print the header line - print(('%10s ' + '%16s ' * - (len(header_fields) - 1)) % tuple(header_fields)) + print(("%10s " + "%16s " * (len(header_fields) - 1)) % tuple(header_fields)) - self.verbose_fmt = ' '.join(verbose_fmt) + self.verbose_fmt = " ".join(verbose_fmt) # plot verbose info each time i % verbose_mod == 0 self.verbose_mod = 1 self.start_time = time() @@ -115,16 +115,21 @@ def update(self, j, est): i = j - self.begin_at_stage # iteration relative to the start iter if (i + 1) % self.verbose_mod == 0: oob_impr = est.oob_improvement_[j] if do_oob else 0 - remaining_time = ((est.n_estimators - (j + 1)) * - (time() - self.start_time) / float(i + 1)) + remaining_time = ( + (est.n_estimators - (j + 1)) * (time() - self.start_time) / float(i + 1) + ) if remaining_time > 60: - remaining_time = '{0:.2f}m'.format(remaining_time / 60.0) + remaining_time = "{0:.2f}m".format(remaining_time / 60.0) else: - remaining_time = '{0:.2f}s'.format(remaining_time) - print(self.verbose_fmt.format(iter=j + 1, - train_score=est.train_score_[j], - oob_impr=oob_impr, - remaining_time=remaining_time)) + remaining_time = "{0:.2f}s".format(remaining_time) + print( + self.verbose_fmt.format( + iter=j + 1, + train_score=est.train_score_[j], + oob_impr=oob_impr, + remaining_time=remaining_time, + ) + ) if self.verbose == 1 and ((i + 1) // (self.verbose_mod * 10) > 0): # adjust verbose frequency (powers of 10) self.verbose_mod *= 10 @@ -134,13 +139,32 @@ class BaseGradientBoosting(BaseEnsemble, metaclass=ABCMeta): """Abstract base class for Gradient Boosting.""" @abstractmethod - def __init__(self, *, loss, learning_rate, n_estimators, criterion, - min_samples_split, min_samples_leaf, min_weight_fraction_leaf, - max_depth, min_impurity_decrease, init, subsample, - max_features, ccp_alpha, random_state, alpha=0.9, verbose=0, - max_leaf_nodes=None, warm_start=False, - validation_fraction=0.1, n_iter_no_change=None, tol=1e-4, - monotonic_cst=None): + def __init__( + self, + *, + loss, + learning_rate, + n_estimators, + criterion, + min_samples_split, + min_samples_leaf, + min_weight_fraction_leaf, + max_depth, + min_impurity_decrease, + init, + subsample, + max_features, + ccp_alpha, + random_state, + alpha=0.9, + verbose=0, + max_leaf_nodes=None, + warm_start=False, + validation_fraction=0.1, + n_iter_no_change=None, + tol=1e-4, + monotonic_cst=None, + ): self.n_estimators = n_estimators self.learning_rate = learning_rate @@ -169,8 +193,18 @@ def __init__(self, *, loss, 
learning_rate, n_estimators, criterion, def _validate_y(self, y, sample_weight=None): """Called by fit to validate y.""" - def _fit_stage(self, i, X, y, raw_predictions, sample_weight, sample_mask, - random_state, X_csc=None, X_csr=None): + def _fit_stage( + self, + i, + X, + y, + raw_predictions, + sample_weight, + sample_mask, + random_state, + X_csc=None, + X_csr=None, + ): """Fit another stage of ``_n_classes`` trees to the boosting model.""" assert sample_mask.dtype == bool @@ -187,13 +221,14 @@ def _fit_stage(self, i, X, y, raw_predictions, sample_weight, sample_mask, if loss.is_multi_class: y = np.array(original_y == k, dtype=np.float64) - residual = loss.negative_gradient(y, raw_predictions_copy, k=k, - sample_weight=sample_weight) + residual = loss.negative_gradient( + y, raw_predictions_copy, k=k, sample_weight=sample_weight + ) # induce regression tree on residuals tree = DecisionTreeRegressor( criterion=self.criterion, - splitter='best', + splitter="best", max_depth=self.max_depth, min_samples_split=self.min_samples_split, min_samples_leaf=self.min_samples_leaf, @@ -203,20 +238,28 @@ def _fit_stage(self, i, X, y, raw_predictions, sample_weight, sample_mask, max_leaf_nodes=self.max_leaf_nodes, random_state=random_state, ccp_alpha=self.ccp_alpha, - monotonic_cst=self.monotonic_cst) + monotonic_cst=self.monotonic_cst, + ) if self.subsample < 1.0: # no inplace multiplication! sample_weight = sample_weight * sample_mask.astype(np.float64) X = X_csr if X_csr is not None else X - tree.fit(X, residual, sample_weight=sample_weight, - check_input=False) + tree.fit(X, residual, sample_weight=sample_weight, check_input=False) # update tree leaves loss.update_terminal_regions( - tree.tree_, X, y, residual, raw_predictions, sample_weight, - sample_mask, learning_rate=self.learning_rate, k=k) + tree.tree_, + X, + y, + residual, + raw_predictions, + sample_weight, + sample_mask, + learning_rate=self.learning_rate, + k=k, + ) # add tree to ensemble self.estimators_[i, k] = tree @@ -226,33 +269,44 @@ def _fit_stage(self, i, X, y, raw_predictions, sample_weight, sample_mask, def _check_params(self): """Check validity of parameters and raise ValueError if not valid.""" if self.n_estimators <= 0: - raise ValueError("n_estimators must be greater than 0 but " - "was %r" % self.n_estimators) + raise ValueError( + "n_estimators must be greater than 0 but " "was %r" % self.n_estimators + ) if self.learning_rate <= 0.0: - raise ValueError("learning_rate must be greater than 0 but " - "was %r" % self.learning_rate) + raise ValueError( + "learning_rate must be greater than 0 but " + "was %r" % self.learning_rate + ) - if (self.loss not in self._SUPPORTED_LOSS - or self.loss not in _gb_losses.LOSS_FUNCTIONS): + if ( + self.loss not in self._SUPPORTED_LOSS + or self.loss not in _gb_losses.LOSS_FUNCTIONS + ): raise ValueError("Loss '{0:s}' not supported. ".format(self.loss)) # TODO: Remove in v1.2 if self.loss == "ls": - warnings.warn("The loss 'ls' was deprecated in v1.0 and " - "will be removed in version 1.2. Use 'squared_error'" - " which is equivalent.", - FutureWarning) + warnings.warn( + "The loss 'ls' was deprecated in v1.0 and " + "will be removed in version 1.2. Use 'squared_error'" + " which is equivalent.", + FutureWarning, + ) elif self.loss == "lad": - warnings.warn("The loss 'lad' was deprecated in v1.0 and " - "will be removed in version 1.2. 
Use " - "'absolute_error' which is equivalent.", - FutureWarning) - - if self.loss == 'deviance': - loss_class = (_gb_losses.MultinomialDeviance - if len(self.classes_) > 2 - else _gb_losses.BinomialDeviance) + warnings.warn( + "The loss 'lad' was deprecated in v1.0 and " + "will be removed in version 1.2. Use " + "'absolute_error' which is equivalent.", + FutureWarning, + ) + + if self.loss == "deviance": + loss_class = ( + _gb_losses.MultinomialDeviance + if len(self.classes_) > 2 + else _gb_losses.BinomialDeviance + ) else: loss_class = _gb_losses.LOSS_FUNCTIONS[self.loss] @@ -264,22 +318,22 @@ def _check_params(self): self.loss_ = loss_class() if not (0.0 < self.subsample <= 1.0): - raise ValueError("subsample must be in (0,1] but " - "was %r" % self.subsample) + raise ValueError( + "subsample must be in (0,1] but " "was %r" % self.subsample + ) if self.init is not None: # init must be an estimator or 'zero' if isinstance(self.init, BaseEstimator): self.loss_.check_init_estimator(self.init) - elif not (isinstance(self.init, str) and self.init == 'zero'): + elif not (isinstance(self.init, str) and self.init == "zero"): raise ValueError( "The init parameter must be an estimator or 'zero'. " "Got init={}".format(self.init) ) if not (0.0 < self.alpha < 1.0): - raise ValueError("alpha must be in (0.0, 1.0) but " - "was %r" % self.alpha) + raise ValueError("alpha must be in (0.0, 1.0) but " "was %r" % self.alpha) if isinstance(self.max_features, str): if self.max_features == "auto": @@ -292,54 +346,53 @@ def _check_params(self): elif self.max_features == "log2": max_features = max(1, int(np.log2(self.n_features_in_))) else: - raise ValueError("Invalid value for max_features: %r. " - "Allowed string values are 'auto', 'sqrt' " - "or 'log2'." % self.max_features) + raise ValueError( + "Invalid value for max_features: %r. " + "Allowed string values are 'auto', 'sqrt' " + "or 'log2'." % self.max_features + ) elif self.max_features is None: max_features = self.n_features_in_ elif isinstance(self.max_features, numbers.Integral): max_features = self.max_features else: # float - if 0. < self.max_features <= 1.: - max_features = max(int(self.max_features * - self.n_features_in_), 1) + if 0.0 < self.max_features <= 1.0: + max_features = max(int(self.max_features * self.n_features_in_), 1) else: raise ValueError("max_features must be in (0, n_features]") self.max_features_ = max_features - if not isinstance(self.n_iter_no_change, - (numbers.Integral, type(None))): - raise ValueError("n_iter_no_change should either be None or an " - "integer. %r was passed" - % self.n_iter_no_change) + if not isinstance(self.n_iter_no_change, (numbers.Integral, type(None))): + raise ValueError( + "n_iter_no_change should either be None or an " + "integer. %r was passed" % self.n_iter_no_change + ) def _init_state(self): - """Initialize model state and allocate model state data structures. """ + """Initialize model state and allocate model state data structures.""" self.init_ = self.init if self.init_ is None: self.init_ = self.loss_.init_estimator() - self.estimators_ = np.empty((self.n_estimators, self.loss_.K), - dtype=object) + self.estimators_ = np.empty((self.n_estimators, self.loss_.K), dtype=object) self.train_score_ = np.zeros((self.n_estimators,), dtype=np.float64) # do oob? 
if self.subsample < 1.0: - self.oob_improvement_ = np.zeros((self.n_estimators), - dtype=np.float64) + self.oob_improvement_ = np.zeros((self.n_estimators), dtype=np.float64) def _clear_state(self): - """Clear the state of the gradient boosting model. """ - if hasattr(self, 'estimators_'): + """Clear the state of the gradient boosting model.""" + if hasattr(self, "estimators_"): self.estimators_ = np.empty((0, 0), dtype=object) - if hasattr(self, 'train_score_'): + if hasattr(self, "train_score_"): del self.train_score_ - if hasattr(self, 'oob_improvement_'): + if hasattr(self, "oob_improvement_"): del self.oob_improvement_ - if hasattr(self, 'init_'): + if hasattr(self, "init_"): del self.init_ - if hasattr(self, '_rng'): + if hasattr(self, "_rng"): del self._rng def _resize_state(self): @@ -347,23 +400,28 @@ def _resize_state(self): # self.n_estimators is the number of additional est to fit total_n_estimators = self.n_estimators if total_n_estimators < self.estimators_.shape[0]: - raise ValueError('resize with smaller n_estimators %d < %d' % - (total_n_estimators, self.estimators_[0])) + raise ValueError( + "resize with smaller n_estimators %d < %d" + % (total_n_estimators, self.estimators_[0]) + ) - self.estimators_ = np.resize(self.estimators_, - (total_n_estimators, self.loss_.K)) + self.estimators_ = np.resize( + self.estimators_, (total_n_estimators, self.loss_.K) + ) self.train_score_ = np.resize(self.train_score_, total_n_estimators) - if (self.subsample < 1 or hasattr(self, 'oob_improvement_')): + if self.subsample < 1 or hasattr(self, "oob_improvement_"): # if do oob resize arrays or create new if not available - if hasattr(self, 'oob_improvement_'): - self.oob_improvement_ = np.resize(self.oob_improvement_, - total_n_estimators) + if hasattr(self, "oob_improvement_"): + self.oob_improvement_ = np.resize( + self.oob_improvement_, total_n_estimators + ) else: - self.oob_improvement_ = np.zeros((total_n_estimators,), - dtype=np.float64) + self.oob_improvement_ = np.zeros( + (total_n_estimators,), dtype=np.float64 + ) def _is_initialized(self): - return len(getattr(self, 'estimators_', [])) > 0 + return len(getattr(self, "estimators_", [])) > 0 def _check_initialized(self): """Check that the estimator is initialized, raising an error if not.""" @@ -408,17 +466,17 @@ def fit(self, X, y, sample_weight=None, monitor=None): ------- self : object """ - if self.criterion in ('absolute_error', 'mae'): + if self.criterion in ("absolute_error", "mae"): # TODO: This should raise an error from 1.1 self._warn_mae_for_criterion() - if self.criterion == 'mse': + if self.criterion == "mse": # TODO: Remove in v1.2. By then it should raise an error. warnings.warn( "Criterion 'mse' was deprecated in v1.0 and will be " "removed in version 1.2. Use `criterion='squared_error'` " "which is equivalent.", - FutureWarning + FutureWarning, ) # if not warmstart - clear the estimator state @@ -429,8 +487,9 @@ def fit(self, X, y, sample_weight=None, monitor=None): # Since check_array converts both X and y to the same dtype, but the # trees use different types for X and y, checking them separately. 
- X, y = self._validate_data(X, y, accept_sparse=['csr', 'csc', 'coo'], - dtype=DTYPE, multi_output=True) + X, y = self._validate_data( + X, y, accept_sparse=["csr", "csc", "coo"], dtype=DTYPE, multi_output=True + ) sample_weight_is_none = sample_weight is None @@ -445,11 +504,14 @@ def fit(self, X, y, sample_weight=None, monitor=None): if self.n_iter_no_change is not None: stratify = y if is_classifier(self) else None - X, X_val, y, y_val, sample_weight, sample_weight_val = ( - train_test_split(X, y, sample_weight, - random_state=self.random_state, - test_size=self.validation_fraction, - stratify=stratify)) + X, X_val, y, y_val, sample_weight, sample_weight_val = train_test_split( + X, + y, + sample_weight, + random_state=self.random_state, + test_size=self.validation_fraction, + stratify=stratify, + ) if is_classifier(self): if self._n_classes != np.unique(y).shape[0]: # We choose to error here. The problem is that the init @@ -457,9 +519,9 @@ def fit(self, X, y, sample_weight=None, monitor=None): # classes now, so its predictions would not have the # correct shape. raise ValueError( - 'The training data after the early stopping split ' - 'is missing some classes. Try using another random ' - 'seed.' + "The training data after the early stopping split " + "is missing some classes. Try using another random " + "seed." ) else: X_val = y_val = sample_weight_val = None @@ -471,31 +533,35 @@ def fit(self, X, y, sample_weight=None, monitor=None): self._init_state() # fit initial model and initialize raw predictions - if self.init_ == 'zero': - raw_predictions = np.zeros(shape=(X.shape[0], self.loss_.K), - dtype=np.float64) + if self.init_ == "zero": + raw_predictions = np.zeros( + shape=(X.shape[0], self.loss_.K), dtype=np.float64 + ) else: # XXX clean this once we have a support_sample_weight tag if sample_weight_is_none: self.init_.fit(X, y) else: - msg = ("The initial estimator {} does not support sample " - "weights.".format(self.init_.__class__.__name__)) + msg = ( + "The initial estimator {} does not support sample " + "weights.".format(self.init_.__class__.__name__) + ) try: self.init_.fit(X, y, sample_weight=sample_weight) except TypeError as e: # regular estimator without SW support raise ValueError(msg) from e except ValueError as e: - if "pass parameters to specific steps of "\ - "your pipeline using the "\ - "stepname__parameter" in str(e): # pipeline + if ( + "pass parameters to specific steps of " + "your pipeline using the " + "stepname__parameter" in str(e) + ): # pipeline raise ValueError(msg) from e else: # regular estimator whose input checking failed raise - raw_predictions = \ - self.loss_.get_init_raw_predictions(X, self.init_) + raw_predictions = self.loss_.get_init_raw_predictions(X, self.init_) begin_at_stage = 0 @@ -506,37 +572,56 @@ def fit(self, X, y, sample_weight=None, monitor=None): # add more estimators to fitted model # invariant: warm_start = True if self.n_estimators < self.estimators_.shape[0]: - raise ValueError('n_estimators=%d must be larger or equal to ' - 'estimators_.shape[0]=%d when ' - 'warm_start==True' - % (self.n_estimators, - self.estimators_.shape[0])) + raise ValueError( + "n_estimators=%d must be larger or equal to " + "estimators_.shape[0]=%d when " + "warm_start==True" % (self.n_estimators, self.estimators_.shape[0]) + ) begin_at_stage = self.estimators_.shape[0] # The requirements of _decision_function (called in two lines # below) are more constrained than fit. It accepts only CSR # matrices. 
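The block above carves out an internal validation set when n_iter_no_change is set, and errors out if the split loses a class. A short usage sketch of that early-stopping path (standard estimator API, not specific to this patch):

from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor

X, y = make_regression(n_samples=1000, n_features=10, random_state=0)
gbr = GradientBoostingRegressor(
    n_estimators=1000,
    validation_fraction=0.1,   # fraction held out internally for scoring
    n_iter_no_change=5,        # stop after 5 stages without improvement
    tol=1e-4,
    random_state=0,
)
gbr.fit(X, y)
print(gbr.n_estimators_)       # stages actually fitted, usually well below 1000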
- X = check_array(X, dtype=DTYPE, order="C", accept_sparse='csr') + X = check_array(X, dtype=DTYPE, order="C", accept_sparse="csr") raw_predictions = self._raw_predict(X) self._resize_state() # fit the boosting stages n_stages = self._fit_stages( - X, y, raw_predictions, sample_weight, self._rng, X_val, y_val, - sample_weight_val, begin_at_stage, monitor) + X, + y, + raw_predictions, + sample_weight, + self._rng, + X_val, + y_val, + sample_weight_val, + begin_at_stage, + monitor, + ) # change shape of arrays after fit (early-stopping or additional ests) if n_stages != self.estimators_.shape[0]: self.estimators_ = self.estimators_[:n_stages] self.train_score_ = self.train_score_[:n_stages] - if hasattr(self, 'oob_improvement_'): + if hasattr(self, "oob_improvement_"): self.oob_improvement_ = self.oob_improvement_[:n_stages] self.n_estimators_ = n_stages return self - def _fit_stages(self, X, y, raw_predictions, sample_weight, random_state, - X_val, y_val, sample_weight_val, - begin_at_stage=0, monitor=None): + def _fit_stages( + self, + X, + y, + raw_predictions, + sample_weight, + random_state, + X_val, + y_val, + sample_weight_val, + begin_at_stage=0, + monitor=None, + ): """Iteratively fits the stages. For each stage it computes the progress (OOB, train score) @@ -546,7 +631,7 @@ def _fit_stages(self, X, y, raw_predictions, sample_weight, random_state, """ n_samples = X.shape[0] do_oob = self.subsample < 1.0 - sample_mask = np.ones((n_samples, ), dtype=bool) + sample_mask = np.ones((n_samples,), dtype=bool) n_inbag = max(1, int(self.subsample * n_samples)) loss_ = self.loss_ @@ -569,27 +654,39 @@ def _fit_stages(self, X, y, raw_predictions, sample_weight, random_state, # subsampling if do_oob: - sample_mask = _random_sample_mask(n_samples, n_inbag, - random_state) + sample_mask = _random_sample_mask(n_samples, n_inbag, random_state) # OOB score before adding this stage - old_oob_score = loss_(y[~sample_mask], - raw_predictions[~sample_mask], - sample_weight[~sample_mask]) + old_oob_score = loss_( + y[~sample_mask], + raw_predictions[~sample_mask], + sample_weight[~sample_mask], + ) # fit next stage of trees raw_predictions = self._fit_stage( - i, X, y, raw_predictions, sample_weight, sample_mask, - random_state, X_csc, X_csr) + i, + X, + y, + raw_predictions, + sample_weight, + sample_mask, + random_state, + X_csc, + X_csr, + ) # track deviance (= loss) if do_oob: - self.train_score_[i] = loss_(y[sample_mask], - raw_predictions[sample_mask], - sample_weight[sample_mask]) - self.oob_improvement_[i] = ( - old_oob_score - loss_(y[~sample_mask], - raw_predictions[~sample_mask], - sample_weight[~sample_mask])) + self.train_score_[i] = loss_( + y[sample_mask], + raw_predictions[sample_mask], + sample_weight[sample_mask], + ) + self.oob_improvement_[i] = old_oob_score - loss_( + y[~sample_mask], + raw_predictions[~sample_mask], + sample_weight[~sample_mask], + ) else: # no need to fancy index w/ no subsampling self.train_score_[i] = loss_(y, raw_predictions, sample_weight) @@ -607,8 +704,7 @@ def _fit_stages(self, X, y, raw_predictions, sample_weight, random_state, if self.n_iter_no_change is not None: # By calling next(y_val_pred_iter), we get the predictions # for X_val after the addition of the current stage - validation_loss = loss_(y_val, next(y_val_pred_iter), - sample_weight_val) + validation_loss = loss_(y_val, next(y_val_pred_iter), sample_weight_val) # Require validation_score to be better (less) than at least # one of the last n_iter_no_change evaluations @@ -627,19 +723,20 @@ def 
_raw_predict_init(self, X): """Check input and compute raw predictions of the init estimator.""" self._check_initialized() X = self.estimators_[0, 0]._validate_X_predict(X, check_input=True) - if self.init_ == 'zero': - raw_predictions = np.zeros(shape=(X.shape[0], self.loss_.K), - dtype=np.float64) + if self.init_ == "zero": + raw_predictions = np.zeros( + shape=(X.shape[0], self.loss_.K), dtype=np.float64 + ) else: - raw_predictions = self.loss_.get_init_raw_predictions( - X, self.init_).astype(np.float64) + raw_predictions = self.loss_.get_init_raw_predictions(X, self.init_).astype( + np.float64 + ) return raw_predictions def _raw_predict(self, X): """Return the sum of the trees raw predictions (+ init estimator).""" raw_predictions = self._raw_predict_init(X) - predict_stages(self.estimators_, X, self.learning_rate, - raw_predictions) + predict_stages(self.estimators_, X, self.learning_rate, raw_predictions) return raw_predictions def _staged_raw_predict(self, X): @@ -663,12 +760,12 @@ def _staged_raw_predict(self, X): Regression and binary classification are special cases with ``k == 1``, otherwise ``k==n_classes``. """ - X = self._validate_data(X, dtype=DTYPE, order="C", accept_sparse='csr', - reset=False) + X = self._validate_data( + X, dtype=DTYPE, order="C", accept_sparse="csr", reset=False + ) raw_predictions = self._raw_predict_init(X) for i in range(self.estimators_.shape[0]): - predict_stage(self.estimators_, i, X, self.learning_rate, - raw_predictions) + predict_stage(self.estimators_, i, X, self.learning_rate, raw_predictions) yield raw_predictions.copy() @property @@ -693,9 +790,12 @@ def feature_importances_(self): """ self._check_initialized() - relevant_trees = [tree - for stage in self.estimators_ for tree in stage - if tree.tree_.node_count > 1] + relevant_trees = [ + tree + for stage in self.estimators_ + for tree in stage + if tree.tree_.node_count > 1 + ] if not relevant_trees: # degenerate case where all trees have only one node return np.zeros(shape=self.n_features_in_, dtype=np.float64) @@ -704,8 +804,9 @@ def feature_importances_(self): tree.tree_.compute_feature_importances(normalize=False) for tree in relevant_trees ] - avg_feature_importances = np.mean(relevant_feature_importances, - axis=0, dtype=np.float64) + avg_feature_importances = np.mean( + relevant_feature_importances, axis=0, dtype=np.float64 + ) return avg_feature_importances / np.sum(avg_feature_importances) def _compute_partial_dependence_recursion(self, grid, target_features): @@ -728,20 +829,22 @@ def _compute_partial_dependence_recursion(self, grid, target_features): """ if self.init is not None: warnings.warn( - 'Using recursion method with a non-constant init predictor ' - 'will lead to incorrect partial dependence values. ' - 'Got init=%s.' % self.init, - UserWarning + "Using recursion method with a non-constant init predictor " + "will lead to incorrect partial dependence values. " + "Got init=%s." 
% self.init, + UserWarning, ) - grid = np.asarray(grid, dtype=DTYPE, order='C') + grid = np.asarray(grid, dtype=DTYPE, order="C") n_estimators, n_trees_per_stage = self.estimators_.shape - averaged_predictions = np.zeros((n_trees_per_stage, grid.shape[0]), - dtype=np.float64, order='C') + averaged_predictions = np.zeros( + (n_trees_per_stage, grid.shape[0]), dtype=np.float64, order="C" + ) for stage in range(n_estimators): for k in range(n_trees_per_stage): tree = self.estimators_[stage, k].tree_ - tree.compute_partial_dependence(grid, target_features, - averaged_predictions[k]) + tree.compute_partial_dependence( + grid, target_features, averaged_predictions[k] + ) averaged_predictions *= self.learning_rate return averaged_predictions @@ -1108,40 +1211,68 @@ class GradientBoostingClassifier(ClassifierMixin, BaseGradientBoosting): 0.913... """ - _SUPPORTED_LOSS = ('deviance', 'exponential') - - def __init__(self, *, loss='deviance', learning_rate=0.1, n_estimators=100, - subsample=1.0, criterion='friedman_mse', min_samples_split=2, - min_samples_leaf=1, min_weight_fraction_leaf=0., - max_depth=3, min_impurity_decrease=0., init=None, - random_state=None, max_features=None, verbose=0, - max_leaf_nodes=None, warm_start=False, - validation_fraction=0.1, n_iter_no_change=None, tol=1e-4, - ccp_alpha=0.0, monotonic_cst=None): + _SUPPORTED_LOSS = ("deviance", "exponential") + + def __init__( + self, + *, + loss="deviance", + learning_rate=0.1, + n_estimators=100, + subsample=1.0, + criterion="friedman_mse", + min_samples_split=2, + min_samples_leaf=1, + min_weight_fraction_leaf=0.0, + max_depth=3, + min_impurity_decrease=0.0, + init=None, + random_state=None, + max_features=None, + verbose=0, + max_leaf_nodes=None, + warm_start=False, + validation_fraction=0.1, + n_iter_no_change=None, + tol=1e-4, + ccp_alpha=0.0, + monotonic_cst=None, + ): super().__init__( - loss=loss, learning_rate=learning_rate, n_estimators=n_estimators, - criterion=criterion, min_samples_split=min_samples_split, + loss=loss, + learning_rate=learning_rate, + n_estimators=n_estimators, + criterion=criterion, + min_samples_split=min_samples_split, min_samples_leaf=min_samples_leaf, min_weight_fraction_leaf=min_weight_fraction_leaf, - max_depth=max_depth, init=init, subsample=subsample, + max_depth=max_depth, + init=init, + subsample=subsample, max_features=max_features, - random_state=random_state, verbose=verbose, + random_state=random_state, + verbose=verbose, max_leaf_nodes=max_leaf_nodes, min_impurity_decrease=min_impurity_decrease, - warm_start=warm_start, validation_fraction=validation_fraction, - n_iter_no_change=n_iter_no_change, tol=tol, ccp_alpha=ccp_alpha, - monotonic_cst=monotonic_cst) + warm_start=warm_start, + validation_fraction=validation_fraction, + n_iter_no_change=n_iter_no_change, + tol=tol, + ccp_alpha=ccp_alpha, + monotonic_cst=monotonic_cst, + ) def _validate_y(self, y, sample_weight): check_classification_targets(y) self.classes_, y = np.unique(y, return_inverse=True) n_trim_classes = np.count_nonzero(np.bincount(y, sample_weight)) if n_trim_classes < 2: - raise ValueError("y contains %d class after sample_weight " - "trimmed classes with zero weights, while a " - "minimum of 2 classes are required." - % n_trim_classes) + raise ValueError( + "y contains %d class after sample_weight " + "trimmed classes with zero weights, while a " + "minimum of 2 classes are required." 
% n_trim_classes + ) self._n_classes = len(self.classes_) # expose n_classes_ attribute self.n_classes_ = self._n_classes @@ -1149,11 +1280,14 @@ def _validate_y(self, y, sample_weight): def _warn_mae_for_criterion(self): # TODO: This should raise an error from 1.1 - warnings.warn("criterion='mae' was deprecated in version 0.24 and " - "will be removed in version 1.1 (renaming of 0.26). Use " - "criterion='friedman_mse' or 'squared_error' instead, as" - " trees should use a squared error criterion in Gradient" - " Boosting.", FutureWarning) + warnings.warn( + "criterion='mae' was deprecated in version 0.24 and " + "will be removed in version 1.1 (renaming of 0.26). Use " + "criterion='friedman_mse' or 'squared_error' instead, as" + " trees should use a squared error criterion in Gradient" + " Boosting.", + FutureWarning, + ) def decision_function(self, X): """Compute the decision function of ``X``. @@ -1174,8 +1308,9 @@ def decision_function(self, X): :term:`classes_`. Regression and binary classification produce an array of shape (n_samples,). """ - X = self._validate_data(X, dtype=DTYPE, order="C", accept_sparse='csr', - reset=False) + X = self._validate_data( + X, dtype=DTYPE, order="C", accept_sparse="csr", reset=False + ) raw_predictions = self._raw_predict(X) if raw_predictions.shape[1] == 1: return raw_predictions.ravel() @@ -1221,8 +1356,7 @@ def predict(self, X): The predicted values. """ raw_predictions = self.decision_function(X) - encoded_labels = \ - self.loss_._raw_prediction_to_decision(raw_predictions) + encoded_labels = self.loss_._raw_prediction_to_decision(raw_predictions) return self.classes_.take(encoded_labels, axis=0) def staged_predict(self, X): @@ -1244,8 +1378,7 @@ def staged_predict(self, X): The predicted value of the input samples. """ for raw_predictions in self._staged_raw_predict(X): - encoded_labels = \ - self.loss_._raw_prediction_to_decision(raw_predictions) + encoded_labels = self.loss_._raw_prediction_to_decision(raw_predictions) yield self.classes_.take(encoded_labels, axis=0) def predict_proba(self, X): @@ -1275,8 +1408,9 @@ def predict_proba(self, X): except NotFittedError: raise except AttributeError as e: - raise AttributeError('loss=%r does not support predict_proba' % - self.loss) from e + raise AttributeError( + "loss=%r does not support predict_proba" % self.loss + ) from e def predict_log_proba(self, X): """Predict class log-probabilities for X. 
@@ -1326,8 +1460,9 @@ def staged_predict_proba(self, X): except NotFittedError: raise except AttributeError as e: - raise AttributeError('loss=%r does not support predict_proba' % - self.loss) from e + raise AttributeError( + "loss=%r does not support predict_proba" % self.loss + ) from e class GradientBoostingRegressor(RegressorMixin, BaseGradientBoosting): @@ -1648,44 +1783,81 @@ class GradientBoostingRegressor(RegressorMixin, BaseGradientBoosting): """ # TODO: remove "ls" in verion 1.2 - _SUPPORTED_LOSS = ("squared_error", 'ls', "absolute_error", 'lad', 'huber', - 'quantile') - - def __init__(self, *, loss="squared_error", learning_rate=0.1, - n_estimators=100, - subsample=1.0, criterion='friedman_mse', min_samples_split=2, - min_samples_leaf=1, min_weight_fraction_leaf=0., - max_depth=3, min_impurity_decrease=0., init=None, - random_state=None, max_features=None, alpha=0.9, verbose=0, - max_leaf_nodes=None, warm_start=False, - validation_fraction=0.1, n_iter_no_change=None, tol=1e-4, - ccp_alpha=0.0, monotonic_cst=None): + _SUPPORTED_LOSS = ( + "squared_error", + "ls", + "absolute_error", + "lad", + "huber", + "quantile", + ) + + def __init__( + self, + *, + loss="squared_error", + learning_rate=0.1, + n_estimators=100, + subsample=1.0, + criterion="friedman_mse", + min_samples_split=2, + min_samples_leaf=1, + min_weight_fraction_leaf=0.0, + max_depth=3, + min_impurity_decrease=0.0, + init=None, + random_state=None, + max_features=None, + alpha=0.9, + verbose=0, + max_leaf_nodes=None, + warm_start=False, + validation_fraction=0.1, + n_iter_no_change=None, + tol=1e-4, + ccp_alpha=0.0, + monotonic_cst=None, + ): super().__init__( - loss=loss, learning_rate=learning_rate, n_estimators=n_estimators, - criterion=criterion, min_samples_split=min_samples_split, + loss=loss, + learning_rate=learning_rate, + n_estimators=n_estimators, + criterion=criterion, + min_samples_split=min_samples_split, min_samples_leaf=min_samples_leaf, min_weight_fraction_leaf=min_weight_fraction_leaf, - max_depth=max_depth, init=init, subsample=subsample, + max_depth=max_depth, + init=init, + subsample=subsample, max_features=max_features, min_impurity_decrease=min_impurity_decrease, - random_state=random_state, alpha=alpha, verbose=verbose, - max_leaf_nodes=max_leaf_nodes, warm_start=warm_start, + random_state=random_state, + alpha=alpha, + verbose=verbose, + max_leaf_nodes=max_leaf_nodes, + warm_start=warm_start, validation_fraction=validation_fraction, - n_iter_no_change=n_iter_no_change, tol=tol, ccp_alpha=ccp_alpha, - monotonic_cst=monotonic_cst) + n_iter_no_change=n_iter_no_change, + tol=tol, + ccp_alpha=ccp_alpha, + monotonic_cst=monotonic_cst, + ) def _validate_y(self, y, sample_weight=None): - if y.dtype.kind == 'O': + if y.dtype.kind == "O": y = y.astype(DOUBLE) return y def _warn_mae_for_criterion(self): # TODO: This should raise an error from 1.1 - warnings.warn("criterion='mae' was deprecated in version 0.24 and " - "will be removed in version 1.1 (renaming of 0.26). The " - "correct way of minimizing the absolute error is to use " - " loss='absolute_error' instead.", FutureWarning) + warnings.warn( + "criterion='mae' was deprecated in version 0.24 and " + "will be removed in version 1.1 (renaming of 0.26). The " + "correct way of minimizing the absolute error is to use " + " loss='absolute_error' instead.", + FutureWarning, + ) def predict(self, X): """Predict regression target for X. @@ -1702,8 +1874,9 @@ def predict(self, X): y : ndarray of shape (n_samples,) The predicted values. 
""" - X = self._validate_data(X, dtype=DTYPE, order="C", accept_sparse='csr', - reset=False) + X = self._validate_data( + X, dtype=DTYPE, order="C", accept_sparse="csr", reset=False + ) # In regression we can directly return the raw value from the trees. return self._raw_predict(X).ravel() @@ -1755,14 +1928,14 @@ def apply(self, X): # mypy error: Decorated property not supported @deprecated( # type: ignore "Attribute n_classes_ was deprecated " - "in version 0.24 and will be removed in 1.1 (renaming of 0.26).") + "in version 0.24 and will be removed in 1.1 (renaming of 0.26)." + ) @property def n_classes_(self): try: check_is_fitted(self) except NotFittedError as nfe: raise AttributeError( - "{} object has no n_classes_ attribute." - .format(self.__class__.__name__) + "{} object has no n_classes_ attribute.".format(self.__class__.__name__) ) from nfe return 1 diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 93d85e11cb836..19720f949e00c 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -47,10 +47,12 @@ from ._tree import ccp_pruning_path from . import _tree, _splitter, _criterion -__all__ = ["DecisionTreeClassifier", - "DecisionTreeRegressor", - "ExtraTreeClassifier", - "ExtraTreeRegressor"] +__all__ = [ + "DecisionTreeClassifier", + "DecisionTreeRegressor", + "ExtraTreeClassifier", + "ExtraTreeRegressor", +] # ============================================================================= @@ -60,21 +62,23 @@ DTYPE = _tree.DTYPE DOUBLE = _tree.DOUBLE -CRITERIA_CLF = {"gini": _criterion.Gini, - "entropy": _criterion.Entropy} +CRITERIA_CLF = {"gini": _criterion.Gini, "entropy": _criterion.Entropy} # TODO: Remove "mse" and "mae" in version 1.2. -CRITERIA_REG = {"squared_error": _criterion.MSE, - "mse": _criterion.MSE, - "friedman_mse": _criterion.FriedmanMSE, - "absolute_error": _criterion.MAE, - "mae": _criterion.MAE, - "poisson": _criterion.Poisson} - -DENSE_SPLITTERS = {"best": _splitter.BestSplitter, - "random": _splitter.RandomSplitter} - -SPARSE_SPLITTERS = {"best": _splitter.BestSparseSplitter, - "random": _splitter.RandomSparseSplitter} +CRITERIA_REG = { + "squared_error": _criterion.MSE, + "mse": _criterion.MSE, + "friedman_mse": _criterion.FriedmanMSE, + "absolute_error": _criterion.MAE, + "mae": _criterion.MAE, + "poisson": _criterion.Poisson, +} + +DENSE_SPLITTERS = {"best": _splitter.BestSplitter, "random": _splitter.RandomSplitter} + +SPARSE_SPLITTERS = { + "best": _splitter.BestSparseSplitter, + "random": _splitter.RandomSparseSplitter, +} # ============================================================================= # Base decision tree @@ -89,20 +93,23 @@ class BaseDecisionTree(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta): """ @abstractmethod - def __init__(self, *, - criterion, - splitter, - max_depth, - min_samples_split, - min_samples_leaf, - min_weight_fraction_leaf, - max_features, - max_leaf_nodes, - random_state, - min_impurity_decrease, - class_weight=None, - ccp_alpha=0.0, - monotonic_cst=None): + def __init__( + self, + *, + criterion, + splitter, + max_depth, + min_samples_split, + min_samples_leaf, + min_weight_fraction_leaf, + max_features, + max_leaf_nodes, + random_state, + min_impurity_decrease, + class_weight=None, + ccp_alpha=0.0, + monotonic_cst=None, + ): self.criterion = criterion self.splitter = splitter @@ -143,8 +150,9 @@ def get_n_leaves(self): check_is_fitted(self) return self.tree_.n_leaves - def fit(self, X, y, sample_weight=None, check_input=True, - X_idx_sorted="deprecated"): + def fit( + self, X, 
y, sample_weight=None, check_input=True, X_idx_sorted="deprecated" + ): random_state = check_random_state(self.random_state) @@ -157,23 +165,28 @@ def fit(self, X, y, sample_weight=None, check_input=True, # csr. check_X_params = dict(dtype=DTYPE, accept_sparse="csc") check_y_params = dict(ensure_2d=False, dtype=None) - X, y = self._validate_data(X, y, - validate_separately=(check_X_params, - check_y_params)) + X, y = self._validate_data( + X, y, validate_separately=(check_X_params, check_y_params) + ) if issparse(X): X.sort_indices() if X.indices.dtype != np.intc or X.indptr.dtype != np.intc: - raise ValueError("No support for np.int64 index based " - "sparse matrices") + raise ValueError( + "No support for np.int64 index based " "sparse matrices" + ) if self.criterion == "poisson": if np.any(y < 0): - raise ValueError("Some value(s) of y are negative which is" - " not allowed for Poisson regression.") + raise ValueError( + "Some value(s) of y are negative which is" + " not allowed for Poisson regression." + ) if np.sum(y) <= 0: - raise ValueError("Sum of y is not positive which is " - "necessary for Poisson regression.") + raise ValueError( + "Sum of y is not positive which is " + "necessary for Poisson regression." + ) # Determine output settings n_samples, self.n_features_in_ = X.shape @@ -201,15 +214,15 @@ def fit(self, X, y, sample_weight=None, check_input=True, y_encoded = np.zeros(y.shape, dtype=int) for k in range(self.n_outputs_): - classes_k, y_encoded[:, k] = np.unique(y[:, k], - return_inverse=True) + classes_k, y_encoded[:, k] = np.unique(y[:, k], return_inverse=True) self.classes_.append(classes_k) self.n_classes_.append(classes_k.shape[0]) y = y_encoded if self.class_weight is not None: expanded_class_weight = compute_sample_weight( - self.class_weight, y_original) + self.class_weight, y_original + ) self.n_classes_ = np.array(self.n_classes_, dtype=np.intp) @@ -217,37 +230,39 @@ def fit(self, X, y, sample_weight=None, check_input=True, y = np.ascontiguousarray(y, dtype=DOUBLE) # Check parameters - max_depth = (np.iinfo(np.int32).max if self.max_depth is None - else self.max_depth) - max_leaf_nodes = (-1 if self.max_leaf_nodes is None - else self.max_leaf_nodes) + max_depth = np.iinfo(np.int32).max if self.max_depth is None else self.max_depth + max_leaf_nodes = -1 if self.max_leaf_nodes is None else self.max_leaf_nodes if isinstance(self.min_samples_leaf, numbers.Integral): if not 1 <= self.min_samples_leaf: - raise ValueError("min_samples_leaf must be at least 1 " - "or in (0, 0.5], got %s" - % self.min_samples_leaf) + raise ValueError( + "min_samples_leaf must be at least 1 " + "or in (0, 0.5], got %s" % self.min_samples_leaf + ) min_samples_leaf = self.min_samples_leaf else: # float - if not 0. 
< self.min_samples_leaf <= 0.5: - raise ValueError("min_samples_leaf must be at least 1 " - "or in (0, 0.5], got %s" - % self.min_samples_leaf) + if not 0.0 < self.min_samples_leaf <= 0.5: + raise ValueError( + "min_samples_leaf must be at least 1 " + "or in (0, 0.5], got %s" % self.min_samples_leaf + ) min_samples_leaf = int(ceil(self.min_samples_leaf * n_samples)) if isinstance(self.min_samples_split, numbers.Integral): if not 2 <= self.min_samples_split: - raise ValueError("min_samples_split must be an integer " - "greater than 1 or a float in (0.0, 1.0]; " - "got the integer %s" - % self.min_samples_split) + raise ValueError( + "min_samples_split must be an integer " + "greater than 1 or a float in (0.0, 1.0]; " + "got the integer %s" % self.min_samples_split + ) min_samples_split = self.min_samples_split else: # float - if not 0. < self.min_samples_split <= 1.: - raise ValueError("min_samples_split must be an integer " - "greater than 1 or a float in (0.0, 1.0]; " - "got the float %s" - % self.min_samples_split) + if not 0.0 < self.min_samples_split <= 1.0: + raise ValueError( + "min_samples_split must be an integer " + "greater than 1 or a float in (0.0, 1.0]; " + "got the float %s" % self.min_samples_split + ) min_samples_split = int(ceil(self.min_samples_split * n_samples)) min_samples_split = max(2, min_samples_split) @@ -264,25 +279,28 @@ def fit(self, X, y, sample_weight=None, check_input=True, elif self.max_features == "log2": max_features = max(1, int(np.log2(self.n_features_in_))) else: - raise ValueError("Invalid value for max_features. " - "Allowed string values are 'auto', " - "'sqrt' or 'log2'.") + raise ValueError( + "Invalid value for max_features. " + "Allowed string values are 'auto', " + "'sqrt' or 'log2'." + ) elif self.max_features is None: max_features = self.n_features_in_ elif isinstance(self.max_features, numbers.Integral): max_features = self.max_features else: # float if self.max_features > 0.0: - max_features = ( - max(1, int(self.max_features * self.n_features_in_))) + max_features = max(1, int(self.max_features * self.n_features_in_)) else: max_features = 0 self.max_features_ = max_features if len(y) != n_samples: - raise ValueError("Number of labels=%d does not match " - "number of samples=%d" % (len(y), n_samples)) + raise ValueError( + "Number of labels=%d does not match " + "number of samples=%d" % (len(y), n_samples) + ) if not 0 <= self.min_weight_fraction_leaf <= 0.5: raise ValueError("min_weight_fraction_leaf must in [0, 0.5]") if max_depth <= 0: @@ -290,11 +308,15 @@ def fit(self, X, y, sample_weight=None, check_input=True, if not (0 < max_features <= self.n_features_in_): raise ValueError("max_features must be in (0, n_features]") if not isinstance(max_leaf_nodes, numbers.Integral): - raise ValueError("max_leaf_nodes must be integral number but was " - "%r" % max_leaf_nodes) + raise ValueError( + "max_leaf_nodes must be integral number but was " "%r" % max_leaf_nodes + ) if -1 < max_leaf_nodes < 2: - raise ValueError(("max_leaf_nodes {0} must be either None " - "or larger than 1").format(max_leaf_nodes)) + raise ValueError( + ("max_leaf_nodes {0} must be either None " "or larger than 1").format( + max_leaf_nodes + ) + ) if sample_weight is not None: sample_weight = _check_sample_weight(sample_weight, X, DOUBLE) @@ -307,15 +329,14 @@ def fit(self, X, y, sample_weight=None, check_input=True, # Set min_weight_leaf from min_weight_fraction_leaf if sample_weight is None: - min_weight_leaf = (self.min_weight_fraction_leaf * - n_samples) + 
min_weight_leaf = self.min_weight_fraction_leaf * n_samples else: - min_weight_leaf = (self.min_weight_fraction_leaf * - np.sum(sample_weight)) + min_weight_leaf = self.min_weight_fraction_leaf * np.sum(sample_weight) - if self.min_impurity_decrease < 0.: - raise ValueError("min_impurity_decrease must be greater than " - "or equal to 0") + if self.min_impurity_decrease < 0.0: + raise ValueError( + "min_impurity_decrease must be greater than " "or equal to 0" + ) # TODO: Remove in 1.1 if X_idx_sorted != "deprecated": @@ -324,32 +345,32 @@ def fit(self, X, y, sample_weight=None, check_input=True, "effect. It will be removed in 1.1 (renaming of 0.26). You " "can suppress this warning by not passing any value to the " "'X_idx_sorted' parameter.", - FutureWarning + FutureWarning, ) # Build tree criterion = self.criterion if not isinstance(criterion, Criterion): if is_classification: - criterion = CRITERIA_CLF[self.criterion](self.n_outputs_, - self.n_classes_) + criterion = CRITERIA_CLF[self.criterion]( + self.n_outputs_, self.n_classes_ + ) else: - criterion = CRITERIA_REG[self.criterion](self.n_outputs_, - n_samples) + criterion = CRITERIA_REG[self.criterion](self.n_outputs_, n_samples) # TODO: Remove in v1.2 if self.criterion == "mse": warnings.warn( "Criterion 'mse' was deprecated in v1.0 and will be " "removed in version 1.2. Use `criterion='squared_error'` " "which is equivalent.", - FutureWarning + FutureWarning, ) elif self.criterion == "mae": warnings.warn( "Criterion 'mae' was deprecated in v1.0 and will be " "removed in version 1.2. Use `criterion='absolute_error'` " "which is equivalent.", - FutureWarning + FutureWarning, ) else: # Make a deepcopy in case the criterion has mutable attributes that @@ -360,59 +381,68 @@ def fit(self, X, y, sample_weight=None, check_input=True, splitter = self.splitter if self.monotonic_cst is None: - monotonic_cst = np.full(shape=X.shape[1], - fill_value=0, - dtype=np.int32) + monotonic_cst = np.full(shape=X.shape[1], fill_value=0, dtype=np.int32) else: if self.n_outputs_ > 1: - raise ValueError("Monotonic constraints are not supported with" - " multiple output") + raise ValueError( + "Monotonic constraints are not supported with" " multiple output" + ) if is_classifier(self) and self.n_classes_[0] > 2: - raise ValueError("Monotonic constraints are not supported with" - " multiclass classification") + raise ValueError( + "Monotonic constraints are not supported with" + " multiclass classification" + ) monotonic_cst = np.asarray(self.monotonic_cst, dtype=np.int32) if monotonic_cst.shape[0] != X.shape[1]: raise ValueError( "monotonic_cst has shape {} but the input data " - "X has {} features.".format( - monotonic_cst.shape[0], X.shape[1] - ) + "X has {} features.".format(monotonic_cst.shape[0], X.shape[1]) ) if np.any(monotonic_cst < -1) or np.any(monotonic_cst > 1): raise ValueError( - "monotonic_cst must be None or an array-like of -1, 0 or 1.") + "monotonic_cst must be None or an array-like of -1, 0 or 1." 
+ ) if not isinstance(self.splitter, Splitter): - splitter = SPLITTERS[self.splitter](criterion, - self.max_features_, - min_samples_leaf, - min_weight_leaf, - random_state, - monotonic_cst) + splitter = SPLITTERS[self.splitter]( + criterion, + self.max_features_, + min_samples_leaf, + min_weight_leaf, + random_state, + monotonic_cst, + ) if is_classifier(self): - self.tree_ = Tree(self.n_features_in_, - self.n_classes_, self.n_outputs_) + self.tree_ = Tree(self.n_features_in_, self.n_classes_, self.n_outputs_) else: - self.tree_ = Tree(self.n_features_in_, - # TODO: tree should't need this in this case - np.array([1] * self.n_outputs_, dtype=np.intp), - self.n_outputs_) + self.tree_ = Tree( + self.n_features_in_, + # TODO: tree should't need this in this case + np.array([1] * self.n_outputs_, dtype=np.intp), + self.n_outputs_, + ) # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise if max_leaf_nodes < 0: - builder = DepthFirstTreeBuilder(splitter, min_samples_split, - min_samples_leaf, - min_weight_leaf, - max_depth, - self.min_impurity_decrease) + builder = DepthFirstTreeBuilder( + splitter, + min_samples_split, + min_samples_leaf, + min_weight_leaf, + max_depth, + self.min_impurity_decrease, + ) else: - builder = BestFirstTreeBuilder(splitter, min_samples_split, - min_samples_leaf, - min_weight_leaf, - max_depth, - max_leaf_nodes, - self.min_impurity_decrease) + builder = BestFirstTreeBuilder( + splitter, + min_samples_split, + min_samples_leaf, + min_weight_leaf, + max_depth, + max_leaf_nodes, + self.min_impurity_decrease, + ) builder.build(self.tree_, X, y, sample_weight) @@ -427,12 +457,13 @@ def fit(self, X, y, sample_weight=None, check_input=True, def _validate_X_predict(self, X, check_input): """Validate the training data on predict (probabilities).""" if check_input: - X = self._validate_data(X, dtype=DTYPE, accept_sparse="csr", - reset=False) - if issparse(X) and (X.indices.dtype != np.intc or - X.indptr.dtype != np.intc): - raise ValueError("No support for np.int64 index based " - "sparse matrices") + X = self._validate_data(X, dtype=DTYPE, accept_sparse="csr", reset=False) + if issparse(X) and ( + X.indices.dtype != np.intc or X.indptr.dtype != np.intc + ): + raise ValueError( + "No support for np.int64 index based " "sparse matrices" + ) else: # The number of features is checked regardless of `check_input` self._check_n_features(X, reset=False) @@ -473,12 +504,11 @@ def predict(self, X, check_input=True): else: class_type = self.classes_[0].dtype - predictions = np.zeros((n_samples, self.n_outputs_), - dtype=class_type) + predictions = np.zeros((n_samples, self.n_outputs_), dtype=class_type) for k in range(self.n_outputs_): predictions[:, k] = self.classes_[k].take( - np.argmax(proba[:, k], axis=1), - axis=0) + np.argmax(proba[:, k], axis=1), axis=0 + ) return predictions @@ -558,10 +588,12 @@ def _prune_tree(self): n_classes = np.atleast_1d(self.n_classes_) pruned_tree = Tree(self.n_features_in_, n_classes, self.n_outputs_) else: - pruned_tree = Tree(self.n_features_in_, - # TODO: the tree shouldn't need this param - np.array([1] * self.n_outputs_, dtype=np.intp), - self.n_outputs_) + pruned_tree = Tree( + self.n_features_in_, + # TODO: the tree shouldn't need this param + np.array([1] * self.n_outputs_, dtype=np.intp), + self.n_outputs_, + ) _build_pruned_tree_ccp(pruned_tree, self.tree_, self.ccp_alpha) self.tree_ = pruned_tree @@ -632,6 +664,7 @@ def feature_importances_(self): # Public estimators # 
============================================================================= + class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree): """A decision tree classifier. @@ -859,20 +892,24 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree): array([ 1. , 0.93..., 0.86..., 0.93..., 0.93..., 0.93..., 0.93..., 1. , 0.93..., 1. ]) """ - def __init__(self, *, - criterion="gini", - splitter="best", - max_depth=None, - min_samples_split=2, - min_samples_leaf=1, - min_weight_fraction_leaf=0., - max_features=None, - random_state=None, - max_leaf_nodes=None, - min_impurity_decrease=0., - class_weight=None, - ccp_alpha=0.0, - monotonic_cst=None): + + def __init__( + self, + *, + criterion="gini", + splitter="best", + max_depth=None, + min_samples_split=2, + min_samples_leaf=1, + min_weight_fraction_leaf=0.0, + max_features=None, + random_state=None, + max_leaf_nodes=None, + min_impurity_decrease=0.0, + class_weight=None, + ccp_alpha=0.0, + monotonic_cst=None, + ): super().__init__( criterion=criterion, @@ -887,10 +924,12 @@ def __init__(self, *, random_state=random_state, min_impurity_decrease=min_impurity_decrease, monotonic_cst=monotonic_cst, - ccp_alpha=ccp_alpha) + ccp_alpha=ccp_alpha, + ) - def fit(self, X, y, sample_weight=None, check_input=True, - X_idx_sorted="deprecated"): + def fit( + self, X, y, sample_weight=None, check_input=True, X_idx_sorted="deprecated" + ): """Build a decision tree classifier from the training set (X, y). Parameters @@ -927,10 +966,12 @@ def fit(self, X, y, sample_weight=None, check_input=True, """ super().fit( - X, y, + X, + y, sample_weight=sample_weight, check_input=check_input, - X_idx_sorted=X_idx_sorted) + X_idx_sorted=X_idx_sorted, + ) return self def predict_proba(self, X, check_input=True): @@ -962,7 +1003,7 @@ class in a leaf. proba = self.tree_.predict(X) if self.n_outputs_ == 1: - proba = proba[:, :self.n_classes_] + proba = proba[:, : self.n_classes_] normalizer = proba.sum(axis=1)[:, np.newaxis] normalizer[normalizer == 0.0] = 1.0 proba /= normalizer @@ -973,7 +1014,7 @@ class in a leaf. all_proba = [] for k in range(self.n_outputs_): - proba_k = proba[:, k, :self.n_classes_[k]] + proba_k = proba[:, k, : self.n_classes_[k]] normalizer = proba_k.sum(axis=1)[:, np.newaxis] normalizer[normalizer == 0.0] = 1.0 proba_k /= normalizer @@ -1011,7 +1052,8 @@ def predict_log_proba(self, X): @deprecated( # type: ignore "The attribute 'n_features_' is deprecated in 1.0 and will be removed " - "in 1.2. Use 'n_features_in_' instead.") + "in 1.2. Use 'n_features_in_' instead." 
+ ) @property def n_features_(self): return self.n_features_in_ @@ -1230,19 +1272,23 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree): array([-0.39..., -0.46..., 0.02..., 0.06..., -0.50..., 0.16..., 0.11..., -0.73..., -0.30..., -0.00...]) """ - def __init__(self, *, - criterion="squared_error", - splitter="best", - max_depth=None, - min_samples_split=2, - min_samples_leaf=1, - min_weight_fraction_leaf=0., - max_features=None, - random_state=None, - max_leaf_nodes=None, - min_impurity_decrease=0., - ccp_alpha=0.0, - monotonic_cst=None): + + def __init__( + self, + *, + criterion="squared_error", + splitter="best", + max_depth=None, + min_samples_split=2, + min_samples_leaf=1, + min_weight_fraction_leaf=0.0, + max_features=None, + random_state=None, + max_leaf_nodes=None, + min_impurity_decrease=0.0, + ccp_alpha=0.0, + monotonic_cst=None, + ): super().__init__( criterion=criterion, splitter=splitter, @@ -1255,10 +1301,12 @@ def __init__(self, *, random_state=random_state, min_impurity_decrease=min_impurity_decrease, ccp_alpha=ccp_alpha, - monotonic_cst=monotonic_cst) + monotonic_cst=monotonic_cst, + ) - def fit(self, X, y, sample_weight=None, check_input=True, - X_idx_sorted="deprecated"): + def fit( + self, X, y, sample_weight=None, check_input=True, X_idx_sorted="deprecated" + ): """Build a decision tree regressor from the training set (X, y). Parameters @@ -1294,10 +1342,12 @@ def fit(self, X, y, sample_weight=None, check_input=True, """ super().fit( - X, y, + X, + y, sample_weight=sample_weight, check_input=check_input, - X_idx_sorted=X_idx_sorted) + X_idx_sorted=X_idx_sorted, + ) return self def _compute_partial_dependence_recursion(self, grid, target_features): @@ -1317,17 +1367,20 @@ def _compute_partial_dependence_recursion(self, grid, target_features): averaged_predictions : ndarray of shape (n_samples,) The value of the partial dependence function on each grid point. """ - grid = np.asarray(grid, dtype=DTYPE, order='C') - averaged_predictions = np.zeros(shape=grid.shape[0], - dtype=np.float64, order='C') + grid = np.asarray(grid, dtype=DTYPE, order="C") + averaged_predictions = np.zeros( + shape=grid.shape[0], dtype=np.float64, order="C" + ) self.tree_.compute_partial_dependence( - grid, target_features, averaged_predictions) + grid, target_features, averaged_predictions + ) return averaged_predictions @deprecated( # type: ignore "The attribute 'n_features_' is deprecated in 1.0 and will be removed " - "in 1.2. Use 'n_features_in_' instead.") + "in 1.2. Use 'n_features_in_' instead." + ) @property def n_features_(self): return self.n_features_in_ @@ -1552,20 +1605,24 @@ class ExtraTreeClassifier(DecisionTreeClassifier): >>> cls.score(X_test, y_test) 0.8947... 
""" - def __init__(self, *, - criterion="gini", - splitter="random", - max_depth=None, - min_samples_split=2, - min_samples_leaf=1, - min_weight_fraction_leaf=0., - max_features="auto", - random_state=None, - max_leaf_nodes=None, - min_impurity_decrease=0., - class_weight=None, - ccp_alpha=0.0, - monotonic_cst=None): + + def __init__( + self, + *, + criterion="gini", + splitter="random", + max_depth=None, + min_samples_split=2, + min_samples_leaf=1, + min_weight_fraction_leaf=0.0, + max_features="auto", + random_state=None, + max_leaf_nodes=None, + min_impurity_decrease=0.0, + class_weight=None, + ccp_alpha=0.0, + monotonic_cst=None, + ): super().__init__( criterion=criterion, @@ -1580,7 +1637,8 @@ def __init__(self, *, min_impurity_decrease=min_impurity_decrease, random_state=random_state, ccp_alpha=ccp_alpha, - monotonic_cst=monotonic_cst) + monotonic_cst=monotonic_cst, + ) class ExtraTreeRegressor(DecisionTreeRegressor): @@ -1786,19 +1844,23 @@ class ExtraTreeRegressor(DecisionTreeRegressor): >>> reg.score(X_test, y_test) 0.33... """ - def __init__(self, *, - criterion="squared_error", - splitter="random", - max_depth=None, - min_samples_split=2, - min_samples_leaf=1, - min_weight_fraction_leaf=0., - max_features="auto", - random_state=None, - min_impurity_decrease=0., - max_leaf_nodes=None, - ccp_alpha=0.0, - monotonic_cst=None): + + def __init__( + self, + *, + criterion="squared_error", + splitter="random", + max_depth=None, + min_samples_split=2, + min_samples_leaf=1, + min_weight_fraction_leaf=0.0, + max_features="auto", + random_state=None, + min_impurity_decrease=0.0, + max_leaf_nodes=None, + ccp_alpha=0.0, + monotonic_cst=None, + ): super().__init__( criterion=criterion, splitter=splitter, @@ -1811,4 +1873,5 @@ def __init__(self, *, min_impurity_decrease=min_impurity_decrease, random_state=random_state, ccp_alpha=ccp_alpha, - monotonic_cst=monotonic_cst) + monotonic_cst=monotonic_cst, + ) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index f23b84baddb6e..4cd73da08a709 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -5,8 +5,8 @@ from sklearn.tree.tests.test_tree import REG_TREES, CLF_TREES -@pytest.mark.parametrize('seed', range(4)) -@pytest.mark.parametrize('depth_first', (True, False)) +@pytest.mark.parametrize("seed", range(4)) +@pytest.mark.parametrize("depth_first", (True, False)) def test_montonic_constraints(seed, depth_first): n_samples = 1000 n_samples_train = 900 @@ -29,8 +29,11 @@ def test_montonic_constraints(seed, depth_first): if depth_first: est = TreeRegressor(max_depth=None, monotonic_cst=monotonic_cst) else: - est = TreeRegressor(max_depth=None, monotonic_cst=monotonic_cst, - max_leaf_nodes=n_samples_train) + est = TreeRegressor( + max_depth=None, + monotonic_cst=monotonic_cst, + max_leaf_nodes=n_samples_train, + ) if hasattr(est, "random_state"): est.set_params(**{"random_state": seed}) est.fit(X_train, y_train) @@ -49,8 +52,11 @@ def test_montonic_constraints(seed, depth_first): if depth_first: est = TreeClassifier(max_depth=None, monotonic_cst=monotonic_cst) else: - est = TreeClassifier(max_depth=None, monotonic_cst=monotonic_cst, - max_leaf_nodes=n_samples_train) + est = TreeClassifier( + max_depth=None, + monotonic_cst=monotonic_cst, + max_leaf_nodes=n_samples_train, + ) if hasattr(est, "random_state"): est.set_params(**{"random_state": seed}) est.fit(X_train, y_train) @@ -74,13 +80,15 @@ def test_multiclass_raises(): monotonic_cst[0] = -1 
monotonic_cst[1] = 1 for name, TreeClassifier in CLF_TREES.items(): - est = TreeClassifier(max_depth=None, monotonic_cst=monotonic_cst, - random_state=0) + est = TreeClassifier( + max_depth=None, monotonic_cst=monotonic_cst, random_state=0 + ) if hasattr(est, "random_state"): est.set_params(**{"random_state": 0}) - msg = "Monotonic constraints are not supported with multiclass " \ - "classification" + msg = ( + "Monotonic constraints are not supported with multiclass " "classification" + ) with pytest.raises(ValueError, match=msg): est.fit(X, y) @@ -110,10 +118,10 @@ def assert_1d_reg_monotonic(clf, monotonic_sign, min_x, max_x, n_steps): assert (monotonic_sign * np.diff(y_pred_grid) >= 0.0).all() -@pytest.mark.parametrize('monotonic_sign', (-1, 1)) -@pytest.mark.parametrize('splitter', ("best", "random")) -@pytest.mark.parametrize('depth_first', (True, False)) -@pytest.mark.parametrize('seed', range(4)) +@pytest.mark.parametrize("monotonic_sign", (-1, 1)) +@pytest.mark.parametrize("splitter", ("best", "random")) +@pytest.mark.parametrize("depth_first", (True, False)) +@pytest.mark.parametrize("seed", range(4)) def test_1d_tree_nodes_values(monotonic_sign, splitter, depth_first, seed): # Adaptation from test_nodes_values in test_montonic_constraints.py # Build a single tree with only one feature, and make sure the nodes @@ -138,15 +146,17 @@ def test_1d_tree_nodes_values(monotonic_sign, splitter, depth_first, seed): if depth_first: # No max_leaf_nodes, default depth first tree builder - clf = DecisionTreeRegressor(splitter=splitter, - monotonic_cst=[monotonic_sign], - random_state=seed) + clf = DecisionTreeRegressor( + splitter=splitter, monotonic_cst=[monotonic_sign], random_state=seed + ) else: # max_leaf_nodes triggers best first tree builder - clf = DecisionTreeRegressor(splitter=splitter, - monotonic_cst=[monotonic_sign], - max_leaf_nodes=n_samples, - random_state=seed) + clf = DecisionTreeRegressor( + splitter=splitter, + monotonic_cst=[monotonic_sign], + max_leaf_nodes=n_samples, + random_state=seed, + ) clf.fit(X, y) assert_1d_reg_tree_children_monotonic_bounded(clf.tree_, monotonic_sign) @@ -175,8 +185,10 @@ def assert_nd_reg_tree_children_monotonic_bounded(tree_, monotonic_cst): else: # constrained feature # check montonicity - assert monotonic_cst[feature] * tree_.value[i_left] \ - <= monotonic_cst[feature] * tree_.value[i_right] + assert ( + monotonic_cst[feature] * tree_.value[i_left] + <= monotonic_cst[feature] * tree_.value[i_right] + ) # update and propagate bounds down the tree if monotonic_cst[feature] == 1: upper_bound[i_left] = tree_.value[i] @@ -190,10 +202,10 @@ def assert_nd_reg_tree_children_monotonic_bounded(tree_, monotonic_cst): lower_bound[i_right] = lower_bound[i] -@pytest.mark.parametrize('monotonic_sign', (-1, 1)) -@pytest.mark.parametrize('splitter', ("best", "random")) -@pytest.mark.parametrize('depth_first', (True, False)) -@pytest.mark.parametrize('seed', range(4)) +@pytest.mark.parametrize("monotonic_sign", (-1, 1)) +@pytest.mark.parametrize("splitter", ("best", "random")) +@pytest.mark.parametrize("depth_first", (True, False)) +@pytest.mark.parametrize("seed", range(4)) def test_nd_tree_nodes_values(monotonic_sign, splitter, depth_first, seed): # Build tree with several features, and make sure the nodes # values respect the monotonic constraints. 
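The helper above checks the ordering of sibling node values and also propagates lower/upper bounds down the tree. As a compact restatement of the core local invariant it asserts, here is a minimal standalone sketch (not part of the changeset; it assumes a fitted single-output regressor `reg` and the `monotonic_cst` array that was passed at fit time):

    import numpy as np

    def check_sibling_value_order(reg, monotonic_cst):
        # For every split on a constrained feature, the left child's mean must
        # not exceed the right child's mean under a +1 constraint (and
        # conversely under a -1 constraint).
        tree_ = reg.tree_
        values = tree_.value[:, 0, 0]  # node means of a single-output regressor
        for node in range(tree_.node_count):
            left = tree_.children_left[node]
            right = tree_.children_right[node]
            if left == -1:  # leaf node
                continue
            cst = monotonic_cst[tree_.feature[node]]
            if cst != 0:
                assert cst * values[left] <= cst * values[right]

The bound propagation done in the test helper additionally guarantees that the property holds across levels of the tree, not only between direct siblings.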
@@ -224,14 +236,16 @@ def test_nd_tree_nodes_values(monotonic_sign, splitter, depth_first, seed): if depth_first: # No max_leaf_nodes, default depth first tree builder - clf = DecisionTreeRegressor(splitter=splitter, - monotonic_cst=monotonic_cst, - random_state=seed) + clf = DecisionTreeRegressor( + splitter=splitter, monotonic_cst=monotonic_cst, random_state=seed + ) else: # max_leaf_nodes triggers best first tree builder - clf = DecisionTreeRegressor(splitter=splitter, - monotonic_cst=monotonic_cst, - max_leaf_nodes=n_samples, - random_state=seed) + clf = DecisionTreeRegressor( + splitter=splitter, + monotonic_cst=monotonic_cst, + max_leaf_nodes=n_samples, + random_state=seed, + ) clf.fit(X, y) assert_nd_reg_tree_children_monotonic_bounded(clf.tree_, monotonic_cst) From 66032e5131994499c99368328c337374d934ecca Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Tue, 22 Jun 2021 00:18:50 +0200 Subject: [PATCH 069/174] Make the monotonicity constraints hold on the positive class for binary classifications --- sklearn/ensemble/_forest.py | 12 +++++++--- sklearn/tree/_classes.py | 28 +++++++++++++++-------- sklearn/tree/tests/test_monotonic_tree.py | 18 ++++++++++----- 3 files changed, 39 insertions(+), 19 deletions(-) diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index 457760f5aca0c..6a3223e1e162a 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -1233,7 +1233,9 @@ class RandomForestClassifier(ForestClassifier): monotonic_cst : array-like of int of shape (n_features), default=None Indicates the monotonic constraint to enforce on each feature. -1, 1 and 0 respectively correspond to a positive constraint, negative - constraint and no constraint. + constraint and no constraint. The constraints are only valid for + binary classifications and hold over the probability of the positive + class. Attributes ---------- @@ -1890,7 +1892,9 @@ class ExtraTreesClassifier(ForestClassifier): monotonic_cst : array-like of int of shape (n_features), default=None Indicates the monotonic constraint to enforce on each feature. -1, 1 and 0 respectively correspond to a positive constraint, negative - constraint and no constraint. + constraint and no constraint. The constraints are only valid for + binary classifications and hold over the probability of the positive + class. Attributes ---------- @@ -2455,7 +2459,9 @@ class RandomTreesEmbedding(BaseForest): monotonic_cst : array-like of int of shape (n_features), default=None Indicates the monotonic constraint to enforce on each feature. -1, 1 and 0 respectively correspond to a positive constraint, negative - constraint and no constraint. + constraint and no constraint. The constraints are only valid for + binary classifications and hold over the probability of the positive + class. 
Attributes ---------- diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 19720f949e00c..9f8ef0255d7ad 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -173,7 +173,7 @@ def fit( if X.indices.dtype != np.intc or X.indptr.dtype != np.intc: raise ValueError( - "No support for np.int64 index based " "sparse matrices" + "No support for np.int64 index based sparse matrices" ) if self.criterion == "poisson": @@ -385,14 +385,18 @@ def fit( else: if self.n_outputs_ > 1: raise ValueError( - "Monotonic constraints are not supported with" " multiple output" + "Monotonic constraints are not supported with multiple output" ) - if is_classifier(self) and self.n_classes_[0] > 2: - raise ValueError( - "Monotonic constraints are not supported with" - " multiclass classification" - ) - monotonic_cst = np.asarray(self.monotonic_cst, dtype=np.int32) + if is_classifier(self): + if self.n_classes_[0] > 2: + raise ValueError( + "Monotonic constraints are not supported with multiclass classification" + ) + # Imposing the constraint on the probability of the positive class + monotonic_cst = -np.asarray(self.monotonic_cst, dtype=np.int32) + else: + monotonic_cst = np.asarray(self.monotonic_cst, dtype=np.int32) + if monotonic_cst.shape[0] != X.shape[1]: raise ValueError( @@ -800,7 +804,9 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree): monotonic_cst : array-like of int of shape (n_features), default=None Indicates the monotonic constraint to enforce on each feature. -1, 1 and 0 respectively correspond to a positive constraint, negative - constraint and no constraint. + constraint and no constraint. The constraints are only valid for + binary classifications and hold over the probability of the positive + class. Attributes ---------- @@ -1522,7 +1528,9 @@ class ExtraTreeClassifier(DecisionTreeClassifier): monotonic_cst : array-like of int of shape (n_features), default=None Indicates the monotonic constraint to enforce on each feature. -1, 1 and 0 respectively correspond to a positive constraint, negative - constraint and no constraint. + constraint and no constraint. The constraints are only valid for + binary classifications and hold over the probability of the positive + class. 
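To make the sentence above concrete, here is a small illustrative sketch of what "hold over the probability of the positive class" means for a binary classifier fitted with these constraints (not part of the changeset; the dataset and the choice of a +1 constraint on feature 0 are arbitrary):

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.tree import DecisionTreeClassifier

    X, y = make_classification(n_samples=500, n_features=5, n_informative=5,
                               n_redundant=0, random_state=0)
    # Constrain feature 0 so that P(y=1) can only go up when the feature goes up.
    clf = DecisionTreeClassifier(monotonic_cst=[1, 0, 0, 0, 0], random_state=0)
    clf.fit(X, y)

    X_shifted = X.copy()
    X_shifted[:, 0] += 10  # push the constrained feature upwards
    proba = clf.predict_proba(X)[:, 1]
    proba_shifted = clf.predict_proba(X_shifted)[:, 1]
    assert np.all(proba_shifted >= proba)  # positive-class probability never drops

This mirrors the updated tests in this commit, which compare `predict_proba(...)[:, 1]` on shifted copies of the test set rather than the first column.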
Attributes ---------- diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 4cd73da08a709..6d02cc09a7ca5 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -3,7 +3,7 @@ from sklearn import datasets from sklearn.tree import DecisionTreeRegressor from sklearn.tree.tests.test_tree import REG_TREES, CLF_TREES - +from sklearn.ensemble import ExtraTreesClassifier, RandomForestClassifier @pytest.mark.parametrize("seed", range(4)) @pytest.mark.parametrize("depth_first", (True, False)) @@ -25,7 +25,7 @@ def test_montonic_constraints(seed, depth_first): monotonic_cst[0] = 1 monotonic_cst[1] = -1 - for name, TreeRegressor in REG_TREES.items(): + for _, TreeRegressor in REG_TREES.items(): if depth_first: est = TreeRegressor(max_depth=None, monotonic_cst=monotonic_cst) else: @@ -48,7 +48,11 @@ def test_montonic_constraints(seed, depth_first): # y_decr should always be lower than y assert np.all(y_decr <= y) - for name, TreeClassifier in CLF_TREES.items(): + classifiers = CLF_TREES.copy() + # TODO: GradientBoostingClassifier + classifiers.update({"RandomForestClassifier": RandomForestClassifier, + "ExtraTreesClassifier": ExtraTreesClassifier}) + for _, TreeClassifier in classifiers.items(): if depth_first: est = TreeClassifier(max_depth=None, monotonic_cst=monotonic_cst) else: @@ -59,16 +63,18 @@ def test_montonic_constraints(seed, depth_first): ) if hasattr(est, "random_state"): est.set_params(**{"random_state": seed}) + if hasattr(est, "n_estimators"): + est.set_params(**{"n_estimators": 10}) est.fit(X_train, y_train) - y = est.predict_proba(X_test)[:, 0] + y = est.predict_proba(X_test)[:, 1] # increasing constraint - y_incr = est.predict_proba(X_test_incr)[:, 0] + y_incr = est.predict_proba(X_test_incr)[:, 1] # y_incr should always be greater than y assert np.all(y_incr >= y) # decreasing constraint - y_decr = est.predict_proba(X_test_decr)[:, 0] + y_decr = est.predict_proba(X_test_decr)[:, 1] # y_decr should always be lower than y assert np.all(y_decr <= y) From e6264d31a302ef928a6640ba5dc025c84be67771 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Thu, 24 Jun 2021 14:26:54 +0200 Subject: [PATCH 070/174] Adapt gradient boosting to disable modifying leaves (line search) and breaking the monotonicity constraints --- sklearn/ensemble/_gb.py | 7 +++++-- sklearn/ensemble/_gb_losses.py | 31 ++++++++++++++++++++----------- sklearn/tree/_classes.py | 1 - 3 files changed, 25 insertions(+), 14 deletions(-) diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index ecf51091aeeb8..ec8915db11e6e 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -249,6 +249,7 @@ def _fit_stage( tree.fit(X, residual, sample_weight=sample_weight, check_input=False) # update tree leaves + perform_line_search = self.monotonic_cst is None loss.update_terminal_regions( tree.tree_, X, @@ -259,8 +260,8 @@ def _fit_stage( sample_mask, learning_rate=self.learning_rate, k=k, + perform_line_search=perform_line_search ) - # add tree to ensemble self.estimators_[i, k] = tree @@ -1093,7 +1094,9 @@ class GradientBoostingClassifier(ClassifierMixin, BaseGradientBoosting): monotonic_cst : array-like of int of shape (n_features), default=None Indicates the monotonic constraint to enforce on each feature. -1, 1 and 0 respectively correspond to a positive constraint, negative - constraint and no constraint. + constraint and no constraint. 
The constraints are only valid for + binary classifications and hold over the probability of the positive + class. Attributes ---------- diff --git a/sklearn/ensemble/_gb_losses.py b/sklearn/ensemble/_gb_losses.py index 95090f235132b..8bbe726dab70f 100644 --- a/sklearn/ensemble/_gb_losses.py +++ b/sklearn/ensemble/_gb_losses.py @@ -80,6 +80,7 @@ def update_terminal_regions( sample_mask, learning_rate=0.1, k=0, + perform_line_search=True, ): """Update the terminal regions (=leaves) of the given tree and updates the current predictions of the model. Traverses tree @@ -107,6 +108,9 @@ def update_terminal_regions( ``learning_rate``. k : int, default=0 The index of the estimator being updated. + perform_line_search : bool, default=True + Whether line search must be performed. Monotonic constraints + requires not performing line search. """ # compute leaf for each sample in ``X``. @@ -117,17 +121,18 @@ def update_terminal_regions( masked_terminal_regions[~sample_mask] = -1 # update each leaf (= perform line search) - for leaf in np.where(tree.children_left == TREE_LEAF)[0]: - self._update_terminal_region( - tree, - masked_terminal_regions, - leaf, - X, - y, - residual, - raw_predictions[:, k], - sample_weight, - ) + if perform_line_search: + for leaf in np.where(tree.children_left == TREE_LEAF)[0]: + self._update_terminal_region( + tree, + masked_terminal_regions, + leaf, + X, + y, + residual, + raw_predictions[:, k], + sample_weight, + ) # update predictions (both in-bag and out-of-bag) raw_predictions[:, k] += learning_rate * tree.value[:, 0, 0].take( @@ -256,6 +261,7 @@ def update_terminal_regions( sample_mask, learning_rate=0.1, k=0, + perform_line_search=True, ): """Least squares does not need to update terminal regions. @@ -283,6 +289,9 @@ def update_terminal_regions( ``learning_rate``. k : int, default=0 The index of the estimator being updated. + perform_line_search : bool, default=True + Whether line search must be performed. Monotonic constraints + requires not performing line search. """ # update predictions raw_predictions[:, k] += learning_rate * tree.predict(X).ravel() diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 9f8ef0255d7ad..c7670d162a03b 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -397,7 +397,6 @@ def fit( else: monotonic_cst = np.asarray(self.monotonic_cst, dtype=np.int32) - if monotonic_cst.shape[0] != X.shape[1]: raise ValueError( "monotonic_cst has shape {} but the input data " From d196acee2173ba6a206ef79207e1b2874657bf6b Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Thu, 24 Jun 2021 15:08:50 +0200 Subject: [PATCH 071/174] Add tests for gradient boosting and removing tests for forests --- sklearn/tree/tests/test_monotonic_tree.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 6d02cc09a7ca5..6650b14d4aa8f 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -3,7 +3,8 @@ from sklearn import datasets from sklearn.tree import DecisionTreeRegressor from sklearn.tree.tests.test_tree import REG_TREES, CLF_TREES -from sklearn.ensemble import ExtraTreesClassifier, RandomForestClassifier +from sklearn.ensemble import GradientBoostingClassifier +from sklearn.ensemble import GradientBoostingRegressor @pytest.mark.parametrize("seed", range(4)) @pytest.mark.parametrize("depth_first", (True, False)) @@ -25,7 +26,12 @@ def test_montonic_constraints(seed, depth_first): monotonic_cst[0] = 1 monotonic_cst[1] = -1 - for _, TreeRegressor in REG_TREES.items(): + regressors = REG_TREES.copy() + regressors.update({ + "GradientBoostingRegressor": GradientBoostingRegressor + }) + + for name, TreeRegressor in regressors.items(): if depth_first: est = TreeRegressor(max_depth=None, monotonic_cst=monotonic_cst) else: @@ -36,6 +42,8 @@ def test_montonic_constraints(seed, depth_first): ) if hasattr(est, "random_state"): est.set_params(**{"random_state": seed}) + if hasattr(est, "n_estimators"): + est.set_params(**{"n_estimators": 5}) est.fit(X_train, y_train) y = est.predict(X_test) # increasing constraint @@ -49,10 +57,10 @@ def test_montonic_constraints(seed, depth_first): assert np.all(y_decr <= y) classifiers = CLF_TREES.copy() - # TODO: GradientBoostingClassifier - classifiers.update({"RandomForestClassifier": RandomForestClassifier, - "ExtraTreesClassifier": ExtraTreesClassifier}) - for _, TreeClassifier in classifiers.items(): + classifiers.update({ + "GradientBoostingClassifier": GradientBoostingClassifier + }) + for name, TreeClassifier in classifiers.items(): if depth_first: est = TreeClassifier(max_depth=None, monotonic_cst=monotonic_cst) else: @@ -64,7 +72,7 @@ def test_montonic_constraints(seed, depth_first): if hasattr(est, "random_state"): est.set_params(**{"random_state": seed}) if hasattr(est, "n_estimators"): - est.set_params(**{"n_estimators": 10}) + est.set_params(**{"n_estimators": 5}) est.fit(X_train, y_train) y = est.predict_proba(X_test)[:, 1] From 250996a6baa1793de0912b63618e44451568a6c1 Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Thu, 24 Jun 2021 15:56:13 +0200 Subject: [PATCH 072/174] Format --- sklearn/ensemble/_gb.py | 2 +- sklearn/tree/_classes.py | 3 ++- sklearn/tree/tests/test_monotonic_tree.py | 9 +++------ 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index ec8915db11e6e..ac294cfee373b 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -260,7 +260,7 @@ def _fit_stage( sample_mask, learning_rate=self.learning_rate, k=k, - perform_line_search=perform_line_search + perform_line_search=perform_line_search, ) # add tree to ensemble self.estimators_[i, k] = tree diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index c7670d162a03b..934f36a59a40e 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -390,7 +390,8 @@ def fit( if is_classifier(self): if self.n_classes_[0] > 2: raise ValueError( - "Monotonic constraints are not supported with multiclass classification" + "Monotonic constraints are not supported with multiclass " + "classification" ) # Imposing the constraint on the probability of the positive class monotonic_cst = -np.asarray(self.monotonic_cst, dtype=np.int32) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 6650b14d4aa8f..10a387a6aee9c 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -6,6 +6,7 @@ from sklearn.ensemble import GradientBoostingClassifier from sklearn.ensemble import GradientBoostingRegressor + @pytest.mark.parametrize("seed", range(4)) @pytest.mark.parametrize("depth_first", (True, False)) def test_montonic_constraints(seed, depth_first): @@ -27,9 +28,7 @@ def test_montonic_constraints(seed, depth_first): monotonic_cst[1] = -1 regressors = REG_TREES.copy() - regressors.update({ - "GradientBoostingRegressor": GradientBoostingRegressor - }) + regressors.update({"GradientBoostingRegressor": GradientBoostingRegressor}) for name, TreeRegressor in regressors.items(): if depth_first: @@ -57,9 +56,7 @@ def test_montonic_constraints(seed, depth_first): assert np.all(y_decr <= y) classifiers = CLF_TREES.copy() - classifiers.update({ - "GradientBoostingClassifier": GradientBoostingClassifier - }) + classifiers.update({"GradientBoostingClassifier": GradientBoostingClassifier}) for name, TreeClassifier in classifiers.items(): if depth_first: est = TreeClassifier(max_depth=None, monotonic_cst=monotonic_cst) From df316970e95b45a301d238918b971508b2e0f73f Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Thu, 24 Jun 2021 16:32:09 +0200 Subject: [PATCH 073/174] Add whats_new entry --- doc/whats_new/v1.0.rst | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst index e4bff3c124dc5..517c35a6ffe2b 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -288,6 +288,16 @@ Changelog interpreted as using all `n_samples` for bootstrapping. :pr:`20159` by :user:`murata-yu`. +- |Feature| :class:`ensemble.RandomForestClassifier`, + :class:`ensemble.RandomForestRegressor`, + :class:`ensemble.ExtraTreesClassifier`, + :class:`ensemble.ExtraTreesRegressor`, + :class:`ensemble.GradientBoostingClassifier` and + :class:`ensemble.GradientBoostingRegressor` now support monotonic + constraints, useful when features are supposed to have a positive/negative + effect on the target. :pr:`13649` by :user:`Samuel Ronsin ` and + :user:`Patrick O'Reilly `. 
+ :mod:`sklearn.feature_extraction` ................................. @@ -563,6 +573,12 @@ Changelog and will be removed in 1.2. :pr:`20272` by :user:`Jérémie du Boisberranger `. +- |Feature| :class:`tree.DecisionTreeClassifier` and + :class:`tree.DecisionTreeRegressor` now support monotonic + constraints, useful when features are supposed to have a positive/negative + effect on the target. :pr:`13649` by :user:`Samuel Ronsin ` and + :user:`Patrick O'Reilly `. + :mod:`sklearn.utils` .................... From e6937d9e6abd69b894e48dca5833860b27519f4e Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Tue, 30 Nov 2021 17:03:30 +0100 Subject: [PATCH 074/174] Fix 1/-1 mixup in the docstrings relative to positive/negative constraint --- sklearn/ensemble/_forest.py | 10 +++++----- sklearn/ensemble/_gb.py | 4 ++-- sklearn/ensemble/_hist_gradient_boosting/grower.py | 2 +- sklearn/tree/_classes.py | 8 ++++---- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index 6a3223e1e162a..85228ca684763 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -1231,7 +1231,7 @@ class RandomForestClassifier(ForestClassifier): .. versionadded:: 0.22 monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. -1, 1 + Indicates the monotonic constraint to enforce on each feature. 1, -1 and 0 respectively correspond to a positive constraint, negative constraint and no constraint. The constraints are only valid for binary classifications and hold over the probability of the positive @@ -1562,7 +1562,7 @@ class RandomForestRegressor(ForestRegressor): .. versionadded:: 0.22 monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. -1, 1 + Indicates the monotonic constraint to enforce on each feature. 1, -1 and 0 respectively correspond to a positive constraint, negative constraint and no constraint. @@ -1890,7 +1890,7 @@ class ExtraTreesClassifier(ForestClassifier): .. versionadded:: 0.22 monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. -1, 1 + Indicates the monotonic constraint to enforce on each feature. 1, -1 and 0 respectively correspond to a positive constraint, negative constraint and no constraint. The constraints are only valid for binary classifications and hold over the probability of the positive @@ -2210,7 +2210,7 @@ class ExtraTreesRegressor(ForestRegressor): .. versionadded:: 0.22 monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. -1, 1 + Indicates the monotonic constraint to enforce on each feature. 1, -1 and 0 respectively correspond to a positive constraint, negative constraint and no constraint. @@ -2457,7 +2457,7 @@ class RandomTreesEmbedding(BaseForest): new forest. See :term:`the Glossary `. monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. -1, 1 + Indicates the monotonic constraint to enforce on each feature. 1, -1 and 0 respectively correspond to a positive constraint, negative constraint and no constraint. 
The constraints are only valid for binary classifications and hold over the probability of the positive diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index ac294cfee373b..45cf4e7a573fe 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -1092,7 +1092,7 @@ class GradientBoostingClassifier(ClassifierMixin, BaseGradientBoosting): .. versionadded:: 0.22 monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. -1, 1 + Indicates the monotonic constraint to enforce on each feature. 1, -1 and 0 respectively correspond to a positive constraint, negative constraint and no constraint. The constraints are only valid for binary classifications and hold over the probability of the positive @@ -1676,7 +1676,7 @@ class GradientBoostingRegressor(RegressorMixin, BaseGradientBoosting): .. versionadded:: 0.22 monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. -1, 1 + Indicates the monotonic constraint to enforce on each feature. 1, -1 and 0 respectively correspond to a positive constraint, negative constraint and no constraint. diff --git a/sklearn/ensemble/_hist_gradient_boosting/grower.py b/sklearn/ensemble/_hist_gradient_boosting/grower.py index 81e971de700e4..1f1d934b629f8 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/grower.py +++ b/sklearn/ensemble/_hist_gradient_boosting/grower.py @@ -162,7 +162,7 @@ class TreeGrower: is_categorical : ndarray of bool of shape (n_features,), default=None Indicates categorical features. monotonic_cst : array-like of shape (n_features,), dtype=int, default=None - Indicates the monotonic constraint to enforce on each feature. -1, 1 + Indicates the monotonic constraint to enforce on each feature. 1, -1 and 0 respectively correspond to a positive constraint, negative constraint and no constraint. Read more in the :ref:`User Guide `. diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 934f36a59a40e..74dc77a58170c 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -802,7 +802,7 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree): .. versionadded:: 0.22 monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. -1, 1 + Indicates the monotonic constraint to enforce on each feature. 1, -1 and 0 respectively correspond to a positive constraint, negative constraint and no constraint. The constraints are only valid for binary classifications and hold over the probability of the positive @@ -1198,7 +1198,7 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree): .. versionadded:: 0.22 monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. -1, 1 + Indicates the monotonic constraint to enforce on each feature. 1, -1 and 0 respectively correspond to a positive constraint, negative constraint and no constraint. @@ -1526,7 +1526,7 @@ class ExtraTreeClassifier(DecisionTreeClassifier): .. versionadded:: 0.22 monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. -1, 1 + Indicates the monotonic constraint to enforce on each feature. 1, -1 and 0 respectively correspond to a positive constraint, negative constraint and no constraint. 
The constraints are only valid for binary classifications and hold over the probability of the positive @@ -1779,7 +1779,7 @@ class ExtraTreeRegressor(DecisionTreeRegressor): .. versionadded:: 0.22 monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. -1, 1 + Indicates the monotonic constraint to enforce on each feature. 1, -1 and 0 respectively correspond to a positive constraint, negative constraint and no constraint. From 56b361d895b7284ce5d40d8f2e0e4222fee3189c Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Tue, 30 Nov 2021 17:11:18 +0100 Subject: [PATCH 075/174] Apply suggestions from code review Co-authored-by: Julien Jerphanion --- sklearn/tree/_classes.py | 1 - sklearn/tree/_splitter.pxd | 13 ++++++++----- sklearn/tree/_tree.pxd | 1 - sklearn/tree/_tree.pyx | 15 ++++++++++++++- 4 files changed, 22 insertions(+), 8 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 74dc77a58170c..6370efb170745 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -916,7 +916,6 @@ def __init__( ccp_alpha=0.0, monotonic_cst=None, ): - super().__init__( criterion=criterion, splitter=splitter, diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 8dc47f53aef9a..10882cc770ca2 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -61,10 +61,12 @@ cdef class Splitter: cdef const DOUBLE_t[:, ::1] y cdef DOUBLE_t* sample_weight - cdef INT32_t* monotonic_cst # Monotonicity constraints - # -1: monotonically decreasing - # 0: no constraint - # +1: monotonically increasing + # Monotonicity constraints for each feature. + # The encoding is as follows: + # -1: monotonically decreasing + # 0: no constraint + # +1: monotonically increasing + cdef INT32_t* monotonic_cst # The samples vector `samples` is maintained by the Splitter object such # that the samples contained in a node are contiguous. 
With this setting, @@ -94,7 +96,8 @@ cdef class Splitter: SplitRecord* split, SIZE_t* n_constant_features, double lower_bound, - double upper_bound) nogil except -1 + double upper_bound, + ) nogil except -1 cdef void node_value(self, double* dest) nogil diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 5c66aaf5ac100..0874187ee98ae 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -59,7 +59,6 @@ cdef class Tree: SIZE_t feature, double threshold, double impurity, SIZE_t n_node_samples, double weighted_n_node_samples) nogil except -1 - cdef int _resize(self, SIZE_t capacity) nogil except -1 cdef int _resize_c(self, SIZE_t capacity=*) nogil except -1 diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index f5ed77419a9a9..c35801c175fac 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -260,16 +260,19 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): if not is_leaf: if splitter.monotonic_cst[split.feature] == 0: + # No constraint left_child_min = lower_bound left_child_max = upper_bound right_child_min = lower_bound right_child_max = upper_bound elif splitter.monotonic_cst[split.feature] == 1: + # Monotonically increasing constraint left_child_min = lower_bound left_child_max = middle_value right_child_min = middle_value right_child_max = upper_bound elif splitter.monotonic_cst[split.feature] == -1: + # Monotonically decreasing constraint left_child_min = middle_value left_child_max = upper_bound right_child_min = lower_bound @@ -1437,7 +1440,17 @@ cdef _cost_complexity_prune(unsigned char[:] leaves_in_subtree, # OUT weighted_n_node_samples[i] * impurity[i] / total_sum_weights) # Push root node, using StackRecord.start as node id - rc = stack.push(0, 0, 0, -1, 0, 0, 0, -INFINITY, INFINITY) + rc = stack.push( + start=0, + end=0, + depth=0, + parent=-1, + is_left=0, + impurity=0, + n_constant_features=0, + lower_bound=-INFINITY, + upper_bound=INFINITY, + ) if rc == -1: with gil: raise MemoryError("pruning tree") From ba95efa45b2548c69176ab47ba130c04b64ae58c Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Tue, 30 Nov 2021 17:42:47 +0100 Subject: [PATCH 076/174] Move blurbs to latest whats_new --- doc/whats_new/v1.0.rst | 15 --------------- doc/whats_new/v1.1.rst | 16 ++++++++++++++++ 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst index 3f73738f49126..a6c955b5afcdc 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -592,15 +592,6 @@ Changelog interpreted as using all `n_samples` for bootstrapping. :pr:`20159` by :user:`murata-yu`. -- |Feature| :class:`ensemble.RandomForestClassifier`, - :class:`ensemble.RandomForestRegressor`, - :class:`ensemble.ExtraTreesClassifier`, - :class:`ensemble.ExtraTreesRegressor`, - :class:`ensemble.GradientBoostingClassifier` and - :class:`ensemble.GradientBoostingRegressor` now support monotonic - constraints, useful when features are supposed to have a positive/negative - effect on the target. :pr:`13649` by :user:`Samuel Ronsin ` and - :user:`Patrick O'Reilly `. - |Fix| Fixed a bug in :class:`ensemble.AdaBoostClassifier` and :class:`ensemble.AdaBoostRegressor` where the `sample_weight` parameter got overwritten during `fit`. @@ -1079,12 +1070,6 @@ Changelog and will be removed in 1.2. :pr:`20272` by :user:`Jérémie du Boisberranger `. 
-- |Feature| :class:`tree.DecisionTreeClassifier` and - :class:`tree.DecisionTreeRegressor` now support monotonic - constraints, useful when features are supposed to have a positive/negative - effect on the target. :pr:`13649` by :user:`Samuel Ronsin ` and - :user:`Patrick O'Reilly `. - :mod:`sklearn.utils` .................... diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index f07b33ae9e4da..08617d76d46d6 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -163,6 +163,16 @@ Changelog :class:`ensemble.ExtraTreesClassifier`. :pr:`20803` by :user:`Brian Sun `. +- |Feature| :class:`ensemble.RandomForestClassifier`, + :class:`ensemble.RandomForestRegressor`, + :class:`ensemble.ExtraTreesClassifier`, + :class:`ensemble.ExtraTreesRegressor`, + :class:`ensemble.GradientBoostingClassifier` and + :class:`ensemble.GradientBoostingRegressor` now support monotonic + constraints, useful when features are supposed to have a positive/negative + effect on the target. :pr:`13649` by :user:`Samuel Ronsin ` and + :user:`Patrick O'Reilly `. + :mod:`sklearn.impute` ..................... @@ -304,6 +314,12 @@ Changelog to train and pickle the model on 64 bit machine and load the model on a 32 bit machine for prediction. :pr:`21552` by :user:`Loïc Estève `. +- |Feature| :class:`tree.DecisionTreeClassifier` and + :class:`tree.DecisionTreeRegressor` now support monotonic + constraints, useful when features are supposed to have a positive/negative + effect on the target. :pr:`13649` by :user:`Samuel Ronsin ` and + :user:`Patrick O'Reilly `. + Code and Documentation Contributors ----------------------------------- From b764a892346b2ca8670d31aa80c3a4ddb221ac5a Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 1 Dec 2021 11:02:09 +0100 Subject: [PATCH 077/174] Remove named argument for cdef function and use deref instead of [0] in criterion sums --- sklearn/tree/_splitter.pyx | 15 +++++++++------ sklearn/tree/_tree.pyx | 12 +----------- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 14e4ef4f3eca2..cc520a11ddf37 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -18,6 +18,8 @@ from libc.stdlib cimport qsort from libc.string cimport memcpy from libc.string cimport memset +from cython.operator cimport dereference as deref + import numpy as np cimport numpy as np np.import_array() @@ -237,15 +239,15 @@ cdef class Splitter: cdef inline bint check_monotonicity(self, INT32_t monotonic_cst, double lower_bound, double upper_bound) nogil: """Check monotonic constraint is satisfied at the current split""" cdef bint check_lower, check_upper, check_monotonic - check_lower = (self.criterion.sum_left[0] >= lower_bound * self.criterion.weighted_n_left) \ - & (self.criterion.sum_right[0] >= lower_bound * self.criterion.weighted_n_right) - check_upper = (self.criterion.sum_left[0] <= upper_bound * self.criterion.weighted_n_left) \ - & (self.criterion.sum_right[0] <= upper_bound * self.criterion.weighted_n_right) + check_lower = (deref(self.criterion.sum_left) >= lower_bound * self.criterion.weighted_n_left) \ + & (deref(self.criterion.sum_right) >= lower_bound * self.criterion.weighted_n_right) + check_upper = (deref(self.criterion.sum_left) <= upper_bound * self.criterion.weighted_n_left) \ + & (deref(self.criterion.sum_right) <= upper_bound * self.criterion.weighted_n_right) if monotonic_cst == 0: # No constraint return check_lower & check_upper else: - check_monotonic = 
(self.criterion.sum_left[0] * self.criterion.weighted_n_right - - self.criterion.sum_right[0] * self.criterion.weighted_n_left) \ + check_monotonic = (deref(self.criterion.sum_left) * self.criterion.weighted_n_right + - deref(self.criterion.sum_right) * self.criterion.weighted_n_left) \ * monotonic_cst <= 0 return check_lower & check_upper & check_monotonic @@ -476,6 +478,7 @@ cdef class BestSplitter(BaseDenseSplitter): n_constant_features[0] = n_total_constants return 0 + # Sort n-element arrays pointed to by Xf and samples, simultaneously, # by the values in Xf. Algorithm: Introsort (Musser, SP&E, 1997). cdef inline void sort(DTYPE_t* Xf, SIZE_t* samples, SIZE_t n) nogil: diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index a65f5ae8040c8..07a46b338db0c 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -1563,17 +1563,7 @@ cdef _cost_complexity_prune(unsigned char[:] leaves_in_subtree, # OUT weighted_n_node_samples[i] * impurity[i] / total_sum_weights) # Push root node, using StackRecord.start as node id - rc = stack.push( - start=0, - end=0, - depth=0, - parent=-1, - is_left=0, - impurity=0, - n_constant_features=0, - lower_bound=-INFINITY, - upper_bound=INFINITY, - ) + rc = stack.push(0, 0, 0, -1, 0, 0, 0, -INFINITY, INFINITY) if rc == -1: with gil: raise MemoryError("pruning tree") From c721fa690ac11b6e1eaf55c9a2c319880d42482f Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 1 Dec 2021 11:09:51 +0100 Subject: [PATCH 078/174] Update wording in docstring for perform_line_search argument --- sklearn/ensemble/_gb_losses.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sklearn/ensemble/_gb_losses.py b/sklearn/ensemble/_gb_losses.py index 8bbe726dab70f..a067c5c67c01f 100644 --- a/sklearn/ensemble/_gb_losses.py +++ b/sklearn/ensemble/_gb_losses.py @@ -109,9 +109,8 @@ def update_terminal_regions( k : int, default=0 The index of the estimator being updated. perform_line_search : bool, default=True - Whether line search must be performed. Monotonic constraints - requires not performing line search. - + Whether line search must be performed. Line search must not be + performed under monotonic constraints. """ # compute leaf for each sample in ``X``. terminal_regions = tree.apply(X) @@ -290,8 +289,8 @@ def update_terminal_regions( k : int, default=0 The index of the estimator being updated. perform_line_search : bool, default=True - Whether line search must be performed. Monotonic constraints - requires not performing line search. + Whether line search must be performed. Line search must not be + performed under monotonic constraints. """ # update predictions raw_predictions[:, k] += learning_rate * tree.predict(X).ravel() From f4ebabe5cbbb2183834767ecf3dbfc345b7b75ad Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Wed, 1 Dec 2021 12:13:29 +0100 Subject: [PATCH 079/174] Split test_montonic_constraints into two tests for classifications/regressions --- sklearn/tree/tests/test_monotonic_tree.py | 84 +++++++++++++++-------- 1 file changed, 57 insertions(+), 27 deletions(-) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 10a387a6aee9c..a93ae621b4d82 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -1,18 +1,26 @@ import numpy as np import pytest -from sklearn import datasets -from sklearn.tree import DecisionTreeRegressor -from sklearn.tree.tests.test_tree import REG_TREES, CLF_TREES + +from sklearn.datasets import make_classification, make_regression from sklearn.ensemble import GradientBoostingClassifier from sklearn.ensemble import GradientBoostingRegressor +from sklearn.tree import DecisionTreeRegressor +from sklearn.tree.tests.test_tree import REG_TREES, CLF_TREES @pytest.mark.parametrize("seed", range(4)) @pytest.mark.parametrize("depth_first", (True, False)) -def test_montonic_constraints(seed, depth_first): +def test_montonic_constraints_classifications(seed, depth_first): n_samples = 1000 n_samples_train = 900 - X, y = datasets.make_hastie_10_2(n_samples=n_samples, random_state=0) + X, y = make_classification( + n_samples=n_samples, + n_classes=2, + n_features=5, + n_informative=5, + n_redundant=0, + random_state=0, + ) train = np.arange(n_samples_train) test = np.arange(n_samples_train, n_samples) X_train = X[train] @@ -27,14 +35,13 @@ def test_montonic_constraints(seed, depth_first): monotonic_cst[0] = 1 monotonic_cst[1] = -1 - regressors = REG_TREES.copy() - regressors.update({"GradientBoostingRegressor": GradientBoostingRegressor}) - - for name, TreeRegressor in regressors.items(): + classifiers = CLF_TREES.copy() + classifiers.update({"GradientBoostingClassifier": GradientBoostingClassifier}) + for name, TreeClassifier in classifiers.items(): if depth_first: - est = TreeRegressor(max_depth=None, monotonic_cst=monotonic_cst) + est = TreeClassifier(max_depth=None, monotonic_cst=monotonic_cst) else: - est = TreeRegressor( + est = TreeClassifier( max_depth=None, monotonic_cst=monotonic_cst, max_leaf_nodes=n_samples_train, @@ -44,24 +51,48 @@ def test_montonic_constraints(seed, depth_first): if hasattr(est, "n_estimators"): est.set_params(**{"n_estimators": 5}) est.fit(X_train, y_train) - y = est.predict(X_test) + y = est.predict_proba(X_test)[:, 1] + # increasing constraint - y_incr = est.predict(X_test_incr) + y_incr = est.predict_proba(X_test_incr)[:, 1] # y_incr should always be greater than y assert np.all(y_incr >= y) # decreasing constraint - y_decr = est.predict(X_test_decr) + y_decr = est.predict_proba(X_test_decr)[:, 1] # y_decr should always be lower than y assert np.all(y_decr <= y) - classifiers = CLF_TREES.copy() - classifiers.update({"GradientBoostingClassifier": GradientBoostingClassifier}) - for name, TreeClassifier in classifiers.items(): + +@pytest.mark.parametrize("seed", range(4)) +@pytest.mark.parametrize("depth_first", (True, False)) +def test_montonic_constraints_regressions(seed, depth_first): + n_samples = 1000 + n_samples_train = 900 + # Build a classification task using 3 informative features + X, y = make_regression( + n_samples=n_samples, n_features=5, n_informative=5, random_state=0 + ) + train = np.arange(n_samples_train) + test = np.arange(n_samples_train, n_samples) + X_train = X[train] + y_train = y[train] + X_test = np.copy(X[test]) + 
X_test_incr = np.copy(X_test) + X_test_decr = np.copy(X_test) + X_test_incr[:, 0] += 10 + X_test_decr[:, 1] += 10 + monotonic_cst = np.zeros(X.shape[1]) + monotonic_cst[0] = 1 + monotonic_cst[1] = -1 + regressors = REG_TREES.copy() + regressors.update({"GradientBoostingRegressor": GradientBoostingRegressor}) + + for name, TreeRegressor in regressors.items(): if depth_first: - est = TreeClassifier(max_depth=None, monotonic_cst=monotonic_cst) + est = TreeRegressor(max_depth=None, monotonic_cst=monotonic_cst) else: - est = TreeClassifier( + est = TreeRegressor( max_depth=None, monotonic_cst=monotonic_cst, max_leaf_nodes=n_samples_train, @@ -71,21 +102,22 @@ def test_montonic_constraints(seed, depth_first): if hasattr(est, "n_estimators"): est.set_params(**{"n_estimators": 5}) est.fit(X_train, y_train) - y = est.predict_proba(X_test)[:, 1] - + y = est.predict(X_test) # increasing constraint - y_incr = est.predict_proba(X_test_incr)[:, 1] + y_incr = est.predict(X_test_incr) # y_incr should always be greater than y assert np.all(y_incr >= y) # decreasing constraint - y_decr = est.predict_proba(X_test_decr)[:, 1] + y_decr = est.predict(X_test_decr) # y_decr should always be lower than y assert np.all(y_decr <= y) def test_multiclass_raises(): - X, y = datasets.make_hastie_10_2(n_samples=100, random_state=0) + X, y = make_classification( + n_samples=100, n_features=5, n_classes=3, n_informative=3, random_state=0 + ) y[0] = 0 monotonic_cst = np.zeros(X.shape[1]) monotonic_cst[0] = -1 @@ -97,9 +129,7 @@ def test_multiclass_raises(): if hasattr(est, "random_state"): est.set_params(**{"random_state": 0}) - msg = ( - "Monotonic constraints are not supported with multiclass " "classification" - ) + msg = "Monotonic constraints are not supported with multiclass classification" with pytest.raises(ValueError, match=msg): est.fit(X, y) From c0fa6f83a343552fa150ab20640f1e5a8c4fba03 Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Wed, 1 Dec 2021 17:18:04 +0100 Subject: [PATCH 080/174] Add name of arguments to stack.push calls as comments --- sklearn/tree/_tree.pyx | 110 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 97 insertions(+), 13 deletions(-) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 07a46b338db0c..c73ea4e23ac34 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -190,7 +190,17 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): with nogil: # push root node onto stack - rc = stack.push(0, n_node_samples, 0, _TREE_UNDEFINED, 0, INFINITY, 0, -INFINITY, INFINITY) + rc = stack.push( + 0, # start + n_node_samples, # end + 0, # depth + _TREE_UNDEFINED, # parent + 0, # is_left + INFINITY, # impurity + 0, # n_constant_features + -INFINITY, # lower_bound + INFINITY # upper_bound + ) if rc == -1: # got return code -1 - out-of-memory with gil: @@ -277,16 +287,32 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): right_child_max = middle_value # Push right child on stack - rc = stack.push(split.pos, end, depth + 1, node_id, 0, - split.impurity_right, n_constant_features, - right_child_min, right_child_max) + rc = stack.push( + split.pos, # start + end, # end + depth + 1, # depth + node_id, # parent + 0, # is_left + split.impurity_right, # impurity + n_constant_features, # n_constant_features + right_child_min, # lower_bound + right_child_max # upper_bound + ) if rc == -1: break # Push left child on stack - rc = stack.push(start, split.pos, depth + 1, node_id, 1, - split.impurity_left, n_constant_features, left_child_min, - left_child_max) + rc = stack.push( + start, # start + split.pos, # end + depth + 1, # depth + node_id, # parent + 1, # is_left + split.impurity_left, # impurity + n_constant_features, # n_constant_features + left_child_min, # lower_bound + left_child_max # upper_bound + ) if rc == -1: break @@ -1563,7 +1589,17 @@ cdef _cost_complexity_prune(unsigned char[:] leaves_in_subtree, # OUT weighted_n_node_samples[i] * impurity[i] / total_sum_weights) # Push root node, using StackRecord.start as node id - rc = stack.push(0, 0, 0, -1, 0, 0, 0, -INFINITY, INFINITY) + rc = stack.push( + 0, # start + 0, # end + 0, # depth + -1, # parent + 0, # is_left + 0, # impurity + 0, # n_constant_features + -INFINITY, # lower_bound + INFINITY # upper_bound + ) if rc == -1: with gil: raise MemoryError("pruning tree") @@ -1576,12 +1612,32 @@ cdef _cost_complexity_prune(unsigned char[:] leaves_in_subtree, # OUT # ... 
and child_r[node_idx] == _TREE_LEAF: leaves_in_subtree[node_idx] = 1 else: - rc = stack.push(child_l[node_idx], 0, 0, node_idx, 0, 0, 0, -INFINITY, INFINITY) + rc = stack.push( + child_l[node_idx], # start + 0, # end + 0, # depth + node_idx, # parent + 0, # is_left + 0, # impurity + 0, # n_constant_features + -INFINITY, # lower_bound + INFINITY # upper_bound + ) if rc == -1: with gil: raise MemoryError("pruning tree") - rc = stack.push(child_r[node_idx], 0, 0, node_idx, 0, 0, 0, -INFINITY, INFINITY) + rc = stack.push( + child_r[node_idx], # start + 0, # end + 0, # depth + node_idx, # parent + 0, # is_left + 0, # impurity + 0, # n_constant_features + -INFINITY, # lower_bound + INFINITY # upper_bound + ) if rc == -1: with gil: raise MemoryError("pruning tree") @@ -1797,7 +1853,17 @@ cdef _build_pruned_tree( with nogil: # push root node onto stack - rc = stack.push(0, 0, 0, _TREE_UNDEFINED, 0, 0.0, 0, -INFINITY, INFINITY) + rc = stack.push( + 0, # start + 0, # end + 0, # depth + _TREE_UNDEFINED, # parent + 0, # is_left + 0.0, # impurity + 0, # n_constant_features + -INFINITY, # lower_bound + INFINITY # upper_bound + ) if rc == -1: with gil: raise MemoryError("pruning tree") @@ -1830,13 +1896,31 @@ cdef _build_pruned_tree( if not is_leaf: # Push right child on stack rc = stack.push( - node.right_child, 0, depth + 1, new_node_id, 0, 0.0, 0, -INFINITY, INFINITY) + node.right_child, # start + 0, # end + depth + 1, # depth + new_node_id, # parent + 0, # is_left + 0.0, # impurity + 0, # n_constant_features + -INFINITY, # lower_bound + INFINITY # upper_bound + ) if rc == -1: break # push left child on stack rc = stack.push( - node.left_child, 0, depth + 1, new_node_id, 1, 0.0, 0, -INFINITY, INFINITY) + node.left_child, # start + 0, # end + depth + 1, # depth + new_node_id, # parent + 1, # is_left + 0.0, # impurity + 0, # n_constant_features + -INFINITY, # lower_bound + INFINITY # upper_bound + ) if rc == -1: break From 70de12992e360ce95047b2f166ef42a9c5e1eaa5 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 1 Dec 2021 19:03:59 +0100 Subject: [PATCH 081/174] Add test for incorrect input --- sklearn/tree/tests/test_monotonic_tree.py | 33 +++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index a93ae621b4d82..160e3062153cb 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -134,6 +134,39 @@ def test_multiclass_raises(): est.fit(X, y) +def test_multiple_output_raises(): + X = [[1, 2, 3, 4, 5], + [6, 7, 8, 9, 10]] + y = [[1, 0, 1, 0, 1], + [1, 0, 1, 0, 1]] + + for name, TreeClassifier in CLF_TREES.items(): + est = TreeClassifier( + max_depth=None, monotonic_cst=np.array([-1, 1]), random_state=0 + ) + msg = "Monotonic constraints are not supported with multiple output" + with pytest.raises(ValueError, match=msg): + est.fit(X, y) + + +def test_bad_monotonic_cst_raises(): + X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]] + y = [1, 0, 1, 0, 1] + + for name, TreeClassifier in CLF_TREES.items(): + est = TreeClassifier( + max_depth=None, monotonic_cst=np.array([-1, 1, 0]), random_state=0 + ) + msg = "monotonic_cst has shape 3 but the input data X has 2 features." + with pytest.raises(ValueError, match=msg): + est.fit(X, y) + est = TreeClassifier( + max_depth=None, monotonic_cst=np.array([-2, 2]), random_state=0 + ) + msg = "monotonic_cst must be None or an array-like of -1, 0 or 1." 
+ with pytest.raises(ValueError, match=msg): + est.fit(X, y) + def assert_1d_reg_tree_children_monotonic_bounded(tree_, monotonic_sign): # Flip values to always check for increasing constraint values = monotonic_sign * tree_.value From 20590d5fe07009a640f44fd5ad205b507b2a1cd0 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 1 Dec 2021 19:05:33 +0100 Subject: [PATCH 082/174] Add blank line --- sklearn/tree/tests/test_monotonic_tree.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 160e3062153cb..68da7b1820846 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -167,6 +167,7 @@ def test_bad_monotonic_cst_raises(): with pytest.raises(ValueError, match=msg): est.fit(X, y) + def assert_1d_reg_tree_children_monotonic_bounded(tree_, monotonic_sign): # Flip values to always check for increasing constraint values = monotonic_sign * tree_.value From 573aa796d7f261a7d6d4f382cca5817c2584b0bf Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Thu, 2 Dec 2021 11:12:40 +0100 Subject: [PATCH 083/174] Format --- sklearn/tree/tests/test_monotonic_tree.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 68da7b1820846..43028545ee2ea 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -135,10 +135,8 @@ def test_multiclass_raises(): def test_multiple_output_raises(): - X = [[1, 2, 3, 4, 5], - [6, 7, 8, 9, 10]] - y = [[1, 0, 1, 0, 1], - [1, 0, 1, 0, 1]] + X = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]] + y = [[1, 0, 1, 0, 1], [1, 0, 1, 0, 1]] for name, TreeClassifier in CLF_TREES.items(): est = TreeClassifier( From 03ee5f40373d25560e9152fe87c303a5ebc1a4af Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 3 Dec 2021 14:55:20 +0100 Subject: [PATCH 084/174] Factor the logic computing the middle value Co-authored-by: Samuel O. 
Ronsin --- sklearn/tree/_tree.pxd | 1 + sklearn/tree/_tree.pyx | 44 +++++++++++++++++++++++------------------- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 0874187ee98ae..7d538cab8d0cc 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -64,6 +64,7 @@ cdef class Tree: cdef np.ndarray _get_value_ndarray(self) cdef np.ndarray _get_node_ndarray(self) + cdef inline double _get_middle_value(self, int node_id) nogil cpdef np.ndarray predict(self, object X) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index c73ea4e23ac34..e7eb436df3b4b 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -255,16 +255,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): # inspection and interpretation splitter.node_value(tree.value + node_id * tree.value_stride) - if tree.max_n_classes == 1: - # Regression - middle_value = tree.value[node_id] - else: - # Classification - with gil: - middle_value = ( - tree._get_value_ndarray()[node_id][0][0] - / np.sum(tree._get_value_ndarray()[node_id][0]) - ) + middle_value = tree._get_middle_value(node_id) if not is_leaf: if splitter.monotonic_cst[split.feature] == 0: @@ -433,27 +424,19 @@ cdef class BestFirstTreeBuilder(TreeBuilder): else: # Node is expandable - if tree.max_n_classes == 1: - # Regression - middle_value = tree.value[record.node_id] - else: - # Classification - with gil: - middle_value = ( - tree._get_value_ndarray()[record.node_id][0][0] - / np.sum(tree._get_value_ndarray()[record.node_id][0]) - ) if splitter.monotonic_cst[node.feature] == 0: left_child_min = record.lower_bound left_child_max = record.upper_bound right_child_min = record.lower_bound right_child_max = record.upper_bound elif splitter.monotonic_cst[node.feature] == 1: + middle_value = tree._get_middle_value(record.node_id) left_child_min = record.lower_bound left_child_max = middle_value right_child_min = middle_value right_child_max = record.upper_bound elif splitter.monotonic_cst[node.feature] == -1: + middle_value = tree._get_middle_value(record.node_id) left_child_min = middle_value left_child_max = record.upper_bound right_child_min = record.lower_bound @@ -1222,6 +1205,27 @@ cdef class Tree: raise ValueError("Can't initialize array.") return arr + cdef inline double _get_middle_value(self, int node_id) nogil: + cdef: + int i + double middle_value = 0 + + if self.max_n_classes == 1: + # Regression + middle_value = self.value[node_id] + else: + # Classification + + # This performs some raw pointers arithmetic + # Ideally a memory view could wrap self.value so that's + # indexing on several axis can be used. + for i in range(self.max_n_classes): + middle_value += self.value[node_id * self.value_stride + i] + + middle_value = self.value[node_id * self.value_stride] / middle_value + + return middle_value + def compute_partial_dependence(self, DTYPE_t[:, ::1] X, int[::1] target_features, double[::1] out): From 779efcbc35008f6068b4867bd016c43e021cc765 Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Thu, 12 May 2022 18:56:58 +0200 Subject: [PATCH 085/174] Apply Julien's suggestions in tests Co-authored-by: Julien Jerphanion --- sklearn/tree/tests/test_monotonic_tree.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 43028545ee2ea..b40706bd43375 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -98,7 +98,7 @@ def test_montonic_constraints_regressions(seed, depth_first): max_leaf_nodes=n_samples_train, ) if hasattr(est, "random_state"): - est.set_params(**{"random_state": seed}) + est.set_params(random_state=seed) if hasattr(est, "n_estimators"): est.set_params(**{"n_estimators": 5}) est.fit(X_train, y_train) @@ -158,10 +158,17 @@ def test_bad_monotonic_cst_raises(): msg = "monotonic_cst has shape 3 but the input data X has 2 features." with pytest.raises(ValueError, match=msg): est.fit(X, y) + msg = "monotonic_cst must be None or an array-like of -1, 0 or 1." + est = TreeClassifier( max_depth=None, monotonic_cst=np.array([-2, 2]), random_state=0 ) - msg = "monotonic_cst must be None or an array-like of -1, 0 or 1." + with pytest.raises(ValueError, match=msg): + est.fit(X, y) + + est = TreeClassifier( + max_depth=None, monotonic_cst=np.array([-1, 0.8]), random_state=0 + ) with pytest.raises(ValueError, match=msg): est.fit(X, y) From 1426e1df8a3d223da440a3eac383756110544023 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Thu, 12 May 2022 19:14:25 +0200 Subject: [PATCH 086/174] Clean up check of monotonic_cst values Co-authored-by: Julien Jerphanion --- sklearn/tree/_classes.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index e5f0a9073deb4..55373925e3fab 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -391,7 +391,14 @@ def fit(self, X, y, sample_weight=None, check_input=True): "monotonic_cst has shape {} but the input data " "X has {} features.".format(monotonic_cst.shape[0], X.shape[1]) ) - if np.any(monotonic_cst < -1) or np.any(monotonic_cst > 1): + # Applying element-wise logical conjunction + # for monotonic constraints' support. + unsatisfied_constraints_conditions = ( + (monotonic_cst != -1) * + (monotonic_cst != 0) * + (monotonic_cst != 1) + ) + if np.any(unsatisfied_constraints_conditions): raise ValueError( "monotonic_cst must be None or an array-like of -1, 0 or 1." ) From 1db26e0f00cb7dd399974c6cf3696009bd086e07 Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Thu, 12 May 2022 19:15:58 +0200 Subject: [PATCH 087/174] Use local variables for attributes and initialize on definition Co-authored-by: Julien Jerphanion --- sklearn/tree/_splitter.pyx | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index cc520a11ddf37..00689763367e4 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -238,17 +238,28 @@ cdef class Splitter: cdef inline bint check_monotonicity(self, INT32_t monotonic_cst, double lower_bound, double upper_bound) nogil: """Check monotonic constraint is satisfied at the current split""" - cdef bint check_lower, check_upper, check_monotonic - check_lower = (deref(self.criterion.sum_left) >= lower_bound * self.criterion.weighted_n_left) \ - & (deref(self.criterion.sum_right) >= lower_bound * self.criterion.weighted_n_right) - check_upper = (deref(self.criterion.sum_left) <= upper_bound * self.criterion.weighted_n_left) \ - & (deref(self.criterion.sum_right) <= upper_bound * self.criterion.weighted_n_right) + cdef: + double sum_left = deref(self.criterion.sum_left) + double sum_right = deref(self.criterion.sum_right) + double weighted_n_left = self.criterion.weighted_n_left + double weighted_n_right = self.criterion.weighted_n_right + + bint check_lower = ( + sum_left >= lower_bound * weighted_n_left & + sum_right >= lower_bound * weighted_n_right + ) + bint check_upper = ( + sum_left <= upper_bound * weighted_n_left & + sum_right <= upper_bound * weighted_n_right + ) + bint check_monotonic if monotonic_cst == 0: # No constraint return check_lower & check_upper else: - check_monotonic = (deref(self.criterion.sum_left) * self.criterion.weighted_n_right - - deref(self.criterion.sum_right) * self.criterion.weighted_n_left) \ - * monotonic_cst <= 0 + check_monotonic = ( + (sum_left * weighted_n_right - + sum_right * weighted_n_left) * monotonic_cst <= 0 + ) return check_lower & check_upper & check_monotonic cdef class BaseDenseSplitter(Splitter): From e87e40ae1db17a96ae6604d46d7f3efae560e294 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Thu, 12 May 2022 19:23:16 +0200 Subject: [PATCH 088/174] Add comments in BestFirstTreeBuilder Co-authored-by: Julien Jerphanion --- sklearn/tree/_tree.pyx | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index e7eb436df3b4b..d5ab4461ec690 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -425,17 +425,20 @@ cdef class BestFirstTreeBuilder(TreeBuilder): else: # Node is expandable if splitter.monotonic_cst[node.feature] == 0: + # No constraint left_child_min = record.lower_bound left_child_max = record.upper_bound right_child_min = record.lower_bound right_child_max = record.upper_bound elif splitter.monotonic_cst[node.feature] == 1: + # Monotonically increasing constraint middle_value = tree._get_middle_value(record.node_id) left_child_min = record.lower_bound left_child_max = middle_value right_child_min = middle_value right_child_max = record.upper_bound elif splitter.monotonic_cst[node.feature] == -1: + # Monotonically decreasing constraint middle_value = tree._get_middle_value(record.node_id) left_child_min = middle_value left_child_max = record.upper_bound From f03f08934ccc8ee6702adda87250f133ec65a8ee Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Thu, 12 May 2022 19:45:50 +0200 Subject: [PATCH 089/174] Update methods signature format for readability --- sklearn/tree/_tree.pyx | 117 +++++++++++++++++++++++++++++++---------- 1 file changed, 89 insertions(+), 28 deletions(-) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index d5ab4461ec690..d21af9cc46970 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -398,10 +398,20 @@ cdef class BestFirstTreeBuilder(TreeBuilder): with nogil: # add root to frontier - rc = self._add_split_node(splitter, tree, 0, n_node_samples, - INFINITY, IS_FIRST, IS_LEFT, NULL, 0, - -INFINITY, INFINITY, - &split_node_left) + rc = self._add_split_node( + splitter, # splitter + tree, # tree + 0, # start + n_node_samples, # end + INFINITY, # impurity + IS_FIRST, # is_first + IS_LEFT, # is_left + NULL, # parent + 0, # depth + -INFINITY, # lower_bound + INFINITY, # upper_bound + &split_node_left # res + ) if rc >= 0: rc = _add_to_frontier(&split_node_left, frontier) @@ -449,14 +459,20 @@ cdef class BestFirstTreeBuilder(TreeBuilder): max_split_nodes -= 1 # Compute left split node - rc = self._add_split_node(splitter, tree, - record.start, record.pos, - record.impurity_left, - IS_NOT_FIRST, IS_LEFT, node, - record.depth + 1, - left_child_min, - left_child_max, - &split_node_left) + rc = self._add_split_node( + splitter, # splitter + tree, # tree + record.start, # start + record.pos, # end + record.impurity_left, # impurity + IS_NOT_FIRST, # is_first + IS_LEFT, # is_left + node, # parent + record.depth + 1, # depth + left_child_min, # lower_bound + left_child_max, # upper_bound + &split_node_left # res + ) if rc == -1: break @@ -464,14 +480,20 @@ cdef class BestFirstTreeBuilder(TreeBuilder): node = &tree.nodes[record.node_id] # Compute right split node - rc = self._add_split_node(splitter, tree, record.pos, - record.end, - record.impurity_right, - IS_NOT_FIRST, IS_NOT_LEFT, node, - record.depth + 1, - right_child_min, - right_child_max, - &split_node_right) + rc = self._add_split_node( + splitter, # splitter + tree, # tree + record.pos, # start + record.end, # end + record.impurity_right,# impurity + IS_NOT_FIRST, # is_first + IS_NOT_LEFT, # is_left + node, # parent + record.depth + 1, # depth + right_child_min, # lower_bound + right_child_max, # upper_bound + &split_node_right # res + ) if rc == -1: break @@ -496,12 +518,21 @@ cdef class BestFirstTreeBuilder(TreeBuilder): if rc == -1: raise MemoryError() - cdef inline int _add_split_node(self, Splitter splitter, Tree tree, - SIZE_t start, SIZE_t end, double impurity, - bint is_first, bint is_left, Node* parent, - SIZE_t depth, double lower_bound, + cdef inline int _add_split_node( + self, + Splitter splitter, + Tree tree, + SIZE_t start, + SIZE_t end, + double impurity, + bint is_first, + bint is_left, + Node* parent, + SIZE_t depth, + double lower_bound, double upper_bound, - PriorityHeapRecord* res) nogil except -1: + PriorityHeapRecord* res + ) nogil except -1: """Adds node w/ partition ``[start, end)`` to the frontier. 
""" cdef SplitRecord split cdef SIZE_t node_id @@ -1687,7 +1718,17 @@ cdef _cost_complexity_prune(unsigned char[:] leaves_in_subtree, # OUT break # stack uses only the start variable - rc = stack.push(pruned_branch_node_idx, 0, 0, 0, 0, 0, 0, -INFINITY, INFINITY) + rc = stack.push( + pruned_branch_node_idx, # start + 0, # end + 0, # depth + 0, # parent + 0, # is_left + 0, # impurity + 0, # n_constant_features + -INFINITY, # lower_bound + INFINITY # upper_bound + ) if rc == -1: with gil: raise MemoryError("pruning tree") @@ -1705,11 +1746,31 @@ cdef _cost_complexity_prune(unsigned char[:] leaves_in_subtree, # OUT if child_l[node_idx] != _TREE_LEAF: # ... and child_r[node_idx] != _TREE_LEAF: - rc = stack.push(child_l[node_idx], 0, 0, 0, 0, 0, 0, -INFINITY, INFINITY) + rc = stack.push( + child_l[node_idx], # start + 0, # end + 0, # depth + 0, # parent + 0, # is_left + 0, # impurity + 0, # n_constant_features + -INFINITY, # lower_bound + INFINITY # upper_bound + ) if rc == -1: with gil: raise MemoryError("pruning tree") - rc = stack.push(child_r[node_idx], 0, 0, 0, 0, 0, 0, -INFINITY, INFINITY) + rc = stack.push( + child_r[node_idx], # start + 0, # end + 0, # depth + 0, # parent + 0, # is_left + 0, # impurity + 0, # n_constant_features + -INFINITY, # lower_bound + INFINITY # upper_bound + ) if rc == -1: with gil: raise MemoryError("pruning tree") From cc446f64584bc81ac199a930e5b4fd7a75448f27 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Thu, 12 May 2022 19:56:53 +0200 Subject: [PATCH 090/174] Update methods signature format for readability --- sklearn/tree/_splitter.pyx | 55 +++++++++++++++++++++++++++----------- sklearn/tree/_tree.pyx | 17 +++++++++--- 2 files changed, 54 insertions(+), 18 deletions(-) diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 00689763367e4..2372724e1e5f2 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -213,9 +213,14 @@ cdef class Splitter: weighted_n_node_samples[0] = self.criterion.weighted_n_node_samples return 0 - cdef int node_split(self, double impurity, SplitRecord* split, - SIZE_t* n_constant_features, double lower_bound, - double upper_bound) nogil except -1: + cdef int node_split( + self, + double impurity, + SplitRecord* split, + SIZE_t* n_constant_features, + double lower_bound, + double upper_bound, + ) nogil except -1: """Find the best split on node samples[start:end]. This is a placeholder method. 
The majority of computation will be done @@ -293,9 +298,14 @@ cdef class BestSplitter(BaseDenseSplitter): self.min_weight_leaf, self.random_state), self.__getstate__()) - cdef int node_split(self, double impurity, SplitRecord* split, - SIZE_t* n_constant_features, double lower_bound, - double upper_bound) nogil except -1: + cdef int node_split( + self, + double impurity, + SplitRecord* split, + SIZE_t* n_constant_features, + double lower_bound, + double upper_bound + ) nogil except -1: """Find the best split on node samples[start:end] Returns -1 in case of failure to allocate memory (and raise MemoryError) @@ -613,9 +623,14 @@ cdef class RandomSplitter(BaseDenseSplitter): self.min_weight_leaf, self.random_state), self.__getstate__()) - cdef int node_split(self, double impurity, SplitRecord* split, - SIZE_t* n_constant_features, double lower_bound, - double upper_bound) nogil except -1: + cdef int node_split( + self, + double impurity, + SplitRecord* split, + SIZE_t* n_constant_features, + double lower_bound, + double upper_bound + ) nogil except -1: """Find the best random split on node samples[start:end] Returns -1 in case of failure to allocate memory (and raise MemoryError) @@ -1139,9 +1154,14 @@ cdef class BestSparseSplitter(BaseSparseSplitter): self.min_weight_leaf, self.random_state), self.__getstate__()) - cdef int node_split(self, double impurity, SplitRecord* split, - SIZE_t* n_constant_features, double lower_bound, - double upper_bound) nogil except -1: + cdef int node_split( + self, + double impurity, + SplitRecord* split, + SIZE_t* n_constant_features, + double lower_bound, + double upper_bound + ) nogil except -1: """Find the best split on node samples[start:end], using sparse features Returns -1 in case of failure to allocate memory (and raise MemoryError) @@ -1376,9 +1396,14 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): self.min_weight_leaf, self.random_state), self.__getstate__()) - cdef int node_split(self, double impurity, SplitRecord* split, - SIZE_t* n_constant_features, double lower_bound, - double upper_bound) nogil except -1: + cdef int node_split( + self, + double impurity, + SplitRecord* split, + SIZE_t* n_constant_features, + double lower_bound, + double upper_bound + ) nogil except -1: """Find a random split on node samples[start:end], using sparse features Returns -1 in case of failure to allocate memory (and raise MemoryError) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index d21af9cc46970..6a75a7b7ee2cc 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -235,7 +235,13 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): is_leaf = is_leaf or impurity <= EPSILON if not is_leaf: - splitter.node_split(impurity, &split, &n_constant_features, lower_bound, upper_bound) + splitter.node_split( + impurity, + &split, + &n_constant_features, + lower_bound, + upper_bound + ) # If EPSILON=0 in the below comparison, float precision # issues stop splitting, producing trees that are # dissimilar to v0.18 @@ -560,8 +566,13 @@ cdef class BestFirstTreeBuilder(TreeBuilder): ) if not is_leaf: - splitter.node_split(impurity, &split, &n_constant_features, lower_bound, - upper_bound) + splitter.node_split( + impurity, + &split, + &n_constant_features, + lower_bound, + upper_bound + ) # If EPSILON=0 in the below comparison, float precision issues stop # splitting early, producing trees that are dissimilar to v0.18 is_leaf = (is_leaf or split.pos >= end or From c8ec3c4dc5e2c8c4f68c807431d791407ac7fc4d Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Thu, 12 May 2022 22:43:34 +0200 Subject: [PATCH 091/174] Format --- sklearn/tree/_classes.py | 4 +--- sklearn/tree/tests/test_monotonic_tree.py | 6 +++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 873b6500632ad..d75fdf11454cd 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -444,9 +444,7 @@ def fit(self, X, y, sample_weight=None, check_input=True): # Applying element-wise logical conjunction # for monotonic constraints' support. unsatisfied_constraints_conditions = ( - (monotonic_cst != -1) * - (monotonic_cst != 0) * - (monotonic_cst != 1) + (monotonic_cst != -1) * (monotonic_cst != 0) * (monotonic_cst != 1) ) if np.any(unsatisfied_constraints_conditions): raise ValueError( diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index b40706bd43375..c0f6298a425b7 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -152,20 +152,20 @@ def test_bad_monotonic_cst_raises(): y = [1, 0, 1, 0, 1] for name, TreeClassifier in CLF_TREES.items(): + msg = "monotonic_cst has shape 3 but the input data X has 2 features." est = TreeClassifier( max_depth=None, monotonic_cst=np.array([-1, 1, 0]), random_state=0 ) - msg = "monotonic_cst has shape 3 but the input data X has 2 features." with pytest.raises(ValueError, match=msg): est.fit(X, y) + msg = "monotonic_cst must be None or an array-like of -1, 0 or 1." - est = TreeClassifier( max_depth=None, monotonic_cst=np.array([-2, 2]), random_state=0 ) with pytest.raises(ValueError, match=msg): est.fit(X, y) - + est = TreeClassifier( max_depth=None, monotonic_cst=np.array([-1, 0.8]), random_state=0 ) From 0a02f62b20cea8c542e682c6af139850e118c097 Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Thu, 12 May 2022 23:05:33 +0200 Subject: [PATCH 092/174] Fix gil stuff and stuff --- sklearn/tree/_splitter.pxd | 2 +- sklearn/tree/_splitter.pyx | 34 +++++++++++++++++++--------------- sklearn/tree/_tree.pyx | 3 ++- 3 files changed, 22 insertions(+), 17 deletions(-) diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 10882cc770ca2..56bcd10306f8b 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -104,4 +104,4 @@ cdef class Splitter: cdef double node_impurity(self) nogil cdef inline bint check_monotonicity(self, INT32_t monotonic_cst, - double lower_bound, double upper_bound) nogil + double lower_bound, double upper_bound) diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 1808fc0367fd5..d1d99bba1d2fd 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -242,21 +242,21 @@ cdef class Splitter: return self.criterion.node_impurity() - cdef inline bint check_monotonicity(self, INT32_t monotonic_cst, double lower_bound, double upper_bound) nogil: + cdef inline bint check_monotonicity(self, INT32_t monotonic_cst, double lower_bound, double upper_bound): """Check monotonic constraint is satisfied at the current split""" cdef: - double sum_left = deref(self.criterion.sum_left) - double sum_right = deref(self.criterion.sum_right) + double sum_left = self.criterion.sum_left + double sum_right = self.criterion.sum_right double weighted_n_left = self.criterion.weighted_n_left double weighted_n_right = self.criterion.weighted_n_right bint check_lower = ( - sum_left >= lower_bound * weighted_n_left & - sum_right >= lower_bound * weighted_n_right + (sum_left >= lower_bound * weighted_n_left) & + (sum_right >= lower_bound * weighted_n_right) ) bint check_upper = ( - sum_left <= upper_bound * weighted_n_left & - sum_right <= upper_bound * weighted_n_right + (sum_left <= upper_bound * weighted_n_left) & + (sum_right <= upper_bound * weighted_n_right) ) bint check_monotonic if monotonic_cst == 0: # No constraint @@ -443,8 +443,9 @@ cdef class BestSplitter(BaseDenseSplitter): self.criterion.update(current.pos) # Reject if monotonicity constraints are not satisfied - if not self.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): - continue + with gil: + if not self.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): + continue # Reject if min_weight_leaf is not satisfied if ((self.criterion.weighted_n_left < min_weight_leaf) or @@ -669,8 +670,9 @@ cdef class RandomSplitter(BaseDenseSplitter): continue # Reject if monotonicity constraints are not satisfied - if not self.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): - continue + with gil: + if not self.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): + continue current_proxy_improvement = self.criterion.proxy_impurity_improvement() @@ -1227,8 +1229,9 @@ cdef class BestSparseSplitter(BaseSparseSplitter): continue # Reject if monotonicity constraints are not satisfied - if not self.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): - continue + with gil: + if not self.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): + continue current_proxy_improvement = self.criterion.proxy_impurity_improvement() @@ -1464,8 +1467,9 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): continue # Reject if monotonicity constraints are not satisfied - if not self.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): - continue + with gil: + if not 
self.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): + continue # Evaluate split self.criterion.reset() diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 8b9a4696f65e0..eba72400a5e44 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -617,7 +617,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder): res.end = end res.depth = depth res.impurity = impurity - node_value = splitter.criterion.sum_total[node_id] / weighted_n_node_samples + with gil: + node_value = splitter.criterion.sum_total[node_id] / weighted_n_node_samples if not is_leaf: # is split node From dea6bda441234d97fb153f2317657c03072851f8 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 18 May 2022 11:08:38 +0200 Subject: [PATCH 093/174] Fix -1 0 1 monotonic_cst check --- sklearn/tree/_classes.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index d75fdf11454cd..6a465df546c8f 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -425,6 +425,16 @@ def fit(self, X, y, sample_weight=None, check_input=True): raise ValueError( "Monotonic constraints are not supported with multiple output" ) + # Applying element-wise logical conjunction + # for monotonic constraints' support. + monotonic_cst = np.asarray(self.monotonic_cst) + unsatisfied_constraints_conditions = ( + (monotonic_cst != -1) * (monotonic_cst != 0) * (monotonic_cst != 1) + ) + if np.any(unsatisfied_constraints_conditions): + raise ValueError( + "monotonic_cst must be None or an array-like of -1, 0 or 1." + ) if is_classifier(self): if self.n_classes_[0] > 2: raise ValueError( @@ -441,15 +451,6 @@ def fit(self, X, y, sample_weight=None, check_input=True): "monotonic_cst has shape {} but the input data " "X has {} features.".format(monotonic_cst.shape[0], X.shape[1]) ) - # Applying element-wise logical conjunction - # for monotonic constraints' support. - unsatisfied_constraints_conditions = ( - (monotonic_cst != -1) * (monotonic_cst != 0) * (monotonic_cst != 1) - ) - if np.any(unsatisfied_constraints_conditions): - raise ValueError( - "monotonic_cst must be None or an array-like of -1, 0 or 1." - ) if not isinstance(self.splitter, Splitter): splitter = SPLITTERS[self.splitter]( criterion, From 1c943a3c86226d9e0831978b731642c2138244fb Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Wed, 18 May 2022 12:08:46 +0200 Subject: [PATCH 094/174] Move check_monotonicity from Splitter to Criterion --- sklearn/tree/_criterion.pxd | 3 ++ sklearn/tree/_criterion.pyx | 66 +++++++++++++++++++++++++++++++++++++ sklearn/tree/_splitter.pxd | 3 -- sklearn/tree/_splitter.pyx | 45 +++++-------------------- sklearn/tree/_tree.pyx | 2 +- 5 files changed, 79 insertions(+), 40 deletions(-) diff --git a/sklearn/tree/_criterion.pxd b/sklearn/tree/_criterion.pxd index bc78c09b6ff5e..7620aaad91f26 100644 --- a/sklearn/tree/_criterion.pxd +++ b/sklearn/tree/_criterion.pxd @@ -55,6 +55,9 @@ cdef class Criterion: double impurity_left, double impurity_right) nogil cdef double proxy_impurity_improvement(self) nogil + cdef double get_sum_total(self, INT32_t node_id) nogil + cdef bint check_monotonicity(self, INT32_t monotonic_cst, + double lower_bound, double upper_bound) nogil cdef class ClassificationCriterion(Criterion): """Abstract criterion for classification.""" diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 680719464e11d..e996cb5661d40 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -192,6 +192,13 @@ cdef class Criterion: - (self.weighted_n_left / self.weighted_n_node_samples * impurity_left))) + cdef double get_sum_total(self, INT32_t node_id) nogil: + pass + + cdef bint check_monotonicity(self, INT32_t monotonic_cst, + double lower_bound, double upper_bound) nogil: + pass + cdef class ClassificationCriterion(Criterion): """Abstract criterion for classification.""" @@ -430,6 +437,36 @@ cdef class ClassificationCriterion(Criterion): memcpy(dest, &self.sum_total[k, 0], self.n_classes[k] * sizeof(double)) dest += self.max_n_classes + cdef inline double get_sum_total(self, INT32_t node_id) nogil: + """Get the weighted sum of the class 0 occurences in node""" + return self.sum_total[node_id][0] + + cdef inline bint check_monotonicity(self, INT32_t monotonic_cst, double lower_bound, double upper_bound) nogil: + """Check monotonic constraint is satisfied at the current classification split""" + cdef: + double sum_left = self.sum_left[0][0] + double sum_right = self.sum_right[0][0] + double weighted_n_left = self.weighted_n_left + double weighted_n_right = self.weighted_n_right + + bint check_lower = ( + (sum_left >= lower_bound * weighted_n_left) & + (sum_right >= lower_bound * weighted_n_right) + ) + bint check_upper = ( + (sum_left <= upper_bound * weighted_n_left) & + (sum_right <= upper_bound * weighted_n_right) + ) + bint check_monotonic + if monotonic_cst == 0: # No constraint + return check_lower & check_upper + else: + check_monotonic = ( + (sum_left * weighted_n_right - + sum_right * weighted_n_left) * monotonic_cst <= 0 + ) + return check_lower & check_upper & check_monotonic + cdef class Entropy(ClassificationCriterion): r"""Cross Entropy impurity criterion. 
@@ -769,6 +806,35 @@ cdef class RegressionCriterion(Criterion): for k in range(self.n_outputs): dest[k] = self.sum_total[k] / self.weighted_n_node_samples + cdef inline double get_sum_total(self, INT32_t node_id) nogil: + """Get the weighted sum of the target values in node""" + return self.sum_total[node_id] + + cdef inline bint check_monotonicity(self, INT32_t monotonic_cst, double lower_bound, double upper_bound) nogil: + """Check monotonic constraint is satisfied at the current regression split""" + cdef: + double sum_left = self.sum_left[0] + double sum_right = self.sum_right[0] + double weighted_n_left = self.weighted_n_left + double weighted_n_right = self.weighted_n_right + + bint check_lower = ( + (sum_left >= lower_bound * weighted_n_left) & + (sum_right >= lower_bound * weighted_n_right) + ) + bint check_upper = ( + (sum_left <= upper_bound * weighted_n_left) & + (sum_right <= upper_bound * weighted_n_right) + ) + bint check_monotonic + if monotonic_cst == 0: # No constraint + return check_lower & check_upper + else: + check_monotonic = ( + (sum_left * weighted_n_right - + sum_right * weighted_n_left) * monotonic_cst <= 0 + ) + return check_lower & check_upper & check_monotonic cdef class MSE(RegressionCriterion): """Mean squared error impurity criterion. diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index a93b4c04b17e1..62245e9f43c95 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -99,6 +99,3 @@ cdef class Splitter: cdef void node_value(self, double* dest) nogil cdef double node_impurity(self) nogil - - cdef inline bint check_monotonicity(self, INT32_t monotonic_cst, - double lower_bound, double upper_bound) diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 977b51ecfb4e8..d99a461594cc7 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -21,6 +21,7 @@ from libc.string cimport memset from cython.operator cimport dereference as deref import numpy as np +cimport numpy as cnp from scipy.sparse import csc_matrix @@ -56,7 +57,7 @@ cdef class Splitter: def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, np.ndarray[INT32_t] monotonic_cst): + object random_state, cnp.ndarray[INT32_t] monotonic_cst): """ Parameters ---------- @@ -225,32 +226,6 @@ cdef class Splitter: return self.criterion.node_impurity() - cdef inline bint check_monotonicity(self, INT32_t monotonic_cst, double lower_bound, double upper_bound): - """Check monotonic constraint is satisfied at the current split""" - cdef: - double sum_left = self.criterion.sum_left - double sum_right = self.criterion.sum_right - double weighted_n_left = self.criterion.weighted_n_left - double weighted_n_right = self.criterion.weighted_n_right - - bint check_lower = ( - (sum_left >= lower_bound * weighted_n_left) & - (sum_right >= lower_bound * weighted_n_right) - ) - bint check_upper = ( - (sum_left <= upper_bound * weighted_n_left) & - (sum_right <= upper_bound * weighted_n_right) - ) - bint check_monotonic - if monotonic_cst == 0: # No constraint - return check_lower & check_upper - else: - check_monotonic = ( - (sum_left * weighted_n_right - - sum_right * weighted_n_left) * monotonic_cst <= 0 - ) - return check_lower & check_upper & check_monotonic - cdef class BaseDenseSplitter(Splitter): cdef const DTYPE_t[:, :] X @@ -427,7 +402,7 @@ cdef class BestSplitter(BaseDenseSplitter): # Reject if monotonicity constraints are not satisfied with gil: - if not 
self.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): + if not self.criterion.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): continue # Reject if min_weight_leaf is not satisfied @@ -654,7 +629,7 @@ cdef class RandomSplitter(BaseDenseSplitter): # Reject if monotonicity constraints are not satisfied with gil: - if not self.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): + if not self.criterion.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): continue current_proxy_improvement = self.criterion.proxy_impurity_improvement() @@ -712,7 +687,7 @@ cdef class BaseSparseSplitter(Splitter): def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, np.ndarray[INT32_t] monotonic_cst): + object random_state, cnp.ndarray[INT32_t] monotonic_cst): # Parent __cinit__ is automatically called self.n_total_samples = 0 self.monotonic_cst = monotonic_cst.data @@ -1194,9 +1169,8 @@ cdef class BestSparseSplitter(BaseSparseSplitter): continue # Reject if monotonicity constraints are not satisfied - with gil: - if not self.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): - continue + if not self.criterion.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): + continue current_proxy_improvement = self.criterion.proxy_impurity_improvement() @@ -1422,9 +1396,8 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): continue # Reject if monotonicity constraints are not satisfied - with gil: - if not self.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): - continue + if not self.criterion.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): + continue # Evaluate split self.criterion.reset() diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 2f827d4ec6b79..a60e5c83e5491 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -618,7 +618,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): res.depth = depth res.impurity = impurity with gil: - node_value = splitter.criterion.sum_total[node_id] / weighted_n_node_samples + node_value = splitter.criterion.get_sum_total(node_id) / weighted_n_node_samples if not is_leaf: # is split node From b0a03175cd2a42119b35eb36bcaf97397c98da51 Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Wed, 18 May 2022 16:23:09 +0200 Subject: [PATCH 095/174] Fix premature call to check_monotonicity in RandomSparseSplitter --- sklearn/tree/_splitter.pyx | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index d99a461594cc7..cc6a7896afac5 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -1395,10 +1395,6 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): ((end - current.pos) < min_samples_leaf)): continue - # Reject if monotonicity constraints are not satisfied - if not self.criterion.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): - continue - # Evaluate split self.criterion.reset() self.criterion.update(current.pos) @@ -1408,6 +1404,10 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): (self.criterion.weighted_n_right < min_weight_leaf)): continue + # Reject if monotonicity constraints are not satisfied + if not self.criterion.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): + continue + current_proxy_improvement = self.criterion.proxy_impurity_improvement() if current_proxy_improvement > best_proxy_improvement: From b784e045c2d2901fec896f452a0295e9b0b4117c Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 18 May 2022 16:28:45 +0200 Subject: [PATCH 096/174] Remove unnecessary GIL-taking --- sklearn/tree/_splitter.pyx | 10 ++++------ sklearn/tree/_tree.pyx | 3 +-- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index cc6a7896afac5..e472cd362e902 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -401,9 +401,8 @@ cdef class BestSplitter(BaseDenseSplitter): self.criterion.update(current.pos) # Reject if monotonicity constraints are not satisfied - with gil: - if not self.criterion.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): - continue + if not self.criterion.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): + continue # Reject if min_weight_leaf is not satisfied if ((self.criterion.weighted_n_left < min_weight_leaf) or @@ -628,9 +627,8 @@ cdef class RandomSplitter(BaseDenseSplitter): continue # Reject if monotonicity constraints are not satisfied - with gil: - if not self.criterion.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): - continue + if not self.criterion.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): + continue current_proxy_improvement = self.criterion.proxy_impurity_improvement() diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index a60e5c83e5491..a122212057206 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -617,8 +617,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): res.end = end res.depth = depth res.impurity = impurity - with gil: - node_value = splitter.criterion.get_sum_total(node_id) / weighted_n_node_samples + node_value = splitter.criterion.get_sum_total(node_id) / weighted_n_node_samples if not is_leaf: # is split node From f68a27b4508e1030077d751cf93231c52396f12b Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Wed, 18 May 2022 16:53:35 +0200 Subject: [PATCH 097/174] Fix formatting --- sklearn/tree/_classes.py | 2 +- sklearn/tree/_utils.pyx | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 6a465df546c8f..2fc38b624ae7b 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -429,7 +429,7 @@ def fit(self, X, y, sample_weight=None, check_input=True): # for monotonic constraints' support. monotonic_cst = np.asarray(self.monotonic_cst) unsatisfied_constraints_conditions = ( - (monotonic_cst != -1) * (monotonic_cst != 0) * (monotonic_cst != 1) + (monotonic_cst != -1) * (monotonic_cst != 0) * (monotonic_cst != 1) ) if np.any(unsatisfied_constraints_conditions): raise ValueError( diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 9f522f5650f06..7346070b7a149 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -72,7 +72,6 @@ cdef inline double rand_uniform(double low, double high, cdef inline double log(double x) nogil: return ln(x) / ln(2.0) - # ============================================================================= # WeightedPQueue data structure # ============================================================================= From 8b795e710f62fd7e63567689d9d589f1326f7898 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 20 Jul 2022 16:09:56 +0200 Subject: [PATCH 098/174] Add monotonic_cst to parameter contraints for decision tree models --- sklearn/tree/_classes.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 7eaf129a1fa04..f118f0277866b 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -120,6 +120,7 @@ class BaseDecisionTree(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta): "max_leaf_nodes": [Interval(Integral, 2, None, closed="left"), None], "min_impurity_decrease": [Interval(Real, 0.0, None, closed="left")], "ccp_alpha": [Interval(Real, 0.0, None, closed="left")], + "monotonic_cst": ["array-like", None] } @abstractmethod From 279ab6ab99a80eb292deae0bd0c469f3aed1878d Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Wed, 20 Jul 2022 16:42:19 +0200 Subject: [PATCH 099/174] Misc formatting --- sklearn/ensemble/_forest.py | 1 - sklearn/ensemble/_gb.py | 1 + sklearn/tree/_classes.py | 10 +++------- sklearn/tree/_splitter.pyx | 1 + sklearn/tree/_tree.pyx | 3 --- 5 files changed, 5 insertions(+), 11 deletions(-) diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index 17d76c83fdf0d..43d3c84553cbd 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -1403,7 +1403,6 @@ def __init__( max_samples=None, monotonic_cst=None, ): - super().__init__( base_estimator=DecisionTreeClassifier(), n_estimators=n_estimators, diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index 494d4a4592320..1b2e7346037e7 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -263,6 +263,7 @@ def _fit_stage( k=k, perform_line_search=perform_line_search, ) + # add tree to ensemble self.estimators_[i, k] = tree diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index f118f0277866b..4ad242e656e0f 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -120,7 +120,7 @@ class BaseDecisionTree(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta): "max_leaf_nodes": [Interval(Integral, 2, None, closed="left"), None], "min_impurity_decrease": [Interval(Real, 0.0, None, closed="left")], "ccp_alpha": [Interval(Real, 0.0, None, closed="left")], - "monotonic_cst": ["array-like", None] + "monotonic_cst": ["array-like", None], } @abstractmethod @@ -141,7 +141,6 @@ def __init__( ccp_alpha=0.0, monotonic_cst=None, ): - self.criterion = criterion self.splitter = splitter self.max_depth = max_depth @@ -368,7 +367,7 @@ def fit(self, X, y, sample_weight=None, check_input=True): ) # Applying element-wise logical conjunction # for monotonic constraints' support. - monotonic_cst = np.asarray(self.monotonic_cst) + monotonic_cst = np.asarray(self.monotonic_cst, dtype=np.int32) unsatisfied_constraints_conditions = ( (monotonic_cst != -1) * (monotonic_cst != 0) * (monotonic_cst != 1) ) @@ -383,9 +382,7 @@ def fit(self, X, y, sample_weight=None, check_input=True): "classification" ) # Imposing the constraint on the probability of the positive class - monotonic_cst = -np.asarray(self.monotonic_cst, dtype=np.int32) - else: - monotonic_cst = np.asarray(self.monotonic_cst, dtype=np.int32) + monotonic_cst *= -1 if monotonic_cst.shape[0] != X.shape[1]: raise ValueError( @@ -1658,7 +1655,6 @@ def __init__( ccp_alpha=0.0, monotonic_cst=None, ): - super().__init__( criterion=criterion, splitter=splitter, diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index f16896310ae5c..6a68f4f7a7bf4 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -225,6 +225,7 @@ cdef class Splitter: return self.criterion.node_impurity() + cdef class BaseDenseSplitter(Splitter): cdef const DTYPE_t[:, :] X diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index a122212057206..e067acd1393d5 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -209,7 +209,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): with nogil: # push root node onto stack - builder_stack.push({ "start": 0, "end": n_node_samples, @@ -336,7 +335,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): # Best first builder ---------------------------------------------------------- - cdef struct FrontierRecord: # Record of information of a Node, the frontier for a split. 
Those records are # maintained in a heap to access the Node with the best improvement in impurity, @@ -410,7 +408,6 @@ cdef class BestFirstTreeBuilder(TreeBuilder): # Recursive partition (without actual recursion) splitter.init(X, y, sample_weight_ptr) - cdef vector[FrontierRecord] frontier cdef FrontierRecord record cdef FrontierRecord split_node_left From 79d345bc4364afc9f4a07ebbc5d7e580e263548f Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 20 Jul 2022 17:43:26 +0200 Subject: [PATCH 100/174] Fix early casting to int beaking test when input is not int --- sklearn/tree/_classes.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 4ad242e656e0f..83e04f5cb0bc5 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -367,7 +367,7 @@ def fit(self, X, y, sample_weight=None, check_input=True): ) # Applying element-wise logical conjunction # for monotonic constraints' support. - monotonic_cst = np.asarray(self.monotonic_cst, dtype=np.int32) + monotonic_cst = np.asarray(self.monotonic_cst) unsatisfied_constraints_conditions = ( (monotonic_cst != -1) * (monotonic_cst != 0) * (monotonic_cst != 1) ) @@ -375,6 +375,7 @@ def fit(self, X, y, sample_weight=None, check_input=True): raise ValueError( "monotonic_cst must be None or an array-like of -1, 0 or 1." ) + monotonic_cst = np.asarray(self.monotonic_cst, dtype=np.int32) if is_classifier(self): if self.n_classes_[0] > 2: raise ValueError( From 0367f26a79a6e1e9e076689600db798cc2e2dfc6 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 20 Jul 2022 17:52:02 +0200 Subject: [PATCH 101/174] Replace paramaterized seed by global_random_seed fixture --- sklearn/tree/tests/test_monotonic_tree.py | 28 ++++++++++------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index c0f6298a425b7..0891653c89500 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -8,9 +8,8 @@ from sklearn.tree.tests.test_tree import REG_TREES, CLF_TREES -@pytest.mark.parametrize("seed", range(4)) @pytest.mark.parametrize("depth_first", (True, False)) -def test_montonic_constraints_classifications(seed, depth_first): +def test_montonic_constraints_classifications(depth_first, global_random_seed): n_samples = 1000 n_samples_train = 900 X, y = make_classification( @@ -47,7 +46,7 @@ def test_montonic_constraints_classifications(seed, depth_first): max_leaf_nodes=n_samples_train, ) if hasattr(est, "random_state"): - est.set_params(**{"random_state": seed}) + est.set_params(**{"random_state": global_random_seed}) if hasattr(est, "n_estimators"): est.set_params(**{"n_estimators": 5}) est.fit(X_train, y_train) @@ -64,9 +63,8 @@ def test_montonic_constraints_classifications(seed, depth_first): assert np.all(y_decr <= y) -@pytest.mark.parametrize("seed", range(4)) @pytest.mark.parametrize("depth_first", (True, False)) -def test_montonic_constraints_regressions(seed, depth_first): +def test_montonic_constraints_regressions(depth_first, global_random_seed): n_samples = 1000 n_samples_train = 900 # Build a classification task using 3 informative features @@ -98,7 +96,7 @@ def test_montonic_constraints_regressions(seed, depth_first): max_leaf_nodes=n_samples_train, ) if hasattr(est, "random_state"): - est.set_params(random_state=seed) + est.set_params(random_state=global_random_seed) if hasattr(est, "n_estimators"): 
est.set_params(**{"n_estimators": 5}) est.fit(X_train, y_train) @@ -201,8 +199,7 @@ def assert_1d_reg_monotonic(clf, monotonic_sign, min_x, max_x, n_steps): @pytest.mark.parametrize("monotonic_sign", (-1, 1)) @pytest.mark.parametrize("splitter", ("best", "random")) @pytest.mark.parametrize("depth_first", (True, False)) -@pytest.mark.parametrize("seed", range(4)) -def test_1d_tree_nodes_values(monotonic_sign, splitter, depth_first, seed): +def test_1d_tree_nodes_values(monotonic_sign, splitter, depth_first, global_random_seed): # Adaptation from test_nodes_values in test_montonic_constraints.py # Build a single tree with only one feature, and make sure the nodes # values respect the monotonic constraints. @@ -218,7 +215,7 @@ def test_1d_tree_nodes_values(monotonic_sign, splitter, depth_first, seed): # # c <= d <= root <= e <= f - rng = np.random.RandomState(seed) + rng = np.random.RandomState(global_random_seed) n_samples = 1000 n_features = 1 X = rng.rand(n_samples, n_features) @@ -227,7 +224,7 @@ def test_1d_tree_nodes_values(monotonic_sign, splitter, depth_first, seed): if depth_first: # No max_leaf_nodes, default depth first tree builder clf = DecisionTreeRegressor( - splitter=splitter, monotonic_cst=[monotonic_sign], random_state=seed + splitter=splitter, monotonic_cst=[monotonic_sign], random_state=global_random_seed ) else: # max_leaf_nodes triggers best first tree builder @@ -235,7 +232,7 @@ def test_1d_tree_nodes_values(monotonic_sign, splitter, depth_first, seed): splitter=splitter, monotonic_cst=[monotonic_sign], max_leaf_nodes=n_samples, - random_state=seed, + random_state=global_random_seed, ) clf.fit(X, y) @@ -285,8 +282,7 @@ def assert_nd_reg_tree_children_monotonic_bounded(tree_, monotonic_cst): @pytest.mark.parametrize("monotonic_sign", (-1, 1)) @pytest.mark.parametrize("splitter", ("best", "random")) @pytest.mark.parametrize("depth_first", (True, False)) -@pytest.mark.parametrize("seed", range(4)) -def test_nd_tree_nodes_values(monotonic_sign, splitter, depth_first, seed): +def test_nd_tree_nodes_values(monotonic_sign, splitter, depth_first, global_random_seed): # Build tree with several features, and make sure the nodes # values respect the monotonic constraints. @@ -307,7 +303,7 @@ def test_nd_tree_nodes_values(monotonic_sign, splitter, depth_first, seed): # For iii) we check that each node value is within the proper lower and # upper bounds. - rng = np.random.RandomState(seed) + rng = np.random.RandomState(global_random_seed) n_samples = 10000 n_features = 2 monotonic_cst = [monotonic_sign, 0] @@ -317,7 +313,7 @@ def test_nd_tree_nodes_values(monotonic_sign, splitter, depth_first, seed): if depth_first: # No max_leaf_nodes, default depth first tree builder clf = DecisionTreeRegressor( - splitter=splitter, monotonic_cst=monotonic_cst, random_state=seed + splitter=splitter, monotonic_cst=monotonic_cst, random_state=global_random_seed ) else: # max_leaf_nodes triggers best first tree builder @@ -325,7 +321,7 @@ def test_nd_tree_nodes_values(monotonic_sign, splitter, depth_first, seed): splitter=splitter, monotonic_cst=monotonic_cst, max_leaf_nodes=n_samples, - random_state=seed, + random_state=global_random_seed, ) clf.fit(X, y) assert_nd_reg_tree_children_monotonic_bounded(clf.tree_, monotonic_cst) From c6e955d0d2bbe79d01ac13cc8c08d2b39d3b01a7 Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Wed, 20 Jul 2022 19:12:09 +0200 Subject: [PATCH 102/174] Comment inline the bounds propagation --- sklearn/tree/_tree.pyx | 48 ++++++++++++++++++++++++++++-------------- 1 file changed, 32 insertions(+), 16 deletions(-) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index e067acd1393d5..301d6df623d75 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -280,22 +280,30 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): if not is_leaf: if splitter.monotonic_cst[split.feature] == 0: - # No constraint + # Split on a feature with no monotonicity constraint + # Current bounds must always be propagated to both children + # Splitting criterion will always check them in check_monotonicity left_child_min = lower_bound left_child_max = upper_bound right_child_min = lower_bound right_child_max = upper_bound elif splitter.monotonic_cst[split.feature] == 1: - # Monotonically increasing constraint + # Split on a feature with increasing monotonicity constraint + # Lower bound for left child is the current lower bound left_child_min = lower_bound - left_child_max = middle_value - right_child_min = middle_value + # Upper bound for right child is the current upper bound right_child_max = upper_bound + # Lower bound for right child and upper bound for left child are set to the same value + right_child_min = middle_value + left_child_max = middle_value elif splitter.monotonic_cst[split.feature] == -1: - # Monotonically decreasing constraint - left_child_min = middle_value - left_child_max = upper_bound + # Split on a feature with decreasing monotonicity constraint + # Lower bound for right child is the current lower bound right_child_min = lower_bound + # Upper bound for left child is the current upper bound + left_child_max = upper_bound + # Lower bound for left child and upper bound for right child are set to the same value + left_child_min = middle_value right_child_max = middle_value # Push right child on stack @@ -465,25 +473,33 @@ cdef class BestFirstTreeBuilder(TreeBuilder): else: # Node is expandable + + middle_value = tree._get_middle_value(record.node_id) if splitter.monotonic_cst[node.feature] == 0: - # No constraint + # Split on a feature with no monotonicity constraint + # Current bounds must always be propagated to both children + # Splitting criterion will always check them in check_monotonicity left_child_min = record.lower_bound left_child_max = record.upper_bound right_child_min = record.lower_bound right_child_max = record.upper_bound elif splitter.monotonic_cst[node.feature] == 1: - # Monotonically increasing constraint - middle_value = tree._get_middle_value(record.node_id) + # Split on a feature with increasing monotonicity constraint + # Lower bound for left child is the current lower bound left_child_min = record.lower_bound - left_child_max = middle_value - right_child_min = middle_value + # Upper bound for right child is the current upper bound right_child_max = record.upper_bound + # Lower bound for right child and upper bound for left child and are set to the same value + right_child_min = middle_value + left_child_max = middle_value elif splitter.monotonic_cst[node.feature] == -1: - # Monotonically decreasing constraint - middle_value = tree._get_middle_value(record.node_id) - left_child_min = middle_value - left_child_max = record.upper_bound + # Split on a feature with decreasing monotonicity constraint + # Lower bound for right child is the current lower bound right_child_min = record.lower_bound + # Upper bound for left child is the 
current upper bound + left_child_max = record.upper_bound + # Lower bound for left child and upper bound for right child and are set to the same value + left_child_min = middle_value right_child_max = middle_value # Decrement number of split nodes available From bde6754378a8c7733c4d9e9c1829225e0f8eda5a Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 20 Jul 2022 19:32:10 +0200 Subject: [PATCH 103/174] Fix format --- sklearn/tree/tests/test_monotonic_tree.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 0891653c89500..32e99527bebb6 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -224,7 +224,9 @@ def test_1d_tree_nodes_values(monotonic_sign, splitter, depth_first, global_rand if depth_first: # No max_leaf_nodes, default depth first tree builder clf = DecisionTreeRegressor( - splitter=splitter, monotonic_cst=[monotonic_sign], random_state=global_random_seed + splitter=splitter, + monotonic_cst=[monotonic_sign], + random_state=global_random_seed, ) else: # max_leaf_nodes triggers best first tree builder @@ -313,7 +315,9 @@ def test_nd_tree_nodes_values(monotonic_sign, splitter, depth_first, global_rand if depth_first: # No max_leaf_nodes, default depth first tree builder clf = DecisionTreeRegressor( - splitter=splitter, monotonic_cst=monotonic_cst, random_state=global_random_seed + splitter=splitter, + monotonic_cst=monotonic_cst, + random_state=global_random_seed, ) else: # max_leaf_nodes triggers best first tree builder From 2d8b025516db2d37639755a047f3befc109168a0 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 20 Jul 2022 19:37:28 +0200 Subject: [PATCH 104/174] Fix format --- sklearn/tree/tests/test_monotonic_tree.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 32e99527bebb6..4f4c1fab45cc9 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -199,7 +199,9 @@ def assert_1d_reg_monotonic(clf, monotonic_sign, min_x, max_x, n_steps): @pytest.mark.parametrize("monotonic_sign", (-1, 1)) @pytest.mark.parametrize("splitter", ("best", "random")) @pytest.mark.parametrize("depth_first", (True, False)) -def test_1d_tree_nodes_values(monotonic_sign, splitter, depth_first, global_random_seed): +def test_1d_tree_nodes_values( + monotonic_sign, splitter, depth_first, global_random_seed +): # Adaptation from test_nodes_values in test_montonic_constraints.py # Build a single tree with only one feature, and make sure the nodes # values respect the monotonic constraints. @@ -284,7 +286,9 @@ def assert_nd_reg_tree_children_monotonic_bounded(tree_, monotonic_cst): @pytest.mark.parametrize("monotonic_sign", (-1, 1)) @pytest.mark.parametrize("splitter", ("best", "random")) @pytest.mark.parametrize("depth_first", (True, False)) -def test_nd_tree_nodes_values(monotonic_sign, splitter, depth_first, global_random_seed): +def test_nd_tree_nodes_values( + monotonic_sign, splitter, depth_first, global_random_seed +): # Build tree with several features, and make sure the nodes # values respect the monotonic constraints. From efb6424974884b8f17abaf772eb4c3f85cf0e525 Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Thu, 21 Jul 2022 00:03:55 +0200 Subject: [PATCH 105/174] Add example of incorrect tree if bounds and checks are not propagated on splits on features without constraint --- sklearn/tree/tests/test_monotonic_tree.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 4f4c1fab45cc9..71626cc720cf8 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -252,25 +252,38 @@ def assert_nd_reg_tree_children_monotonic_bounded(tree_, monotonic_cst): assert tree_.value[i] <= upper_bound[i] assert tree_.value[i] >= lower_bound[i] if feature < 0: - # leaf: nothing to do + # Leaf: nothing to do continue else: i_left = tree_.children_left[i] i_right = tree_.children_right[i] if monotonic_cst[feature] == 0: - # unconstrained feature: propagate bounds down the tree + # Feature without monotonicity constraint: propagate bounds + # down the tree to both children. + # Otherwise with 2 features and a POS constraint on feature 0 + # the following tree can be accepted, although it does not + # respect the positive monotonicity constraint: + # + # X[0] <= 0 + # value = 100 + # / \ + # X[0] <= -1 X[1] <= 0 + # value = 50 value = 150 + # / \ / \ + # leaf leaf leaf leaf + # value = 25 value = 75 value = 50 value = 250 + upper_bound[i_left] = upper_bound[i] lower_bound[i_left] = lower_bound[i] upper_bound[i_right] = upper_bound[i] lower_bound[i_right] = lower_bound[i] else: - # constrained feature - # check montonicity + # Feature with constraint: check monotonicity assert ( monotonic_cst[feature] * tree_.value[i_left] <= monotonic_cst[feature] * tree_.value[i_right] ) - # update and propagate bounds down the tree + # Update and propagate bounds down the tree to both children. if monotonic_cst[feature] == 1: upper_bound[i_left] = tree_.value[i] lower_bound[i_left] = lower_bound[i] From 896ef6f267354b11773f3f52fede77a231f52d55 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Fri, 22 Jul 2022 17:46:13 +0200 Subject: [PATCH 106/174] Update whats_new entry --- doc/whats_new/v1.1.rst | 15 --------------- doc/whats_new/v1.2.rst | 16 ++++++++++++++++ 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index fe8c8f4df5062..4c46c0d631f76 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -623,15 +623,6 @@ Changelog :class:`ensemble.ExtraTreesClassifier`. :pr:`20803` by :user:`Brian Sun `. -- |Feature| :class:`ensemble.RandomForestClassifier`, - :class:`ensemble.RandomForestRegressor`, - :class:`ensemble.ExtraTreesClassifier`, - :class:`ensemble.ExtraTreesRegressor`, - :class:`ensemble.GradientBoostingClassifier` and - :class:`ensemble.GradientBoostingRegressor` now support monotonic - constraints, useful when features are supposed to have a positive/negative - effect on the target. :pr:`13649` by :user:`Samuel Ronsin ` and - :user:`Patrick O'Reilly `. - |Efficiency| Improve runtime performance of :class:`ensemble.IsolationForest` by skipping repetitive input checks. :pr:`23149` by :user:`Zhehao Liu `. @@ -1233,12 +1224,6 @@ Changelog removed in version 1.3. Use :func:`utils.metaestimators.available_if` instead. :pr:`22830` by :user:`Jérémie du Boisberranger `. 
-- |Feature| :class:`tree.DecisionTreeClassifier` and - :class:`tree.DecisionTreeRegressor` now support monotonic - constraints, useful when features are supposed to have a positive/negative - effect on the target. :pr:`13649` by :user:`Samuel Ronsin ` and - :user:`Patrick O'Reilly `. - Code and Documentation Contributors ----------------------------------- diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index 7a5e92f5f960d..237ca6a30aa71 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -143,6 +143,16 @@ Changelog :mod:`sklearn.ensemble` ....................... +- |Feature| :class:`ensemble.RandomForestClassifier`, + :class:`ensemble.RandomForestRegressor`, + :class:`ensemble.ExtraTreesClassifier`, + :class:`ensemble.ExtraTreesRegressor`, + :class:`ensemble.GradientBoostingClassifier` and + :class:`ensemble.GradientBoostingRegressor` now support monotonic + constraints, useful when features are supposed to have a positive/negative + effect on the target. :pr:`13649` by :user:`Samuel Ronsin ` and + :user:`Patrick O'Reilly `. + - |Efficiency| Improve runtime performance of :class:`ensemble.IsolationForest` by avoiding data copies. :pr:`23252` by :user:`Zhehao Liu `. @@ -243,6 +253,12 @@ Changelog :mod:`sklearn.tree` ................... +- |Feature| :class:`tree.DecisionTreeClassifier` and + :class:`tree.DecisionTreeRegressor` now support monotonic + constraints, useful when features are supposed to have a positive/negative + effect on the target. :pr:`13649` by :user:`Samuel Ronsin ` and + :user:`Patrick O'Reilly `. + - |Enhancement| :func:`tree.plot_tree`, :func:`tree.export_graphviz` now uses a lower case `x[i]` to represent feature `i`. :pr:`23480` by `Thomas Fan`_. From cc4be1a3f62c7cea260c136b77f7c2105412d812 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Fri, 22 Jul 2022 17:46:51 +0200 Subject: [PATCH 107/174] Reduce dataset size in test --- sklearn/tree/tests/test_monotonic_tree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 71626cc720cf8..883d6ba6bfaf8 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -323,7 +323,7 @@ def test_nd_tree_nodes_values( # upper bounds. rng = np.random.RandomState(global_random_seed) - n_samples = 10000 + n_samples = 1000 n_features = 2 monotonic_cst = [monotonic_sign, 0] X = rng.rand(n_samples, n_features) From 23d3a2b350f2f136ff6b40a582344cad90785e83 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 7 Sep 2022 07:51:31 +0200 Subject: [PATCH 108/174] MAINT Adapt logic to use the latest NumPy API A newer NumPy C API is now used when possible. The updates the code post-merge to make it use this newer NumPy C API. 
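For illustration only, a minimal sketch of the resulting pattern (the `MonotonicStore` class below is hypothetical and is not the actual `Splitter`; `INT32_t` is assumed to be the `cnp.int32_t` ctypedef used in the tree code, and a NumPy-aware Cython build is assumed): the buffer-typed argument `cnp.ndarray[INT32_t] monotonic_cst` becomes a typed memoryview, and a raw pointer to its first element is cached for nogil access.

    # Sketch only: hypothetical class, not the actual Splitter.
    cimport numpy as cnp

    ctypedef cnp.int32_t INT32_t

    cdef class MonotonicStore:
        # The memoryview keeps a reference to the underlying buffer alive;
        # the raw pointer gives cheap element access inside nogil sections.
        cdef INT32_t[:] monotonic_cst_view
        cdef INT32_t* monotonic_cst

        def __cinit__(self, INT32_t[:] monotonic_cst):
            self.monotonic_cst_view = monotonic_cst
            self.monotonic_cst = &monotonic_cst[0]

In this sketch, holding the memoryview attribute guarantees the buffer outlives the cached pointer; callers can pass any int32 array-like, e.g. np.asarray(cst, dtype=np.int32).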
Relates to: https://github.com/scikit-learn/scikit-learn/pull/24049 --- sklearn/tree/_splitter.pyx | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 0d52a811dca96..b745944698570 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -54,7 +54,7 @@ cdef class Splitter: def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, cnp.ndarray[INT32_t] monotonic_cst): + object random_state, INT32_t[:] monotonic_cst): """ Parameters ---------- @@ -93,7 +93,7 @@ cdef class Splitter: self.min_samples_leaf = min_samples_leaf self.min_weight_leaf = min_weight_leaf self.random_state = random_state - self.monotonic_cst = monotonic_cst.data + self.monotonic_cst = &monotonic_cst[0] def __getstate__(self): return {} @@ -792,10 +792,10 @@ cdef class BaseSparseSplitter(Splitter): def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, cnp.ndarray[INT32_t] monotonic_cst): + object random_state, INT32_t[:] monotonic_cst): # Parent __cinit__ is automatically called self.n_total_samples = 0 - self.monotonic_cst = monotonic_cst.data + self.monotonic_cst = &monotonic_cst[0] cdef int init(self, object X, From 69a0dd80749e314ccf18c199d7469002284d30b8 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 7 Sep 2022 10:20:46 +0200 Subject: [PATCH 109/174] Clean up cimports in _splitter.pyx --- sklearn/tree/_splitter.pyx | 3 --- 1 file changed, 3 deletions(-) diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index b745944698570..2e25440e08423 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -16,10 +16,7 @@ from ._criterion cimport Criterion from libc.stdlib cimport qsort from libc.string cimport memcpy -from cython.operator cimport dereference as deref - import numpy as np -cimport numpy as cnp from scipy.sparse import csc_matrix From baa1cf6f90a852bae43ef04b55cfc41c8eea4b55 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 7 Sep 2022 10:37:29 +0200 Subject: [PATCH 110/174] Revert hist_gradient_boosting docstring correction to save for a dedicated PR Co-authored-by: Julien Jerphanion --- sklearn/ensemble/_hist_gradient_boosting/grower.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/grower.py b/sklearn/ensemble/_hist_gradient_boosting/grower.py index 65523b511f77c..926071cad6ea9 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/grower.py +++ b/sklearn/ensemble/_hist_gradient_boosting/grower.py @@ -164,7 +164,7 @@ class TreeGrower: is_categorical : ndarray of bool of shape (n_features,), default=None Indicates categorical features. monotonic_cst : array-like of shape (n_features,), dtype=int, default=None - Indicates the monotonic constraint to enforce on each feature. 1, -1 + Indicates the monotonic constraint to enforce on each feature. -1, 1 and 0 respectively correspond to a positive constraint, negative constraint and no constraint. Read more in the :ref:`User Guide `. From 26a69c0a83a66aaf5b72b35f1e3cd982a6b08ed6 Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Wed, 7 Sep 2022 10:46:09 +0200 Subject: [PATCH 111/174] Address compilation warnings on uninitialised variables Co-authored-by: Julien Jerphanion --- sklearn/tree/_tree.pyx | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index b45f69473d089..6dcbd1fe7e07b 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -294,7 +294,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): # Lower bound for right child and upper bound for left child are set to the same value right_child_min = middle_value left_child_max = middle_value - elif splitter.monotonic_cst[split.feature] == -1: + else: # i.e. splitter.monotonic_cst[split.feature] == -1 # Split on a feature with decreasing monotonicity constraint # Lower bound for right child is the current lower bound right_child_min = lower_bound @@ -487,7 +487,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): # Lower bound for right child and upper bound for left child and are set to the same value right_child_min = middle_value left_child_max = middle_value - elif splitter.monotonic_cst[node.feature] == -1: + else: # i.e. splitter.monotonic_cst[node.feature] == -1 # Split on a feature with decreasing monotonicity constraint # Lower bound for right child is the current lower bound right_child_min = record.lower_bound @@ -1736,7 +1736,6 @@ cdef _cost_complexity_prune(unsigned char[:] leaves_in_subtree, # OUT # ... and child_r[node_idx] != _TREE_LEAF: node_indices_stack.push(child_l[node_idx]) node_indices_stack.push(child_r[node_idx]) - leaves_in_subtree[pruned_branch_node_idx] = 1 in_subtree[pruned_branch_node_idx] = 1 From 66a18e0f4e79f6edf2ea17d4e3d44a11db9999ff Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 7 Sep 2022 11:27:29 +0200 Subject: [PATCH 112/174] Use keyword arguments in _add_split_node calls Co-authored-by: Julien Jerphanion --- sklearn/tree/_tree.pyx | 72 +++++++++++++++++++++--------------------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 6dcbd1fe7e07b..23fa12df574cd 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -435,18 +435,18 @@ cdef class BestFirstTreeBuilder(TreeBuilder): with nogil: # add root to frontier rc = self._add_split_node( - splitter, # splitter - tree, # tree - 0, # start - n_node_samples, # end - INFINITY, # impurity - IS_FIRST, # is_first - IS_LEFT, # is_left - NULL, # parent - 0, # depth - -INFINITY, # lower_bound - INFINITY, # upper_bound - &split_node_left # res + splitter=splitter, + tree=tree, + start=0, + end=n_node_samples, + impurity=INFINITY, + is_first=IS_FIRST, + is_left=IS_LEFT, + parent=NULL, + depth=0, + lower_bound=-INFINITY, + upper_bound=INFINITY, + res=&split_node_left, ) if rc >= 0: _add_to_frontier(split_node_left, frontier) @@ -502,18 +502,18 @@ cdef class BestFirstTreeBuilder(TreeBuilder): # Compute left split node rc = self._add_split_node( - splitter, # splitter - tree, # tree - record.start, # start - record.pos, # end - record.impurity_left, # impurity - IS_NOT_FIRST, # is_first - IS_LEFT, # is_left - node, # parent - record.depth + 1, # depth - left_child_min, # lower_bound - left_child_max, # upper_bound - &split_node_left # res + splitter=splitter, + tree=tree, + start=record.start, + end=record.pos, + impurity=record.impurity_left, + if_first=IS_NOT_FIRST, + is_left=IS_LEFT, + parent=node, + depth=record.depth + 1, + lower_bound=left_child_min, + upper_bound=left_child_max, + 
res=&split_node_left, ) if rc == -1: break @@ -523,18 +523,18 @@ cdef class BestFirstTreeBuilder(TreeBuilder): # Compute right split node rc = self._add_split_node( - splitter, # splitter - tree, # tree - record.pos, # start - record.end, # end - record.impurity_right,# impurity - IS_NOT_FIRST, # is_first - IS_NOT_LEFT, # is_left - node, # parent - record.depth + 1, # depth - right_child_min, # lower_bound - right_child_max, # upper_bound - &split_node_right # res + splitter=splitter, + tree=tree, + start=record.pos, + end=record.end, + impurity=record.impurity_right, + is_first=IS_NOT_FIRST, + is_left=IS_NOT_LEFT, + parent=node, + depth=record.depth + 1, + lower_bound=right_child_min, + upper_bound=right_child_max, + res=&split_node_right, ) if rc == -1: break From bbf5fb28b10db11b9d2f6c35a54ac4c9021cc9b3 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 7 Sep 2022 11:22:39 +0200 Subject: [PATCH 113/174] Qualify Splitter.monotonic_cst as const and switch to MemoryView --- sklearn/tree/_splitter.pxd | 2 +- sklearn/tree/_splitter.pyx | 16 ++++++++++------ 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 62245e9f43c95..367f8db5f1e2e 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -63,7 +63,7 @@ cdef class Splitter: # -1: monotonically decreasing # 0: no constraint # +1: monotonically increasing - cdef INT32_t* monotonic_cst + cdef const INT32_t[:] monotonic_cst # The samples vector `samples` is maintained by the Splitter object such # that the samples contained in a node are contiguous. With this setting, diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 2e25440e08423..8d822ca239f4a 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -51,7 +51,7 @@ cdef class Splitter: def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, INT32_t[:] monotonic_cst): + object random_state, const INT32_t[:] monotonic_cst): """ Parameters ---------- @@ -74,7 +74,7 @@ cdef class Splitter: random_state : object The user inputted random state to be used for pseudo-randomness - monotonic_cst : INT32_t* + monotonic_cst : INT32_t[:] Monotonicity constraints """ @@ -90,7 +90,7 @@ cdef class Splitter: self.min_samples_leaf = min_samples_leaf self.min_weight_leaf = min_weight_leaf self.random_state = random_state - self.monotonic_cst = &monotonic_cst[0] + self.monotonic_cst = monotonic_cst def __getstate__(self): return {} @@ -279,6 +279,7 @@ cdef class BestSplitter(BaseDenseSplitter): cdef SIZE_t min_samples_leaf = self.min_samples_leaf cdef double min_weight_leaf = self.min_weight_leaf cdef UINT32_t* random_state = &self.rand_r_state + cdef INT32_t monotonic_constraint cdef SplitRecord best, current cdef double current_proxy_improvement = -INFINITY @@ -345,7 +346,7 @@ cdef class BestSplitter(BaseDenseSplitter): # f_j in the interval [n_total_constants, f_i[ current.feature = features[f_j] - monotonic_constraint = self.monotonic_cst[current.feature] + monotonic_constraint = self.monotonic_cst[0] # Sort samples along that feature; by # copying the values into an array and # sorting the array in a manner which utilizes the cache more @@ -603,6 +604,7 @@ cdef class RandomSplitter(BaseDenseSplitter): cdef SIZE_t min_samples_leaf = self.min_samples_leaf cdef double min_weight_leaf = self.min_weight_leaf cdef UINT32_t* random_state = &self.rand_r_state + cdef INT32_t 
monotonic_constraint cdef SplitRecord best, current cdef double current_proxy_improvement = - INFINITY @@ -789,10 +791,10 @@ cdef class BaseSparseSplitter(Splitter): def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, INT32_t[:] monotonic_cst): + object random_state, const INT32_t[:] monotonic_cst): # Parent __cinit__ is automatically called self.n_total_samples = 0 - self.monotonic_cst = &monotonic_cst[0] + self.monotonic_cst = monotonic_cst cdef int init(self, object X, @@ -1120,6 +1122,7 @@ cdef class BestSparseSplitter(BaseSparseSplitter): cdef SIZE_t min_samples_leaf = self.min_samples_leaf cdef double min_weight_leaf = self.min_weight_leaf cdef UINT32_t* random_state = &self.rand_r_state + cdef INT32_t monotonic_constraint cdef SplitRecord best, current _init_split(&best, end) @@ -1357,6 +1360,7 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): cdef SIZE_t min_samples_leaf = self.min_samples_leaf cdef double min_weight_leaf = self.min_weight_leaf cdef UINT32_t* random_state = &self.rand_r_state + cdef INT32_t monotonic_constraint cdef SplitRecord best, current _init_split(&best, end) From 96923b689162eee5462a2f7c111f195194316576 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 7 Sep 2022 11:46:23 +0200 Subject: [PATCH 114/174] Revert "Use keyword arguments in _add_split_node calls" This reverts commit 66a18e0f4e79f6edf2ea17d4e3d44a11db9999ff. --- sklearn/tree/_tree.pyx | 72 +++++++++++++++++++++--------------------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 23fa12df574cd..6dcbd1fe7e07b 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -435,18 +435,18 @@ cdef class BestFirstTreeBuilder(TreeBuilder): with nogil: # add root to frontier rc = self._add_split_node( - splitter=splitter, - tree=tree, - start=0, - end=n_node_samples, - impurity=INFINITY, - is_first=IS_FIRST, - is_left=IS_LEFT, - parent=NULL, - depth=0, - lower_bound=-INFINITY, - upper_bound=INFINITY, - res=&split_node_left, + splitter, # splitter + tree, # tree + 0, # start + n_node_samples, # end + INFINITY, # impurity + IS_FIRST, # is_first + IS_LEFT, # is_left + NULL, # parent + 0, # depth + -INFINITY, # lower_bound + INFINITY, # upper_bound + &split_node_left # res ) if rc >= 0: _add_to_frontier(split_node_left, frontier) @@ -502,18 +502,18 @@ cdef class BestFirstTreeBuilder(TreeBuilder): # Compute left split node rc = self._add_split_node( - splitter=splitter, - tree=tree, - start=record.start, - end=record.pos, - impurity=record.impurity_left, - if_first=IS_NOT_FIRST, - is_left=IS_LEFT, - parent=node, - depth=record.depth + 1, - lower_bound=left_child_min, - upper_bound=left_child_max, - res=&split_node_left, + splitter, # splitter + tree, # tree + record.start, # start + record.pos, # end + record.impurity_left, # impurity + IS_NOT_FIRST, # is_first + IS_LEFT, # is_left + node, # parent + record.depth + 1, # depth + left_child_min, # lower_bound + left_child_max, # upper_bound + &split_node_left # res ) if rc == -1: break @@ -523,18 +523,18 @@ cdef class BestFirstTreeBuilder(TreeBuilder): # Compute right split node rc = self._add_split_node( - splitter=splitter, - tree=tree, - start=record.pos, - end=record.end, - impurity=record.impurity_right, - is_first=IS_NOT_FIRST, - is_left=IS_NOT_LEFT, - parent=node, - depth=record.depth + 1, - lower_bound=right_child_min, - upper_bound=right_child_max, - res=&split_node_right, + splitter, 
# splitter + tree, # tree + record.pos, # start + record.end, # end + record.impurity_right,# impurity + IS_NOT_FIRST, # is_first + IS_NOT_LEFT, # is_left + node, # parent + record.depth + 1, # depth + right_child_min, # lower_bound + right_child_max, # upper_bound + &split_node_right # res ) if rc == -1: break From e06806afcc55d4be2b843cd3d67779d0889a29a5 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 7 Sep 2022 12:48:01 +0200 Subject: [PATCH 115/174] Factorize common regression and classification logic of constraints checks into inlined private method --- sklearn/tree/_criterion.pxd | 2 ++ sklearn/tree/_criterion.pyx | 66 +++++++++++++++---------------------- 2 files changed, 29 insertions(+), 39 deletions(-) diff --git a/sklearn/tree/_criterion.pxd b/sklearn/tree/_criterion.pxd index 7620aaad91f26..482d41e53903e 100644 --- a/sklearn/tree/_criterion.pxd +++ b/sklearn/tree/_criterion.pxd @@ -58,6 +58,8 @@ cdef class Criterion: cdef double get_sum_total(self, INT32_t node_id) nogil cdef bint check_monotonicity(self, INT32_t monotonic_cst, double lower_bound, double upper_bound) nogil + cdef inline bint _check_monotonicity(self, INT32_t monotonic_cst, double lower_bound, double upper_bound, + double sum_left, double sum_right) nogil cdef class ClassificationCriterion(Criterion): """Abstract criterion for classification.""" diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index c94090491ce5e..6d5e07b8e1036 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -196,9 +196,33 @@ cdef class Criterion: pass cdef bint check_monotonicity(self, INT32_t monotonic_cst, - double lower_bound, double upper_bound) nogil: + double lower_bound, double upper_bound) nogil: pass + cdef inline bint _check_monotonicity(self, INT32_t monotonic_cst, + double lower_bound, double upper_bound, + double sum_left, double sum_right) nogil: + cdef: + double weighted_n_left = self.weighted_n_left + double weighted_n_right = self.weighted_n_right + bint check_lower_bound = ( + (sum_left >= lower_bound * weighted_n_left) & + (sum_right >= lower_bound * weighted_n_right) + ) + bint check_upper_bound = ( + (sum_left <= upper_bound * weighted_n_left) & + (sum_right <= upper_bound * weighted_n_right) + ) + bint check_monotonic_cst + if monotonic_cst == 0: # No constraint + return check_lower_bound & check_upper_bound + else: + check_monotonic_cst = ( + (sum_left * weighted_n_right - + sum_right * weighted_n_left) * monotonic_cst <= 0 + ) + return check_lower_bound & check_upper_bound & check_monotonic_cst + cdef class ClassificationCriterion(Criterion): """Abstract criterion for classification.""" @@ -445,26 +469,8 @@ cdef class ClassificationCriterion(Criterion): cdef: double sum_left = self.sum_left[0][0] double sum_right = self.sum_right[0][0] - double weighted_n_left = self.weighted_n_left - double weighted_n_right = self.weighted_n_right - bint check_lower = ( - (sum_left >= lower_bound * weighted_n_left) & - (sum_right >= lower_bound * weighted_n_right) - ) - bint check_upper = ( - (sum_left <= upper_bound * weighted_n_left) & - (sum_right <= upper_bound * weighted_n_right) - ) - bint check_monotonic - if monotonic_cst == 0: # No constraint - return check_lower & check_upper - else: - check_monotonic = ( - (sum_left * weighted_n_right - - sum_right * weighted_n_left) * monotonic_cst <= 0 - ) - return check_lower & check_upper & check_monotonic + return self._check_monotonicity(monotonic_cst, lower_bound, upper_bound, sum_left, sum_right) cdef class 
Entropy(ClassificationCriterion): @@ -814,26 +820,8 @@ cdef class RegressionCriterion(Criterion): cdef: double sum_left = self.sum_left[0] double sum_right = self.sum_right[0] - double weighted_n_left = self.weighted_n_left - double weighted_n_right = self.weighted_n_right - bint check_lower = ( - (sum_left >= lower_bound * weighted_n_left) & - (sum_right >= lower_bound * weighted_n_right) - ) - bint check_upper = ( - (sum_left <= upper_bound * weighted_n_left) & - (sum_right <= upper_bound * weighted_n_right) - ) - bint check_monotonic - if monotonic_cst == 0: # No constraint - return check_lower & check_upper - else: - check_monotonic = ( - (sum_left * weighted_n_right - - sum_right * weighted_n_left) * monotonic_cst <= 0 - ) - return check_lower & check_upper & check_monotonic + return self._check_monotonicity(monotonic_cst, lower_bound, upper_bound, sum_left, sum_right) cdef class MSE(RegressionCriterion): """Mean squared error impurity criterion. From 7bfb603bcf1f068e4f013a858f7a71394b5da13c Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 7 Sep 2022 13:55:46 +0200 Subject: [PATCH 116/174] Revert "Qualify Splitter.monotonic_cst as const and switch to MemoryView" This reverts commit bbf5fb28b10db11b9d2f6c35a54ac4c9021cc9b3. --- sklearn/tree/_splitter.pxd | 2 +- sklearn/tree/_splitter.pyx | 16 ++++++---------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 367f8db5f1e2e..62245e9f43c95 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -63,7 +63,7 @@ cdef class Splitter: # -1: monotonically decreasing # 0: no constraint # +1: monotonically increasing - cdef const INT32_t[:] monotonic_cst + cdef INT32_t* monotonic_cst # The samples vector `samples` is maintained by the Splitter object such # that the samples contained in a node are contiguous. 
With this setting, diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 8d822ca239f4a..2e25440e08423 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -51,7 +51,7 @@ cdef class Splitter: def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, const INT32_t[:] monotonic_cst): + object random_state, INT32_t[:] monotonic_cst): """ Parameters ---------- @@ -74,7 +74,7 @@ cdef class Splitter: random_state : object The user inputted random state to be used for pseudo-randomness - monotonic_cst : INT32_t[:] + monotonic_cst : INT32_t* Monotonicity constraints """ @@ -90,7 +90,7 @@ cdef class Splitter: self.min_samples_leaf = min_samples_leaf self.min_weight_leaf = min_weight_leaf self.random_state = random_state - self.monotonic_cst = monotonic_cst + self.monotonic_cst = &monotonic_cst[0] def __getstate__(self): return {} @@ -279,7 +279,6 @@ cdef class BestSplitter(BaseDenseSplitter): cdef SIZE_t min_samples_leaf = self.min_samples_leaf cdef double min_weight_leaf = self.min_weight_leaf cdef UINT32_t* random_state = &self.rand_r_state - cdef INT32_t monotonic_constraint cdef SplitRecord best, current cdef double current_proxy_improvement = -INFINITY @@ -346,7 +345,7 @@ cdef class BestSplitter(BaseDenseSplitter): # f_j in the interval [n_total_constants, f_i[ current.feature = features[f_j] - monotonic_constraint = self.monotonic_cst[0] + monotonic_constraint = self.monotonic_cst[current.feature] # Sort samples along that feature; by # copying the values into an array and # sorting the array in a manner which utilizes the cache more @@ -604,7 +603,6 @@ cdef class RandomSplitter(BaseDenseSplitter): cdef SIZE_t min_samples_leaf = self.min_samples_leaf cdef double min_weight_leaf = self.min_weight_leaf cdef UINT32_t* random_state = &self.rand_r_state - cdef INT32_t monotonic_constraint cdef SplitRecord best, current cdef double current_proxy_improvement = - INFINITY @@ -791,10 +789,10 @@ cdef class BaseSparseSplitter(Splitter): def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, const INT32_t[:] monotonic_cst): + object random_state, INT32_t[:] monotonic_cst): # Parent __cinit__ is automatically called self.n_total_samples = 0 - self.monotonic_cst = monotonic_cst + self.monotonic_cst = &monotonic_cst[0] cdef int init(self, object X, @@ -1122,7 +1120,6 @@ cdef class BestSparseSplitter(BaseSparseSplitter): cdef SIZE_t min_samples_leaf = self.min_samples_leaf cdef double min_weight_leaf = self.min_weight_leaf cdef UINT32_t* random_state = &self.rand_r_state - cdef INT32_t monotonic_constraint cdef SplitRecord best, current _init_split(&best, end) @@ -1360,7 +1357,6 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): cdef SIZE_t min_samples_leaf = self.min_samples_leaf cdef double min_weight_leaf = self.min_weight_leaf cdef UINT32_t* random_state = &self.rand_r_state - cdef INT32_t monotonic_constraint cdef SplitRecord best, current _init_split(&best, end) From 8065ab785c08964af540a0b797ef51f67818ffaa Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Wed, 7 Sep 2022 14:06:10 +0200 Subject: [PATCH 117/174] Qualify Splitter.monotonic_cst as const --- sklearn/tree/_splitter.pxd | 2 +- sklearn/tree/_splitter.pyx | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 62245e9f43c95..c1c680734f8d9 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -63,7 +63,7 @@ cdef class Splitter: # -1: monotonically decreasing # 0: no constraint # +1: monotonically increasing - cdef INT32_t* monotonic_cst + cdef const INT32_t* monotonic_cst # The samples vector `samples` is maintained by the Splitter object such # that the samples contained in a node are contiguous. With this setting, diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 2e25440e08423..44abfeaad74ff 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -51,7 +51,7 @@ cdef class Splitter: def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, INT32_t[:] monotonic_cst): + object random_state, const INT32_t[:] monotonic_cst): """ Parameters ---------- @@ -789,7 +789,7 @@ cdef class BaseSparseSplitter(Splitter): def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, INT32_t[:] monotonic_cst): + object random_state, const INT32_t[:] monotonic_cst): # Parent __cinit__ is automatically called self.n_total_samples = 0 self.monotonic_cst = &monotonic_cst[0] From c54eab9edd9d6158fa52d8fe5266f74c365c7f77 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Thu, 8 Sep 2022 11:05:33 +0200 Subject: [PATCH 118/174] Use MemoryViews for monotonic_cst instead of pointers --- sklearn/tree/_splitter.pxd | 2 +- sklearn/tree/_splitter.pyx | 14 +++++++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index c1c680734f8d9..367f8db5f1e2e 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -63,7 +63,7 @@ cdef class Splitter: # -1: monotonically decreasing # 0: no constraint # +1: monotonically increasing - cdef const INT32_t* monotonic_cst + cdef const INT32_t[:] monotonic_cst # The samples vector `samples` is maintained by the Splitter object such # that the samples contained in a node are contiguous. 
With this setting, diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 44abfeaad74ff..b4e202e079ec9 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -74,7 +74,7 @@ cdef class Splitter: random_state : object The user inputted random state to be used for pseudo-randomness - monotonic_cst : INT32_t* + monotonic_cst : INT32_t[:] Monotonicity constraints """ @@ -90,7 +90,7 @@ cdef class Splitter: self.min_samples_leaf = min_samples_leaf self.min_weight_leaf = min_weight_leaf self.random_state = random_state - self.monotonic_cst = &monotonic_cst[0] + self.monotonic_cst = monotonic_cst def __getstate__(self): return {} @@ -299,6 +299,8 @@ cdef class BestSplitter(BaseDenseSplitter): cdef SIZE_t n_total_constants = n_known_constants cdef SIZE_t partition_end + cdef INT32_t monotonic_constraint + _init_split(&best, end) # Sample up to max_features without replacement using a @@ -624,6 +626,8 @@ cdef class RandomSplitter(BaseDenseSplitter): cdef DTYPE_t max_feature_value cdef DTYPE_t current_feature_value + cdef INT32_t monotonic_constraint + _init_split(&best, end) # Sample up to max_features without replacement using a @@ -792,7 +796,7 @@ cdef class BaseSparseSplitter(Splitter): object random_state, const INT32_t[:] monotonic_cst): # Parent __cinit__ is automatically called self.n_total_samples = 0 - self.monotonic_cst = &monotonic_cst[0] + self.monotonic_cst = monotonic_cst cdef int init(self, object X, @@ -1147,6 +1151,8 @@ cdef class BestSparseSplitter(BaseSparseSplitter): cdef SIZE_t start_positive cdef SIZE_t end_negative + cdef INT32_t monotonic_constraint + # Sample up to max_features without replacement using a # Fisher-Yates-based algorithm (using the local variables `f_i` and # `f_j` to compute a permutation of the `features` array). @@ -1387,6 +1393,8 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): cdef SIZE_t start_positive cdef SIZE_t end_negative + cdef INT32_t monotonic_constraint + # Sample up to max_features without replacement using a # Fisher-Yates-based algorithm (using the local variables `f_i` and # `f_j` to compute a permutation of the `features` array). From eb4ee0fdc808bd25ab184a688581b902e9167a58 Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Thu, 8 Sep 2022 11:18:19 +0200 Subject: [PATCH 119/174] Use keyword arguments in _add_split_node calls Co-authored-by: Julien Jerphanion --- sklearn/tree/_tree.pyx | 99 +++++++++++++++++++++--------------------- 1 file changed, 50 insertions(+), 49 deletions(-) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 6dcbd1fe7e07b..e73336aa01a12 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -435,18 +435,18 @@ cdef class BestFirstTreeBuilder(TreeBuilder): with nogil: # add root to frontier rc = self._add_split_node( - splitter, # splitter - tree, # tree - 0, # start - n_node_samples, # end - INFINITY, # impurity - IS_FIRST, # is_first - IS_LEFT, # is_left - NULL, # parent - 0, # depth - -INFINITY, # lower_bound - INFINITY, # upper_bound - &split_node_left # res + splitter=splitter, + tree=tree, + start=0, + end=n_node_samples, + impurity=INFINITY, + is_first=IS_FIRST, + is_left=IS_LEFT, + parent=NULL, + depth=0, + lower_bound=-INFINITY, + upper_bound=INFINITY, + res=&split_node_left, ) if rc >= 0: _add_to_frontier(split_node_left, frontier) @@ -502,18 +502,18 @@ cdef class BestFirstTreeBuilder(TreeBuilder): # Compute left split node rc = self._add_split_node( - splitter, # splitter - tree, # tree - record.start, # start - record.pos, # end - record.impurity_left, # impurity - IS_NOT_FIRST, # is_first - IS_LEFT, # is_left - node, # parent - record.depth + 1, # depth - left_child_min, # lower_bound - left_child_max, # upper_bound - &split_node_left # res + splitter=splitter, + tree=tree, + start=record.start, + end=record.pos, + impurity=record.impurity_left, + is_first=IS_NOT_FIRST, + is_left=IS_LEFT, + parent=node, + depth=record.depth + 1, + lower_bound=left_child_min, + upper_bound=left_child_max, + res=&split_node_left, ) if rc == -1: break @@ -523,18 +523,18 @@ cdef class BestFirstTreeBuilder(TreeBuilder): # Compute right split node rc = self._add_split_node( - splitter, # splitter - tree, # tree - record.pos, # start - record.end, # end - record.impurity_right,# impurity - IS_NOT_FIRST, # is_first - IS_NOT_LEFT, # is_left - node, # parent - record.depth + 1, # depth - right_child_min, # lower_bound - right_child_max, # upper_bound - &split_node_right # res + splitter=splitter, + tree=tree, + start=record.pos, + end=record.end, + impurity=record.impurity_right, + is_first=IS_NOT_FIRST, + is_left=IS_NOT_LEFT, + parent=node, + depth=record.depth + 1, + lower_bound=right_child_min, + upper_bound=right_child_max, + res=&split_node_right, ) if rc == -1: break @@ -556,19 +556,20 @@ cdef class BestFirstTreeBuilder(TreeBuilder): raise MemoryError() cdef inline int _add_split_node( - self, - Splitter splitter, - Tree tree, - SIZE_t start, - SIZE_t end, - double impurity, - bint is_first, - bint is_left, - Node* parent, - SIZE_t depth, - double lower_bound, - double upper_bound, - FrontierRecord* res) nogil except -1: + self, + Splitter splitter, + Tree tree, + SIZE_t start, + SIZE_t end, + double impurity, + bint is_first, + bint is_left, + Node* parent, + SIZE_t depth, + double lower_bound, + double upper_bound, + FrontierRecord* res + ) nogil except -1: """Adds node w/ partition ``[start, end)`` to the frontier. 
""" cdef SplitRecord split cdef SIZE_t node_id From 1cbc5d8a295238a1865d4beb83a3d3ff48cf16be Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 9 Sep 2022 18:54:25 +0200 Subject: [PATCH 120/174] CI Retrigger CI due to assert reproducibility of the faulty run From 438d224e316f07ef2f320974eeebef1d106f070c Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 14 Sep 2022 15:35:39 +0200 Subject: [PATCH 121/174] Remove unused get_sum_total Co-authored-by: Samuel O. Ronsin --- sklearn/tree/_criterion.pxd | 1 - sklearn/tree/_criterion.pyx | 11 ----------- sklearn/tree/_tree.pyx | 1 - 3 files changed, 13 deletions(-) diff --git a/sklearn/tree/_criterion.pxd b/sklearn/tree/_criterion.pxd index 482d41e53903e..35ee14c5a7024 100644 --- a/sklearn/tree/_criterion.pxd +++ b/sklearn/tree/_criterion.pxd @@ -55,7 +55,6 @@ cdef class Criterion: double impurity_left, double impurity_right) nogil cdef double proxy_impurity_improvement(self) nogil - cdef double get_sum_total(self, INT32_t node_id) nogil cdef bint check_monotonicity(self, INT32_t monotonic_cst, double lower_bound, double upper_bound) nogil cdef inline bint _check_monotonicity(self, INT32_t monotonic_cst, double lower_bound, double upper_bound, diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 6d5e07b8e1036..5caa3fd18e9ad 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -192,9 +192,6 @@ cdef class Criterion: - (self.weighted_n_left / self.weighted_n_node_samples * impurity_left))) - cdef double get_sum_total(self, INT32_t node_id) nogil: - pass - cdef bint check_monotonicity(self, INT32_t monotonic_cst, double lower_bound, double upper_bound) nogil: pass @@ -460,10 +457,6 @@ cdef class ClassificationCriterion(Criterion): memcpy(dest, &self.sum_total[k, 0], self.n_classes[k] * sizeof(double)) dest += self.max_n_classes - cdef inline double get_sum_total(self, INT32_t node_id) nogil: - """Get the weighted sum of the class 0 occurences in node""" - return self.sum_total[node_id][0] - cdef inline bint check_monotonicity(self, INT32_t monotonic_cst, double lower_bound, double upper_bound) nogil: """Check monotonic constraint is satisfied at the current classification split""" cdef: @@ -811,10 +804,6 @@ cdef class RegressionCriterion(Criterion): for k in range(self.n_outputs): dest[k] = self.sum_total[k] / self.weighted_n_node_samples - cdef inline double get_sum_total(self, INT32_t node_id) nogil: - """Get the weighted sum of the target values in node""" - return self.sum_total[node_id] - cdef inline bint check_monotonicity(self, INT32_t monotonic_cst, double lower_bound, double upper_bound) nogil: """Check monotonic constraint is satisfied at the current regression split""" cdef: diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index e73336aa01a12..0d3151c35b598 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -623,7 +623,6 @@ cdef class BestFirstTreeBuilder(TreeBuilder): res.end = end res.depth = depth res.impurity = impurity - node_value = splitter.criterion.get_sum_total(node_id) / weighted_n_node_samples if not is_leaf: # is split node From 301fb4e4b8355b61d2796f2dbd7a8c54659d1dd9 Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Wed, 14 Sep 2022 16:24:11 +0200 Subject: [PATCH 122/174] Replace positive/negative wording by increasing/decreasing and use bullet points in docstring --- sklearn/ensemble/_forest.py | 47 ++++++++++++++++++++----------------- sklearn/ensemble/_gb.py | 18 +++++++------- sklearn/tree/_classes.py | 36 +++++++++++++++------------- 3 files changed, 56 insertions(+), 45 deletions(-) diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index f3962e7d071ab..9e48cfa5db1b8 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -1270,11 +1270,12 @@ class RandomForestClassifier(ForestClassifier): .. versionadded:: 0.22 monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. 1, -1 - and 0 respectively correspond to a positive constraint, negative - constraint and no constraint. The constraints are only valid for - binary classifications and hold over the probability of the positive - class. + Indicates the monotonic constraint to enforce on each feature. + - 1: monotonically increasing + - 0: no constraint + - -1: monotonically decreasing + The constraints are only valid for binary classifications and hold + over the probability of the positive class. Attributes ---------- @@ -1626,9 +1627,10 @@ class RandomForestRegressor(ForestRegressor): .. versionadded:: 0.22 monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. 1, -1 - and 0 respectively correspond to a positive constraint, negative - constraint and no constraint. + Indicates the monotonic constraint to enforce on each feature. + - 1: monotonically increasing + - 0: no constraint + - -1: monotonically decreasing Attributes ---------- @@ -1977,11 +1979,12 @@ class ExtraTreesClassifier(ForestClassifier): .. versionadded:: 0.22 monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. 1, -1 - and 0 respectively correspond to a positive constraint, negative - constraint and no constraint. The constraints are only valid for - binary classifications and hold over the probability of the positive - class. + Indicates the monotonic constraint to enforce on each feature. + - 1: monotonically increasing + - 0: no constraint + - -1: monotonically decreasing + The constraints are only valid for binary classifications and hold + over the probability of the positive class. Attributes ---------- @@ -2324,9 +2327,10 @@ class ExtraTreesRegressor(ForestRegressor): .. versionadded:: 0.22 monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. 1, -1 - and 0 respectively correspond to a positive constraint, negative - constraint and no constraint. + Indicates the monotonic constraint to enforce on each feature. + - 1: monotonically increasing + - 0: no constraint + - -1: monotonically decreasing Attributes ---------- @@ -2584,11 +2588,12 @@ class RandomTreesEmbedding(TransformerMixin, BaseForest): new forest. See :term:`the Glossary `. monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. 1, -1 - and 0 respectively correspond to a positive constraint, negative - constraint and no constraint. The constraints are only valid for - binary classifications and hold over the probability of the positive - class. 
+ Indicates the monotonic constraint to enforce on each feature. + - 1: monotonically increasing + - 0: no constraint + - -1: monotonically decreasing + The constraints are only valid for binary classifications and hold + over the probability of the positive class. Attributes ---------- diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index 456760bef324f..a6819d7b56b4e 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -1077,11 +1077,12 @@ class GradientBoostingClassifier(ClassifierMixin, BaseGradientBoosting): .. versionadded:: 0.22 monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. 1, -1 - and 0 respectively correspond to a positive constraint, negative - constraint and no constraint. The constraints are only valid for - binary classifications and hold over the probability of the positive - class. + Indicates the monotonic constraint to enforce on each feature. + - 1: monotonically increasing + - 0: no constraint + - -1: monotonically decreasing + The constraints are only valid for binary classifications and hold + over the probability of the positive class. Attributes ---------- @@ -1670,9 +1671,10 @@ class GradientBoostingRegressor(RegressorMixin, BaseGradientBoosting): .. versionadded:: 0.22 monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. 1, -1 - and 0 respectively correspond to a positive constraint, negative - constraint and no constraint. + Indicates the monotonic constraint to enforce on each feature. + - 1: monotonically increasing + - 0: no constraint + - -1: monotonically decreasing Attributes ---------- diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 7b02fdad355c2..0eed0a207c76c 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -767,11 +767,12 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree): .. versionadded:: 0.22 monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. 1, -1 - and 0 respectively correspond to a positive constraint, negative - constraint and no constraint. The constraints are only valid for - binary classifications and hold over the probability of the positive - class. + Indicates the monotonic constraint to enforce on each feature. + - 1: monotonically increasing + - 0: no constraint + - -1: monotonically decreasing + The constraints are only valid for binary classifications and hold + over the probability of the positive class. Attributes ---------- @@ -1164,9 +1165,10 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree): .. versionadded:: 0.22 monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. 1, -1 - and 0 respectively correspond to a positive constraint, negative - constraint and no constraint. + Indicates the monotonic constraint to enforce on each feature. + - 1: monotonically increasing + - 0: no constraint + - -1: monotonically decreasing Attributes ---------- @@ -1505,11 +1507,12 @@ class ExtraTreeClassifier(DecisionTreeClassifier): .. versionadded:: 0.22 monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. 1, -1 - and 0 respectively correspond to a positive constraint, negative - constraint and no constraint. 
The constraints are only valid for - binary classifications and hold over the probability of the positive - class. + Indicates the monotonic constraint to enforce on each feature. + - 1: monotonically increasing + - 0: no constraint + - -1: monotonically decreasing + The constraints are only valid for binary classifications and hold + over the probability of the positive class. Attributes ---------- @@ -1770,9 +1773,10 @@ class ExtraTreeRegressor(DecisionTreeRegressor): .. versionadded:: 0.22 monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. 1, -1 - and 0 respectively correspond to a positive constraint, negative - constraint and no constraint. + Indicates the monotonic constraint to enforce on each feature. + - 1: monotonically increasing + - 0: no constraint + - -1: monotonically decreasing Attributes ---------- From a491f59479b6228f81e40cb841e67f8b39210565 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 14 Sep 2022 17:20:02 +0200 Subject: [PATCH 123/174] Fix docstrings --- sklearn/ensemble/_forest.py | 3 +++ sklearn/ensemble/_gb.py | 1 + 2 files changed, 4 insertions(+) diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index 9e48cfa5db1b8..21cc877206930 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -1274,6 +1274,7 @@ class RandomForestClassifier(ForestClassifier): - 1: monotonically increasing - 0: no constraint - -1: monotonically decreasing + The constraints are only valid for binary classifications and hold over the probability of the positive class. @@ -1983,6 +1984,7 @@ class ExtraTreesClassifier(ForestClassifier): - 1: monotonically increasing - 0: no constraint - -1: monotonically decreasing + The constraints are only valid for binary classifications and hold over the probability of the positive class. @@ -2592,6 +2594,7 @@ class RandomTreesEmbedding(TransformerMixin, BaseForest): - 1: monotonically increasing - 0: no constraint - -1: monotonically decreasing + The constraints are only valid for binary classifications and hold over the probability of the positive class. diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index a6819d7b56b4e..a4fc8c2b44db4 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -1081,6 +1081,7 @@ class GradientBoostingClassifier(ClassifierMixin, BaseGradientBoosting): - 1: monotonically increasing - 0: no constraint - -1: monotonically decreasing + The constraints are only valid for binary classifications and hold over the probability of the positive class. From 1c7ac01753a47baff271fb5aa34b86045ad49def Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 14 Sep 2022 17:38:40 +0200 Subject: [PATCH 124/174] Fix docstrings --- sklearn/tree/_classes.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 0eed0a207c76c..605e9162a60d5 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -771,6 +771,7 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree): - 1: monotonically increasing - 0: no constraint - -1: monotonically decreasing + The constraints are only valid for binary classifications and hold over the probability of the positive class. 
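To make the binary-classification wording above concrete, here is a hedged usage sketch (it assumes the monotonic_cst API introduced by this PR and is not code from the patch): an increasing constraint on a feature means the predicted probability of the positive class cannot decrease when only that feature increases.

    import numpy as np
    from sklearn.tree import DecisionTreeClassifier

    rng = np.random.RandomState(0)
    X = rng.rand(200, 2)
    y = (X[:, 0] - X[:, 1] + 0.1 * rng.randn(200) > 0).astype(int)

    # +1: increasing effect on the positive class, -1: decreasing, 0: none.
    clf = DecisionTreeClassifier(monotonic_cst=[1, -1], random_state=0)
    clf.fit(X, y)

    proba = clf.predict_proba(X)[:, 1]
    X_up = X.copy()
    X_up[:, 0] += 10.0  # increase only the constrained feature
    assert np.all(clf.predict_proba(X_up)[:, 1] >= proba)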
@@ -1511,6 +1512,7 @@ class ExtraTreeClassifier(DecisionTreeClassifier): - 1: monotonically increasing - 0: no constraint - -1: monotonically decreasing + The constraints are only valid for binary classifications and hold over the probability of the positive class. From d65c437bb7d6e42dd511af51f44e9a6a8521c790 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 27 Oct 2022 11:34:59 +0200 Subject: [PATCH 125/174] Remove monotonic_cst for RandomTreesEmbedding --- sklearn/ensemble/_forest.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index 6deec3e0a855b..4d9ad1254eec0 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -2600,15 +2600,6 @@ class RandomTreesEmbedding(TransformerMixin, BaseForest): and add more estimators to the ensemble, otherwise, just fit a whole new forest. See :term:`the Glossary `. - monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. - - 1: monotonically increasing - - 0: no constraint - - -1: monotonically decreasing - - The constraints are only valid for binary classifications and hold - over the probability of the positive class. - Attributes ---------- estimator_ : :class:`~sklearn.tree.ExtraTreeRegressor` instance @@ -2691,7 +2682,7 @@ class RandomTreesEmbedding(TransformerMixin, BaseForest): **BaseDecisionTree._parameter_constraints, "sparse_output": ["boolean"], } - for param in ("max_features", "ccp_alpha", "splitter"): + for param in ("max_features", "ccp_alpha", "splitter", "monotonic_cst"): _parameter_constraints.pop(param) criterion = "squared_error" @@ -2712,7 +2703,6 @@ def __init__( random_state=None, verbose=0, warm_start=False, - monotonic_cst=None, ): super().__init__( estimator=ExtraTreeRegressor(), @@ -2727,7 +2717,6 @@ def __init__( "max_leaf_nodes", "min_impurity_decrease", "random_state", - "monotonic_cst", ), bootstrap=False, oob_score=False, @@ -2745,7 +2734,6 @@ def __init__( self.max_leaf_nodes = max_leaf_nodes self.min_impurity_decrease = min_impurity_decrease self.sparse_output = sparse_output - self.monotonic_cst = monotonic_cst def _set_oob_score_and_attributes(self, X, y): raise NotImplementedError("OOB score not supported by tree embedding") From 56966a6ef93191a76e4206d1c209563f9de7fb58 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 27 Oct 2022 14:27:11 +0200 Subject: [PATCH 126/174] Use np.int8 by default Co-authored-by: Julien Jerphanion --- sklearn/tree/_classes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 8817cf4047477..6e111b35c9f41 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -340,7 +340,7 @@ def fit(self, X, y, sample_weight=None, check_input=True): splitter = self.splitter if self.monotonic_cst is None: - monotonic_cst = np.full(shape=X.shape[1], fill_value=0, dtype=np.int32) + monotonic_cst = np.zeros(shape=X.shape[1], dtype=np.int8) else: if self.n_outputs_ > 1: raise ValueError( From e00ab65c35d1b4b18af22c48ed303b3fa918b694 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 27 Oct 2022 14:50:11 +0200 Subject: [PATCH 127/174] Apply suggestions from code review Co-authored-by: Christian Lorentzen Co-authored-by: Julien Jerphanion --- sklearn/ensemble/_forest.py | 2 +- sklearn/tree/_classes.py | 13 +++++++-- sklearn/tree/tests/test_monotonic_tree.py | 34 +++++++++++------------ 3 files changed, 27 
insertions(+), 22 deletions(-) diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index 4d9ad1254eec0..59041e7824f90 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -1988,7 +1988,7 @@ class ExtraTreesClassifier(ForestClassifier): - 0: no constraint - -1: monotonically decreasing - The constraints are only valid for binary classifications and hold + The constraints are only valid for binary classification and hold over the probability of the positive class. Attributes diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 6e111b35c9f41..4e0870ac86df2 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -344,11 +344,11 @@ def fit(self, X, y, sample_weight=None, check_input=True): else: if self.n_outputs_ > 1: raise ValueError( - "Monotonic constraints are not supported with multiple output" + "Monotonic constraints are not supported with multiple outputs." ) # Applying element-wise logical conjunction # for monotonic constraints' support. - monotonic_cst = np.asarray(self.monotonic_cst) + monotonic_cst = np.asarray(self.monotonic_cst, dtype=np.int8) unsatisfied_constraints_conditions = ( (monotonic_cst != -1) * (monotonic_cst != 0) * (monotonic_cst != 1) ) @@ -356,7 +356,6 @@ def fit(self, X, y, sample_weight=None, check_input=True): raise ValueError( "monotonic_cst must be None or an array-like of -1, 0 or 1." ) - monotonic_cst = np.asarray(self.monotonic_cst, dtype=np.int32) if is_classifier(self): if self.n_classes_[0] > 2: raise ValueError( @@ -771,6 +770,8 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree): - 0: no constraint - -1: monotonically decreasing + If monotonic_cst is None, no constraints are applied. + The constraints are only valid for binary classifications and hold over the probability of the positive class. @@ -1155,6 +1156,10 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree): - 0: no constraint - -1: monotonically decreasing + If monotonic_cst is None, no constraints are applied. + + Monotonic constraints are not supported for multioutput regression + (`n_outputs_>1`). Attributes ---------- feature_importances_ : ndarray of shape (n_features,) @@ -1482,6 +1487,8 @@ class ExtraTreeClassifier(DecisionTreeClassifier): - 0: no constraint - -1: monotonically decreasing + If monotonic_cst is None, no constraints are applied. + The constraints are only valid for binary classifications and hold over the probability of the positive class. 
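For the regressor docstrings just above, a similar hedged sketch (again assuming the monotonic_cst API added in this PR, not code from the patch) shows predictions that are non-decreasing along a constrained feature when the other feature is held fixed:

    import numpy as np
    from sklearn.tree import DecisionTreeRegressor

    rng = np.random.RandomState(0)
    X = rng.rand(300, 2)
    y = 5 * X[:, 0] - 3 * X[:, 1] + 0.1 * rng.randn(300)

    # Feature 0: increasing effect; feature 1: decreasing effect.
    reg = DecisionTreeRegressor(monotonic_cst=[1, -1], max_depth=5)
    reg.fit(X, y)

    grid = np.linspace(0.0, 1.0, 50)
    preds = reg.predict(np.c_[grid, np.full_like(grid, 0.5)])
    assert np.all(np.diff(preds) >= 0.0)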
diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 883d6ba6bfaf8..498d26fcb9290 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -20,16 +20,15 @@ def test_montonic_constraints_classifications(depth_first, global_random_seed): n_redundant=0, random_state=0, ) - train = np.arange(n_samples_train) - test = np.arange(n_samples_train, n_samples) - X_train = X[train] - y_train = y[train] - X_test = np.copy(X[test]) - - X_test_incr = np.copy(X_test) - X_test_decr = np.copy(X_test) - X_test_incr[:, 0] += 10 - X_test_decr[:, 1] += 10 + X_train, y_train = X[:n_samples_train], y[:n_samples_train] + X_test, y_test = X[n_samples_train:], y[n_samples_train:] + + X_test_0incr, X_test_0decr = np.copy(X_test), np.copy(X_test) + X_test_1incr, X_test_1decr = np.copy(X_test), np.copy(X_test) + X_test_0incr[:, 0] += 10 + X_test_0decr[:, 0] -= 10 + X_test_1incr[:, 1] += 10 + X_test_1decr[:, 1] -= 10 monotonic_cst = np.zeros(X.shape[1]) monotonic_cst[0] = 1 monotonic_cst[1] = -1 @@ -52,15 +51,13 @@ def test_montonic_constraints_classifications(depth_first, global_random_seed): est.fit(X_train, y_train) y = est.predict_proba(X_test)[:, 1] - # increasing constraint - y_incr = est.predict_proba(X_test_incr)[:, 1] - # y_incr should always be greater than y - assert np.all(y_incr >= y) + # increasing constraint, they apply to positive class + assert np.all(est.predict_proba(X_test_0incr)[:, 1] >= y) + assert np.all(est.predict_proba(X_test_0decr)[:, 1] <= y) # decreasing constraint - y_decr = est.predict_proba(X_test_decr)[:, 1] - # y_decr should always be lower than y - assert np.all(y_decr <= y) + assert np.all(est.predict_proba(X_test_1incr)[:, 1] <= y) + assert np.all(est.predict_proba(X_test_1decr)[:, 1] >= y) @pytest.mark.parametrize("depth_first", (True, False)) @@ -203,10 +200,11 @@ def test_1d_tree_nodes_values( monotonic_sign, splitter, depth_first, global_random_seed ): # Adaptation from test_nodes_values in test_montonic_constraints.py + # in sklearn.ensemble._hist_gradient_boosting # Build a single tree with only one feature, and make sure the nodes # values respect the monotonic constraints. - # Considering the following tree with a monotonic POS constraint, we + # Considering the following tree with a monotonic +1 constraint, we # should have: # # root From e047f52b320a24b2dd1377d4050c2f29259b9afc Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 27 Oct 2022 14:58:53 +0200 Subject: [PATCH 128/174] Apply suggestions from code review Co-authored-by: Julien Jerphanion Co-authored-by: Christian Lorentzen --- doc/whats_new/v1.2.rst | 13 ++++++++----- sklearn/ensemble/_forest.py | 2 +- sklearn/tree/_classes.py | 4 ++++ sklearn/tree/tests/test_monotonic_tree.py | 2 +- 4 files changed, 14 insertions(+), 7 deletions(-) diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index a7eff52860894..76a0ecc714b98 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -475,11 +475,14 @@ Changelog :mod:`sklearn.tree` ................... -- |Feature| :class:`tree.DecisionTreeClassifier` and - :class:`tree.DecisionTreeRegressor` now support monotonic - constraints, useful when features are supposed to have a positive/negative - effect on the target. :pr:`13649` by :user:`Samuel Ronsin `, - initiated by :user:`Patrick O'Reilly `. 
+- |Feature| :class:`tree.DecisionTreeClassifier`, + :class:`tree.DecisionTreeRegressor`, + :class:`tree.ExtraTreeClassifier` and + :class:`tree.ExtraTreeRegressor` now support monotonic + constraints, useful when features are supposed to have a positive or + negative effect on the target. + :pr:`13649` by :user:`Samuel Ronsin `, initiated by + :user:`Patrick O'Reilly `. - |Enhancement| :func:`tree.plot_tree`, :func:`tree.export_graphviz` now uses a lower case `x[i]` to represent feature `i`. :pr:`23480` by `Thomas Fan`_. diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index 59041e7824f90..4295a2e200fb3 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -1270,7 +1270,7 @@ class RandomForestClassifier(ForestClassifier): - 0: no constraint - -1: monotonically decreasing - The constraints are only valid for binary classifications and hold + The constraints are only valid for binary classification and hold over the probability of the positive class. Attributes diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 4e0870ac86df2..a4076a9649b12 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -1749,6 +1749,10 @@ class ExtraTreeRegressor(DecisionTreeRegressor): - 0: no constraint - -1: monotonically decreasing + If monotonic_cst is None, no constraints are applied. + + Monotonic constraints are not supported for multioutput regression + (`n_outputs_>1`). Attributes ---------- max_features_ : int diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 498d26fcb9290..0a72213fbff55 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -258,7 +258,7 @@ def assert_nd_reg_tree_children_monotonic_bounded(tree_, monotonic_cst): if monotonic_cst[feature] == 0: # Feature without monotonicity constraint: propagate bounds # down the tree to both children. - # Otherwise with 2 features and a POS constraint on feature 0 + # Otherwise with 2 features and a +1 constraint on feature 0 # the following tree can be accepted, although it does not # respect the positive monotonicity constraint: # From c70c1b67881d2a076c0d5fb87870e2852591c836 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 27 Oct 2022 17:28:26 +0200 Subject: [PATCH 129/174] Missing .. versionadded:: 1.2 & more consistent docstrings --- sklearn/ensemble/_forest.py | 23 +++++++++++++++++++++++ sklearn/tree/_classes.py | 16 +++++++++++++--- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index 4295a2e200fb3..2a51403780d23 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -1270,9 +1270,13 @@ class RandomForestClassifier(ForestClassifier): - 0: no constraint - -1: monotonically decreasing + If monotonic_cst is None, no constraints are applied. + The constraints are only valid for binary classification and hold over the probability of the positive class. + .. versionadded:: 1.2 + Attributes ---------- estimator_ : :class:`~sklearn.tree.DecisionTreeClassifier` @@ -1632,6 +1636,13 @@ class RandomForestRegressor(ForestRegressor): - 0: no constraint - -1: monotonically decreasing + If monotonic_cst is None, no constraints are applied. + + Monotonic constraints are not supported for multioutput regression + (`n_outputs_>1`). + + .. 
versionadded:: 1.2 + Attributes ---------- estimator_ : :class:`~sklearn.tree.DecisionTreeRegressor` @@ -1988,9 +1999,14 @@ class ExtraTreesClassifier(ForestClassifier): - 0: no constraint - -1: monotonically decreasing + If monotonic_cst is None, no constraints are applied. + The constraints are only valid for binary classification and hold over the probability of the positive class. + .. versionadded:: 1.2 + + Attributes ---------- estimator_ : :class:`~sklearn.tree.ExtraTreesClassifier` @@ -2341,6 +2357,13 @@ class ExtraTreesRegressor(ForestRegressor): - 0: no constraint - -1: monotonically decreasing + If monotonic_cst is None, no constraints are applied. + + Monotonic constraints are not supported for multioutput regression + (`n_outputs_>1`). + + .. versionadded:: 1.2 + Attributes ---------- estimator_ : :class:`~sklearn.tree.ExtraTreeRegressor` diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index a4076a9649b12..85ce141dc4349 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -775,6 +775,8 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree): The constraints are only valid for binary classifications and hold over the probability of the positive class. + .. versionadded:: 1.2 + Attributes ---------- classes_ : ndarray of shape (n_classes,) or list of ndarray @@ -1156,10 +1158,13 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree): - 0: no constraint - -1: monotonically decreasing - If monotonic_cst is None, no constraints are applied. + If monotonic_cst is None, no constraints are applied. + + Monotonic constraints are not supported for multioutput regression + (`n_outputs_>1`). + + .. versionadded:: 1.2 - Monotonic constraints are not supported for multioutput regression - (`n_outputs_>1`). Attributes ---------- feature_importances_ : ndarray of shape (n_features,) @@ -1492,6 +1497,8 @@ class ExtraTreeClassifier(DecisionTreeClassifier): The constraints are only valid for binary classifications and hold over the probability of the positive class. + .. versionadded:: 1.2 + Attributes ---------- classes_ : ndarray of shape (n_classes,) or list of ndarray @@ -1753,6 +1760,9 @@ class ExtraTreeRegressor(DecisionTreeRegressor): Monotonic constraints are not supported for multioutput regression (`n_outputs_>1`). + + .. 
versionadded:: 1.2 + Attributes ---------- max_features_ : int From e41435f8da0dd1235ad299e4ccbbecdfeeb20201 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 27 Oct 2022 17:30:18 +0200 Subject: [PATCH 130/174] Fix linter --- sklearn/tree/tests/test_monotonic_tree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 0a72213fbff55..f871bd9e132d4 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -21,7 +21,7 @@ def test_montonic_constraints_classifications(depth_first, global_random_seed): random_state=0, ) X_train, y_train = X[:n_samples_train], y[:n_samples_train] - X_test, y_test = X[n_samples_train:], y[n_samples_train:] + X_test, _ = X[n_samples_train:], y[n_samples_train:] X_test_0incr, X_test_0decr = np.copy(X_test), np.copy(X_test) X_test_1incr, X_test_1decr = np.copy(X_test), np.copy(X_test) From 3a46ff6be340bddcfd40c8bd33aa58b94a18277c Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 27 Oct 2022 17:35:17 +0200 Subject: [PATCH 131/174] Monotonic increase --- sklearn/tree/tests/test_monotonic_tree.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index f871bd9e132d4..8bb181837daad 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -51,11 +51,11 @@ def test_montonic_constraints_classifications(depth_first, global_random_seed): est.fit(X_train, y_train) y = est.predict_proba(X_test)[:, 1] - # increasing constraint, they apply to positive class + # Monotonic increase constraint, it applies to the positive class assert np.all(est.predict_proba(X_test_0incr)[:, 1] >= y) assert np.all(est.predict_proba(X_test_0decr)[:, 1] <= y) - # decreasing constraint + # Monotonic decrease constraint, it applies to the positive class assert np.all(est.predict_proba(X_test_1incr)[:, 1] <= y) assert np.all(est.predict_proba(X_test_1decr)[:, 1] >= y) @@ -98,12 +98,12 @@ def test_montonic_constraints_regressions(depth_first, global_random_seed): est.set_params(**{"n_estimators": 5}) est.fit(X_train, y_train) y = est.predict(X_test) - # increasing constraint + # Monotonic increase constraint y_incr = est.predict(X_test_incr) # y_incr should always be greater than y assert np.all(y_incr >= y) - # decreasing constraint + # Monotonic decrease constraint y_decr = est.predict(X_test_decr) # y_decr should always be lower than y assert np.all(y_decr <= y) @@ -169,7 +169,7 @@ def test_bad_monotonic_cst_raises(): def assert_1d_reg_tree_children_monotonic_bounded(tree_, monotonic_sign): - # Flip values to always check for increasing constraint + # Flip values to always check for monotonic increase constraint values = monotonic_sign * tree_.value for i in range(tree_.node_count): if tree_.children_left[i] > i and tree_.children_right[i] > i: From 7a392196337769fe88ae2401b1e14c48c9c31577 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 28 Oct 2022 09:23:30 +0200 Subject: [PATCH 132/174] Encode monotonicity constraints with cnp.int8 --- sklearn/tree/_classes.py | 11 +++++++---- sklearn/tree/_criterion.pxd | 19 +++++++++++++++---- sklearn/tree/_criterion.pyx | 33 ++++++++++++++++++++++++++------- sklearn/tree/_splitter.pxd | 3 ++- sklearn/tree/_splitter.pyx | 22 +++++++++++++++------- 5 files changed, 65 insertions(+), 23 deletions(-) diff --git 
a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 85ce141dc4349..6383b9ea53c13 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -346,16 +346,19 @@ def fit(self, X, y, sample_weight=None, check_input=True): raise ValueError( "Monotonic constraints are not supported with multiple outputs." ) - # Applying element-wise logical conjunction - # for monotonic constraints' support. - monotonic_cst = np.asarray(self.monotonic_cst, dtype=np.int8) + # Check to correct monotonic constraints' specification, + # by applying element-wise logical conjunction + monotonic_cst = np.asarray(self.monotonic_cst) unsatisfied_constraints_conditions = ( (monotonic_cst != -1) * (monotonic_cst != 0) * (monotonic_cst != 1) ) if np.any(unsatisfied_constraints_conditions): + unique_constaints_value = np.unique(monotonic_cst) raise ValueError( - "monotonic_cst must be None or an array-like of -1, 0 or 1." + "monotonic_cst must be None or an array-like of -1, 0 or 1, but" + f" got {unique_constaints_value}" ) + monotonic_cst = np.asarray(monotonic_cst, dtype=np.int8) if is_classifier(self): if self.n_classes_[0] > 2: raise ValueError( diff --git a/sklearn/tree/_criterion.pxd b/sklearn/tree/_criterion.pxd index 35ee14c5a7024..ea045f557d685 100644 --- a/sklearn/tree/_criterion.pxd +++ b/sklearn/tree/_criterion.pxd @@ -8,6 +8,7 @@ # License: BSD 3 clause # See _criterion.pyx for implementation details. +cimport numpy as cnp from ._tree cimport DTYPE_t # Type of X from ._tree cimport DOUBLE_t # Type of y, sample_weight @@ -55,10 +56,20 @@ cdef class Criterion: double impurity_left, double impurity_right) nogil cdef double proxy_impurity_improvement(self) nogil - cdef bint check_monotonicity(self, INT32_t monotonic_cst, - double lower_bound, double upper_bound) nogil - cdef inline bint _check_monotonicity(self, INT32_t monotonic_cst, double lower_bound, double upper_bound, - double sum_left, double sum_right) nogil + cdef bint check_monotonicity( + self, + cnp.int8_t monotonic_cst, + double lower_bound, + double upper_bound, + ) nogil + cdef inline bint _check_monotonicity( + self, + cnp.int8_t monotonic_cst, + double lower_bound, + double upper_bound, + double sum_left, + double sum_right, + ) nogil cdef class ClassificationCriterion(Criterion): """Abstract criterion for classification.""" diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 5caa3fd18e9ad..4b8658df4a531 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -192,13 +192,22 @@ cdef class Criterion: - (self.weighted_n_left / self.weighted_n_node_samples * impurity_left))) - cdef bint check_monotonicity(self, INT32_t monotonic_cst, - double lower_bound, double upper_bound) nogil: + cdef bint check_monotonicity( + self, + cnp.int8_t monotonic_cst, + double lower_bound, + double upper_bound, + ) nogil: pass - cdef inline bint _check_monotonicity(self, INT32_t monotonic_cst, - double lower_bound, double upper_bound, - double sum_left, double sum_right) nogil: + cdef inline bint _check_monotonicity( + self, + cnp.int8_t monotonic_cst, + double lower_bound, + double upper_bound, + double sum_left, + double sum_right, + ) nogil: cdef: double weighted_n_left = self.weighted_n_left double weighted_n_right = self.weighted_n_right @@ -457,7 +466,12 @@ cdef class ClassificationCriterion(Criterion): memcpy(dest, &self.sum_total[k, 0], self.n_classes[k] * sizeof(double)) dest += self.max_n_classes - cdef inline bint check_monotonicity(self, INT32_t monotonic_cst, double lower_bound, double 
upper_bound) nogil: + cdef inline bint check_monotonicity( + self, + cnp.int8_t monotonic_cst, + double lower_bound, + double upper_bound, + ) nogil: """Check monotonic constraint is satisfied at the current classification split""" cdef: double sum_left = self.sum_left[0][0] @@ -804,7 +818,12 @@ cdef class RegressionCriterion(Criterion): for k in range(self.n_outputs): dest[k] = self.sum_total[k] / self.weighted_n_node_samples - cdef inline bint check_monotonicity(self, INT32_t monotonic_cst, double lower_bound, double upper_bound) nogil: + cdef inline bint check_monotonicity( + self, + cnp.int8_t monotonic_cst, + double lower_bound, + double upper_bound, + ) nogil: """Check monotonic constraint is satisfied at the current regression split""" cdef: double sum_left = self.sum_left[0] diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 367f8db5f1e2e..dac0d56bec9ef 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -8,6 +8,7 @@ # License: BSD 3 clause # See _splitter.pyx for details. +cimport numpy as cnp from ._criterion cimport Criterion @@ -63,7 +64,7 @@ cdef class Splitter: # -1: monotonically decreasing # 0: no constraint # +1: monotonically increasing - cdef const INT32_t[:] monotonic_cst + cdef const cnp.int8_t[:] monotonic_cst # The samples vector `samples` is maintained by the Splitter object such # that the samples contained in a node are contiguous. With this setting, diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index b4e202e079ec9..e0cfcc1dc631e 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -11,6 +11,8 @@ # # License: BSD 3 clause +cimport numpy as cnp + from ._criterion cimport Criterion from libc.stdlib cimport qsort @@ -49,9 +51,15 @@ cdef class Splitter: sparse and dense data, one split at a time. 
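In plainer terms, the check these `check_monotonicity` methods perform on a candidate split can be paraphrased in pure Python roughly as follows. The standalone function and its arguments are hypothetical stand-ins for the criterion's internal state, not the compiled code:

    def check_monotonicity(monotonic_cst, lower_bound, upper_bound,
                           sum_left, sum_right, weighted_n_left, weighted_n_right):
        # Pure-Python paraphrase of Criterion._check_monotonicity (a sketch).
        value_left = sum_left / weighted_n_left      # would-be left child value
        value_right = sum_right / weighted_n_right   # would-be right child value
        in_bounds = (lower_bound <= value_left <= upper_bound
                     and lower_bound <= value_right <= upper_bound)
        if monotonic_cst == 0:        # feature without constraint
            return in_bounds
        # +1: the left child value may not exceed the right child value; -1: reverse.
        return in_bounds and monotonic_cst * (value_right - value_left) >= 0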
""" - def __cinit__(self, Criterion criterion, SIZE_t max_features, - SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, const INT32_t[:] monotonic_cst): + def __cinit__( + self, + Criterion criterion, + SIZE_t max_features, + SIZE_t min_samples_leaf, + double min_weight_leaf, + object random_state, + const cnp.int8_t[:] monotonic_cst, + ): """ Parameters ---------- @@ -74,7 +82,7 @@ cdef class Splitter: random_state : object The user inputted random state to be used for pseudo-randomness - monotonic_cst : INT32_t[:] + monotonic_cst : const cnp.int8_t[:] Monotonicity constraints """ @@ -299,7 +307,7 @@ cdef class BestSplitter(BaseDenseSplitter): cdef SIZE_t n_total_constants = n_known_constants cdef SIZE_t partition_end - cdef INT32_t monotonic_constraint + cdef cnp.int8_t monotonic_constraint _init_split(&best, end) @@ -626,7 +634,7 @@ cdef class RandomSplitter(BaseDenseSplitter): cdef DTYPE_t max_feature_value cdef DTYPE_t current_feature_value - cdef INT32_t monotonic_constraint + cdef cnp.int8_t monotonic_constraint _init_split(&best, end) @@ -793,7 +801,7 @@ cdef class BaseSparseSplitter(Splitter): def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, const INT32_t[:] monotonic_cst): + object random_state, const cnp.int8_t[:] monotonic_cst): # Parent __cinit__ is automatically called self.n_total_samples = 0 self.monotonic_cst = monotonic_cst From dd5df852c75048fe6bc694a1f77742988482bde2 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 28 Oct 2022 14:26:28 +0200 Subject: [PATCH 133/174] Fix docstring format of ExtraTreesClassifier (double blank line) --- sklearn/ensemble/_forest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index 2a51403780d23..956cf1470774b 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -2006,7 +2006,6 @@ class ExtraTreesClassifier(ForestClassifier): .. versionadded:: 1.2 - Attributes ---------- estimator_ : :class:`~sklearn.tree.ExtraTreesClassifier` From 2b596ec07e37ea1b38b73d31f118d128b7dca8df Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 28 Oct 2022 14:14:11 +0200 Subject: [PATCH 134/174] Remove monotonic_cst support for GradientBoosting for now Some tests aren't passing for reasons given in this thread: https://github.com/scikit-learn/scikit-learn/pull/13649/files#r973630373 Olivier and I agreed that this support can be postponed and integrated in another pull request. Co-authored-by: Olivier Grisel --- doc/whats_new/v1.2.rst | 19 +++++------- sklearn/ensemble/_gb.py | 25 +-------------- sklearn/ensemble/_gb_losses.py | 31 +++++++------------ .../_base.pyx.tp | 2 +- sklearn/tree/tests/test_monotonic_tree.py | 4 --- 5 files changed, 20 insertions(+), 61 deletions(-) diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index af5c49cd2a4e2..579a618725e42 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -276,13 +276,10 @@ Changelog ....................... - |Feature| :class:`ensemble.RandomForestClassifier`, - :class:`ensemble.RandomForestRegressor`, - :class:`ensemble.ExtraTreesClassifier`, - :class:`ensemble.ExtraTreesRegressor`, - :class:`ensemble.GradientBoostingClassifier` and - :class:`ensemble.GradientBoostingRegressor` now support monotonic - constraints, useful when features are supposed to have a positive or negative - effect on the target. 
:pr:`13649` by :user:`Samuel Ronsin `, + :class:`ensemble.RandomForestRegressor`, :class:`ensemble.ExtraTreesClassifier` + and :class:`ensemble.ExtraTreesRegressor` now support monotonic constraints, + useful when features are supposed to have a positive or negative effect on the target. + :pr:`13649` by :user:`Samuel Ronsin `, initiated by :user:`Patrick O'Reilly `. - |Feature| :class:`~sklearn.ensemble.HistGradientBoostingClassifier` and @@ -566,11 +563,9 @@ Changelog :mod:`sklearn.tree` ................... -- |Feature| :class:`tree.DecisionTreeClassifier`, - :class:`tree.DecisionTreeRegressor`, - :class:`tree.ExtraTreeClassifier` and - :class:`tree.ExtraTreeRegressor` now support monotonic - constraints, useful when features are supposed to have a positive or +- |Feature| :class:`tree.DecisionTreeClassifier`, :class:`tree.DecisionTreeRegressor`, + :class:`tree.ExtraTreeClassifier` and :class:`tree.ExtraTreeRegressor` now support + monotonic constraints, useful when features are supposed to have a positive or negative effect on the target. :pr:`13649` by :user:`Samuel Ronsin `, initiated by :user:`Patrick O'Reilly `. diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index 3278ba3c5531a..e41742f3d51d6 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -148,6 +148,7 @@ class BaseGradientBoosting(BaseEnsemble, metaclass=ABCMeta): "tol": [Interval(Real, 0.0, None, closed="left")], } _parameter_constraints.pop("splitter") + _parameter_constraints.pop("monotonic_cst") @abstractmethod def __init__( @@ -174,7 +175,6 @@ def __init__( validation_fraction=0.1, n_iter_no_change=None, tol=1e-4, - monotonic_cst=None, ): self.n_estimators = n_estimators @@ -198,7 +198,6 @@ def __init__( self.validation_fraction = validation_fraction self.n_iter_no_change = n_iter_no_change self.tol = tol - self.monotonic_cst = monotonic_cst @abstractmethod def _validate_y(self, y, sample_weight=None): @@ -249,7 +248,6 @@ def _fit_stage( max_leaf_nodes=self.max_leaf_nodes, random_state=random_state, ccp_alpha=self.ccp_alpha, - monotonic_cst=self.monotonic_cst, ) if self.subsample < 1.0: @@ -260,7 +258,6 @@ def _fit_stage( tree.fit(X, residual, sample_weight=sample_weight, check_input=False) # update tree leaves - perform_line_search = self.monotonic_cst is None loss.update_terminal_regions( tree.tree_, X, @@ -271,7 +268,6 @@ def _fit_stage( sample_mask, learning_rate=self.learning_rate, k=k, - perform_line_search=perform_line_search, ) # add tree to ensemble @@ -1050,15 +1046,6 @@ class GradientBoostingClassifier(ClassifierMixin, BaseGradientBoosting): .. versionadded:: 0.22 - monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. - - 1: monotonically increasing - - 0: no constraint - - -1: monotonically decreasing - - The constraints are only valid for binary classifications and hold - over the probability of the positive class. - Attributes ---------- n_estimators_ : int @@ -1210,7 +1197,6 @@ def __init__( n_iter_no_change=None, tol=1e-4, ccp_alpha=0.0, - monotonic_cst=None, ): super().__init__( @@ -1234,7 +1220,6 @@ def __init__( n_iter_no_change=n_iter_no_change, tol=tol, ccp_alpha=ccp_alpha, - monotonic_cst=monotonic_cst, ) def _validate_y(self, y, sample_weight): @@ -1630,12 +1615,6 @@ class GradientBoostingRegressor(RegressorMixin, BaseGradientBoosting): .. 
versionadded:: 0.22 - monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. - - 1: monotonically increasing - - 0: no constraint - - -1: monotonically decreasing - Attributes ---------- feature_importances_ : ndarray of shape (n_features,) @@ -1768,7 +1747,6 @@ def __init__( n_iter_no_change=None, tol=1e-4, ccp_alpha=0.0, - monotonic_cst=None, ): super().__init__( @@ -1793,7 +1771,6 @@ def __init__( n_iter_no_change=n_iter_no_change, tol=tol, ccp_alpha=ccp_alpha, - monotonic_cst=monotonic_cst, ) def _validate_y(self, y, sample_weight=None): diff --git a/sklearn/ensemble/_gb_losses.py b/sklearn/ensemble/_gb_losses.py index 110eb2c295e13..e44692f518bb5 100644 --- a/sklearn/ensemble/_gb_losses.py +++ b/sklearn/ensemble/_gb_losses.py @@ -80,7 +80,6 @@ def update_terminal_regions( sample_mask, learning_rate=0.1, k=0, - perform_line_search=True, ): """Update the terminal regions (=leaves) of the given tree and updates the current predictions of the model. Traverses tree @@ -108,9 +107,6 @@ def update_terminal_regions( ``learning_rate``. k : int, default=0 The index of the estimator being updated. - perform_line_search : bool, default=True - Whether line search must be performed. Line search must not be - performed under monotonic constraints. """ # compute leaf for each sample in ``X``. terminal_regions = tree.apply(X) @@ -120,18 +116,17 @@ def update_terminal_regions( masked_terminal_regions[~sample_mask] = -1 # update each leaf (= perform line search) - if perform_line_search: - for leaf in np.where(tree.children_left == TREE_LEAF)[0]: - self._update_terminal_region( - tree, - masked_terminal_regions, - leaf, - X, - y, - residual, - raw_predictions[:, k], - sample_weight, - ) + for leaf in np.where(tree.children_left == TREE_LEAF)[0]: + self._update_terminal_region( + tree, + masked_terminal_regions, + leaf, + X, + y, + residual, + raw_predictions[:, k], + sample_weight, + ) # update predictions (both in-bag and out-of-bag) raw_predictions[:, k] += learning_rate * tree.value[:, 0, 0].take( @@ -260,7 +255,6 @@ def update_terminal_regions( sample_mask, learning_rate=0.1, k=0, - perform_line_search=True, ): """Least squares does not need to update terminal regions. @@ -288,9 +282,6 @@ def update_terminal_regions( ``learning_rate``. k : int, default=0 The index of the estimator being updated. - perform_line_search : bool, default=True - Whether line search must be performed. Line search must not be - performed under monotonic constraints. 
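The incompatibility motivating this removal can be sketched directly: the line search rewrites `tree.value[leaf, 0, 0]` with a loss-optimal constant after the constrained splitter has already fixed the leaf ordering, and nothing re-checks the monotonic bounds afterwards. A rough illustration, assuming a build with this series and using an in-place edit of `tree_.value` (a writable view, which is what the loss classes in `_gb_losses.py` rely on internally) as a stand-in for the loss-specific `_update_terminal_region`:

    import numpy as np
    from sklearn.tree import DecisionTreeRegressor

    rng = np.random.RandomState(0)
    X = rng.uniform(size=(200, 1))
    residual = rng.normal(size=200)

    # A boosting stage: tree fitted on residuals, constrained to increase in X[:, 0].
    stage = DecisionTreeRegressor(max_depth=3, monotonic_cst=[1], random_state=0)
    stage.fit(X, residual)

    # GradientBoosting-style line search: overwrite each leaf with another
    # loss-optimal constant (the median here, as an LAD-like loss would).
    leaves = stage.apply(X)
    for leaf in np.unique(leaves):
        stage.tree_.value[leaf, 0, 0] = np.median(residual[leaves == leaf])

    grid = np.linspace(0, 1, 1000).reshape(-1, 1)
    print(np.all(np.diff(stage.predict(grid)) >= 0))  # may print False: constraint broken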
""" # update predictions raw_predictions[:, k] += learning_rate * tree.predict(X).ravel() diff --git a/sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp index c1bade148c988..46bef76bb64b0 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp @@ -112,7 +112,7 @@ cdef class BaseDistancesReduction{{name_suffix}}: if chunk_size is None: chunk_size = get_config().get("pairwise_dist_chunk_size", 256) - self.chunk_size = check_scalar(chunk_size, "chunk_size", Integral, min_val=20) + self.chunk_size = check_scalar(chunk_size, "chunk_size", Integral, min_val=1) self.effective_n_threads = _openmp_effective_n_threads() diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 8bb181837daad..14319df5330c5 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -2,8 +2,6 @@ import pytest from sklearn.datasets import make_classification, make_regression -from sklearn.ensemble import GradientBoostingClassifier -from sklearn.ensemble import GradientBoostingRegressor from sklearn.tree import DecisionTreeRegressor from sklearn.tree.tests.test_tree import REG_TREES, CLF_TREES @@ -34,7 +32,6 @@ def test_montonic_constraints_classifications(depth_first, global_random_seed): monotonic_cst[1] = -1 classifiers = CLF_TREES.copy() - classifiers.update({"GradientBoostingClassifier": GradientBoostingClassifier}) for name, TreeClassifier in classifiers.items(): if depth_first: est = TreeClassifier(max_depth=None, monotonic_cst=monotonic_cst) @@ -81,7 +78,6 @@ def test_montonic_constraints_regressions(depth_first, global_random_seed): monotonic_cst[0] = 1 monotonic_cst[1] = -1 regressors = REG_TREES.copy() - regressors.update({"GradientBoostingRegressor": GradientBoostingRegressor}) for name, TreeRegressor in regressors.items(): if depth_first: From b9c7a90838bc84fcc9a6939604ddb51bfc52d316 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 28 Oct 2022 14:40:03 +0200 Subject: [PATCH 135/174] DOC Clarify wording for constraints Prefer: - "monotonicity constraints" to refer to all of constraints (encoded by -1, 0, 1) which replaces "monotonic constraints" - "monotonic increase constraints" (encoded by 1) which replaces "positive constraints", "monotonically increasing" and "monotonically increase" - "monotonic decrease constraints" (encoded by -1) which replaces "negative constraints", "monotonically decreasing" and "monotonically decrease" Note: "monotonically {de,in}creasing" is valid when using continuous tenses but was improperly used before. See: https://github.com/scikit-learn/scikit-learn/pull/13649#discussion_r1006646907 Co-authored-by: Olivier Grisel --- sklearn/ensemble/_forest.py | 12 +++---- sklearn/tree/_classes.py | 38 +++++++++++------------ sklearn/tree/_criterion.pyx | 4 +-- sklearn/tree/_splitter.pxd | 4 +-- sklearn/tree/tests/test_monotonic_tree.py | 18 ++++++----- 5 files changed, 39 insertions(+), 37 deletions(-) diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index 956cf1470774b..b3a951d17dba7 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -1265,10 +1265,10 @@ class RandomForestClassifier(ForestClassifier): .. versionadded:: 0.22 monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. 
- - 1: monotonically increasing + Indicates the monotonicity constraint to enforce on each feature. + - 1: monotonic increase - 0: no constraint - - -1: monotonically decreasing + - -1: monotonic decrease If monotonic_cst is None, no constraints are applied. @@ -1631,7 +1631,7 @@ class RandomForestRegressor(ForestRegressor): .. versionadded:: 0.22 monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. + Indicates the monotonicity constraint to enforce on each feature. - 1: monotonically increasing - 0: no constraint - -1: monotonically decreasing @@ -1994,7 +1994,7 @@ class ExtraTreesClassifier(ForestClassifier): .. versionadded:: 0.22 monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. + Indicates the monotonicity constraint to enforce on each feature. - 1: monotonically increasing - 0: no constraint - -1: monotonically decreasing @@ -2351,7 +2351,7 @@ class ExtraTreesRegressor(ForestRegressor): .. versionadded:: 0.22 monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. + Indicates the monotonicity constraint to enforce on each feature. - 1: monotonically increasing - 0: no constraint - -1: monotonically decreasing diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 6383b9ea53c13..3b056382a75be 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -344,9 +344,9 @@ def fit(self, X, y, sample_weight=None, check_input=True): else: if self.n_outputs_ > 1: raise ValueError( - "Monotonic constraints are not supported with multiple outputs." + "Monotonicity constraints are not supported with multiple outputs." ) - # Check to correct monotonic constraints' specification, + # Check to correct monotonicity constraint' specification, # by applying element-wise logical conjunction monotonic_cst = np.asarray(self.monotonic_cst) unsatisfied_constraints_conditions = ( @@ -362,7 +362,7 @@ def fit(self, X, y, sample_weight=None, check_input=True): if is_classifier(self): if self.n_classes_[0] > 2: raise ValueError( - "Monotonic constraints are not supported with multiclass " + "Monotonicity constraints are not supported with multiclass " "classification" ) # Imposing the constraint on the probability of the positive class @@ -768,10 +768,10 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree): .. versionadded:: 0.22 monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. - - 1: monotonically increasing + Indicates the monotonicity constraint to enforce on each feature. + - 1: monotonic increase - 0: no constraint - - -1: monotonically decreasing + - -1: monotonic decrease If monotonic_cst is None, no constraints are applied. @@ -1156,15 +1156,15 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree): .. versionadded:: 0.22 monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. - - 1: monotonically increasing + Indicates the monotonicity constraint to enforce on each feature. + - 1: monotonic increase - 0: no constraint - - -1: monotonically decreasing + - -1: monotonic decrease If monotonic_cst is None, no constraints are applied. - Monotonic constraints are not supported for multioutput regression - (`n_outputs_>1`). 
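The re-worded error paths above can be exercised directly; a minimal reproduction of the multiclass case, with parameters mirroring `test_multiclass_raises` and the message wording introduced in this commit:

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.tree import DecisionTreeClassifier

    X, y = make_classification(
        n_samples=100, n_features=5, n_classes=3, n_informative=3, random_state=0
    )
    clf = DecisionTreeClassifier(monotonic_cst=np.array([1, 0, 0, 0, 0]))
    try:
        clf.fit(X, y)
    except ValueError as exc:
        print(exc)
        # Monotonicity constraints are not supported with multiclass classification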
+ Monotonicity constraints are not supported for multioutput regression + (i.e. when `n_outputs_ > 1`). .. versionadded:: 1.2 @@ -1490,10 +1490,10 @@ class ExtraTreeClassifier(DecisionTreeClassifier): .. versionadded:: 0.22 monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. - - 1: monotonically increasing + Indicates the monotonicity constraint to enforce on each feature. + - 1: monotonic increase - 0: no constraint - - -1: monotonically decreasing + - -1: monotonic decrease If monotonic_cst is None, no constraints are applied. @@ -1754,15 +1754,15 @@ class ExtraTreeRegressor(DecisionTreeRegressor): .. versionadded:: 0.22 monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. - - 1: monotonically increasing + Indicates the monotonicity constraint to enforce on each feature. + - 1: monotonic increase - 0: no constraint - - -1: monotonically decreasing + - -1: monotonic decrease If monotonic_cst is None, no constraints are applied. - Monotonic constraints are not supported for multioutput regression - (`n_outputs_>1`). + Monotonicity constraints are not supported for multioutput regression + (i.e. when `n_outputs_ > 1`). .. versionadded:: 1.2 diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 4b8658df4a531..9a7472610910f 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -472,7 +472,7 @@ cdef class ClassificationCriterion(Criterion): double lower_bound, double upper_bound, ) nogil: - """Check monotonic constraint is satisfied at the current classification split""" + """Check monotonicity constraint is satisfied at the current classification split""" cdef: double sum_left = self.sum_left[0][0] double sum_right = self.sum_right[0][0] @@ -824,7 +824,7 @@ cdef class RegressionCriterion(Criterion): double lower_bound, double upper_bound, ) nogil: - """Check monotonic constraint is satisfied at the current regression split""" + """Check monotonicity constraint is satisfied at the current regression split""" cdef: double sum_left = self.sum_left[0] double sum_right = self.sum_right[0] diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index dac0d56bec9ef..024a836061373 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -61,9 +61,9 @@ cdef class Splitter: cdef DOUBLE_t* sample_weight # Monotonicity constraints for each feature. 
# The encoding is as follows: - # -1: monotonically decreasing + # -1: monotonic decrease # 0: no constraint - # +1: monotonically increasing + # +1: monotonic increase cdef const cnp.int8_t[:] monotonic_cst # The samples vector `samples` is maintained by the Splitter object such diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 14319df5330c5..cf35fe0abca0a 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -120,7 +120,9 @@ def test_multiclass_raises(): if hasattr(est, "random_state"): est.set_params(**{"random_state": 0}) - msg = "Monotonic constraints are not supported with multiclass classification" + msg = ( + "Monotonicity constraints are not supported with multiclass classification" + ) with pytest.raises(ValueError, match=msg): est.fit(X, y) @@ -133,7 +135,7 @@ def test_multiple_output_raises(): est = TreeClassifier( max_depth=None, monotonic_cst=np.array([-1, 1]), random_state=0 ) - msg = "Monotonic constraints are not supported with multiple output" + msg = "Monotonicity constraints are not supported with multiple output" with pytest.raises(ValueError, match=msg): est.fit(X, y) @@ -198,7 +200,7 @@ def test_1d_tree_nodes_values( # Adaptation from test_nodes_values in test_montonic_constraints.py # in sklearn.ensemble._hist_gradient_boosting # Build a single tree with only one feature, and make sure the nodes - # values respect the monotonic constraints. + # values respect the monotonicity constraints. # Considering the following tree with a monotonic +1 constraint, we # should have: @@ -254,9 +256,9 @@ def assert_nd_reg_tree_children_monotonic_bounded(tree_, monotonic_cst): if monotonic_cst[feature] == 0: # Feature without monotonicity constraint: propagate bounds # down the tree to both children. - # Otherwise with 2 features and a +1 constraint on feature 0 - # the following tree can be accepted, although it does not - # respect the positive monotonicity constraint: + # Otherwise, with 2 features and a monotonic increase constraint + # (encoded by +1) on feature 0, the following tree can be accepted, + # although it does not respect the monotonic increase constraint: # # X[0] <= 0 # value = 100 @@ -297,9 +299,9 @@ def test_nd_tree_nodes_values( monotonic_sign, splitter, depth_first, global_random_seed ): # Build tree with several features, and make sure the nodes - # values respect the monotonic constraints. + # values respect the monotonicity constraints. 
- # Considering the following tree with a monotonic POS constraint on X[0], + # Considering the following tree with a monotonic increase constraint on X[0], # we should have: # # root From 94d993cae8635f91e03f4e65789f1bc8c9168a45 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 28 Oct 2022 14:53:55 +0200 Subject: [PATCH 136/174] TST Parametrize tests and simplify parameters' lists --- sklearn/tree/tests/test_monotonic_tree.py | 207 +++++++++++----------- 1 file changed, 107 insertions(+), 100 deletions(-) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index cf35fe0abca0a..113d23ae852dc 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -2,12 +2,22 @@ import pytest from sklearn.datasets import make_classification, make_regression -from sklearn.tree import DecisionTreeRegressor -from sklearn.tree.tests.test_tree import REG_TREES, CLF_TREES +from sklearn.tree import ( + DecisionTreeRegressor, + DecisionTreeClassifier, + ExtraTreeRegressor, + ExtraTreeClassifier, +) +TREE_CLASSIFIER_CLASSES = [DecisionTreeClassifier, ExtraTreeClassifier] +TREE_REGRESSOR_CLASSES = [DecisionTreeRegressor, ExtraTreeRegressor] + +@pytest.mark.parametrize("TreeClassifier", TREE_CLASSIFIER_CLASSES) @pytest.mark.parametrize("depth_first", (True, False)) -def test_montonic_constraints_classifications(depth_first, global_random_seed): +def test_montonic_constraints_classifications( + TreeClassifier, depth_first, global_random_seed +): n_samples = 1000 n_samples_train = 900 X, y = make_classification( @@ -31,34 +41,35 @@ def test_montonic_constraints_classifications(depth_first, global_random_seed): monotonic_cst[0] = 1 monotonic_cst[1] = -1 - classifiers = CLF_TREES.copy() - for name, TreeClassifier in classifiers.items(): - if depth_first: - est = TreeClassifier(max_depth=None, monotonic_cst=monotonic_cst) - else: - est = TreeClassifier( - max_depth=None, - monotonic_cst=monotonic_cst, - max_leaf_nodes=n_samples_train, - ) - if hasattr(est, "random_state"): - est.set_params(**{"random_state": global_random_seed}) - if hasattr(est, "n_estimators"): - est.set_params(**{"n_estimators": 5}) - est.fit(X_train, y_train) - y = est.predict_proba(X_test)[:, 1] - - # Monotonic increase constraint, it applies to the positive class - assert np.all(est.predict_proba(X_test_0incr)[:, 1] >= y) - assert np.all(est.predict_proba(X_test_0decr)[:, 1] <= y) - - # Monotonic decrease constraint, it applies to the positive class - assert np.all(est.predict_proba(X_test_1incr)[:, 1] <= y) - assert np.all(est.predict_proba(X_test_1decr)[:, 1] >= y) + if depth_first: + est = TreeClassifier(max_depth=None, monotonic_cst=monotonic_cst) + else: + est = TreeClassifier( + max_depth=None, + monotonic_cst=monotonic_cst, + max_leaf_nodes=n_samples_train, + ) + if hasattr(est, "random_state"): + est.set_params(**{"random_state": global_random_seed}) + if hasattr(est, "n_estimators"): + est.set_params(**{"n_estimators": 5}) + est.fit(X_train, y_train) + y = est.predict_proba(X_test)[:, 1] + # Monotonic increase constraint, it applies to the positive class + assert np.all(est.predict_proba(X_test_0incr)[:, 1] >= y) + assert np.all(est.predict_proba(X_test_0decr)[:, 1] <= y) + # Monotonic decrease constraint, it applies to the positive class + assert np.all(est.predict_proba(X_test_1incr)[:, 1] <= y) + assert np.all(est.predict_proba(X_test_1decr)[:, 1] >= y) + + +@pytest.mark.parametrize("TreeRegressor", TREE_REGRESSOR_CLASSES) 
@pytest.mark.parametrize("depth_first", (True, False)) -def test_montonic_constraints_regressions(depth_first, global_random_seed): +def test_montonic_constraints_regressions( + TreeRegressor, depth_first, global_random_seed +): n_samples = 1000 n_samples_train = 900 # Build a classification task using 3 informative features @@ -77,35 +88,34 @@ def test_montonic_constraints_regressions(depth_first, global_random_seed): monotonic_cst = np.zeros(X.shape[1]) monotonic_cst[0] = 1 monotonic_cst[1] = -1 - regressors = REG_TREES.copy() - for name, TreeRegressor in regressors.items(): - if depth_first: - est = TreeRegressor(max_depth=None, monotonic_cst=monotonic_cst) - else: - est = TreeRegressor( - max_depth=None, - monotonic_cst=monotonic_cst, - max_leaf_nodes=n_samples_train, - ) - if hasattr(est, "random_state"): - est.set_params(random_state=global_random_seed) - if hasattr(est, "n_estimators"): - est.set_params(**{"n_estimators": 5}) - est.fit(X_train, y_train) - y = est.predict(X_test) - # Monotonic increase constraint - y_incr = est.predict(X_test_incr) - # y_incr should always be greater than y - assert np.all(y_incr >= y) - - # Monotonic decrease constraint - y_decr = est.predict(X_test_decr) - # y_decr should always be lower than y - assert np.all(y_decr <= y) - - -def test_multiclass_raises(): + if depth_first: + est = TreeRegressor(max_depth=None, monotonic_cst=monotonic_cst) + else: + est = TreeRegressor( + max_depth=None, + monotonic_cst=monotonic_cst, + max_leaf_nodes=n_samples_train, + ) + if hasattr(est, "random_state"): + est.set_params(random_state=global_random_seed) + if hasattr(est, "n_estimators"): + est.set_params(**{"n_estimators": 5}) + est.fit(X_train, y_train) + y = est.predict(X_test) + # Monotonic increase constraint + y_incr = est.predict(X_test_incr) + # y_incr should always be greater than y + assert np.all(y_incr >= y) + + # Monotonic decrease constraint + y_decr = est.predict(X_test_decr) + # y_decr should always be lower than y + assert np.all(y_decr <= y) + + +@pytest.mark.parametrize("TreeClassifier", TREE_CLASSIFIER_CLASSES) +def test_multiclass_raises(TreeClassifier): X, y = make_classification( n_samples=100, n_features=5, n_classes=3, n_informative=3, random_state=0 ) @@ -113,57 +123,52 @@ def test_multiclass_raises(): monotonic_cst = np.zeros(X.shape[1]) monotonic_cst[0] = -1 monotonic_cst[1] = 1 - for name, TreeClassifier in CLF_TREES.items(): - est = TreeClassifier( - max_depth=None, monotonic_cst=monotonic_cst, random_state=0 - ) - if hasattr(est, "random_state"): - est.set_params(**{"random_state": 0}) + est = TreeClassifier(max_depth=None, monotonic_cst=monotonic_cst, random_state=0) + if hasattr(est, "random_state"): + est.set_params(**{"random_state": 0}) - msg = ( - "Monotonicity constraints are not supported with multiclass classification" - ) - with pytest.raises(ValueError, match=msg): - est.fit(X, y) + msg = "Monotonicity constraints are not supported with multiclass classification" + with pytest.raises(ValueError, match=msg): + est.fit(X, y) -def test_multiple_output_raises(): +@pytest.mark.parametrize("TreeClassifier", TREE_CLASSIFIER_CLASSES) +def test_multiple_output_raises(TreeClassifier): X = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]] y = [[1, 0, 1, 0, 1], [1, 0, 1, 0, 1]] - for name, TreeClassifier in CLF_TREES.items(): - est = TreeClassifier( - max_depth=None, monotonic_cst=np.array([-1, 1]), random_state=0 - ) - msg = "Monotonicity constraints are not supported with multiple output" - with pytest.raises(ValueError, match=msg): - est.fit(X, 
y) + est = TreeClassifier( + max_depth=None, monotonic_cst=np.array([-1, 1]), random_state=0 + ) + msg = "Monotonicity constraints are not supported with multiple output" + with pytest.raises(ValueError, match=msg): + est.fit(X, y) -def test_bad_monotonic_cst_raises(): +@pytest.mark.parametrize("TreeClassifier", TREE_CLASSIFIER_CLASSES) +def test_bad_monotonic_cst_raises(TreeClassifier): X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]] y = [1, 0, 1, 0, 1] - for name, TreeClassifier in CLF_TREES.items(): - msg = "monotonic_cst has shape 3 but the input data X has 2 features." - est = TreeClassifier( - max_depth=None, monotonic_cst=np.array([-1, 1, 0]), random_state=0 - ) - with pytest.raises(ValueError, match=msg): - est.fit(X, y) + msg = "monotonic_cst has shape 3 but the input data X has 2 features." + est = TreeClassifier( + max_depth=None, monotonic_cst=np.array([-1, 1, 0]), random_state=0 + ) + with pytest.raises(ValueError, match=msg): + est.fit(X, y) - msg = "monotonic_cst must be None or an array-like of -1, 0 or 1." - est = TreeClassifier( - max_depth=None, monotonic_cst=np.array([-2, 2]), random_state=0 - ) - with pytest.raises(ValueError, match=msg): - est.fit(X, y) + msg = "monotonic_cst must be None or an array-like of -1, 0 or 1." + est = TreeClassifier( + max_depth=None, monotonic_cst=np.array([-2, 2]), random_state=0 + ) + with pytest.raises(ValueError, match=msg): + est.fit(X, y) - est = TreeClassifier( - max_depth=None, monotonic_cst=np.array([-1, 0.8]), random_state=0 - ) - with pytest.raises(ValueError, match=msg): - est.fit(X, y) + est = TreeClassifier( + max_depth=None, monotonic_cst=np.array([-1, 0.8]), random_state=0 + ) + with pytest.raises(ValueError, match=msg): + est.fit(X, y) def assert_1d_reg_tree_children_monotonic_bounded(tree_, monotonic_sign): @@ -191,11 +196,12 @@ def assert_1d_reg_monotonic(clf, monotonic_sign, min_x, max_x, n_steps): assert (monotonic_sign * np.diff(y_pred_grid) >= 0.0).all() +@pytest.mark.parametrize("TreeRegressor", TREE_REGRESSOR_CLASSES) @pytest.mark.parametrize("monotonic_sign", (-1, 1)) @pytest.mark.parametrize("splitter", ("best", "random")) @pytest.mark.parametrize("depth_first", (True, False)) def test_1d_tree_nodes_values( - monotonic_sign, splitter, depth_first, global_random_seed + TreeRegressor, monotonic_sign, splitter, depth_first, global_random_seed ): # Adaptation from test_nodes_values in test_montonic_constraints.py # in sklearn.ensemble._hist_gradient_boosting @@ -221,14 +227,14 @@ def test_1d_tree_nodes_values( if depth_first: # No max_leaf_nodes, default depth first tree builder - clf = DecisionTreeRegressor( + clf = TreeRegressor( splitter=splitter, monotonic_cst=[monotonic_sign], random_state=global_random_seed, ) else: # max_leaf_nodes triggers best first tree builder - clf = DecisionTreeRegressor( + clf = TreeRegressor( splitter=splitter, monotonic_cst=[monotonic_sign], max_leaf_nodes=n_samples, @@ -292,11 +298,12 @@ def assert_nd_reg_tree_children_monotonic_bounded(tree_, monotonic_cst): lower_bound[i_right] = lower_bound[i] +@pytest.mark.parametrize("TreeRegressor", TREE_REGRESSOR_CLASSES) @pytest.mark.parametrize("monotonic_sign", (-1, 1)) @pytest.mark.parametrize("splitter", ("best", "random")) @pytest.mark.parametrize("depth_first", (True, False)) def test_nd_tree_nodes_values( - monotonic_sign, splitter, depth_first, global_random_seed + TreeRegressor, monotonic_sign, splitter, depth_first, global_random_seed ): # Build tree with several features, and make sure the nodes # values respect the monotonicity 
constraints. @@ -327,14 +334,14 @@ def test_nd_tree_nodes_values( if depth_first: # No max_leaf_nodes, default depth first tree builder - clf = DecisionTreeRegressor( + clf = TreeRegressor( splitter=splitter, monotonic_cst=monotonic_cst, random_state=global_random_seed, ) else: # max_leaf_nodes triggers best first tree builder - clf = DecisionTreeRegressor( + clf = TreeRegressor( splitter=splitter, monotonic_cst=monotonic_cst, max_leaf_nodes=n_samples, From 41e443e7dd4c672730eb4c3511a2d219c8b43741 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 28 Oct 2022 15:01:24 +0200 Subject: [PATCH 137/174] MAINT Format ala black --- sklearn/tree/_criterion.pyx | 1 + sklearn/tree/_splitter.pyx | 82 ++++++++++++++++++++----------------- sklearn/tree/_tree.pyx | 29 +++++++------ 3 files changed, 61 insertions(+), 51 deletions(-) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 9a7472610910f..89d4add2fa97f 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -220,6 +220,7 @@ cdef class Criterion: (sum_right <= upper_bound * weighted_n_right) ) bint check_monotonic_cst + if monotonic_cst == 0: # No constraint return check_lower_bound & check_upper_bound else: diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index e0cfcc1dc631e..67810f9a71762 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -201,13 +201,13 @@ cdef class Splitter: return 0 cdef int node_split( - self, - double impurity, - SplitRecord* split, - SIZE_t* n_constant_features, - double lower_bound, - double upper_bound, - ) nogil except -1: + self, + double impurity, + SplitRecord* split, + SIZE_t* n_constant_features, + double lower_bound, + double upper_bound, + ) nogil except -1: """Find the best split on node samples[start:end]. This is a placeholder method. 
The majority of computation will be done @@ -261,13 +261,13 @@ cdef class BestSplitter(BaseDenseSplitter): self.random_state), self.__getstate__()) cdef int node_split( - self, - double impurity, - SplitRecord* split, - SIZE_t* n_constant_features, - double lower_bound, - double upper_bound - ) nogil except -1: + self, + double impurity, + SplitRecord* split, + SIZE_t* n_constant_features, + double lower_bound, + double upper_bound + ) nogil except -1: """Find the best split on node samples[start:end] Returns -1 in case of failure to allocate memory (and raise MemoryError) @@ -587,13 +587,13 @@ cdef class RandomSplitter(BaseDenseSplitter): self.random_state), self.__getstate__()) cdef int node_split( - self, - double impurity, - SplitRecord* split, - SIZE_t* n_constant_features, - double lower_bound, - double upper_bound - ) nogil except -1: + self, + double impurity, + SplitRecord* split, + SIZE_t* n_constant_features, + double lower_bound, + double upper_bound + ) nogil except -1: """Find the best random split on node samples[start:end] Returns -1 in case of failure to allocate memory (and raise MemoryError) @@ -799,9 +799,15 @@ cdef class BaseSparseSplitter(Splitter): cdef SIZE_t[::1] index_to_samples cdef SIZE_t[::1] sorted_samples - def __cinit__(self, Criterion criterion, SIZE_t max_features, - SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, const cnp.int8_t[:] monotonic_cst): + def __cinit__( + self, + Criterion criterion, + SIZE_t max_features, + SIZE_t min_samples_leaf, + double min_weight_leaf, + object random_state, + const cnp.int8_t[:] monotonic_cst, + ): # Parent __cinit__ is automatically called self.n_total_samples = 0 self.monotonic_cst = monotonic_cst @@ -1105,13 +1111,13 @@ cdef class BestSparseSplitter(BaseSparseSplitter): self.random_state), self.__getstate__()) cdef int node_split( - self, - double impurity, - SplitRecord* split, - SIZE_t* n_constant_features, - double lower_bound, - double upper_bound - ) nogil except -1: + self, + double impurity, + SplitRecord* split, + SIZE_t* n_constant_features, + double lower_bound, + double upper_bound + ) nogil except -1: """Find the best split on node samples[start:end], using sparse features Returns -1 in case of failure to allocate memory (and raise MemoryError) @@ -1346,13 +1352,13 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): self.random_state), self.__getstate__()) cdef int node_split( - self, - double impurity, - SplitRecord* split, - SIZE_t* n_constant_features, - double lower_bound, - double upper_bound - ) nogil except -1: + self, + double impurity, + SplitRecord* split, + SIZE_t* n_constant_features, + double lower_bound, + double upper_bound + ) nogil except -1: """Find a random split on node samples[start:end], using sparse features Returns -1 in case of failure to allocate memory (and raise MemoryError) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 0d3151c35b598..1c08b69d6396c 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -216,7 +216,8 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): "impurity": INFINITY, "n_constant_features": 0, "lower_bound": -INFINITY, - "upper_bound": INFINITY}) + "upper_bound": INFINITY, + }) while not builder_stack.empty(): stack_record = builder_stack.top() @@ -249,11 +250,11 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): if not is_leaf: splitter.node_split( - impurity, - &split, - &n_constant_features, - lower_bound, - upper_bound + impurity, + &split, + &n_constant_features, + lower_bound, + upper_bound ) 
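These stack records carry the running `[lower_bound, upper_bound]` interval for each node. The invariant they are meant to guarantee can be restated as a small pure-Python checker in the spirit of `assert_nd_reg_tree_children_monotonic_bounded` from the test file; this is a simplified sketch with hypothetical names, and the test's exact bookkeeping differs slightly:

    import numpy as np

    def children_monotonic_bounded(tree_, monotonic_cst):
        # Walk a fitted regression tree and check that every node value stays
        # inside the interval inherited from its ancestors' constrained splits.
        value = tree_.value[:, 0, 0]
        lower = np.full(tree_.node_count, -np.inf)
        upper = np.full(tree_.node_count, np.inf)
        for i in range(tree_.node_count):
            assert lower[i] <= value[i] <= upper[i]
            left, right = tree_.children_left[i], tree_.children_right[i]
            if left < 0:      # leaf node, nothing to propagate
                continue
            cst = monotonic_cst[tree_.feature[i]]
            if cst == 0:
                # Unconstrained split feature: both children simply inherit
                # the bounds holding at the parent.
                lower[left], upper[left] = lower[i], upper[i]
                lower[right], upper[right] = lower[i], upper[i]
            elif cst == 1:
                # Monotonic increase: nothing under the left child may exceed
                # the right child's value, and conversely.
                assert value[left] <= value[right]
                lower[left], upper[left] = lower[i], value[right]
                lower[right], upper[right] = value[left], upper[i]
            else:
                # Monotonic decrease: mirror image of the increase case.
                assert value[left] >= value[right]
                lower[left], upper[left] = value[right], upper[i]
                lower[right], upper[right] = lower[i], value[left]
        return True

It could be called as `children_monotonic_bounded(reg.tree_, monotonic_cst)` on any fitted single-output regressor.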
# If EPSILON=0 in the below comparison, float precision # issues stop splitting, producing trees that are @@ -314,7 +315,8 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): "impurity": split.impurity_right, "n_constant_features": n_constant_features, "lower_bound": right_child_min, - "upper_bound": right_child_max}) + "upper_bound": right_child_max, + }) # Push left child on stack builder_stack.push({ @@ -326,7 +328,8 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): "impurity": split.impurity_left, "n_constant_features": n_constant_features, "lower_bound": left_child_min, - "upper_bound": left_child_max}) + "upper_bound": left_child_max, + }) if depth > max_depth_seen: max_depth_seen = depth @@ -595,11 +598,11 @@ cdef class BestFirstTreeBuilder(TreeBuilder): if not is_leaf: splitter.node_split( - impurity, - &split, - &n_constant_features, - lower_bound, - upper_bound + impurity, + &split, + &n_constant_features, + lower_bound, + upper_bound ) # If EPSILON=0 in the below comparison, float precision issues stop # splitting early, producing trees that are dissimilar to v0.18 From bf0a1065695d011b276810e522c93578f9cb2989 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 28 Oct 2022 15:14:45 +0200 Subject: [PATCH 138/174] Remove irrelevant changes --- sklearn/ensemble/_gb_losses.py | 1 + sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/ensemble/_gb_losses.py b/sklearn/ensemble/_gb_losses.py index e44692f518bb5..228250910b94f 100644 --- a/sklearn/ensemble/_gb_losses.py +++ b/sklearn/ensemble/_gb_losses.py @@ -107,6 +107,7 @@ def update_terminal_regions( ``learning_rate``. k : int, default=0 The index of the estimator being updated. + """ # compute leaf for each sample in ``X``. terminal_regions = tree.apply(X) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp index 46bef76bb64b0..c1bade148c988 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp @@ -112,7 +112,7 @@ cdef class BaseDistancesReduction{{name_suffix}}: if chunk_size is None: chunk_size = get_config().get("pairwise_dist_chunk_size", 256) - self.chunk_size = check_scalar(chunk_size, "chunk_size", Integral, min_val=1) + self.chunk_size = check_scalar(chunk_size, "chunk_size", Integral, min_val=20) self.effective_n_threads = _openmp_effective_n_threads() From ec9cb4a6d5d6dea2f243921e9ccf34ed3caff836 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 28 Oct 2022 15:16:44 +0200 Subject: [PATCH 139/174] fixup! DOC Clarify wording for constraints --- sklearn/ensemble/_forest.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index b3a951d17dba7..58641278346a7 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -1638,8 +1638,8 @@ class RandomForestRegressor(ForestRegressor): If monotonic_cst is None, no constraints are applied. - Monotonic constraints are not supported for multioutput regression - (`n_outputs_>1`). + Monotonicity constraints are not supported for multioutput regression + (i.e. when `n_outputs_ > 1`). .. versionadded:: 1.2 @@ -2358,8 +2358,8 @@ class ExtraTreesRegressor(ForestRegressor): If monotonic_cst is None, no constraints are applied. - Monotonic constraints are not supported for multioutput regression - (`n_outputs_>1`). 
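The multioutput restriction restated here also holds at the level of a single tree; a minimal reproduction with toy data (message wording from this series):

    from sklearn.tree import DecisionTreeRegressor

    X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
    Y = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # two outputs
    reg = DecisionTreeRegressor(monotonic_cst=[1, -1])
    try:
        reg.fit(X, Y)
    except ValueError as exc:
        print(exc)  # Monotonicity constraints are not supported with multiple outputs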
+ Monotonicity constraints are not supported for multioutput regression + (i.e. when `n_outputs_ > 1`). .. versionadded:: 1.2 From 0d58fe502695836c0c2259ae2da575ecee694f73 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 28 Oct 2022 15:19:13 +0200 Subject: [PATCH 140/174] Use global_random_seed in data generation for tests --- sklearn/tree/tests/test_monotonic_tree.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 113d23ae852dc..7e9486d988313 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -26,7 +26,7 @@ def test_montonic_constraints_classifications( n_features=5, n_informative=5, n_redundant=0, - random_state=0, + random_state=global_random_seed, ) X_train, y_train = X[:n_samples_train], y[:n_samples_train] X_test, _ = X[n_samples_train:], y[n_samples_train:] @@ -74,7 +74,10 @@ def test_montonic_constraints_regressions( n_samples_train = 900 # Build a classification task using 3 informative features X, y = make_regression( - n_samples=n_samples, n_features=5, n_informative=5, random_state=0 + n_samples=n_samples, + n_features=5, + n_informative=5, + random_state=global_random_seed, ) train = np.arange(n_samples_train) test = np.arange(n_samples_train, n_samples) From 23d44978c8a7b3037eecc94d7a7acee5a2f99e5c Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 28 Oct 2022 15:28:50 +0200 Subject: [PATCH 141/174] English grammar Co-authored-by: Christian Lorentzen --- sklearn/tree/tests/test_monotonic_tree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 7e9486d988313..82ac5993bd3ec 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -208,7 +208,7 @@ def test_1d_tree_nodes_values( ): # Adaptation from test_nodes_values in test_montonic_constraints.py # in sklearn.ensemble._hist_gradient_boosting - # Build a single tree with only one feature, and make sure the nodes + # Build a single tree with only one feature, and make sure the node # values respect the monotonicity constraints. 
# Considering the following tree with a monotonic +1 constraint, we From bfe94afe93fc44a39b6cd2cbc749bc369bd6c61d Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 28 Oct 2022 15:57:18 +0200 Subject: [PATCH 142/174] Optim: protect check_monotonicity under monotonic_constraint != 0 Co-authored-by: Julien Jerphanion --- sklearn/tree/_splitter.pyx | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 67810f9a71762..4d8ba4add6611 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -402,7 +402,14 @@ cdef class BestSplitter(BaseDenseSplitter): self.criterion.update(current.pos) # Reject if monotonicity constraints are not satisfied - if not self.criterion.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): + if ( + monotonic_constraint != 0 and not + self.criterion.check_monotonicity( + monotonic_constraint, + lower_bound, + upper_bound, + ) + ): continue # Reject if min_weight_leaf is not satisfied From 9e1f8b1a2d39101ec7b9b2f29a919c1b4d124a32 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 28 Oct 2022 16:38:29 +0200 Subject: [PATCH 143/174] Revert "Optim: protect check_monotonicity under monotonic_constraint != 0" This reverts commit bfe94afe93fc44a39b6cd2cbc749bc369bd6c61d. --- sklearn/tree/_splitter.pyx | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 4d8ba4add6611..67810f9a71762 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -402,14 +402,7 @@ cdef class BestSplitter(BaseDenseSplitter): self.criterion.update(current.pos) # Reject if monotonicity constraints are not satisfied - if ( - monotonic_constraint != 0 and not - self.criterion.check_monotonicity( - monotonic_constraint, - lower_bound, - upper_bound, - ) - ): + if not self.criterion.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): continue # Reject if min_weight_leaf is not satisfied From 3fdae14832a3cbd6b3e9e21e824eb9a64b084d2f Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 28 Oct 2022 16:40:05 +0200 Subject: [PATCH 144/174] Test clean-up --- sklearn/tree/tests/test_monotonic_tree.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 82ac5993bd3ec..39fb1044a45d3 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -51,8 +51,6 @@ def test_montonic_constraints_classifications( ) if hasattr(est, "random_state"): est.set_params(**{"random_state": global_random_seed}) - if hasattr(est, "n_estimators"): - est.set_params(**{"n_estimators": 5}) est.fit(X_train, y_train) y = est.predict_proba(X_test)[:, 1] @@ -127,8 +125,6 @@ def test_multiclass_raises(TreeClassifier): monotonic_cst[0] = -1 monotonic_cst[1] = 1 est = TreeClassifier(max_depth=None, monotonic_cst=monotonic_cst, random_state=0) - if hasattr(est, "random_state"): - est.set_params(**{"random_state": 0}) msg = "Monotonicity constraints are not supported with multiclass classification" with pytest.raises(ValueError, match=msg): From 7e33b05cec773dd6c3ee1f1bb07ff12d007fe333 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 28 Oct 2022 17:12:48 +0200 Subject: [PATCH 145/174] PERF Use monotonic_cst only is it is provided This does not reallocate placeholder null'ed numpy arrays, and perform operation only when needed. 
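For reference, the fit-time validation that this commit reorders (shape check before value check, no placeholder array when `monotonic_cst` is `None`) looks roughly like the following in pure Python; the standalone helper is hypothetical, the real checks live inline in `BaseDecisionTree.fit`:

    import numpy as np

    def validate_monotonic_cst(monotonic_cst, n_features, for_classifier):
        # Rough outline of the checks performed at fit time (a sketch).
        if monotonic_cst is None:
            return None                       # no constraint array is built at all
        monotonic_cst = np.asarray(monotonic_cst)
        if monotonic_cst.shape[0] != n_features:
            raise ValueError(
                f"monotonic_cst has shape {monotonic_cst.shape[0]} but the input "
                f"data X has {n_features} features."
            )
        bad = (monotonic_cst != -1) & (monotonic_cst != 0) & (monotonic_cst != 1)
        if np.any(bad):
            raise ValueError(
                "monotonic_cst must be None or an array-like of -1, 0 or 1, but "
                f"got {np.unique(monotonic_cst)}"
            )
        monotonic_cst = monotonic_cst.astype(np.int8)
        if for_classifier:
            # The criterion tracks the negative class internally, so the sign is
            # flipped so the constraint applies to the positive class probability.
            monotonic_cst = -monotonic_cst
        return monotonic_cst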
Co-authored-by: Christian Lorentzen Co-authored-by: Olivier Grisel --- sklearn/tree/_classes.py | 12 ++++----- sklearn/tree/_splitter.pxd | 1 + sklearn/tree/_splitter.pyx | 53 ++++++++++++++++++++++++-------------- sklearn/tree/_tree.pyx | 4 +-- 4 files changed, 42 insertions(+), 28 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 3b056382a75be..baa802b991fdd 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -340,7 +340,7 @@ def fit(self, X, y, sample_weight=None, check_input=True): splitter = self.splitter if self.monotonic_cst is None: - monotonic_cst = np.zeros(shape=X.shape[1], dtype=np.int8) + monotonic_cst = None else: if self.n_outputs_ > 1: raise ValueError( @@ -349,6 +349,11 @@ def fit(self, X, y, sample_weight=None, check_input=True): # Check to correct monotonicity constraint' specification, # by applying element-wise logical conjunction monotonic_cst = np.asarray(self.monotonic_cst) + if monotonic_cst.shape[0] != X.shape[1]: + raise ValueError( + "monotonic_cst has shape {} but the input data " + "X has {} features.".format(monotonic_cst.shape[0], X.shape[1]) + ) unsatisfied_constraints_conditions = ( (monotonic_cst != -1) * (monotonic_cst != 0) * (monotonic_cst != 1) ) @@ -368,11 +373,6 @@ def fit(self, X, y, sample_weight=None, check_input=True): # Imposing the constraint on the probability of the positive class monotonic_cst *= -1 - if monotonic_cst.shape[0] != X.shape[1]: - raise ValueError( - "monotonic_cst has shape {} but the input data " - "X has {} features.".format(monotonic_cst.shape[0], X.shape[1]) - ) if not isinstance(self.splitter, Splitter): splitter = SPLITTERS[self.splitter]( criterion, diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 024a836061373..78e3a75d2e55b 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -65,6 +65,7 @@ cdef class Splitter: # 0: no constraint # +1: monotonic increase cdef const cnp.int8_t[:] monotonic_cst + cdef bint with_monotonic_cst # The samples vector `samples` is maintained by the Splitter object such # that the samples contained in a node are contiguous. 
With this setting, diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 67810f9a71762..602d138fdd85c 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -99,6 +99,7 @@ cdef class Splitter: self.min_weight_leaf = min_weight_leaf self.random_state = random_state self.monotonic_cst = monotonic_cst + self.with_monotonic_cst = monotonic_cst is not None def __getstate__(self): return {} @@ -307,8 +308,6 @@ cdef class BestSplitter(BaseDenseSplitter): cdef SIZE_t n_total_constants = n_known_constants cdef SIZE_t partition_end - cdef cnp.int8_t monotonic_constraint - _init_split(&best, end) # Sample up to max_features without replacement using a @@ -355,7 +354,6 @@ cdef class BestSplitter(BaseDenseSplitter): # f_j in the interval [n_total_constants, f_i[ current.feature = features[f_j] - monotonic_constraint = self.monotonic_cst[current.feature] # Sort samples along that feature; by # copying the values into an array and # sorting the array in a manner which utilizes the cache more @@ -402,7 +400,14 @@ cdef class BestSplitter(BaseDenseSplitter): self.criterion.update(current.pos) # Reject if monotonicity constraints are not satisfied - if not self.criterion.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): + if ( + self.with_monotonic_cst and + not self.criterion.check_monotonicity( + self.monotonic_cst[current.feature], + lower_bound, + upper_bound, + ) + ): continue # Reject if min_weight_leaf is not satisfied @@ -634,8 +639,6 @@ cdef class RandomSplitter(BaseDenseSplitter): cdef DTYPE_t max_feature_value cdef DTYPE_t current_feature_value - cdef cnp.int8_t monotonic_constraint - _init_split(&best, end) # Sample up to max_features without replacement using a @@ -681,8 +684,6 @@ cdef class RandomSplitter(BaseDenseSplitter): current.feature = features[f_j] - monotonic_constraint = self.monotonic_cst[current.feature] - # Find min, max min_feature_value = self.X[samples[start], current.feature] max_feature_value = min_feature_value @@ -743,7 +744,14 @@ cdef class RandomSplitter(BaseDenseSplitter): continue # Reject if monotonicity constraints are not satisfied - if not self.criterion.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): + if ( + self.with_monotonic_cst and + not self.criterion.check_monotonicity( + self.monotonic_cst[current.feature], + lower_bound, + upper_bound, + ) + ): continue current_proxy_improvement = self.criterion.proxy_impurity_improvement() @@ -810,7 +818,6 @@ cdef class BaseSparseSplitter(Splitter): ): # Parent __cinit__ is automatically called self.n_total_samples = 0 - self.monotonic_cst = monotonic_cst cdef int init(self, object X, @@ -1165,8 +1172,6 @@ cdef class BestSparseSplitter(BaseSparseSplitter): cdef SIZE_t start_positive cdef SIZE_t end_negative - cdef INT32_t monotonic_constraint - # Sample up to max_features without replacement using a # Fisher-Yates-based algorithm (using the local variables `f_i` and # `f_j` to compute a permutation of the `features` array). 
@@ -1214,8 +1219,6 @@ cdef class BestSparseSplitter(BaseSparseSplitter): self.extract_nnz(current.feature, &end_negative, &start_positive, &is_samples_sorted) - monotonic_constraint = self.monotonic_cst[current.feature] - # Sort the positive and negative parts of `Xf` sort(&Xf[start], &samples[start], end_negative - start) if start_positive < end: @@ -1291,7 +1294,14 @@ cdef class BestSparseSplitter(BaseSparseSplitter): continue # Reject if monotonicity constraints are not satisfied - if not self.criterion.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): + if ( + self.with_monotonic_cst and + not self.criterion.check_monotonicity( + self.monotonic_cst[current.feature], + lower_bound, + upper_bound, + ) + ): continue current_proxy_improvement = self.criterion.proxy_impurity_improvement() @@ -1407,8 +1417,6 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): cdef SIZE_t start_positive cdef SIZE_t end_negative - cdef INT32_t monotonic_constraint - # Sample up to max_features without replacement using a # Fisher-Yates-based algorithm (using the local variables `f_i` and # `f_j` to compute a permutation of the `features` array). @@ -1458,8 +1466,6 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): &end_negative, &start_positive, &is_samples_sorted) - monotonic_constraint = self.monotonic_cst[current.feature] - if end_negative != start_positive: # There is a zero min_feature_value = 0 @@ -1526,7 +1532,14 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): continue # Reject if monotonicity constraints are not satisfied - if not self.criterion.check_monotonicity(monotonic_constraint, lower_bound, upper_bound): + if ( + self.with_monotonic_cst and + not self.criterion.check_monotonicity( + self.monotonic_cst[current.feature], + lower_bound, + upper_bound, + ) + ): continue current_proxy_improvement = self.criterion.proxy_impurity_improvement() diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 1c08b69d6396c..3dd0446531da5 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -278,7 +278,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): middle_value = tree._get_middle_value(node_id) if not is_leaf: - if splitter.monotonic_cst[split.feature] == 0: + if not splitter.with_monotonic_cst or splitter.monotonic_cst[split.feature] == 0: # Split on a feature with no monotonicity constraint # Current bounds must always be propagated to both children # Splitting criterion will always check them in check_monotonicity @@ -473,7 +473,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): # Node is expandable middle_value = tree._get_middle_value(record.node_id) - if splitter.monotonic_cst[node.feature] == 0: + if not splitter.with_monotonic_cst or splitter.monotonic_cst[node.feature] == 0: # Split on a feature with no monotonicity constraint # Current bounds must always be propagated to both children # Splitting criterion will always check them in check_monotonicity From 1cae8eeec83d6fdf37e69cf9defdecda022c9c5b Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 28 Oct 2022 18:25:38 +0200 Subject: [PATCH 146/174] Improve comments in _get_middle_value --- sklearn/tree/_tree.pyx | 26 ++++++++++++++++++----- sklearn/tree/tests/test_monotonic_tree.py | 2 +- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 3dd0446531da5..10cebb69c7785 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -1281,19 +1281,35 @@ cdef class Tree: return arr cdef inline double 
_get_middle_value(self, int node_id) nogil: + """Returns the new bound when a monotonic constraint is active + + For regression, this is the average target value of the samples in + the parent node. + + For binary classification, this is the proportion of the negative + class of samples in the parent node. + """ cdef: int i double middle_value = 0 + # self.value has shape: [node_count, n_outputs, max_n_classes] + # Monotonic constraints are only supported for single-output trees so + # we can safely assume n_outputs == 1. + + # Furthermore, we assume that for + # classification trees, max_n_classes == 2 as we only support binary + # classification (and so self.value_stride == 2). + if self.max_n_classes == 1: - # Regression + # Regression: self.max_n_classes == 1, so self.value_stride == 1 middle_value = self.value[node_id] else: - # Classification + # Binary classification: self.max_n_classes == 2 - # This performs some raw pointers arithmetic - # Ideally a memory view could wrap self.value so that's - # indexing on several axis can be used. + # This performs some raw pointers arithmetic. Ideally a memory view + # could wrap self.value so that indexing on several axis can be + # used. for i in range(self.max_n_classes): middle_value += self.value[node_id * self.value_stride + i] diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 39fb1044a45d3..7eb51ebacf181 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -202,7 +202,7 @@ def assert_1d_reg_monotonic(clf, monotonic_sign, min_x, max_x, n_steps): def test_1d_tree_nodes_values( TreeRegressor, monotonic_sign, splitter, depth_first, global_random_seed ): - # Adaptation from test_nodes_values in test_montonic_constraints.py + # Adaptation from test_nodes_values in test_monotonic_constraints.py # in sklearn.ensemble._hist_gradient_boosting # Build a single tree with only one feature, and make sure the node # values respect the monotonicity constraints. From 3558bfd0b4e56ad58280d96dd93eb20b27e7465e Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 2 Nov 2022 15:27:34 +0100 Subject: [PATCH 147/174] Only compute bounds' middle_value when needed Also simplify the logic and its comment. Note: we could factor this duplicated logic in a private function or a private method, but we would need to use pointers to modify variable on the program stack because we can't simply return several values. This factorization might be more complex than the current duplication, hence we propose leaving it as is. 
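To make the duplicated logic concrete, here is a hedged pure-Python sketch of what
each copy computes for a split node, given the parent's bounds and the middle value
(child_bounds is a hypothetical helper written only for illustration; as explained
above, the actual logic stays inlined in the depth-first and best-first builders):

def child_bounds(constraint, lower_bound, upper_bound, middle_value):
    # constraint is the monotonic_cst entry of the split feature:
    # 0 = no constraint, +1 = monotonic increase, -1 = monotonic decrease.
    if constraint == 0:
        # Propagate the parent's bounds unchanged to both children.
        return (lower_bound, upper_bound), (lower_bound, upper_bound)
    if constraint == 1:
        # The left child is capped at middle_value and the right child
        # cannot go below it.
        return (lower_bound, middle_value), (middle_value, upper_bound)
    # constraint == -1: mirror image of the increasing case.
    return (middle_value, upper_bound), (lower_bound, middle_value)

Each returned pair is (lower, upper) for the left and right child respectively.
Note that middle_value is only needed in the two constrained branches, which is
what this change exploits by computing it there rather than for every node.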
See: https://github.com/scikit-learn/scikit-learn/pull/13649/files#r973633867 Co-authored-by: Olivier Grisel --- sklearn/tree/_tree.pyx | 67 +++++++++++++++++++++++------------------- 1 file changed, 36 insertions(+), 31 deletions(-) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 10cebb69c7785..3fbb08c00edce 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -275,33 +275,35 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): # inspection and interpretation splitter.node_value(tree.value + node_id * tree.value_stride) - middle_value = tree._get_middle_value(node_id) - if not is_leaf: - if not splitter.with_monotonic_cst or splitter.monotonic_cst[split.feature] == 0: + if ( + not splitter.with_monotonic_cst or + splitter.monotonic_cst[split.feature] == 0 + ): # Split on a feature with no monotonicity constraint + # Current bounds must always be propagated to both children # Splitting criterion will always check them in check_monotonicity - left_child_min = lower_bound - left_child_max = upper_bound - right_child_min = lower_bound - right_child_max = upper_bound + left_child_min = right_child_min = lower_bound + left_child_max = right_child_max = upper_bound elif splitter.monotonic_cst[split.feature] == 1: - # Split on a feature with increasing monotonicity constraint - # Lower bound for left child is the current lower bound + # Split on a feature with monotonic increase constraint left_child_min = lower_bound - # Upper bound for right child is the current upper bound right_child_max = upper_bound - # Lower bound for right child and upper bound for left child are set to the same value + + # Lower bound for right child and upper bound for left child + # are set to the same value. + middle_value = tree._get_middle_value(node_id) right_child_min = middle_value left_child_max = middle_value else: # i.e. splitter.monotonic_cst[split.feature] == -1 - # Split on a feature with decreasing monotonicity constraint - # Lower bound for right child is the current lower bound + # Split on a feature with monotonic decrease constraint right_child_min = lower_bound - # Upper bound for left child is the current upper bound left_child_max = upper_bound - # Lower bound for left child and upper bound for right child are set to the same value + + # Lower bound for left child and upper bound for right child + # are set to the same value. + middle_value = tree._get_middle_value(node_id) left_child_min = middle_value right_child_max = middle_value @@ -472,31 +474,34 @@ cdef class BestFirstTreeBuilder(TreeBuilder): else: # Node is expandable - middle_value = tree._get_middle_value(record.node_id) - if not splitter.with_monotonic_cst or splitter.monotonic_cst[node.feature] == 0: + if ( + not splitter.with_monotonic_cst or + splitter.monotonic_cst[node.feature] == 0 + ): # Split on a feature with no monotonicity constraint - # Current bounds must always be propagated to both children + + # Current bounds must always be propagated to both children. 
# Splitting criterion will always check them in check_monotonicity - left_child_min = record.lower_bound - left_child_max = record.upper_bound - right_child_min = record.lower_bound - right_child_max = record.upper_bound + left_child_min = right_child_min = record.lower_bound + left_child_max = right_child_max = record.upper_bound elif splitter.monotonic_cst[node.feature] == 1: - # Split on a feature with increasing monotonicity constraint - # Lower bound for left child is the current lower bound + # Split on a feature with monotonic increase constraint left_child_min = record.lower_bound - # Upper bound for right child is the current upper bound right_child_max = record.upper_bound - # Lower bound for right child and upper bound for left child and are set to the same value + + # Lower bound for right child and upper bound for left child + # are set to the same value. + middle_value = tree._get_middle_value(record.node_id) right_child_min = middle_value left_child_max = middle_value - else: # i.e. splitter.monotonic_cst[node.feature] == -1 - # Split on a feature with decreasing monotonicity constraint - # Lower bound for right child is the current lower bound + else: # i.e. splitter.monotonic_cst[split.feature] == -1 + # Split on a feature with monotonic decrease constraint right_child_min = record.lower_bound - # Upper bound for left child is the current upper bound left_child_max = record.upper_bound - # Lower bound for left child and upper bound for right child and are set to the same value + + # Lower bound for left child and upper bound for right child + # are set to the same value. + middle_value = tree._get_middle_value(record.node_id) left_child_min = middle_value right_child_max = middle_value From 35ae0a3b3795a0fbdad0ad57fd63dc3869b81559 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 4 Nov 2022 14:54:55 +0100 Subject: [PATCH 148/174] More intuitive assert_nd_reg_tree_children_monotonic_bounded --- sklearn/tree/tests/test_monotonic_tree.py | 37 ++++++++++++++--------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 7eb51ebacf181..ca840d2bff3ca 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -258,6 +258,7 @@ def assert_nd_reg_tree_children_monotonic_bounded(tree_, monotonic_cst): else: i_left = tree_.children_left[i] i_right = tree_.children_right[i] + if monotonic_cst[feature] == 0: # Feature without monotonicity constraint: propagate bounds # down the tree to both children. @@ -278,23 +279,29 @@ def assert_nd_reg_tree_children_monotonic_bounded(tree_, monotonic_cst): lower_bound[i_left] = lower_bound[i] upper_bound[i_right] = upper_bound[i] lower_bound[i_right] = lower_bound[i] - else: + + elif monotonic_cst[feature] == 1: # Feature with constraint: check monotonicity - assert ( - monotonic_cst[feature] * tree_.value[i_left] - <= monotonic_cst[feature] * tree_.value[i_right] - ) + assert tree_.value[i_left] <= tree_.value[i_right] + + # Propagate bounds down the tree to both children. + upper_bound[i_left] = tree_.value[i] + lower_bound[i_left] = lower_bound[i] + upper_bound[i_right] = upper_bound[i] + lower_bound[i_right] = tree_.value[i] + + elif monotonic_cst[feature] == -1: + # Feature with constraint: check monotonicity + assert tree_.value[i_left] >= tree_.value[i_right] + # Update and propagate bounds down the tree to both children. 
- if monotonic_cst[feature] == 1: - upper_bound[i_left] = tree_.value[i] - lower_bound[i_left] = lower_bound[i] - upper_bound[i_right] = upper_bound[i] - lower_bound[i_right] = tree_.value[i] - else: - upper_bound[i_left] = upper_bound[i] - lower_bound[i_left] = tree_.value[i] - upper_bound[i_right] = tree_.value[i] - lower_bound[i_right] = lower_bound[i] + upper_bound[i_left] = upper_bound[i] + lower_bound[i_left] = tree_.value[i] + upper_bound[i_right] = tree_.value[i] + lower_bound[i_right] = lower_bound[i] + + else: # pragma: no cover + raise ValueError(f"monotonic_cst[{feature}]={monotonic_cst[feature]}") @pytest.mark.parametrize("TreeRegressor", TREE_REGRESSOR_CLASSES) From 5926abef3356346fd8938cd4d0bb925172c3dace Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 4 Nov 2022 19:10:23 +0100 Subject: [PATCH 149/174] WIP try to test assert_nd_reg_tree_children_monotonic_bounded --- sklearn/tree/tests/test_monotonic_tree.py | 139 ++++++++++++++-------- 1 file changed, 91 insertions(+), 48 deletions(-) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index ca840d2bff3ca..d0992e8d6decc 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -250,58 +250,101 @@ def assert_nd_reg_tree_children_monotonic_bounded(tree_, monotonic_cst): lower_bound = np.full(tree_.node_count, -np.inf) for i in range(tree_.node_count): feature = tree_.feature[i] - assert tree_.value[i] <= upper_bound[i] - assert tree_.value[i] >= lower_bound[i] + node_value = tree_.value[i][0][0] # unpack value from nx1x1 array + assert node_value <= upper_bound[i] + assert node_value >= lower_bound[i] + if feature < 0: # Leaf: nothing to do continue - else: - i_left = tree_.children_left[i] - i_right = tree_.children_right[i] - if monotonic_cst[feature] == 0: - # Feature without monotonicity constraint: propagate bounds - # down the tree to both children. - # Otherwise, with 2 features and a monotonic increase constraint - # (encoded by +1) on feature 0, the following tree can be accepted, - # although it does not respect the monotonic increase constraint: - # - # X[0] <= 0 - # value = 100 - # / \ - # X[0] <= -1 X[1] <= 0 - # value = 50 value = 150 - # / \ / \ - # leaf leaf leaf leaf - # value = 25 value = 75 value = 50 value = 250 - - upper_bound[i_left] = upper_bound[i] - lower_bound[i_left] = lower_bound[i] - upper_bound[i_right] = upper_bound[i] - lower_bound[i_right] = lower_bound[i] - - elif monotonic_cst[feature] == 1: - # Feature with constraint: check monotonicity - assert tree_.value[i_left] <= tree_.value[i_right] - - # Propagate bounds down the tree to both children. - upper_bound[i_left] = tree_.value[i] - lower_bound[i_left] = lower_bound[i] - upper_bound[i_right] = upper_bound[i] - lower_bound[i_right] = tree_.value[i] - - elif monotonic_cst[feature] == -1: - # Feature with constraint: check monotonicity - assert tree_.value[i_left] >= tree_.value[i_right] - - # Update and propagate bounds down the tree to both children. - upper_bound[i_left] = upper_bound[i] - lower_bound[i_left] = tree_.value[i] - upper_bound[i_right] = tree_.value[i] - lower_bound[i_right] = lower_bound[i] - - else: # pragma: no cover - raise ValueError(f"monotonic_cst[{feature}]={monotonic_cst[feature]}") + # Split node: check and update bounds for the children. 
+ i_left = tree_.children_left[i] + i_right = tree_.children_right[i] + + if monotonic_cst[feature] == 0: + # Feature without monotonicity constraint: propagate bounds + # down the tree to both children. + # Otherwise, with 2 features and a monotonic increase constraint + # (encoded by +1) on feature 0, the following tree can be accepted, + # although it does not respect the monotonic increase constraint: + # + # X[0] <= 0 + # value = 100 + # / \ + # X[0] <= -1 X[1] <= 0 + # value = 50 value = 150 + # / \ / \ + # leaf leaf leaf leaf + # value = 25 value = 75 value = 50 value = 250 + + lower_bound[i_left] = lower_bound[i] + upper_bound[i_left] = upper_bound[i] + lower_bound[i_right] = lower_bound[i] + upper_bound[i_right] = upper_bound[i] + + elif monotonic_cst[feature] == 1: + # Feature with constraint: check monotonicity + assert tree_.value[i_left] <= tree_.value[i_right] + + # Propagate bounds down the tree to both children. + lower_bound[i_left] = lower_bound[i] + upper_bound[i_left] = node_value + lower_bound[i_right] = node_value + upper_bound[i_right] = upper_bound[i] + + elif monotonic_cst[feature] == -1: + # Feature with constraint: check monotonicity + assert tree_.value[i_left] >= tree_.value[i_right] + + # Update and propagate bounds down the tree to both children. + lower_bound[i_left] = node_value + upper_bound[i_left] = upper_bound[i] + lower_bound[i_right] = lower_bound[i] + upper_bound[i_right] = node_value + + else: # pragma: no cover + raise ValueError(f"monotonic_cst[{feature}]={monotonic_cst[feature]}") + + +def test_assert_nd_reg_tree_children_monotonic_bounded(): + # Check that assert_nd_reg_tree_children_monotonic_bounded can detect + # non-monotonic tree predictions. + X = np.linspace(0, 2 * np.pi, 30).reshape(-1, 1) + y = np.sin(X).ravel() + reg = DecisionTreeRegressor(max_depth=None, random_state=0).fit(X, y) + + with pytest.raises(AssertionError): + assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [1]) + + with pytest.raises(AssertionError): + assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [-1]) + + assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [0]) + + # Check that assert_nd_reg_tree_children_monotonic_bounded does not raise + # when the data (and therefore the model) is naturally monotonic in the + # right direction. + X = np.linspace(-5, 5, 5).reshape(-1, 1) + y = X.ravel() ** 3 + reg = DecisionTreeRegressor(max_depth=None, random_state=0).fit(X, y) + + assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [1]) + + with pytest.raises(AssertionError): + assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [-1]) + + assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [0]) + + # For completeness, check that the converse holds when swapping the sign. 
+ reg = DecisionTreeRegressor(max_depth=None, random_state=0).fit(X, -y) + + with pytest.raises(AssertionError): + assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [1]) + + assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [-1]) + + assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [0]) @pytest.mark.parametrize("TreeRegressor", TREE_REGRESSOR_CLASSES) From a6502241bd90fa5e99dfcd8460defa37b5847861 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 4 Nov 2022 19:31:47 +0100 Subject: [PATCH 150/174] Apply suggestions from code review Co-authored-by: Julien Jerphanion --- sklearn/tree/_tree.pyx | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 3fbb08c00edce..db57c1b314adb 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -1286,7 +1286,7 @@ cdef class Tree: return arr cdef inline double _get_middle_value(self, int node_id) nogil: - """Returns the new bound when a monotonic constraint is active + """Returns the new bound when a monotonicity constraint is active For regression, this is the average target value of the samples in the parent node. @@ -1298,9 +1298,14 @@ cdef class Tree: int i double middle_value = 0 - # self.value has shape: [node_count, n_outputs, max_n_classes] - # Monotonic constraints are only supported for single-output trees so - # we can safely assume n_outputs == 1. + # self.value conceptually has shape: + # + # [node_count, n_outputs, max_n_classes] + # + # but is here handled via a pointer and strides. + # + # Monotonicity constraints are only supported for single-output + # trees we can safely assume n_outputs == 1. # Furthermore, we assume that for # classification trees, max_n_classes == 2 as we only support binary From 6809be74b9435ee080a097930fa7f0c7bd181b95 Mon Sep 17 00:00:00 2001 From: Samuel Ronsin Date: Wed, 14 Jun 2023 15:54:41 +0200 Subject: [PATCH 151/174] Clean up --- sklearn/tree/_criterion.pyx | 2 +- sklearn/tree/_tree.pyx | 11 +++-------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 71212461c9a7c..1393497a9b91d 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -240,7 +240,7 @@ cdef class Criterion: ) bint check_monotonic_cst - if monotonic_cst == 0: # No constraint + if monotonic_cst == 0: # No constraint return check_lower_bound & check_upper_bound else: check_monotonic_cst = ( diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 41e28738bdf54..fc944426be567 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -211,10 +211,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef double impurity = INFINITY cdef double lower_bound cdef double upper_bound - cdef double lower_bound_left - cdef double upper_bound_left - cdef double lower_bound_right - cdef double upper_bound_right cdef double middle_value cdef SIZE_t n_constant_features cdef bint is_leaf @@ -609,7 +605,6 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cdef double min_impurity_decrease = self.min_impurity_decrease cdef double weighted_n_node_samples cdef bint is_leaf - cdef double node_value splitter.node_reset(start, end, &weighted_n_node_samples) @@ -1348,9 +1343,9 @@ cdef class Tree: # self.value conceptually has shape: # # [node_count, n_outputs, max_n_classes] - # - # but is here handled via a pointer and strides. - # + # + # but is here handled via a pointer and strides. 
+ # # Monotonicity constraints are only supported for single-output # trees we can safely assume n_outputs == 1. From 27895f5566a77bbbfd90e7b1686eee3e109cfb55 Mon Sep 17 00:00:00 2001 From: Samuel Ronsin Date: Thu, 15 Jun 2023 17:04:19 +0200 Subject: [PATCH 152/174] Implement middle value as unweighted average of right and left children values of constrained nodes --- sklearn/tree/_criterion.pxd | 1 + sklearn/tree/_criterion.pyx | 26 +++++++++ sklearn/tree/_tree.pxd | 1 - sklearn/tree/_tree.pyx | 64 ++++------------------- sklearn/tree/tests/test_monotonic_tree.py | 16 +++--- 5 files changed, 45 insertions(+), 63 deletions(-) diff --git a/sklearn/tree/_criterion.pxd b/sklearn/tree/_criterion.pxd index bcba4e98776f7..ab9d46ff10286 100644 --- a/sklearn/tree/_criterion.pxd +++ b/sklearn/tree/_criterion.pxd @@ -69,6 +69,7 @@ cdef class Criterion: self, double* dest ) noexcept nogil + cdef double middle_value(self) noexcept nogil cdef double impurity_improvement( self, double impurity_parent, diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 1393497a9b91d..856b3645f6838 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -157,6 +157,13 @@ cdef class Criterion: """ pass + cdef double middle_value(self) noexcept nogil: + """Compute the middle value of a split for monotonicity constraints + + This method is implemented in ClassificationCriterion and RegressionCriterion. + """ + pass + cdef double proxy_impurity_improvement(self) noexcept nogil: """Compute a proxy of the impurity reduction. @@ -572,6 +579,16 @@ cdef class ClassificationCriterion(Criterion): memcpy(dest, &self.sum_total[k, 0], self.n_classes[k] * sizeof(double)) dest += self.max_n_classes + cdef double middle_value(self) noexcept nogil: + """Compute the middle value of a split for monotonicity constraints as the simple average + of the left and right children values. + + Note that monotonicity constraints are only supported for: + - single-output trees and + - binary classifications. + """ + return ((self.sum_left[0, 0] / self.weighted_n_left) + (self.sum_right[0, 0] / self.weighted_n_right)) / 2 + cdef inline bint check_monotonicity( self, cnp.int8_t monotonic_cst, @@ -1010,6 +1027,15 @@ cdef class RegressionCriterion(Criterion): for k in range(self.n_outputs): dest[k] = self.sum_total[k] / self.weighted_n_node_samples + cdef double middle_value(self) noexcept nogil: + """Compute the middle value of a split for monotonicity constraints as the simple average + of the left and right children values. + + Monotonicity constraints are only supported for single-output trees we can safely assume + n_outputs == 1. 
+ """ + return ((self.sum_left[0] / self.weighted_n_left) + (self.sum_right[0] / self.weighted_n_right)) / 2 + cdef inline bint check_monotonicity( self, cnp.int8_t monotonic_cst, diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 2f8fe42d9670c..e08ec5c94e41a 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -66,7 +66,6 @@ cdef class Tree: cdef cnp.ndarray _get_value_ndarray(self) cdef cnp.ndarray _get_node_ndarray(self) - cdef inline double _get_middle_value(self, int node_id) nogil cpdef cnp.ndarray predict(self, object X) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index fc944426be567..b4748e75871ef 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -272,6 +272,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): lower_bound, upper_bound ) + middle_value = splitter.criterion.middle_value() # If EPSILON=0 in the below comparison, float precision # issues stop splitting, producing trees that are # dissimilar to v0.18 @@ -310,7 +311,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): # Lower bound for right child and upper bound for left child # are set to the same value. - middle_value = tree._get_middle_value(node_id) right_child_min = middle_value left_child_max = middle_value else: # i.e. splitter.monotonic_cst[split.feature] == -1 @@ -320,7 +320,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): # Lower bound for left child and upper bound for right child # are set to the same value. - middle_value = tree._get_middle_value(node_id) left_child_min = middle_value right_child_max = middle_value @@ -379,6 +378,7 @@ cdef struct FrontierRecord: double improvement double lower_bound double upper_bound + double middle_value cdef inline bool _compare_records( const FrontierRecord& left, @@ -439,7 +439,6 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cdef FrontierRecord record cdef FrontierRecord split_node_left cdef FrontierRecord split_node_right - cdef double middle_value cdef double left_child_min cdef double left_child_max cdef double right_child_min @@ -510,9 +509,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder): # Lower bound for right child and upper bound for left child # are set to the same value. - middle_value = tree._get_middle_value(record.node_id) - right_child_min = middle_value - left_child_max = middle_value + right_child_min = record.middle_value + left_child_max = record.middle_value else: # i.e. splitter.monotonic_cst[split.feature] == -1 # Split on a feature with monotonic decrease constraint right_child_min = record.lower_bound @@ -520,9 +518,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder): # Lower bound for left child and upper bound for right child # are set to the same value. 
- middle_value = tree._get_middle_value(record.node_id) - left_child_min = middle_value - right_child_max = middle_value + left_child_min = record.middle_value + right_child_max = record.middle_value # Decrement number of split nodes available max_split_nodes -= 1 @@ -650,6 +647,9 @@ cdef class BestFirstTreeBuilder(TreeBuilder): res.end = end res.depth = depth res.impurity = impurity + res.lower_bound = lower_bound + res.upper_bound = upper_bound + res.middle_value = splitter.criterion.middle_value() if not is_leaf: # is split node @@ -658,8 +658,6 @@ cdef class BestFirstTreeBuilder(TreeBuilder): res.improvement = split.improvement res.impurity_left = split.impurity_left res.impurity_right = split.impurity_right - res.lower_bound = lower_bound - res.upper_bound = upper_bound else: # is leaf => 0 improvement @@ -668,8 +666,6 @@ cdef class BestFirstTreeBuilder(TreeBuilder): res.improvement = 0.0 res.impurity_left = impurity res.impurity_right = impurity - res.lower_bound = lower_bound - res.upper_bound = upper_bound return 0 @@ -1327,48 +1323,6 @@ cdef class Tree: raise ValueError("Can't initialize array.") return arr - cdef inline double _get_middle_value(self, int node_id) nogil: - """Returns the new bound when a monotonicity constraint is active - - For regression, this is the average target value of the samples in - the parent node. - - For binary classification, this is the proportion of the negative - class of samples in the parent node. - """ - cdef: - int i - double middle_value = 0 - - # self.value conceptually has shape: - # - # [node_count, n_outputs, max_n_classes] - # - # but is here handled via a pointer and strides. - # - # Monotonicity constraints are only supported for single-output - # trees we can safely assume n_outputs == 1. - - # Furthermore, we assume that for - # classification trees, max_n_classes == 2 as we only support binary - # classification (and so self.value_stride == 2). - - if self.max_n_classes == 1: - # Regression: self.max_n_classes == 1, so self.value_stride == 1 - middle_value = self.value[node_id] - else: - # Binary classification: self.max_n_classes == 2 - - # This performs some raw pointers arithmetic. Ideally a memory view - # could wrap self.value so that indexing on several axis can be - # used. - for i in range(self.max_n_classes): - middle_value += self.value[node_id * self.value_stride + i] - - middle_value = self.value[node_id * self.value_stride] / middle_value - - return middle_value - def compute_partial_dependence(self, DTYPE_t[:, ::1] X, int[::1] target_features, double[::1] out): diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index d0992e8d6decc..3cbdac450c98e 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -179,7 +179,7 @@ def assert_1d_reg_tree_children_monotonic_bounded(tree_, monotonic_sign): i_left = tree_.children_left[i] i_right = tree_.children_right[i] assert values[i_left] <= values[i_right] - val_middle = values[i] + val_middle = (values[i_left] + values[i_right]) / 2 # Check bounds on grand-children, filtering out leaf nodes if tree_.feature[i_left] >= 0: i_left_right = tree_.children_right[i_left] @@ -261,6 +261,8 @@ def assert_nd_reg_tree_children_monotonic_bounded(tree_, monotonic_cst): # Split node: check and update bounds for the children. 
i_left = tree_.children_left[i] i_right = tree_.children_right[i] + # unpack value from nx1x1 array + middle_value = (tree_.value[i_left][0][0] + tree_.value[i_right][0][0]) / 2 if monotonic_cst[feature] == 0: # Feature without monotonicity constraint: propagate bounds @@ -289,8 +291,8 @@ def assert_nd_reg_tree_children_monotonic_bounded(tree_, monotonic_cst): # Propagate bounds down the tree to both children. lower_bound[i_left] = lower_bound[i] - upper_bound[i_left] = node_value - lower_bound[i_right] = node_value + upper_bound[i_left] = middle_value + lower_bound[i_right] = middle_value upper_bound[i_right] = upper_bound[i] elif monotonic_cst[feature] == -1: @@ -298,10 +300,10 @@ def assert_nd_reg_tree_children_monotonic_bounded(tree_, monotonic_cst): assert tree_.value[i_left] >= tree_.value[i_right] # Update and propagate bounds down the tree to both children. - lower_bound[i_left] = node_value + lower_bound[i_left] = middle_value upper_bound[i_left] = upper_bound[i] lower_bound[i_right] = lower_bound[i] - upper_bound[i_right] = node_value + upper_bound[i_right] = middle_value else: # pragma: no cover raise ValueError(f"monotonic_cst[{feature}]={monotonic_cst[feature]}") @@ -369,8 +371,8 @@ def test_nd_tree_nodes_values( # c d e f # # i) a <= root <= b - # ii) c <= a <= d <= root - # iii) root <= min(e,f) + # ii) c <= a <= d <= (a+b)/2 + # iii) (a+b)/2 <= min(e,f) # For iii) we check that each node value is within the proper lower and # upper bounds. From c6e27d63e5d3ad1b95bef7a38b542d59272fbf8d Mon Sep 17 00:00:00 2001 From: Samuel Ronsin Date: Thu, 15 Jun 2023 17:08:51 +0200 Subject: [PATCH 153/174] Fix broken test caused by Splitter init change --- sklearn/tree/_splitter.pyx | 3 ++- sklearn/tree/tests/test_tree.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index f4606c83450e3..64ddde9f91913 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -114,7 +114,8 @@ cdef class Splitter: self.max_features, self.min_samples_leaf, self.min_weight_leaf, - self.random_state), self.__getstate__()) + self.random_state, + self.monotonic_cst), self.__getstate__()) cdef int init( self, diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index eefae6cdaa3f6..0468d5a2bdeee 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -2367,7 +2367,7 @@ def test_splitter_serializable(Splitter): n_outputs, n_classes = 2, np.array([3, 2], dtype=np.intp) criterion = CRITERIA_CLF["gini"](n_outputs, n_classes) - splitter = Splitter(criterion, max_features, 5, 0.5, rng) + splitter = Splitter(criterion, max_features, 5, 0.5, rng, None) splitter_serialize = pickle.dumps(splitter) splitter_back = pickle.loads(splitter_serialize) From 187a5347b379ed443452e263cec36c707a3ad300 Mon Sep 17 00:00:00 2001 From: "Samuel O. 
Ronsin" Date: Fri, 16 Jun 2023 17:22:40 +0200 Subject: [PATCH 154/174] Clean up Co-authored-by: Olivier Grisel --- sklearn/ensemble/_forest.py | 8 ++++---- sklearn/tree/_classes.py | 11 +++++++---- sklearn/tree/_criterion.pyx | 10 ++++++++-- sklearn/tree/tests/test_monotonic_tree.py | 2 +- 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index 856b59fdc164e..8fa6debb5929a 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -1282,7 +1282,7 @@ class RandomForestClassifier(ForestClassifier): The constraints are only valid for binary classification and hold over the probability of the positive class. - .. versionadded:: 1.2 + .. versionadded:: 1.4 Attributes ---------- @@ -1652,7 +1652,7 @@ class RandomForestRegressor(ForestRegressor): Monotonicity constraints are not supported for multioutput regression (i.e. when `n_outputs_ > 1`). - .. versionadded:: 1.2 + .. versionadded:: 1.4 Attributes ---------- @@ -2016,7 +2016,7 @@ class ExtraTreesClassifier(ForestClassifier): The constraints are only valid for binary classification and hold over the probability of the positive class. - .. versionadded:: 1.2 + .. versionadded:: 1.4 Attributes ---------- @@ -2371,7 +2371,7 @@ class ExtraTreesRegressor(ForestRegressor): Monotonicity constraints are not supported for multioutput regression (i.e. when `n_outputs_ > 1`). - .. versionadded:: 1.2 + .. versionadded:: 1.4 Attributes ---------- diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 33a1129ced637..3df86d30ad9f9 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -411,6 +411,9 @@ def _fit( ) # Check to correct monotonicity constraint' specification, # by applying element-wise logical conjunction + # Note: we do not cast `np.asarray(self.monotonic_cst, dtype=np.int8)` + # straight away here so as to generate error messages for invalid + # values using the original values prior to any dtype related conversion. monotonic_cst = np.asarray(self.monotonic_cst) if monotonic_cst.shape[0] != X.shape[1]: raise ValueError( @@ -846,7 +849,7 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree): The constraints are only valid for binary classifications and hold over the probability of the positive class. - .. versionadded:: 1.2 + .. versionadded:: 1.4 Attributes ---------- @@ -1237,7 +1240,7 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree): Monotonicity constraints are not supported for multioutput regression (i.e. when `n_outputs_ > 1`). - .. versionadded:: 1.2 + .. versionadded:: 1.4 Attributes ---------- @@ -1577,7 +1580,7 @@ class ExtraTreeClassifier(DecisionTreeClassifier): The constraints are only valid for binary classifications and hold over the probability of the positive class. - .. versionadded:: 1.2 + .. versionadded:: 1.4 Attributes ---------- @@ -1836,7 +1839,7 @@ class ExtraTreeRegressor(DecisionTreeRegressor): Monotonicity constraints are not supported for multioutput regression (i.e. when `n_outputs_ > 1`). - .. versionadded:: 1.2 + .. versionadded:: 1.4 Attributes ---------- diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 856b3645f6838..1a3b07fcbfe81 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -587,7 +587,10 @@ cdef class ClassificationCriterion(Criterion): - single-output trees and - binary classifications. 
""" - return ((self.sum_left[0, 0] / self.weighted_n_left) + (self.sum_right[0, 0] / self.weighted_n_right)) / 2 + return ( + (self.sum_left[0, 0] / self.weighted_n_left) + + (self.sum_right[0, 0] / self.weighted_n_right) + ) / 2 cdef inline bint check_monotonicity( self, @@ -1034,7 +1037,10 @@ cdef class RegressionCriterion(Criterion): Monotonicity constraints are only supported for single-output trees we can safely assume n_outputs == 1. """ - return ((self.sum_left[0] / self.weighted_n_left) + (self.sum_right[0] / self.weighted_n_right)) / 2 + return ( + (self.sum_left[0] / self.weighted_n_left) + + (self.sum_right[0] / self.weighted_n_right) + ) / 2 cdef inline bint check_monotonicity( self, diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 3cbdac450c98e..15430cae25b6e 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -70,7 +70,7 @@ def test_montonic_constraints_regressions( ): n_samples = 1000 n_samples_train = 900 - # Build a classification task using 3 informative features + # Build a regression task using 5 informative features X, y = make_regression( n_samples=n_samples, n_features=5, From b4142c8b15edd728ca3dbf5d107ebe5be77fb939 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Mon, 19 Jun 2023 10:45:01 +0200 Subject: [PATCH 155/174] Make cython-lint happy Signed-off-by: Julien Jerphanion --- sklearn/tree/_criterion.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 1a3b07fcbfe81..08cc743bc208a 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -588,7 +588,7 @@ cdef class ClassificationCriterion(Criterion): - binary classifications. 
""" return ( - (self.sum_left[0, 0] / self.weighted_n_left) + + (self.sum_left[0, 0] / self.weighted_n_left) + (self.sum_right[0, 0] / self.weighted_n_right) ) / 2 From 55a054444bf2c74740b8b800ed7499297bddec2b Mon Sep 17 00:00:00 2001 From: Samuel Ronsin Date: Mon, 19 Jun 2023 14:18:40 +0200 Subject: [PATCH 156/174] Improve and test assert_1d_reg_tree_children_monotonic_bounded --- sklearn/tree/tests/test_monotonic_tree.py | 31 +++++++++++++++++++---- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 15430cae25b6e..6ff0cb44cce6e 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -171,22 +171,43 @@ def test_bad_monotonic_cst_raises(TreeClassifier): def assert_1d_reg_tree_children_monotonic_bounded(tree_, monotonic_sign): - # Flip values to always check for monotonic increase constraint - values = monotonic_sign * tree_.value + values = tree_.value for i in range(tree_.node_count): if tree_.children_left[i] > i and tree_.children_right[i] > i: # Check monotonicity on children i_left = tree_.children_left[i] i_right = tree_.children_right[i] - assert values[i_left] <= values[i_right] + if monotonic_sign == 1: + assert values[i_left] <= values[i_right] + elif monotonic_sign == -1: + assert values[i_left] >= values[i_right] val_middle = (values[i_left] + values[i_right]) / 2 # Check bounds on grand-children, filtering out leaf nodes if tree_.feature[i_left] >= 0: i_left_right = tree_.children_right[i_left] - assert values[i_left_right] <= val_middle + if monotonic_sign == 1: + assert values[i_left_right] <= val_middle + elif monotonic_sign == -1: + assert values[i_left_right] >= val_middle if tree_.feature[i_right] >= 0: i_right_left = tree_.children_left[i_right] - assert val_middle <= values[i_right_left] + if monotonic_sign == 1: + assert val_middle <= values[i_right_left] + elif monotonic_sign == -1: + assert val_middle >= values[i_right_left] + + +def test_assert_1d_reg_tree_children_monotonic_bounded(): + X = np.linspace(-1, 1, 7).reshape(-1, 1) + y = np.sin(2 * np.pi * X.ravel()) + + reg = DecisionTreeRegressor(max_depth=None, random_state=0).fit(X, y) + + with pytest.raises(AssertionError): + assert_1d_reg_tree_children_monotonic_bounded(reg.tree_, 1) + + with pytest.raises(AssertionError): + assert_1d_reg_tree_children_monotonic_bounded(reg.tree_, -1) def assert_1d_reg_monotonic(clf, monotonic_sign, min_x, max_x, n_steps): From e92a10ae5f06f5faa64ff0b919e2121d92b7ea56 Mon Sep 17 00:00:00 2001 From: Samuel Ronsin Date: Mon, 19 Jun 2023 14:42:41 +0200 Subject: [PATCH 157/174] Remove excessive assertions from test_assert_nd_reg_tree_children_monotonic_bounded --- sklearn/tree/tests/test_monotonic_tree.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 6ff0cb44cce6e..57b42cf48de82 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -345,30 +345,22 @@ def test_assert_nd_reg_tree_children_monotonic_bounded(): assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [0]) - # Check that assert_nd_reg_tree_children_monotonic_bounded does not raise + # Check that assert_nd_reg_tree_children_monotonic_bounded raises # when the data (and therefore the model) is naturally monotonic in the - # right direction. + # opposite direction. 
X = np.linspace(-5, 5, 5).reshape(-1, 1) y = X.ravel() ** 3 reg = DecisionTreeRegressor(max_depth=None, random_state=0).fit(X, y) - assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [1]) - with pytest.raises(AssertionError): assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [-1]) - assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [0]) - # For completeness, check that the converse holds when swapping the sign. reg = DecisionTreeRegressor(max_depth=None, random_state=0).fit(X, -y) with pytest.raises(AssertionError): assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [1]) - assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [-1]) - - assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [0]) - @pytest.mark.parametrize("TreeRegressor", TREE_REGRESSOR_CLASSES) @pytest.mark.parametrize("monotonic_sign", (-1, 1)) From ca7f03175b298d100a95d5f382a601d7945f828a Mon Sep 17 00:00:00 2001 From: Samuel Ronsin Date: Mon, 19 Jun 2023 15:07:37 +0200 Subject: [PATCH 158/174] Add tree ensembles to estimator tests --- sklearn/tree/tests/test_monotonic_tree.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 57b42cf48de82..f82376ecb49c5 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -8,12 +8,19 @@ ExtraTreeRegressor, ExtraTreeClassifier, ) +from sklearn.ensemble import ( + RandomForestRegressor, + RandomForestClassifier, + ExtraTreesRegressor, + ExtraTreesClassifier, +) TREE_CLASSIFIER_CLASSES = [DecisionTreeClassifier, ExtraTreeClassifier] TREE_REGRESSOR_CLASSES = [DecisionTreeRegressor, ExtraTreeRegressor] +TREE_BASED_CLASSIFIER_CLASSES = TREE_CLASSIFIER_CLASSES + [RandomForestClassifier, ExtraTreesClassifier] +TREE_BASED_REGRESSOR_CLASSES = TREE_REGRESSOR_CLASSES + [RandomForestRegressor, ExtraTreesRegressor] - -@pytest.mark.parametrize("TreeClassifier", TREE_CLASSIFIER_CLASSES) +@pytest.mark.parametrize("TreeClassifier", TREE_BASED_CLASSIFIER_CLASSES) @pytest.mark.parametrize("depth_first", (True, False)) def test_montonic_constraints_classifications( TreeClassifier, depth_first, global_random_seed @@ -51,6 +58,8 @@ def test_montonic_constraints_classifications( ) if hasattr(est, "random_state"): est.set_params(**{"random_state": global_random_seed}) + if hasattr(est, "n_estimators"): + est.set_params(**{"n_estimators": 5}) est.fit(X_train, y_train) y = est.predict_proba(X_test)[:, 1] @@ -63,7 +72,7 @@ def test_montonic_constraints_classifications( assert np.all(est.predict_proba(X_test_1decr)[:, 1] >= y) -@pytest.mark.parametrize("TreeRegressor", TREE_REGRESSOR_CLASSES) +@pytest.mark.parametrize("TreeRegressor", TREE_BASED_REGRESSOR_CLASSES) @pytest.mark.parametrize("depth_first", (True, False)) def test_montonic_constraints_regressions( TreeRegressor, depth_first, global_random_seed @@ -115,7 +124,7 @@ def test_montonic_constraints_regressions( assert np.all(y_decr <= y) -@pytest.mark.parametrize("TreeClassifier", TREE_CLASSIFIER_CLASSES) +@pytest.mark.parametrize("TreeClassifier", TREE_BASED_CLASSIFIER_CLASSES) def test_multiclass_raises(TreeClassifier): X, y = make_classification( n_samples=100, n_features=5, n_classes=3, n_informative=3, random_state=0 @@ -131,7 +140,7 @@ def test_multiclass_raises(TreeClassifier): est.fit(X, y) -@pytest.mark.parametrize("TreeClassifier", TREE_CLASSIFIER_CLASSES) +@pytest.mark.parametrize("TreeClassifier", TREE_BASED_CLASSIFIER_CLASSES) def 
test_multiple_output_raises(TreeClassifier): X = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]] y = [[1, 0, 1, 0, 1], [1, 0, 1, 0, 1]] @@ -144,7 +153,7 @@ def test_multiple_output_raises(TreeClassifier): est.fit(X, y) -@pytest.mark.parametrize("TreeClassifier", TREE_CLASSIFIER_CLASSES) +@pytest.mark.parametrize("TreeClassifier", TREE_BASED_CLASSIFIER_CLASSES) def test_bad_monotonic_cst_raises(TreeClassifier): X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]] y = [1, 0, 1, 0, 1] From e49a58b1137254c6e206839b5f3fc3d334235893 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Mon, 19 Jun 2023 15:17:54 +0200 Subject: [PATCH 159/174] Clean up tests Co-authored-by: Julien Jerphanion --- sklearn/tree/tests/test_monotonic_tree.py | 3 ++- sklearn/tree/tests/test_tree.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index f82376ecb49c5..8418a7ed76602 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -246,7 +246,8 @@ def test_1d_tree_nodes_values( # / \ / \ # c d e f # - # c <= d <= root <= e <= f + # a <= root <= b + # c <= d <= (a + b) / 2 <= e <= f rng = np.random.RandomState(global_random_seed) n_samples = 1000 diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 0468d5a2bdeee..a8e0a16cef411 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -2367,7 +2367,7 @@ def test_splitter_serializable(Splitter): n_outputs, n_classes = 2, np.array([3, 2], dtype=np.intp) criterion = CRITERIA_CLF["gini"](n_outputs, n_classes) - splitter = Splitter(criterion, max_features, 5, 0.5, rng, None) + splitter = Splitter(criterion, max_features, 5, 0.5, rng, monotonic_cst=None) splitter_serialize = pickle.dumps(splitter) splitter_back = pickle.loads(splitter_serialize) From e89fa576371b6f8e265c9308a09ffd5980f772f2 Mon Sep 17 00:00:00 2001 From: Samuel Ronsin Date: Mon, 19 Jun 2023 15:53:06 +0200 Subject: [PATCH 160/174] Optimize Criterion modifications Co-authored-by: Julien Jerphanion --- sklearn/tree/_criterion.pxd | 4 ++-- sklearn/tree/_criterion.pyx | 22 +++++++++++----------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/sklearn/tree/_criterion.pxd b/sklearn/tree/_criterion.pxd index ab9d46ff10286..55b91dbcb1e38 100644 --- a/sklearn/tree/_criterion.pxd +++ b/sklearn/tree/_criterion.pxd @@ -82,7 +82,7 @@ cdef class Criterion: cnp.int8_t monotonic_cst, double lower_bound, double upper_bound, - ) nogil + ) noexcept nogil cdef inline bint _check_monotonicity( self, cnp.int8_t monotonic_cst, @@ -90,7 +90,7 @@ cdef class Criterion: double upper_bound, double sum_left, double sum_right, - ) nogil + ) noexcept nogil cdef class ClassificationCriterion(Criterion): """Abstract criterion for classification.""" diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 08cc743bc208a..75570528767b4 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -223,7 +223,7 @@ cdef class Criterion: cnp.int8_t monotonic_cst, double lower_bound, double upper_bound, - ) nogil: + ) noexcept nogil: pass cdef inline bint _check_monotonicity( @@ -579,7 +579,7 @@ cdef class ClassificationCriterion(Criterion): memcpy(dest, &self.sum_total[k, 0], self.n_classes[k] * sizeof(double)) dest += self.max_n_classes - cdef double middle_value(self) noexcept nogil: + cdef inline double middle_value(self) noexcept nogil: """Compute the middle value of a split for monotonicity 
constraints as the simple average of the left and right children values. @@ -588,16 +588,16 @@ cdef class ClassificationCriterion(Criterion): - binary classifications. """ return ( - (self.sum_left[0, 0] / self.weighted_n_left) + - (self.sum_right[0, 0] / self.weighted_n_right) - ) / 2 + (self.sum_left[0, 0] / (2 * self.weighted_n_left)) + + (self.sum_right[0, 0] / (2 * self.weighted_n_right)) + ) cdef inline bint check_monotonicity( self, cnp.int8_t monotonic_cst, double lower_bound, double upper_bound, - ) nogil: + ) noexcept nogil: """Check monotonicity constraint is satisfied at the current classification split""" cdef: double sum_left = self.sum_left[0][0] @@ -1030,7 +1030,7 @@ cdef class RegressionCriterion(Criterion): for k in range(self.n_outputs): dest[k] = self.sum_total[k] / self.weighted_n_node_samples - cdef double middle_value(self) noexcept nogil: + cdef inline double middle_value(self) noexcept nogil: """Compute the middle value of a split for monotonicity constraints as the simple average of the left and right children values. @@ -1038,16 +1038,16 @@ cdef class RegressionCriterion(Criterion): n_outputs == 1. """ return ( - (self.sum_left[0] / self.weighted_n_left) + - (self.sum_right[0] / self.weighted_n_right) - ) / 2 + (self.sum_left[0] / (2 * self.weighted_n_left)) + + (self.sum_right[0] / (2 * self.weighted_n_right)) + ) cdef inline bint check_monotonicity( self, cnp.int8_t monotonic_cst, double lower_bound, double upper_bound, - ) nogil: + ) noexcept nogil: """Check monotonicity constraint is satisfied at the current regression split""" cdef: double sum_left = self.sum_left[0] From 80594a2eb80d44fa0f887c2212e7ef47344315e5 Mon Sep 17 00:00:00 2001 From: Samuel Ronsin Date: Mon, 19 Jun 2023 17:12:38 +0200 Subject: [PATCH 161/174] Black stuff + check that float value is not cast too early to int8 --- sklearn/tree/tests/test_monotonic_tree.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 8418a7ed76602..8af77a044ae74 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -17,8 +17,14 @@ TREE_CLASSIFIER_CLASSES = [DecisionTreeClassifier, ExtraTreeClassifier] TREE_REGRESSOR_CLASSES = [DecisionTreeRegressor, ExtraTreeRegressor] -TREE_BASED_CLASSIFIER_CLASSES = TREE_CLASSIFIER_CLASSES + [RandomForestClassifier, ExtraTreesClassifier] -TREE_BASED_REGRESSOR_CLASSES = TREE_REGRESSOR_CLASSES + [RandomForestRegressor, ExtraTreesRegressor] +TREE_BASED_CLASSIFIER_CLASSES = TREE_CLASSIFIER_CLASSES + [ + RandomForestClassifier, + ExtraTreesClassifier +] +TREE_BASED_REGRESSOR_CLASSES = TREE_REGRESSOR_CLASSES + [ + RandomForestRegressor, + ExtraTreesRegressor +] @pytest.mark.parametrize("TreeClassifier", TREE_BASED_CLASSIFIER_CLASSES) @pytest.mark.parametrize("depth_first", (True, False)) @@ -175,7 +181,7 @@ def test_bad_monotonic_cst_raises(TreeClassifier): est = TreeClassifier( max_depth=None, monotonic_cst=np.array([-1, 0.8]), random_state=0 ) - with pytest.raises(ValueError, match=msg): + with pytest.raises(ValueError, match=msg + "(.*)0.8]"): est.fit(X, y) From 7f9dc7103c8eeec515f2f3ebcdcf962e2c4497b4 Mon Sep 17 00:00:00 2001 From: Samuel Ronsin Date: Mon, 19 Jun 2023 18:38:53 +0200 Subject: [PATCH 162/174] Add test with opposite monotonicity data and constraints --- sklearn/tree/tests/test_monotonic_tree.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git 
a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 8af77a044ae74..3022365a999d1 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -228,7 +228,28 @@ def test_assert_1d_reg_tree_children_monotonic_bounded(): def assert_1d_reg_monotonic(clf, monotonic_sign, min_x, max_x, n_steps): X_grid = np.linspace(min_x, max_x, n_steps).reshape(-1, 1) y_pred_grid = clf.predict(X_grid) - assert (monotonic_sign * np.diff(y_pred_grid) >= 0.0).all() + if monotonic_sign == 1: + assert (np.diff(y_pred_grid) >= 0.0).all() + elif monotonic_sign == -1: + assert (np.diff(y_pred_grid) <= 0.0).all() + + +@pytest.mark.parametrize("TreeRegressor", TREE_REGRESSOR_CLASSES) +def test_1d_opposite_monotonicity_cst_data(TreeRegressor): + # Check that positive monotonic data with negative monotonic constraint + # yield constant predictions equal to the average target + X = np.linspace(-2, 2, 10).reshape(-1, 1) + y = X.ravel() + clf = TreeRegressor(monotonic_cst=[-1]) + clf.fit(X, y) + assert clf.tree_.node_count == 1 + assert clf.tree_.value[0] == 0. + + # Swap monotonicity + clf = TreeRegressor(monotonic_cst=[1]) + clf.fit(X, -y) + assert clf.tree_.node_count == 1 + assert clf.tree_.value[0] == 0. @pytest.mark.parametrize("TreeRegressor", TREE_REGRESSOR_CLASSES) From 4e3edacc034f0c250718bbab01165de6b84d37a7 Mon Sep 17 00:00:00 2001 From: "Samuel O. Ronsin" Date: Wed, 21 Jun 2023 13:42:05 +0200 Subject: [PATCH 163/174] Implement value clipping * Clip the values of the nodes for constrained trees instead of checking the bounds for unconstrained features * Explain numerical noise issue in assert_nd_reg_tree_children_monotonic_bounded comparisons and the casting to 32-bit floats --- sklearn/tree/_criterion.pxd | 6 ++++ sklearn/tree/_criterion.pyx | 34 ++++++++++++++++++----- sklearn/tree/_splitter.pxd | 2 ++ sklearn/tree/_splitter.pyx | 7 +++++ sklearn/tree/_tree.pyx | 4 +++ sklearn/tree/tests/test_monotonic_tree.py | 10 +++++-- 6 files changed, 54 insertions(+), 9 deletions(-) diff --git a/sklearn/tree/_criterion.pxd b/sklearn/tree/_criterion.pxd index 55b91dbcb1e38..b765d324bebb9 100644 --- a/sklearn/tree/_criterion.pxd +++ b/sklearn/tree/_criterion.pxd @@ -69,6 +69,12 @@ cdef class Criterion: self, double* dest ) noexcept nogil + cdef void clip_node_value( + self, + double* dest, + double lower_bound, + double upper_bound + ) noexcept nogil cdef double middle_value(self) noexcept nogil cdef double impurity_improvement( self, diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 75570528767b4..c0d1fe2ad0a94 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -157,6 +157,9 @@ cdef class Criterion: """ pass + cdef void clip_node_value(self, double* dest, double lower_bound, double upper_bound) noexcept nogil: + pass + cdef double middle_value(self) noexcept nogil: """Compute the middle value of a split for monotonicity constraints @@ -245,16 +248,11 @@ cdef class Criterion: (sum_left <= upper_bound * weighted_n_left) & (sum_right <= upper_bound * weighted_n_right) ) - bint check_monotonic_cst - - if monotonic_cst == 0: # No constraint - return check_lower_bound & check_upper_bound - else: - check_monotonic_cst = ( + bint check_monotonic_cst = ( (sum_left * weighted_n_right - sum_right * weighted_n_left) * monotonic_cst <= 0 ) - return check_lower_bound & check_upper_bound & check_monotonic_cst + return check_lower_bound & check_upper_bound & check_monotonic_cst cdef void 
init_sum_missing(self): """Init sum_missing to hold sums for missing values.""" @@ -579,6 +577,21 @@ cdef class ClassificationCriterion(Criterion): memcpy(dest, &self.sum_total[k, 0], self.n_classes[k] * sizeof(double)) dest += self.max_n_classes + cdef void clip_node_value(self, double * dest, double lower_bound, double upper_bound) noexcept nogil: + """Clip the value in dest between lower_bound and upper_bound for monotonic constraints. + + Note that monotonicity constraints are only supported for: + - single-output trees and + - binary classifications. + """ + if dest[0] < lower_bound: + dest[0] = lower_bound + elif dest[0] > upper_bound: + dest[0] = upper_bound + + # Class proportions for binary classification must sum to 1. + dest[1] = 1 - dest[0] + cdef inline double middle_value(self) noexcept nogil: """Compute the middle value of a split for monotonicity constraints as the simple average of the left and right children values. @@ -1030,6 +1043,13 @@ cdef class RegressionCriterion(Criterion): for k in range(self.n_outputs): dest[k] = self.sum_total[k] / self.weighted_n_node_samples + cdef inline void clip_node_value(self, double* dest, double lower_bound, double upper_bound) noexcept nogil: + """Clip the value in dest between lower_bound and upper_bound for monotonic constraints.""" + if dest[0] < lower_bound: + dest[0] = lower_bound + elif dest[0] > upper_bound: + dest[0] = upper_bound + cdef inline double middle_value(self) noexcept nogil: """Compute the middle value of a split for monotonicity constraints as the simple average of the left and right children values. diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index bfb401fea808e..2547e14b324df 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -112,4 +112,6 @@ cdef class Splitter: cdef void node_value(self, double* dest) noexcept nogil + cdef void clip_node_value(self, double* dest, double lower_bound, double upper_bound) noexcept nogil + cdef double node_impurity(self) noexcept nogil diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 64ddde9f91913..edbfff13cd941 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -247,6 +247,11 @@ cdef class Splitter: self.criterion.node_value(dest) + cdef inline void clip_node_value(self, double* dest, double lower_bound, double upper_bound) noexcept nogil: + """Clip the value in dest between lower_bound and upper_bound for monotonic constraints.""" + + self.criterion.clip_node_value(dest, lower_bound, upper_bound) + cdef double node_impurity(self) noexcept nogil: """Return the impurity of the current node.""" @@ -445,6 +450,7 @@ cdef inline int node_split_best( # Reject if monotonicity constraints are not satisfied if ( with_monotonic_cst and + monotonic_cst[current_split.feature] != 0 and not criterion.check_monotonicity( monotonic_cst[current_split.feature], lower_bound, @@ -800,6 +806,7 @@ cdef inline int node_split_random( # Reject if monotonicity constraints are not satisfied if ( with_monotonic_cst and + monotonic_cst[current_split.feature] != 0 and not criterion.check_monotonicity( monotonic_cst[current_split.feature], lower_bound, diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 60a0ab521e71b..ef35bb5858933 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -292,6 +292,8 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): # Store value for all nodes, to facilitate tree/model # inspection and interpretation splitter.node_value(tree.value + node_id * 
tree.value_stride) + if splitter.with_monotonic_cst: + splitter.clip_node_value(tree.value + node_id * tree.value_stride, lower_bound, upper_bound) if not is_leaf: if ( @@ -641,6 +643,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder): # compute values also for split nodes (might become leafs later). splitter.node_value(tree.value + node_id * tree.value_stride) + if splitter.with_monotonic_cst: + splitter.clip_node_value(tree.value + node_id * tree.value_stride, lower_bound, upper_bound) res.node_id = node_id res.start = start diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 3022365a999d1..139af7e1e2d42 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -309,8 +309,14 @@ def assert_nd_reg_tree_children_monotonic_bounded(tree_, monotonic_cst): for i in range(tree_.node_count): feature = tree_.feature[i] node_value = tree_.value[i][0][0] # unpack value from nx1x1 array - assert node_value <= upper_bound[i] - assert node_value >= lower_bound[i] + # While building the tree, the computed middle value is slightly + # different from the average of the siblings values, because + # sum_right / weighted_n_right + # is slightly different from the value of the right sibling. + # This can cause a discrepancy up to numerical noise when clipping, + # which is resolved by comparing with some loss of precision. + assert np.float32(node_value) <= np.float32(upper_bound[i]) + assert np.float32(node_value) >= np.float32(lower_bound[i]) if feature < 0: # Leaf: nothing to do From 4e1d7849e8e5a85fc848cd78f26f1430091b3638 Mon Sep 17 00:00:00 2001 From: Samuel Ronsin Date: Wed, 21 Jun 2023 13:58:30 +0200 Subject: [PATCH 164/174] Remove redundant tests: splitter kwarg is equivalent to choice of {Decision,Extra}TreeClassifier --- sklearn/tree/tests/test_monotonic_tree.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 139af7e1e2d42..449fed27f9c18 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -19,13 +19,14 @@ TREE_REGRESSOR_CLASSES = [DecisionTreeRegressor, ExtraTreeRegressor] TREE_BASED_CLASSIFIER_CLASSES = TREE_CLASSIFIER_CLASSES + [ RandomForestClassifier, - ExtraTreesClassifier + ExtraTreesClassifier, ] TREE_BASED_REGRESSOR_CLASSES = TREE_REGRESSOR_CLASSES + [ RandomForestRegressor, - ExtraTreesRegressor + ExtraTreesRegressor, ] + @pytest.mark.parametrize("TreeClassifier", TREE_BASED_CLASSIFIER_CLASSES) @pytest.mark.parametrize("depth_first", (True, False)) def test_montonic_constraints_classifications( @@ -243,21 +244,20 @@ def test_1d_opposite_monotonicity_cst_data(TreeRegressor): clf = TreeRegressor(monotonic_cst=[-1]) clf.fit(X, y) assert clf.tree_.node_count == 1 - assert clf.tree_.value[0] == 0. + assert clf.tree_.value[0] == 0.0 # Swap monotonicity clf = TreeRegressor(monotonic_cst=[1]) clf.fit(X, -y) assert clf.tree_.node_count == 1 - assert clf.tree_.value[0] == 0. 
+ assert clf.tree_.value[0] == 0.0 @pytest.mark.parametrize("TreeRegressor", TREE_REGRESSOR_CLASSES) @pytest.mark.parametrize("monotonic_sign", (-1, 1)) -@pytest.mark.parametrize("splitter", ("best", "random")) @pytest.mark.parametrize("depth_first", (True, False)) def test_1d_tree_nodes_values( - TreeRegressor, monotonic_sign, splitter, depth_first, global_random_seed + TreeRegressor, monotonic_sign, depth_first, global_random_seed ): # Adaptation from test_nodes_values in test_monotonic_constraints.py # in sklearn.ensemble._hist_gradient_boosting @@ -285,14 +285,12 @@ def test_1d_tree_nodes_values( if depth_first: # No max_leaf_nodes, default depth first tree builder clf = TreeRegressor( - splitter=splitter, monotonic_cst=[monotonic_sign], random_state=global_random_seed, ) else: # max_leaf_nodes triggers best first tree builder clf = TreeRegressor( - splitter=splitter, monotonic_cst=[monotonic_sign], max_leaf_nodes=n_samples, random_state=global_random_seed, @@ -407,10 +405,9 @@ def test_assert_nd_reg_tree_children_monotonic_bounded(): @pytest.mark.parametrize("TreeRegressor", TREE_REGRESSOR_CLASSES) @pytest.mark.parametrize("monotonic_sign", (-1, 1)) -@pytest.mark.parametrize("splitter", ("best", "random")) @pytest.mark.parametrize("depth_first", (True, False)) def test_nd_tree_nodes_values( - TreeRegressor, monotonic_sign, splitter, depth_first, global_random_seed + TreeRegressor, monotonic_sign, depth_first, global_random_seed ): # Build tree with several features, and make sure the nodes # values respect the monotonicity constraints. @@ -442,14 +439,12 @@ def test_nd_tree_nodes_values( if depth_first: # No max_leaf_nodes, default depth first tree builder clf = TreeRegressor( - splitter=splitter, monotonic_cst=monotonic_cst, random_state=global_random_seed, ) else: # max_leaf_nodes triggers best first tree builder clf = TreeRegressor( - splitter=splitter, monotonic_cst=monotonic_cst, max_leaf_nodes=n_samples, random_state=global_random_seed, From 4e509fe588d0fde7411bdead6543e2508cae2069 Mon Sep 17 00:00:00 2001 From: Samuel Ronsin Date: Wed, 21 Jun 2023 14:56:55 +0200 Subject: [PATCH 165/174] Extend test coverage to random splitters --- sklearn/tree/tests/test_monotonic_tree.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 449fed27f9c18..bb0e5904db1fd 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -1,5 +1,6 @@ import numpy as np import pytest +import scipy.sparse from sklearn.datasets import make_classification, make_regression from sklearn.tree import ( @@ -28,9 +29,10 @@ @pytest.mark.parametrize("TreeClassifier", TREE_BASED_CLASSIFIER_CLASSES) -@pytest.mark.parametrize("depth_first", (True, False)) +@pytest.mark.parametrize("depth_first_builder", (True, False)) +@pytest.mark.parametrize("sparse_splitter", (True, False)) def test_montonic_constraints_classifications( - TreeClassifier, depth_first, global_random_seed + TreeClassifier, depth_first_builder, sparse_splitter, global_random_seed ): n_samples = 1000 n_samples_train = 900 @@ -55,7 +57,7 @@ def test_montonic_constraints_classifications( monotonic_cst[0] = 1 monotonic_cst[1] = -1 - if depth_first: + if depth_first_builder: est = TreeClassifier(max_depth=None, monotonic_cst=monotonic_cst) else: est = TreeClassifier( @@ -67,6 +69,8 @@ def test_montonic_constraints_classifications( est.set_params(**{"random_state": global_random_seed}) if 
hasattr(est, "n_estimators"): est.set_params(**{"n_estimators": 5}) + if sparse_splitter: + X_train = scipy.sparse.csc_matrix(X_train) est.fit(X_train, y_train) y = est.predict_proba(X_test)[:, 1] @@ -80,9 +84,10 @@ def test_montonic_constraints_classifications( @pytest.mark.parametrize("TreeRegressor", TREE_BASED_REGRESSOR_CLASSES) -@pytest.mark.parametrize("depth_first", (True, False)) +@pytest.mark.parametrize("depth_first_builder", (True, False)) +@pytest.mark.parametrize("sparse_splitter", (True, False)) def test_montonic_constraints_regressions( - TreeRegressor, depth_first, global_random_seed + TreeRegressor, depth_first_builder, sparse_splitter, global_random_seed ): n_samples = 1000 n_samples_train = 900 @@ -106,7 +111,7 @@ def test_montonic_constraints_regressions( monotonic_cst[0] = 1 monotonic_cst[1] = -1 - if depth_first: + if depth_first_builder: est = TreeRegressor(max_depth=None, monotonic_cst=monotonic_cst) else: est = TreeRegressor( @@ -118,6 +123,8 @@ def test_montonic_constraints_regressions( est.set_params(random_state=global_random_seed) if hasattr(est, "n_estimators"): est.set_params(**{"n_estimators": 5}) + if sparse_splitter: + X_train = scipy.sparse.csc_matrix(X_train) est.fit(X_train, y_train) y = est.predict(X_test) # Monotonic increase constraint From c0483dad47fa9d8a0cf5d2d5a65f4ab98173562e Mon Sep 17 00:00:00 2001 From: Samuel Ronsin Date: Thu, 22 Jun 2023 00:17:01 +0200 Subject: [PATCH 166/174] Add support for absolute error regression criterion --- sklearn/tree/_criterion.pyx | 58 ++++++++++++++++------- sklearn/tree/_tree.pyx | 11 +++-- sklearn/tree/tests/test_monotonic_tree.py | 32 +++++++++---- 3 files changed, 69 insertions(+), 32 deletions(-) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index c0d1fe2ad0a94..79f6346be239d 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -234,23 +234,20 @@ cdef class Criterion: cnp.int8_t monotonic_cst, double lower_bound, double upper_bound, - double sum_left, - double sum_right, + double value_left, + double value_right, ) nogil: cdef: - double weighted_n_left = self.weighted_n_left - double weighted_n_right = self.weighted_n_right bint check_lower_bound = ( - (sum_left >= lower_bound * weighted_n_left) & - (sum_right >= lower_bound * weighted_n_right) + (value_left >= lower_bound) & + (value_right >= lower_bound) ) bint check_upper_bound = ( - (sum_left <= upper_bound * weighted_n_left) & - (sum_right <= upper_bound * weighted_n_right) + (value_left <= upper_bound) & + (value_right <= upper_bound) ) bint check_monotonic_cst = ( - (sum_left * weighted_n_right - - sum_right * weighted_n_left) * monotonic_cst <= 0 + (value_left - value_right) * monotonic_cst <= 0 ) return check_lower_bound & check_upper_bound & check_monotonic_cst @@ -613,10 +610,10 @@ cdef class ClassificationCriterion(Criterion): ) noexcept nogil: """Check monotonicity constraint is satisfied at the current classification split""" cdef: - double sum_left = self.sum_left[0][0] - double sum_right = self.sum_right[0][0] + double value_left = self.sum_left[0][0] / self.weighted_n_left + double value_right = self.sum_right[0][0] / self.weighted_n_right - return self._check_monotonicity(monotonic_cst, lower_bound, upper_bound, sum_left, sum_right) + return self._check_monotonicity(monotonic_cst, lower_bound, upper_bound, value_left, value_right) cdef class Entropy(ClassificationCriterion): @@ -1050,7 +1047,7 @@ cdef class RegressionCriterion(Criterion): elif dest[0] > upper_bound: dest[0] = upper_bound - 
cdef inline double middle_value(self) noexcept nogil: + cdef double middle_value(self) noexcept nogil: """Compute the middle value of a split for monotonicity constraints as the simple average of the left and right children values. @@ -1062,7 +1059,7 @@ cdef class RegressionCriterion(Criterion): (self.sum_right[0] / (2 * self.weighted_n_right)) ) - cdef inline bint check_monotonicity( + cdef bint check_monotonicity( self, cnp.int8_t monotonic_cst, double lower_bound, @@ -1070,10 +1067,10 @@ cdef class RegressionCriterion(Criterion): ) noexcept nogil: """Check monotonicity constraint is satisfied at the current regression split""" cdef: - double sum_left = self.sum_left[0] - double sum_right = self.sum_right[0] + double value_left = self.sum_left[0] / self.weighted_n_left + double value_right = self.sum_right[0] / self.weighted_n_right - return self._check_monotonicity(monotonic_cst, lower_bound, upper_bound, sum_left, sum_right) + return self._check_monotonicity(monotonic_cst, lower_bound, upper_bound, value_left, value_right) cdef class MSE(RegressionCriterion): """Mean squared error impurity criterion. @@ -1404,6 +1401,31 @@ cdef class MAE(RegressionCriterion): for k in range(self.n_outputs): dest[k] = self.node_medians[k] + cdef inline double middle_value(self) noexcept nogil: + """Compute the middle value of a split for monotonicity constraints as the simple average + of the left and right children values. + + Monotonicity constraints are only supported for single-output trees we can safely assume + n_outputs == 1. + """ + return ( + (<WeightedMedianCalculator> self.left_child_ptr[0]).get_median() + + (<WeightedMedianCalculator> self.right_child_ptr[0]).get_median() + ) / 2 + + cdef inline bint check_monotonicity( + self, + cnp.int8_t monotonic_cst, + double lower_bound, + double upper_bound, + ) noexcept nogil: + """Check monotonicity constraint is satisfied at the current regression split""" + cdef: + double value_left = (<WeightedMedianCalculator> self.left_child_ptr[0]).get_median() + double value_right = (<WeightedMedianCalculator> self.right_child_ptr[0]).get_median() + + return self._check_monotonicity(monotonic_cst, lower_bound, upper_bound, value_left, value_right) + cdef double node_impurity(self) noexcept nogil: """Evaluate the impurity of the current node. diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index ef35bb5858933..e0aeec26bfef4 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -272,7 +272,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): lower_bound, upper_bound ) - middle_value = splitter.criterion.middle_value() # If EPSILON=0 in the below comparison, float precision # issues stop splitting, producing trees that are # dissimilar to v0.18 @@ -302,8 +301,9 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): ): # Split on a feature with no monotonicity constraint - # Current bounds must always be propagated to both children - # Splitting criterion will always check them in check_monotonicity + # Current bounds must always be propagated to both children. + # If a monotonic constraint is active, bounds are used in + # node value clipping. left_child_min = right_child_min = lower_bound left_child_max = right_child_max = upper_bound elif splitter.monotonic_cst[split.feature] == 1: @@ -313,6 +313,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): # Lower bound for right child and upper bound for left child # are set to the same value. + middle_value = splitter.criterion.middle_value() right_child_min = middle_value left_child_max = middle_value else: # i.e.
splitter.monotonic_cst[split.feature] == -1 @@ -322,6 +323,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): # Lower bound for left child and upper bound for right child # are set to the same value. + middle_value = splitter.criterion.middle_value() left_child_min = middle_value right_child_max = middle_value @@ -501,7 +503,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder): # Split on a feature with no monotonicity constraint # Current bounds must always be propagated to both children. - # Splitting criterion will always check them in check_monotonicity + # If a monotonic constraint is active, bounds are used in + # node value clipping. left_child_min = right_child_min = record.lower_bound left_child_max = right_child_max = record.upper_bound elif splitter.monotonic_cst[node.feature] == 1: diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index bb0e5904db1fd..5b83825a6d3c0 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -86,8 +86,9 @@ def test_montonic_constraints_classifications( @pytest.mark.parametrize("TreeRegressor", TREE_BASED_REGRESSOR_CLASSES) @pytest.mark.parametrize("depth_first_builder", (True, False)) @pytest.mark.parametrize("sparse_splitter", (True, False)) +@pytest.mark.parametrize("criterion", ("absolute_error", "squared_error")) def test_montonic_constraints_regressions( - TreeRegressor, depth_first_builder, sparse_splitter, global_random_seed + TreeRegressor, depth_first_builder, sparse_splitter, criterion, global_random_seed ): n_samples = 1000 n_samples_train = 900 @@ -112,11 +113,16 @@ def test_montonic_constraints_regressions( monotonic_cst[1] = -1 if depth_first_builder: - est = TreeRegressor(max_depth=None, monotonic_cst=monotonic_cst) - else: est = TreeRegressor( max_depth=None, monotonic_cst=monotonic_cst, + criterion=criterion, + ) + else: + est = TreeRegressor( + max_depth=8, + monotonic_cst=monotonic_cst, + criterion=criterion, max_leaf_nodes=n_samples_train, ) if hasattr(est, "random_state"): @@ -245,7 +251,7 @@ def assert_1d_reg_monotonic(clf, monotonic_sign, min_x, max_x, n_steps): @pytest.mark.parametrize("TreeRegressor", TREE_REGRESSOR_CLASSES) def test_1d_opposite_monotonicity_cst_data(TreeRegressor): # Check that positive monotonic data with negative monotonic constraint - # yield constant predictions equal to the average target + # yield constant predictions, equal to the average of target values X = np.linspace(-2, 2, 10).reshape(-1, 1) y = X.ravel() clf = TreeRegressor(monotonic_cst=[-1]) @@ -262,9 +268,10 @@ def test_1d_opposite_monotonicity_cst_data(TreeRegressor): @pytest.mark.parametrize("TreeRegressor", TREE_REGRESSOR_CLASSES) @pytest.mark.parametrize("monotonic_sign", (-1, 1)) -@pytest.mark.parametrize("depth_first", (True, False)) +@pytest.mark.parametrize("depth_first_builder", (True, False)) +@pytest.mark.parametrize("criterion", ("absolute_error", "squared_error")) def test_1d_tree_nodes_values( - TreeRegressor, monotonic_sign, depth_first, global_random_seed + TreeRegressor, monotonic_sign, depth_first_builder, criterion, global_random_seed ): # Adaptation from test_nodes_values in test_monotonic_constraints.py # in sklearn.ensemble._hist_gradient_boosting @@ -289,10 +296,11 @@ def test_1d_tree_nodes_values( X = rng.rand(n_samples, n_features) y = rng.rand(n_samples) - if depth_first: + if depth_first_builder: # No max_leaf_nodes, default depth first tree builder clf = TreeRegressor( monotonic_cst=[monotonic_sign], + 
criterion=criterion, random_state=global_random_seed, ) else: @@ -300,6 +308,7 @@ def test_1d_tree_nodes_values( clf = TreeRegressor( monotonic_cst=[monotonic_sign], max_leaf_nodes=n_samples, + criterion=criterion, random_state=global_random_seed, ) clf.fit(X, y) @@ -412,9 +421,10 @@ def test_assert_nd_reg_tree_children_monotonic_bounded(): @pytest.mark.parametrize("TreeRegressor", TREE_REGRESSOR_CLASSES) @pytest.mark.parametrize("monotonic_sign", (-1, 1)) -@pytest.mark.parametrize("depth_first", (True, False)) +@pytest.mark.parametrize("depth_first_builder", (True, False)) +@pytest.mark.parametrize("criterion", ("absolute_error", "squared_error")) def test_nd_tree_nodes_values( - TreeRegressor, monotonic_sign, depth_first, global_random_seed + TreeRegressor, monotonic_sign, depth_first_builder, criterion, global_random_seed ): # Build tree with several features, and make sure the nodes # values respect the monotonicity constraints. @@ -443,10 +453,11 @@ def test_nd_tree_nodes_values( X = rng.rand(n_samples, n_features) y = rng.rand(n_samples) - if depth_first: + if depth_first_builder: # No max_leaf_nodes, default depth first tree builder clf = TreeRegressor( monotonic_cst=monotonic_cst, + criterion=criterion, random_state=global_random_seed, ) else: @@ -454,6 +465,7 @@ def test_nd_tree_nodes_values( clf = TreeRegressor( monotonic_cst=monotonic_cst, max_leaf_nodes=n_samples, + criterion=criterion, random_state=global_random_seed, ) clf.fit(X, y) From 16a88636b3eba77a94f4b1e7f79171eb454fd235 Mon Sep 17 00:00:00 2001 From: Samuel Ronsin Date: Thu, 22 Jun 2023 13:51:19 +0200 Subject: [PATCH 167/174] Mark missing values as unsupported --- sklearn/tree/_classes.py | 6 +++- sklearn/tree/tests/test_monotonic_tree.py | 39 +++++++++++++++++------ 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index c130243709b26..499a8ce6913cc 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -183,7 +183,11 @@ def get_n_leaves(self): return self.tree_.n_leaves def _support_missing_values(self, X): - return not issparse(X) and self._get_tags()["allow_nan"] + return ( + not issparse(X) + and self._get_tags()["allow_nan"] + and self.monotonic_cst is None + ) def _compute_missing_values_in_feature_mask(self, X): """Return boolean mask denoting if there are missing values for each feature. 
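Illustration of the `_support_missing_values` change above (a minimal sketch, assuming this branch of scikit-learn is installed): once `monotonic_cst` is set, NaNs in the training data are rejected by the standard input validation instead of being routed to the missing-value-aware code path.

    import numpy as np
    from sklearn.tree import DecisionTreeRegressor

    X = np.array([[1.0], [2.0], [np.nan], [4.0]])
    y = np.array([1.0, 2.0, 3.0, 4.0])

    # With a monotonicity constraint, _support_missing_values returns False,
    # so fit() raises ValueError("Input X contains NaN") on this data.
    DecisionTreeRegressor(monotonic_cst=[1], random_state=0).fit(X, y)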
diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 5b83825a6d3c0..462ac7305d7c2 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -3,17 +3,17 @@ import scipy.sparse from sklearn.datasets import make_classification, make_regression +from sklearn.ensemble import ( + ExtraTreesClassifier, + ExtraTreesRegressor, + RandomForestClassifier, + RandomForestRegressor, +) from sklearn.tree import ( - DecisionTreeRegressor, DecisionTreeClassifier, - ExtraTreeRegressor, + DecisionTreeRegressor, ExtraTreeClassifier, -) -from sklearn.ensemble import ( - RandomForestRegressor, - RandomForestClassifier, - ExtraTreesRegressor, - ExtraTreesClassifier, + ExtraTreeRegressor, ) TREE_CLASSIFIER_CLASSES = [DecisionTreeClassifier, ExtraTreeClassifier] @@ -31,7 +31,7 @@ @pytest.mark.parametrize("TreeClassifier", TREE_BASED_CLASSIFIER_CLASSES) @pytest.mark.parametrize("depth_first_builder", (True, False)) @pytest.mark.parametrize("sparse_splitter", (True, False)) -def test_montonic_constraints_classifications( +def test_monotonic_constraints_classifications( TreeClassifier, depth_first_builder, sparse_splitter, global_random_seed ): n_samples = 1000 @@ -87,7 +87,7 @@ def test_montonic_constraints_classifications( @pytest.mark.parametrize("depth_first_builder", (True, False)) @pytest.mark.parametrize("sparse_splitter", (True, False)) @pytest.mark.parametrize("criterion", ("absolute_error", "squared_error")) -def test_montonic_constraints_regressions( +def test_monotonic_constraints_regressions( TreeRegressor, depth_first_builder, sparse_splitter, criterion, global_random_seed ): n_samples = 1000 @@ -173,6 +173,25 @@ def test_multiple_output_raises(TreeClassifier): est.fit(X, y) +@pytest.mark.parametrize( + "DecisionTreeEstimator", [DecisionTreeClassifier, DecisionTreeRegressor] +) +def test_missing_values_raises(DecisionTreeEstimator): + X, y = make_classification( + n_samples=100, n_features=5, n_classes=2, n_informative=3, random_state=0 + ) + X[0, 0] = np.nan + monotonic_cst = np.zeros(X.shape[1]) + monotonic_cst[0] = 1 + est = DecisionTreeEstimator( + max_depth=None, monotonic_cst=monotonic_cst, random_state=0 + ) + + msg = "Input X contains NaN" + with pytest.raises(ValueError, match=msg): + est.fit(X, y) + + @pytest.mark.parametrize("TreeClassifier", TREE_BASED_CLASSIFIER_CLASSES) def test_bad_monotonic_cst_raises(TreeClassifier): X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]] From c08c4bfcbe56511a41bdf437fd96d97a1554e76e Mon Sep 17 00:00:00 2001 From: Samuel Ronsin Date: Thu, 22 Jun 2023 18:45:26 +0200 Subject: [PATCH 168/174] Update docstrings to specify training data with missing values is not supported --- sklearn/ensemble/_forest.py | 26 ++++++++++++++++++-------- sklearn/tree/_classes.py | 26 ++++++++++++++++++-------- 2 files changed, 36 insertions(+), 16 deletions(-) diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index 1194f119a1c6f..87323393a3284 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -1281,8 +1281,12 @@ class RandomForestClassifier(ForestClassifier): If monotonic_cst is None, no constraints are applied. - The constraints are only valid for binary classification and hold - over the probability of the positive class. + Monotonicity constraints are not supported for: + - multiclass classifications (i.e. when `n_classes > 2`), + - multioutput classifications (i.e. 
when `n_outputs_ > 1`), + - classifications trained on data with missing values. + + The constraints hold over the probability of the positive class. .. versionadded:: 1.4 @@ -1651,8 +1655,9 @@ class RandomForestRegressor(ForestRegressor): If monotonic_cst is None, no constraints are applied. - Monotonicity constraints are not supported for multioutput regression - (i.e. when `n_outputs_ > 1`). + Monotonicity constraints are not supported for: + - multioutput regressions (i.e. when `n_outputs_ > 1`), + - regressions trained on data with missing values. .. versionadded:: 1.4 @@ -2015,8 +2020,12 @@ class ExtraTreesClassifier(ForestClassifier): If monotonic_cst is None, no constraints are applied. - The constraints are only valid for binary classification and hold - over the probability of the positive class. + Monotonicity constraints are not supported for: + - multiclass classifications (i.e. when `n_classes > 2`), + - multioutput classifications (i.e. when `n_outputs_ > 1`), + - classifications trained on data with missing values. + + The constraints hold over the probability of the positive class. .. versionadded:: 1.4 @@ -2370,8 +2379,9 @@ class ExtraTreesRegressor(ForestRegressor): If monotonic_cst is None, no constraints are applied. - Monotonicity constraints are not supported for multioutput regression - (i.e. when `n_outputs_ > 1`). + Monotonicity constraints are not supported for: + - multioutput regressions (i.e. when `n_outputs_ > 1`), + - regressions trained on data with missing values. .. versionadded:: 1.4 diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 499a8ce6913cc..d0b976cae685a 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -851,8 +851,12 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree): If monotonic_cst is None, no constraints are applied. - The constraints are only valid for binary classifications and hold - over the probability of the positive class. + Monotonicity constraints are not supported for: + - multiclass classifications (i.e. when `n_classes > 2`), + - multioutput classifications (i.e. when `n_outputs_ > 1`), + - classifications trained on data with missing values. + + The constraints hold over the probability of the positive class. .. versionadded:: 1.4 @@ -1242,8 +1246,9 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree): If monotonic_cst is None, no constraints are applied. - Monotonicity constraints are not supported for multioutput regression - (i.e. when `n_outputs_ > 1`). + Monotonicity constraints are not supported for: + - multioutput regressions (i.e. when `n_outputs_ > 1`), + - regressions trained on data with missing values. .. versionadded:: 1.4 @@ -1582,8 +1587,12 @@ class ExtraTreeClassifier(DecisionTreeClassifier): If monotonic_cst is None, no constraints are applied. - The constraints are only valid for binary classifications and hold - over the probability of the positive class. + Monotonicity constraints are not supported for: + - multiclass classifications (i.e. when `n_classes > 2`), + - multioutput classifications (i.e. when `n_outputs_ > 1`), + - classifications trained on data with missing values. + + The constraints hold over the probability of the positive class. .. versionadded:: 1.4 @@ -1841,8 +1850,9 @@ class ExtraTreeRegressor(DecisionTreeRegressor): If monotonic_cst is None, no constraints are applied. - Monotonicity constraints are not supported for multioutput regression - (i.e. when `n_outputs_ > 1`). 
+ Monotonicity constraints are not supported for: + - multioutput regressions (i.e. when `n_outputs_ > 1`), + - regressions trained on data with missing values. .. versionadded:: 1.4 From e77c505a4770287a0cfcfd0bfbc401ae1c802c89 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Sat, 24 Jun 2023 18:17:05 +0200 Subject: [PATCH 169/174] DOC Add changelog entries Signed-off-by: Julien Jerphanion --- doc/whats_new/v1.4.rst | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index 4946576d83056..ff7f2523cc681 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -58,3 +58,24 @@ TODO: update at the time of the release. :meth:`base.OutlierMixin.fit_predict` now accept ``**kwargs`` which are passed to the ``fit`` method of the the estimator. :pr:`26506` by `Adrin Jalali`_. + + +:mod:`sklearn.ensemble` +....................... + +- |Feature| :class:`ensemble.RandomForestClassifier`, + :class:`ensemble.RandomForestRegressor`, :class:`ensemble.ExtraTreesClassifier` + and :class:`ensemble.ExtraTreesRegressor` now support monotonic constraints in + the absence of missing values for single-output problems. + :pr:`13649` by :user:`Samuel Ronsin `, + initiated by :user:`Patrick O'Reilly `. + + +:mod:`sklearn.tree` +................... + +- |Feature| :class:`tree.DecisionTreeClassifier`, :class:`tree.DecisionTreeRegressor`, + :class:`tree.ExtraTreeClassifier` and :class:`tree.ExtraTreeRegressor` now support + monotonic constraints in the absence of missing values for single-output problems. + :pr:`13649` by :user:`Samuel Ronsin `, initiated by + :user:`Patrick O'Reilly `. From 9b3bc221b7a8822020a89fe8db9ddac45d5b1595 Mon Sep 17 00:00:00 2001 From: Samuel Ronsin Date: Fri, 30 Jun 2023 10:27:28 +0200 Subject: [PATCH 170/174] Add comment regarding monotonic_cst sign flipping for classification trees --- sklearn/tree/_classes.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index d0b976cae685a..bf43f79e65d12 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -441,7 +441,11 @@ def _fit( "Monotonicity constraints are not supported with multiclass " "classification" ) - # Imposing the constraint on the probability of the positive class + # Binary classification trees are built by constraining the probabilities + # of the *negative class* in order to make the implementation similar to + # regression trees. + # Since self.monotonic_cst encodes constraints on the *positive class*, + # its signs must be flipped before passing it to the Splitter. 
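# For example: monotonic_cst = [1, -1, 0] requested by the user
# (P(y=1) increasing in feature 0, decreasing in feature 1, feature 2 free)
# reaches the Splitter as [-1, 1, 0], the same constraints expressed on the
# negative class, whose weighted counts are what sum_left[0, 0] and
# sum_right[0, 0] accumulate.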
monotonic_cst *= -1 if not isinstance(self.splitter, Splitter): From 816778d16f23d9cf9dbb18aad68b34a34503da9c Mon Sep 17 00:00:00 2001 From: Samuel Ronsin Date: Fri, 30 Jun 2023 12:10:03 +0200 Subject: [PATCH 171/174] Use np.isin to validate monotonic_cst --- sklearn/tree/_classes.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index bf43f79e65d12..7af692d23f1e8 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -417,7 +417,7 @@ def _fit( # Check to correct monotonicity constraint' specification, # by applying element-wise logical conjunction # Note: we do not cast `np.asarray(self.monotonic_cst, dtype=np.int8)` - # straight away here so as to generate error messages for invalid + # straight away here so as to generate error messages for invalid # values using the original values prior to any dtype related conversion. monotonic_cst = np.asarray(self.monotonic_cst) if monotonic_cst.shape[0] != X.shape[1]: @@ -425,10 +425,8 @@ def _fit( "monotonic_cst has shape {} but the input data " "X has {} features.".format(monotonic_cst.shape[0], X.shape[1]) ) - unsatisfied_constraints_conditions = ( - (monotonic_cst != -1) * (monotonic_cst != 0) * (monotonic_cst != 1) - ) - if np.any(unsatisfied_constraints_conditions): + valid_constraints = np.isin(monotonic_cst, (-1, 0, 1)) + if not np.all(valid_constraints): unique_constaints_value = np.unique(monotonic_cst) raise ValueError( "monotonic_cst must be None or an array-like of -1, 0 or 1, but" From 88ae965c4b9a62747cfc1c4053d0c9256d4cc6ad Mon Sep 17 00:00:00 2001 From: Samuel Ronsin Date: Fri, 30 Jun 2023 12:23:48 +0200 Subject: [PATCH 172/174] Update whats_new --- doc/whats_new/v1.4.rst | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index ff7f2523cc681..5533adff4343f 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -65,8 +65,9 @@ TODO: update at the time of the release. - |Feature| :class:`ensemble.RandomForestClassifier`, :class:`ensemble.RandomForestRegressor`, :class:`ensemble.ExtraTreesClassifier` - and :class:`ensemble.ExtraTreesRegressor` now support monotonic constraints in - the absence of missing values for single-output problems. + and :class:`ensemble.ExtraTreesRegressor` now support monotonic constraints, + useful when features are supposed to have a positive/negative effect on the target. + Missing values in the train data and multi-output targets are not supported. :pr:`13649` by :user:`Samuel Ronsin `, initiated by :user:`Patrick O'Reilly `. @@ -76,6 +77,8 @@ TODO: update at the time of the release. - |Feature| :class:`tree.DecisionTreeClassifier`, :class:`tree.DecisionTreeRegressor`, :class:`tree.ExtraTreeClassifier` and :class:`tree.ExtraTreeRegressor` now support - monotonic constraints in the absence of missing values for single-output problems. + monotonic constraints, useful when features are supposed to have a positive/negative + effect on the target. Missing values in the train data and multi-output targets are + not supported. :pr:`13649` by :user:`Samuel Ronsin `, initiated by :user:`Patrick O'Reilly `. 
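As a usage illustration of the changelog entry above, a minimal sketch (assuming this branch of scikit-learn; the data and hyper-parameters are made up for the example):

    import numpy as np
    from sklearn.ensemble import RandomForestRegressor

    rng = np.random.RandomState(0)
    X = rng.rand(300, 2)
    # The target roughly increases with feature 0 and decreases with feature 1.
    y = 5 * X[:, 0] - 5 * X[:, 1] + rng.normal(scale=0.1, size=300)

    # 1: monotonically increasing effect, -1: decreasing, 0: no constraint.
    reg = RandomForestRegressor(
        n_estimators=10, monotonic_cst=[1, -1], random_state=0
    ).fit(X, y)

    # Predictions never decrease along feature 0 when feature 1 is held fixed.
    grid = np.linspace(0.0, 1.0, 100).reshape(-1, 1)
    pred = reg.predict(np.hstack([grid, np.full_like(grid, 0.5)]))
    assert np.all(np.diff(pred) >= 0.0)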
From 31d5edb8efe8b71666f1e123bdfe4ab1e50486e1 Mon Sep 17 00:00:00 2001 From: Samuel Ronsin Date: Fri, 30 Jun 2023 12:48:11 +0200 Subject: [PATCH 173/174] Fix lign too long --- sklearn/tree/_classes.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 7af692d23f1e8..673137e1cc98a 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -439,11 +439,11 @@ def _fit( "Monotonicity constraints are not supported with multiclass " "classification" ) - # Binary classification trees are built by constraining the probabilities - # of the *negative class* in order to make the implementation similar to - # regression trees. - # Since self.monotonic_cst encodes constraints on the *positive class*, - # its signs must be flipped before passing it to the Splitter. + # Binary classification trees are built by constraining probabilities + # of the *negative class* in order to make the implementation similar + # to regression trees. + # Since self.monotonic_cst encodes constraints on probabilities of the + # *positive class*, all signs must be flipped. monotonic_cst *= -1 if not isinstance(self.splitter, Splitter): From 57c78a5fe5e346f84dc5669e3c72144d059f8a2d Mon Sep 17 00:00:00 2001 From: Samuel Ronsin Date: Fri, 30 Jun 2023 18:25:30 +0200 Subject: [PATCH 174/174] Add pointers to user guide --- sklearn/ensemble/_forest.py | 8 ++++++++ sklearn/tree/_classes.py | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index 87323393a3284..3d984c104f891 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -1288,6 +1288,8 @@ class RandomForestClassifier(ForestClassifier): The constraints hold over the probability of the positive class. + Read more in the :ref:`User Guide `. + .. versionadded:: 1.4 Attributes @@ -1659,6 +1661,8 @@ class RandomForestRegressor(ForestRegressor): - multioutput regressions (i.e. when `n_outputs_ > 1`), - regressions trained on data with missing values. + Read more in the :ref:`User Guide `. + .. versionadded:: 1.4 Attributes @@ -2027,6 +2031,8 @@ class ExtraTreesClassifier(ForestClassifier): The constraints hold over the probability of the positive class. + Read more in the :ref:`User Guide `. + .. versionadded:: 1.4 Attributes @@ -2383,6 +2389,8 @@ class ExtraTreesRegressor(ForestRegressor): - multioutput regressions (i.e. when `n_outputs_ > 1`), - regressions trained on data with missing values. + Read more in the :ref:`User Guide `. + .. versionadded:: 1.4 Attributes diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 673137e1cc98a..a9f367f0b21d3 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -860,6 +860,8 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree): The constraints hold over the probability of the positive class. + Read more in the :ref:`User Guide `. + .. versionadded:: 1.4 Attributes @@ -1252,6 +1254,8 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree): - multioutput regressions (i.e. when `n_outputs_ > 1`), - regressions trained on data with missing values. + Read more in the :ref:`User Guide `. + .. versionadded:: 1.4 Attributes @@ -1596,6 +1600,8 @@ class ExtraTreeClassifier(DecisionTreeClassifier): The constraints hold over the probability of the positive class. + Read more in the :ref:`User Guide `. + .. 
versionadded:: 1.4 Attributes @@ -1856,6 +1862,8 @@ class ExtraTreeRegressor(DecisionTreeRegressor): - multioutput regressions (i.e. when `n_outputs_ > 1`), - regressions trained on data with missing values. + Read more in the :ref:`User Guide `. + .. versionadded:: 1.4 Attributes
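A closing sketch of the classifier-side behaviour documented above, where the constraint applies to the predicted probability of the positive class (a minimal example, assuming the `monotonic_cst` API introduced in this series; the data are synthetic):

    import numpy as np
    from sklearn.tree import DecisionTreeClassifier

    rng = np.random.RandomState(0)
    X = rng.rand(500, 2)
    # The positive class gets more likely as feature 0 grows and less likely
    # as feature 1 grows.
    y = (X[:, 0] - X[:, 1] + 0.2 * rng.randn(500) > 0).astype(int)

    clf = DecisionTreeClassifier(
        monotonic_cst=np.array([1, -1], dtype=np.int8), random_state=0
    ).fit(X, y)

    # P(y=1) never decreases when feature 0 increases (feature 1 held fixed).
    line = np.linspace(0.0, 1.0, 50).reshape(-1, 1)
    proba = clf.predict_proba(np.hstack([line, np.full_like(line, 0.5)]))[:, 1]
    assert np.all(np.diff(proba) >= 0.0)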