@@ -1004,3 +1004,35 @@ def check_transformer_n_iter(name, estimator):
1004
1004
assert_greater (iter_ , 1 )
1005
1005
else :
1006
1006
assert_greater (estimator .n_iter_ , 1 )
1007
+
1008
+
1009
+ @ignore_warnings
1010
+ def check_fit_reset (name , Estimator ):
1011
+ X1 , y1 = make_blobs (n_samples = 50 , n_features = 2 , center_box = (- 200 , - 150 ),
1012
+ centers = 2 , random_state = 0 )
1013
+ X2 , y2 = make_blobs (n_samples = 100 , n_features = 3 , center_box = (- 1 , 1 ),
1014
+ centers = 1 , random_state = 1 )
1015
+ X3 , y3 = make_blobs (n_samples = 200 , n_features = 4 , center_box = (- 100 , - 50 ),
1016
+ centers = 5 , random_state = 2 )
1017
+ X4 , y4 = make_blobs (n_samples = 150 , n_features = 5 , center_box = (50 , 100 ),
1018
+ centers = 10 , random_state = 3 )
1019
+
1020
+ estimator_1 = Estimator ()
1021
+ estimator_2 = Estimator ()
1022
+
1023
+ set_fast_parameters (estimator_1 )
1024
+ set_fast_parameters (estimator_2 )
1025
+
1026
+ set_random_state (estimator_1 )
1027
+ set_random_state (estimator_2 )
1028
+
1029
+ _fit (estimator_1 , X1 , y1 )
1030
+ _fit (estimator_2 , X3 , y3 )
1031
+ assert_not_same_model (estimator_1 , estimator_2 )
1032
+
1033
+ _fit (estimator_2 , X4 , y4 )
1034
+ assert_not_same_model (estimator_1 , estimator_2 )
1035
+
1036
+ _fit (estimator_1 , X2 , y2 )
1037
+ _fit (estimator_2 , X2 , y2 )
1038
+ assert_same_model (estimator_1 , estimator_2 )
0 commit comments