@@ -169,30 +169,38 @@ def __iter__(self):
169
169
# in this case we want to sample without replacement
170
170
all_lists = np .all ([not hasattr (v , "rvs" )
171
171
for v in self .param_distributions .values ()])
172
+ rnd = check_random_state (self .random_state )
173
+
172
174
if all_lists :
173
- # size of complete grid
174
- grid_size = np .prod ([len (v ) for v in self .param_distributions .values ()])
175
+ param_grid = list (ParameterGrid (self .param_distributions ))
176
+ grid_size = len (param_grid )
177
+
178
+ if all_lists and self .n_iter > 0.1 * grid_size :
179
+ # get complete grid and yield from it
175
180
if grid_size < self .n_iter :
176
181
raise ValueError ("The total space of parameters %d is smaller than n_iter=%d. "
177
182
% (grid_size , self .n_iter )
178
183
+ "For exhaustive searches, use GridSearchCV." )
179
- rnd = check_random_state (self .random_state )
180
- # Always sort the keys of a dictionary, for reproducibility
181
- items = sorted (self .param_distributions .items ())
182
- while len (samples ) < self .n_iter :
183
- params = dict ()
184
- for k , v in items :
185
- if hasattr (v , "rvs" ):
186
- params [k ] = v .rvs ()
187
- else :
188
- params [k ] = v [rnd .randint (len (v ))]
189
- if all_lists and params in samples :
190
- # do sampling without replacement only if all_lists
191
- # otherwise distributions with finite support might
192
- # cause infinite loops
193
- continue
194
- samples .append (params )
195
- yield params
184
+ for i in rnd .permutation (grid_size )[:self .n_iter ]:
185
+ yield param_grid [i ]
186
+
187
+ else :
188
+ # Always sort the keys of a dictionary, for reproducibility
189
+ items = sorted (self .param_distributions .items ())
190
+ while len (samples ) < self .n_iter :
191
+ params = dict ()
192
+ for k , v in items :
193
+ if hasattr (v , "rvs" ):
194
+ params [k ] = v .rvs ()
195
+ else :
196
+ params [k ] = v [rnd .randint (len (v ))]
197
+ if all_lists and params in samples :
198
+ # do sampling without replacement only if all_lists
199
+ # otherwise distributions with finite support might
200
+ # cause infinite loops
201
+ continue
202
+ samples .append (params )
203
+ yield params
196
204
197
205
def __len__ (self ):
198
206
"""Number of points that will be sampled."""
@@ -266,7 +274,7 @@ def _check_param_grid(param_grid):
266
274
raise ValueError ("Parameter array should be one-dimensional." )
267
275
268
276
check = [isinstance (v , k ) for k in (list , tuple , np .ndarray )]
269
- if not True in check :
277
+ if True not in check :
270
278
raise ValueError ("Parameter values should be a list." )
271
279
272
280
if len (v ) == 0 :
0 commit comments