@@ -104,11 +104,12 @@ def test_same_predictions_classification(
104
104
105
105
rng = np .random .RandomState (seed = seed )
106
106
max_iter = 1
107
+ n_classes = 2
107
108
max_bins = 255
108
109
109
110
X , y = make_classification (
110
111
n_samples = n_samples ,
111
- n_classes = 2 ,
112
+ n_classes = n_classes ,
112
113
n_features = 5 ,
113
114
n_informative = 5 ,
114
115
n_redundant = 0 ,
@@ -174,13 +175,14 @@ def test_same_predictions_multiclass_classification(
174
175
pytest .importorskip ("lightgbm" )
175
176
176
177
rng = np .random .RandomState (seed = seed )
178
+ n_classes = 3
177
179
max_iter = 1
178
180
max_bins = 255
179
181
lr = 1
180
182
181
183
X , y = make_classification (
182
184
n_samples = n_samples ,
183
- n_classes = 3 ,
185
+ n_classes = n_classes ,
184
186
n_features = 5 ,
185
187
n_informative = 5 ,
186
188
n_redundant = 0 ,
@@ -204,7 +206,9 @@ def test_same_predictions_multiclass_classification(
204
206
min_samples_leaf = min_samples_leaf ,
205
207
max_leaf_nodes = max_leaf_nodes ,
206
208
)
207
- est_lightgbm = get_equivalent_estimator (est_sklearn , lib = "lightgbm" )
209
+ est_lightgbm = get_equivalent_estimator (
210
+ est_sklearn , lib = "lightgbm" , n_classes = n_classes
211
+ )
208
212
209
213
est_lightgbm .fit (X_train , y_train )
210
214
est_sklearn .fit (X_train , y_train )
@@ -224,7 +228,8 @@ def test_same_predictions_multiclass_classification(
224
228
225
229
acc_lightgbm = accuracy_score (y_train , pred_lightgbm )
226
230
acc_sklearn = accuracy_score (y_train , pred_sklearn )
227
- np .testing .assert_almost_equal (acc_lightgbm , acc_sklearn , decimal = 2 )
231
+
232
+ np .testing .assert_allclose (acc_lightgbm , acc_sklearn , rtol = 0 , atol = 5e-2 )
228
233
229
234
if max_leaf_nodes < 10 and n_samples >= 1000 :
230
235
0 commit comments