Skip to content

Commit 2183d28

Browse files
committed
TST Add test to check if estimators reset model when fit is called
1 parent 92e1e39 commit 2183d28

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

sklearn/tests/test_common.py

+3
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
check_estimator_sparse_data,
3333
check_transformer,
3434
check_clustering,
35+
check_fit_reset,
3536
check_clusterer_compute_labels_predict,
3637
check_regressors_int,
3738
check_regressors_train,
@@ -100,6 +101,8 @@ def test_non_meta_estimators():
100101
yield check_sparsify_coefficients, name, Estimator
101102

102103
yield check_estimator_sparse_data, name, Estimator
104+
# test if fit resets model
105+
yield check_fit_reset, name, Alg
103106

104107

105108
def test_transformers():

sklearn/utils/estimator_checks.py

+32
Original file line numberDiff line numberDiff line change
@@ -1004,3 +1004,35 @@ def check_transformer_n_iter(name, estimator):
10041004
assert_greater(iter_, 1)
10051005
else:
10061006
assert_greater(estimator.n_iter_, 1)
1007+
1008+
1009+
@ignore_warnings
1010+
def check_fit_reset(name, Estimator):
1011+
X1, y1 = make_blobs(n_samples=50, n_features=2, center_box=(-200, -150),
1012+
centers=2, random_state=0)
1013+
X2, y2 = make_blobs(n_samples=100, n_features=3, center_box=(-1, 1),
1014+
centers=1, random_state=1)
1015+
X3, y3 = make_blobs(n_samples=200, n_features=4, center_box=(-100, -50),
1016+
centers=5, random_state=2)
1017+
X4, y4 = make_blobs(n_samples=150, n_features=5, center_box=(50, 100),
1018+
centers=10, random_state=3)
1019+
1020+
estimator_1 = Estimator()
1021+
estimator_2 = Estimator()
1022+
1023+
set_fast_parameters(estimator_1)
1024+
set_fast_parameters(estimator_2)
1025+
1026+
set_random_state(estimator_1)
1027+
set_random_state(estimator_2)
1028+
1029+
_fit(estimator_1, X1, y1)
1030+
_fit(estimator_2, X3, y3)
1031+
assert_not_same_model(estimator_1, estimator_2)
1032+
1033+
_fit(estimator_2, X4, y4)
1034+
assert_not_same_model(estimator_1, estimator_2)
1035+
1036+
_fit(estimator_1, X2, y2)
1037+
_fit(estimator_2, X2, y2)
1038+
assert_same_model(estimator_1, estimator_2)

0 commit comments

Comments
 (0)