FEA Add Gamma deviance as loss function to HGBT #22409


Merged · 28 commits · Jan 30, 2023

Conversation

@lorentzenchr (Member) commented Feb 7, 2022:

Reference Issues/PRs

Follow-up of #20567 and #20811.

What does this implement/fix? Explain your changes.

This PR adds the Gamma deviance as loss function to HistGradientBoostingRegressor(loss="gamma").
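For illustration, a minimal usage sketch of the new option (assuming a scikit-learn build that includes this PR; the data generation here is made up):

import numpy as np
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.metrics import mean_gamma_deviance

rng = np.random.RandomState(0)
X = rng.normal(size=(1000, 3))
# Gamma targets must be strictly positive; here E[y|X] = exp(X[:, 0]).
y = rng.gamma(shape=2.0, scale=np.exp(X[:, 0]) / 2.0)

model = HistGradientBoostingRegressor(loss="gamma").fit(X, y)
# Predictions are positive by construction (log link).
print(mean_gamma_deviance(y, model.predict(X)))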

Any other comments?

"Loss-wise", this brings HGBT closer to paritity with XGBoost and LightGBM.

Open question: In practice, I (almost) always observed that a post-training calibration of gradient boosted trees with a Gamma loss is beneficial not only for calibration but also for out-of-sample performance. The step is to add a multiplicative constant (or, in link space, an additive constant, i.e. resetting _baseline_prediction) such that on the training set one has y_train.mean() == model.predict(X_train).mean(). Should we include it by default, add an option, or not consider it (which would make me sad)?
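For concreteness, a sketch of that recalibration step (a hypothetical helper, not part of this PR; model is assumed to be a fitted HGBT with log link):

def mean_recalibration_factor(model, X_train, y_train):
    # Multiplicative constant c such that c * model.predict(X_train) has
    # the same mean as y_train; in link (log) space this corresponds to
    # adding log(c) to the baseline prediction.
    return y_train.mean() / model.predict(X_train).mean()

# Usage: y_pred = mean_recalibration_factor(model, X_train, y_train) * model.predict(X_test)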

@lorentzenchr (Member Author) left a comment:

Need help with constructing a good test.

gbdt.fit(np.zeros(shape=(len(y), 1)), y)


def test_gamma():
@lorentzenchr (Member Author):

Currently, this test fails because the HGBT with poisson loss out-performs all other models on the test set.
I observe the exact same behaviour, with (often) similar numbers, with LightGBM:

from lightgbm import LGBMRegressor

lgbm_gamma = LGBMRegressor(objective="gamma")
lgbm_pois = LGBMRegressor(objective="poisson")

I don't know how to construct a better test - for now.

@ogrisel (Member) commented Feb 8, 2022:

Is this true only for the small sample sizes that we typically use in our tests, or would this still be true for a very large n_samples?

Is it similar to the fact that the median can be a better estimate of the expected value than the sample mean for a finite sample of a long tailed distribution (robustness to outliers)? This was observed empirically when computing metrics on a left out validation set here: https://scikit-learn.org/stable/auto_examples/ensemble/plot_gradient_boosting_quantile.html
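A quick numerical check of that intuition (a sketch, not from the thread): for a heavy-tailed Pareto distribution with finite mean but infinite variance, the sample median has a much lower mean squared error as an estimator of the true expected value than the sample mean, despite its bias.

import numpy as np

rng = np.random.default_rng(0)
n_reps, n = 10_000, 50
alpha = 1.5  # Pareto tail index: finite mean, infinite variance
true_mean = alpha / (alpha - 1)  # = 3.0
samples = rng.pareto(alpha, size=(n_reps, n)) + 1  # standard Pareto on [1, inf)
mse_mean = np.mean((samples.mean(axis=1) - true_mean) ** 2)
mse_median = np.mean((np.median(samples, axis=1) - true_mean) ** 2)
print(f"MSE of the sample mean:   {mse_mean:.2f}")
print(f"MSE of the sample median: {mse_median:.2f}")  # far smaller, despite the bias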

@lorentzenchr (Member Author):

I played with n_samples and n_features without success: the problem that Poisson HGBT always beats Gamma HGBT out-of-sample persists. I think one reason is that the log link is not canonical for the Gamma deviance, but it is canonical for the Poisson deviance. Therefore, the Poisson HGBT is better calibrated, which here also translates into better out-of-sample performance. Maybe the features are not informative enough in this example.

I can try to generate a decision tree and use it to generate a Gamma distributed sample per tree node. That should be easier to fit for tree based models than this GLM based data generation.

Member:

> I can try to generate a decision tree and use it to generate a Gamma distributed sample per tree node. That should be easier to fit for tree based models than this GLM based data generation.

That would be interesting to know indeed.

@lorentzenchr (Member Author):

I tried to create gamma distributed data with a decision tree structure for the mean value, see details.
This gives the same result: Poisson loss beats Gamma loss when comparing out-of-sample Gamma deviance.

Proposed solutions:

  1. Remove this test and add (and rely on) a test comparing with LightGBM.
  2. Do not compare with Poisson loss, just with squared error (which depends a lot on the clipping constant!).

I'm in favor of the first option.

import numpy as np
from sklearn.dummy import DummyRegressor
from sklearn.datasets import make_low_rank_matrix
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.metrics import mean_gamma_deviance
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

from lightgbm import LGBMRegressor


### Data Generation
rng = np.random.RandomState(42)
n_train, n_test, n_features = 500, 500, 20
X = make_low_rank_matrix(
    n_samples=n_train + n_test, n_features=n_features, random_state=rng, 
)
# First, we create normally distributed data
coef = rng.uniform(low=-10, high=20, size=n_features)
y_normal = np.abs(rng.normal(loc=np.exp(X @ coef)))

# Second, we fit a decision tree
dtree = DecisionTreeRegressor().fit(X, y_normal)

# Finally, we generate gamma distributed data according to the tree
# mean(y) = dtree.predict(X)
# var(y) = dispersion * mean(y)**2
dispersion = 0.5
y = rng.gamma(shape=1 / dispersion, scale=dispersion * dtree.predict(X))
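# Sanity check of the parametrization above: for Gamma(shape=k, scale=s),
# mean = k * s and var = k * s**2, so k = 1/dispersion and
# s = dispersion * mu indeed give mean(y) = mu and var(y) = dispersion * mu**2.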

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=n_test, random_state=rng
)


### Model fitting
params = {"early_stopping": False, "random_state": 123}
p_lgbt = {"min_child_weight": 0, "random_state": 123}

gbdt_gamma = HistGradientBoostingRegressor(loss="gamma", **params)
gbdt_pois = HistGradientBoostingRegressor(loss="poisson", **params)
gbdt_ls = HistGradientBoostingRegressor(loss="squared_error", **params)
lgbm_gamma = LGBMRegressor(objective="gamma", **p_lgbt)
lgbm_pois = LGBMRegressor(objective="poisson", **p_lgbt)
for model in (gbdt_gamma, gbdt_pois, gbdt_ls, lgbm_gamma, lgbm_pois):
    model.fit(X_train, y_train)
dummy = DummyRegressor(strategy="mean").fit(X_train, y_train)

for m in (dummy, gbdt_gamma, gbdt_ls, gbdt_pois, lgbm_gamma, lgbm_pois):
    message = f"training gamma deviance {m.__class__.__name__}"
    if hasattr(m, "loss"):
        message += " " + m.loss
    elif hasattr(m, "objective"):
        message += " " + m.objective
    print(f"{message: <68}: {mean_gamma_deviance(y_train, np.clip(m.predict(X_train), 1e-12, None))}")

for m in (dummy, gbdt_gamma, gbdt_ls, gbdt_pois, lgbm_gamma, lgbm_pois):
    message = f"test gamma deviance {m.__class__.__name__}"
    if hasattr(m, "loss"):
        message += " " + m.loss
    elif hasattr(m, "objective"):
        message += " " + m.objective
    print(f"{message: <68}: {mean_gamma_deviance(y_test, np.clip(m.predict(X_test), 1e-12, None))}")

@lorentzenchr (Member Author):

I replaced Poisson with squared error as the model to compare against, bb234ee. Now it works. Comments on the comparison to Poisson are preserved.
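Roughly, the revised test asserts something like the following (a sketch reusing the names from the script above; the exact assertions in bb234ee may differ):

# Out-of-sample Gamma deviance: the Gamma model should beat both the
# mean-predicting dummy and the squared-error model (whose predictions
# can be non-positive and therefore need clipping).
loss_dummy = mean_gamma_deviance(y_test, dummy.predict(X_test))
loss_gamma = mean_gamma_deviance(y_test, gbdt_gamma.predict(X_test))
loss_ls = mean_gamma_deviance(y_test, np.clip(gbdt_ls.predict(X_test), 1e-12, None))
assert loss_gamma < loss_ls
assert loss_gamma < loss_dummy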


@ogrisel (Member) left a comment:

I messed up: I used direct comments while I was thinking of making review comments. Please ignore this.

@lorentzenchr (Member Author) commented Apr 20, 2022:

I excluded the out-of-sample test of Gamma vs Poisson, but added Poisson and Gamma to the comparison with LightGBM.

Poisson passes as expected, but Gamma does not. Some tree leaves agree, others do not, see details.

import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.datasets import make_regression
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.ensemble._hist_gradient_boosting.binning import _BinMapper
from sklearn.ensemble._hist_gradient_boosting.utils import get_equivalent_estimator

seed = 0
loss = "gamma"
min_samples_leaf = 20
n_samples, max_leaf_nodes = 255, 4096

rng = np.random.RandomState(seed=seed)
max_iter = 1
max_bins = 255

X, y = make_regression(
    n_samples=n_samples, n_features=5, n_informative=5, random_state=0
)

if loss in ("gamma", "poisson"):
    # make the target positive
    y = np.abs(y) + np.mean(np.abs(y))

if n_samples > 255:
    # bin data and convert it to float32 so that the estimator doesn't
    # treat it as pre-binned
    X = _BinMapper(n_bins=max_bins + 1).fit_transform(X).astype(np.float32)

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng)

est_sklearn = HistGradientBoostingRegressor(
    loss=loss,
    max_iter=max_iter,
    max_bins=max_bins,
    learning_rate=1,
    early_stopping=False,
    min_samples_leaf=min_samples_leaf,
    max_leaf_nodes=max_leaf_nodes,
)
est_lightgbm = get_equivalent_estimator(est_sklearn, lib="lightgbm")
est_lightgbm.set_params(poisson_max_delta_step=1e-10, min_sum_hessian_in_leaf=1e-10)

est_lightgbm.fit(X_train, y_train)
est_sklearn.fit(X_train, y_train)

# We need X to be treated as numerical data, not pre-binned data.
X_train, X_test = X_train.astype(np.float32), X_test.astype(np.float32)

pred_lightgbm = est_lightgbm.predict(X_train)
pred_sklearn = est_sklearn.predict(X_train)
# less than 1% of the predictions are different up to the 3rd decimal
assert np.mean(abs(pred_lightgbm - pred_sklearn) > 1e-3) < 0.011
print(est_lightgbm.booster_.model_to_string())
tree
version=v3
num_class=1
num_tree_per_iteration=1
label_index=0
max_feature_idx=4
objective=gamma
feature_names=Column_0 Column_1 Column_2 Column_3 Column_4
feature_infos=[-2.8345545052747023:2.3207998392802982] [-2.6591722379967409:3.1709747732901796] [-2.7773591454274333:2.1167910214836754] [-2.7396771671895563:2.3039166976839418] [-2.369586905226603:2.7593551140215822]
tree_sizes=832

Tree=0
num_leaves=7
num_cat=0
split_feature=4 3 3 4 1 3
split_gain=1.50966 3.80344 2.42188 1.54375 0.428615 0.932329
threshold=0.66013127401116201 -0.31332876935095394 0.47338878305964766 -0.83951219291483059 -0.36042762073349738 0.61440687356221846
decision_type=2 2 2 2 2 2
left_child=1 3 -2 -1 -3 -6
right_child=2 4 -4 -5 5 -7
leaf_value=5.7732130694905219 5.4075671050861942 5.1374153352283081 5.8184182891726719 5.4470463945554721 5.170449726462687 5.4388259391229958
leaf_weight=26.181041717529297 31.334081172943115 22.670285880565643 26.466757297515869 32.554425358772278 26.350793361663818 25.442615479230881
leaf_count=18 33 30 17 33 34 26
internal_value=5.46073 5.40217 5.59569 5.59243 5.25209 5.30228
internal_weight=0 133.199 57.8008 58.7355 74.4637 51.7934
internal_count=191 141 50 51 90 60
is_linear=0
shrinkage=1


end of trees

feature_importances:
Column_3=3
Column_4=2
Column_1=1

parameters:
[boosting: gbdt]
[objective: gamma]
[metric: gamma]
[tree_learner: serial]
[device_type: cpu]
[data: ]
[valid: ]
[num_iterations: 1]
[learning_rate: 1]
[num_leaves: 4096]
[num_threads: -1]
[deterministic: 0]
[force_col_wise: 0]
[force_row_wise: 0]
[histogram_pool_size: -1]
[max_depth: -1]
[min_data_in_leaf: 20]
[min_sum_hessian_in_leaf: 1e-10]
[bagging_fraction: 1]
[pos_bagging_fraction: 1]
[neg_bagging_fraction: 1]
[bagging_freq: 0]
[bagging_seed: 3]
[feature_fraction: 1]
[feature_fraction_bynode: 1]
[feature_fraction_seed: 2]
[extra_trees: 0]
[extra_seed: 6]
[early_stopping_round: 0]
[first_metric_only: 0]
[max_delta_step: 0]
[lambda_l1: 0]
[lambda_l2: 0]
[linear_lambda: 0]
[min_gain_to_split: 0]
[drop_rate: 0.1]
[max_drop: 50]
[skip_drop: 0.5]
[xgboost_dart_mode: 0]
[uniform_drop: 0]
[drop_seed: 4]
[top_rate: 0.2]
[other_rate: 0.1]
[min_data_per_group: 100]
[max_cat_threshold: 32]
[cat_l2: 10]
[cat_smooth: 10]
[max_cat_to_onehot: 4]
[top_k: 20]
[monotone_constraints: ]
[monotone_constraints_method: basic]
[monotone_penalty: 0]
[feature_contri: ]
[forcedsplits_filename: ]
[refit_decay_rate: 0.9]
[cegb_tradeoff: 1]
[cegb_penalty_split: 0]
[cegb_penalty_feature_lazy: ]
[cegb_penalty_feature_coupled: ]
[path_smooth: 0]
[interaction_constraints: ]
[verbosity: -10]
[saved_feature_importance_type: 0]
[linear_tree: 0]
[max_bin: 255]
[max_bin_by_feature: ]
[min_data_in_bin: 1]
[bin_construct_sample_cnt: 200000]
[data_random_seed: 1]
[is_enable_sparse: 1]
[enable_bundle: 0]
[use_missing: 1]
[zero_as_missing: 0]
[feature_pre_filter: 1]
[pre_partition: 0]
[two_round: 0]
[header: 0]
[label_column: ]
[weight_column: ]
[group_column: ]
[ignore_column: ]
[categorical_feature: ]
[forcedbins_filename: ]
[precise_float_parser: 0]
[objective_seed: 5]
[num_class: 1]
[is_unbalance: 0]
[scale_pos_weight: 1]
[sigmoid: 1]
[boost_from_average: 1]
[reg_sqrt: 0]
[alpha: 0.9]
[fair_c: 1]
[poisson_max_delta_step: 1e-10]
[tweedie_variance_power: 1.5]
[lambdarank_truncation_level: 30]
[lambdarank_norm: 1]
[label_gain: ]
[eval_at: ]
[multi_error_top_k: 1]
[auc_mu_weights: ]
[num_machines: 1]
[local_listen_port: 12400]
[time_out: 120]
[machine_list_filename: ]
[machines: ]
[gpu_platform_id: -1]
[gpu_device_id: -1]
[gpu_use_dp: 0]
[num_gpu: 1]

end of parameters

pandas_categorical:null
est_sklearn._predictors[0][0].nodes
array([( 0.        , 191, 4,  0.66013127, 1,  1, 10,  1.50966412, 0, 0, 140, 0, 0),
       (-0.05856522, 141, 3, -0.31332877, 0,  2,  5,  3.80344076, 1, 0,  66, 0, 0),
       ( 0.13170011,  51, 4, -0.76239709, 0,  3,  4,  1.3315025 , 2, 0,  50, 0, 0),
       ( 0.28879573,  20, 0,  0.        , 0,  0,  0, -1.        , 3, 1,   0, 0, 0),
       (-0.0126036 ,  31, 0,  0.        , 0,  0,  0, -1.        , 3, 1,   0, 0, 0),
       (-0.2086427 ,  90, 1, -0.36042762, 0,  6,  7,  0.42861547, 2, 0,  63, 0, 0),
       (-0.32331813,  30, 0,  0.        , 0,  0,  0, -1.        , 3, 1,   0, 0, 0),
       (-0.15844857,  60, 3,  0.61440687, 1,  8,  9,  0.93232857, 3, 0, 131, 0, 0),
       (-0.29028374,  34, 0,  0.        , 0,  0,  0, -1.        , 4, 1,   0, 0, 0),
       (-0.02190753,  26, 0,  0.        , 0,  0,  0, -1.        , 4, 1,   0, 0, 0),
       ( 0.13496065,  50, 3,  0.30072468, 1, 11, 12,  2.13598751, 1, 0, 114, 0, 0),
       (-0.06158485,  30, 0,  0.        , 0,  0,  0, -1.        , 2, 1,   0, 0, 0),
       ( 0.32297953,  20, 0,  0.        , 0,  0,  0, -1.        , 2, 1,   0, 0, 0)],
      dtype=[('value', '<f8'), ('count', '<u4'), ('feature_idx', '<u4'), ('num_threshold', '<f8'), ('missing_go_to_left', 'u1'), ('left', '<u4'), ('right', '<u4'), ('gain', '<f8'), ('depth', '<u4'), ('is_leaf', 'u1'), ('bin_threshold', 'u1'), ('is_categorical', 'u1'), ('bitset_idx', '<u4')])
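To line the two dumps up, note that scikit-learn stores node values in log-link space relative to the baseline prediction, while the leaf_value entries in the LightGBM dump above already include the boost_from_average offset. A sketch of the comparison, assuming the private attributes keep their current layout:

nodes = est_sklearn._predictors[0][0].nodes
baseline = np.ravel(est_sklearn._baseline_prediction)[0]  # ~5.46 = log(mean(y_train))
leaf_values = np.sort(nodes["value"][nodes["is_leaf"] == 1] + baseline)
print(leaf_values)  # compare with the sorted leaf_value entries above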

@lorentzenchr added this to the 1.2 milestone Oct 25, 2022
@jeremiedbb (Member):

We won't have time to finish the review on this one before the 1.2 release. Moving it to 1.3.

@jeremiedbb modified the milestones: 1.2, 1.3 Nov 24, 2022
@jjerphan (Member) left a comment:

Thank you for adding support for the gamma deviance, @lorentzenchr.

This LGTM. Is there anything to improve or to add, as initially discussed in this thread: https://github.com/scikit-learn/scikit-learn/pull/22409/files#r801057198?

@haiatn (Contributor) commented Dec 16, 2022:

Waiting for this to merge. Good job!

@lorentzenchr added the Waiting for Second Reviewer label (First reviewer is done, need a second one!) Dec 28, 2022
@glemaitre (Member) left a comment:

I will let @ogrisel review this one since he already had a look at it.

@jjerphan (Member) left a comment:

Have a good year, @lorentzenchr!

This looks good to me.

@lorentzenchr (Member Author):

@jjerphan Thanks, you too!

@jjerphan changed the title from "FEA add gamma loss to HGBT" to "FEA Add Gamma deviance as loss function to HGBT" Jan 12, 2023
@lorentzenchr (Member Author):

@jjerphan Thanks for helping out. The remaining CI failure seems unrelated.

@jjerphan (Member):

Yes, the Ubuntu repository on Azure has not been responding lately.

@jjerphan (Member) left a comment:

LGTM!

@lorentzenchr (Member Author):

@ogrisel @glemaitre Any chance this PR could get a second review from you?

@ogrisel (Member) left a comment:

Some more feedback, to try to improve things and probably to fix a problem in test_same_predictions_regression in particular.

Beyond this, LGTM!


if max_leaf_nodes < 10 and n_samples >= 1000:
    if loss in ("gamma", "poisson"):
        # More than 65% of the predictions must be close up to the 2nd decimal.
@ogrisel (Member):

This looks a bit lax. Any idea why we can be stricter in the case of the squared error? I wouldn't require making it stricter as a prerequisite to merging this PR, as the Poisson loss is already in main but LightGBM equivalence was previously just not tested.

But we could at least add a TODO comment to inform the reader that we are not 100% satisfied with the state of things :)

Suggested change:
- # More than 65% of the predictions must be close up to the 2nd decimal.
+ # More than 65% of the predictions must be close up to the 2nd decimal.
+ # TODO: investigate the cause of the remaining discrepancies with lightgbm.

@ogrisel (Member):

Maybe this is because, in the context of the squared error, we only test for equivalence with the larger dataset with shallow trees (max_leaf_nodes < 10 and n_samples >= 1000): this condition might allow the averaging effect in leaves to mitigate the impact of rounding errors, and maybe reduce the likelihood of ties when deciding which features and thresholds to split on.

@lorentzenchr (Member Author):

Maybe LightGBM's algorithm deviates more from ours for general losses (i.e., not squared error). For instance, in the Poisson case, LightGBM has a poisson_max_delta_step which I set to 1e-12, but I'm not 100% sure of the effective difference this causes.

I'll add a comment as suggested.

@lorentzenchr (Member Author):

And BTW, our losses have a better test suite than the ones in LightGBM :smirk:

# Less than 1% of the predictions may deviate more than 1e-3 in relative terms.
assert np.mean(np.isclose(pred_lightgbm, pred_sklearn, rtol=1e-3)) > 1 - 0.01

if max_leaf_nodes < 10 and n_samples >= 1000 and loss in ("squared_error"):
@ogrisel (Member):

I don't understand how the current code can possibly work. I think the condition was never met and the assertion skipped:

Suggested change:
- if max_leaf_nodes < 10 and n_samples >= 1000 and loss in ("squared_error"):
+ if max_leaf_nodes < 10 and n_samples >= 1000 and loss in ("squared_error",):

@lorentzenchr (Member Author):

The pair n_samples, max_leaf_nodes = 1000, 8 triggers this test.
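As an aside (not from the thread), the reason the un-tupled version still runs for this loss: ("squared_error") is just a string, and the in operator then performs a substring test rather than a tuple membership test.

loss = "squared_error"
print(loss in ("squared_error"))       # True: substring test against a plain string
print(loss in ("squared_error",))      # True: membership test in a 1-tuple
print("squared" in ("squared_error"))  # also True -- the trap the suggested fix avoids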

@jjerphan (Member):

I'll let @ogrisel merge. 🙂

@ogrisel enabled auto-merge (squash) January 30, 2023 15:05
@ogrisel (Member) commented Jan 30, 2023:

LGTM! I enabled auto-merge.

@ogrisel merged commit 7b13a8f into scikit-learn:main Jan 30, 2023
@jjerphan (Member):

Thank you once again for authoring notable contributions, @lorentzenchr. 🤝 💯

@lorentzenchr deleted the hgbt_gamma branch January 30, 2023 17:07
thomasjpfan pushed a commit to thomasjpfan/scikit-learn that referenced this pull request Feb 3, 2023
* FEA add gamma loss to HGBT

* DOC add whatsnew

* CLN address review comments

* TST make test_gamma pass by not testing out-of-sample

* TST compare gamma and poisson to LightGBM

* TST fix test_gamma by comparing to MSE HGBT instead of Poisson HGBT

* TST fix for test_same_predictions_regression for poisson

* CLN address review comments

* CLN nits

* CLN better comments

* TST use pytest.param with skip mark

* TST Correct conditional test parametrization mark

Co-authored-by: Christian Lorentzen <lorentzen.ch@gmail.com>

* CI Trigger CI

Builds currently fail because requests to Azure Ubuntu repository
timeout.

* DOC add comment for lax comparison with LightGBM

* CLN tuple needs trailing comma

---------

Co-authored-by: Julien Jerphanion <git@jjerphan.xyz>
AdarshPrusty7 added a commit to AdarshPrusty7/GSGP that referenced this pull request Mar 6, 2023
Labels: module:ensemble, Waiting for Second Reviewer

6 participants