Skip to content

Commit 133f3fd

Browse files
committed
if ParamSampler is used with all lists, instantiate grid and sample from it.
1 parent 71ee3a7 commit 133f3fd

File tree

2 files changed

+36
-20
lines changed

2 files changed

+36
-20
lines changed

sklearn/grid_search.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -169,30 +169,38 @@ def __iter__(self):
169169
# in this case we want to sample without replacement
170170
all_lists = np.all([not hasattr(v, "rvs")
171171
for v in self.param_distributions.values()])
172+
rnd = check_random_state(self.random_state)
173+
172174
if all_lists:
173-
# size of complete grid
174-
grid_size = np.prod([len(v) for v in self.param_distributions.values()])
175+
param_grid = list(ParameterGrid(self.param_distributions))
176+
grid_size = len(param_grid)
177+
178+
if all_lists and self.n_iter > 0.1 * grid_size:
179+
# get complete grid and yield from it
175180
if grid_size < self.n_iter:
176181
raise ValueError("The total space of parameters %d is smaller than n_iter=%d. "
177182
% (grid_size, self.n_iter)
178183
+ "For exhaustive searches, use GridSearchCV.")
179-
rnd = check_random_state(self.random_state)
180-
# Always sort the keys of a dictionary, for reproducibility
181-
items = sorted(self.param_distributions.items())
182-
while len(samples) < self.n_iter:
183-
params = dict()
184-
for k, v in items:
185-
if hasattr(v, "rvs"):
186-
params[k] = v.rvs()
187-
else:
188-
params[k] = v[rnd.randint(len(v))]
189-
if all_lists and params in samples:
190-
# do sampling without replacement only if all_lists
191-
# otherwise distributions with finite support might
192-
# cause infinite loops
193-
continue
194-
samples.append(params)
195-
yield params
184+
for i in rnd.permutation(grid_size)[:self.n_iter]:
185+
yield param_grid[i]
186+
187+
else:
188+
# Always sort the keys of a dictionary, for reproducibility
189+
items = sorted(self.param_distributions.items())
190+
while len(samples) < self.n_iter:
191+
params = dict()
192+
for k, v in items:
193+
if hasattr(v, "rvs"):
194+
params[k] = v.rvs()
195+
else:
196+
params[k] = v[rnd.randint(len(v))]
197+
if all_lists and params in samples:
198+
# do sampling without replacement only if all_lists
199+
# otherwise distributions with finite support might
200+
# cause infinite loops
201+
continue
202+
samples.append(params)
203+
yield params
196204

197205
def __len__(self):
198206
"""Number of points that will be sampled."""
@@ -266,7 +274,7 @@ def _check_param_grid(param_grid):
266274
raise ValueError("Parameter array should be one-dimensional.")
267275

268276
check = [isinstance(v, k) for k in (list, tuple, np.ndarray)]
269-
if not True in check:
277+
if True not in check:
270278
raise ValueError("Parameter values should be a list.")
271279

272280
if len(v) == 0:

sklearn/tests/test_grid_search.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,14 @@ def test_parameters_sampler_replacement():
741741
for values in ParameterGrid(params):
742742
assert_true(values in samples)
743743

744+
# test sampling without replacement in a large grid
745+
params = {'a': range(10), 'b': range(10), 'c': range(10)}
746+
sampler = ParameterSampler(params, n_iter=99, random_state=42)
747+
samples = list(sampler)
748+
assert_equal(len(samples), 99)
749+
hashable_samples = ["a%db%dc%d" % (p['a'], p['b'], p['c']) for p in samples]
750+
assert_equal(len(set(hashable_samples)), 99)
751+
744752
# doesn't go into infinite loops
745753
params_distribution = {'first': bernoulli(.5), 'second': ['a', 'b', 'c']}
746754
sampler = ParameterSampler(params_distribution, n_iter=7)

0 commit comments

Comments
 (0)