
[MRG+1] Make ParameterSampler sample without replacement #3850


Merged
3 changes: 3 additions & 0 deletions doc/whats_new.rst
@@ -158,6 +158,9 @@ Enhancements
:class:`tree.DecisionTreeClassifier`, :class:`ensemble.ExtraTreesClassifier`
and :class:`tree.ExtraTreeClassifier`. By `Trevor Stephens`_.

- :class:`grid_search.RandomizedSearchCV` now does sampling without
  replacement if all parameters are given as lists. By `Andreas Mueller`_.

Documentation improvements
..........................

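The changelog entry above can be illustrated with a short sketch. This snippet is editorial (not part of the diff) and assumes the 0.16-era sklearn.grid_search API used throughout this PR; the SVC parameters and the iris data are only there to make it runnable.

    from sklearn.datasets import load_iris
    from sklearn.grid_search import RandomizedSearchCV
    from sklearn.svm import SVC

    # Every parameter is given as a list, so the 2 * 3 = 6 candidates are
    # sampled without replacement; n_iter may not exceed the grid size.
    param_lists = {"kernel": ["rbf", "linear"], "C": [1, 10, 100]}
    search = RandomizedSearchCV(SVC(), param_distributions=param_lists, n_iter=6)

    iris = load_iris()
    # With n_iter equal to the grid size this behaves like an exhaustive search.
    search.fit(iris.data, iris.target)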
57 changes: 45 additions & 12 deletions sklearn/grid_search.py
@@ -26,6 +26,7 @@
from .externals.joblib import Parallel, delayed
from .externals import six
from .utils import check_random_state
from .utils.random import sample_without_replacement
from .utils.validation import _num_samples, indexable
from .utils.metaestimators import if_delegate_has_method
from .metrics.scorer import check_scoring
@@ -113,7 +114,11 @@ class ParameterSampler(object):
     """Generator on parameters sampled from given distributions.

     Non-deterministic iterable over random candidate combinations for hyper-
-    parameter search.
+    parameter search. If all parameters are presented as a list,
+    sampling without replacement is performed. If at least one parameter
+    is given as a distribution, sampling with replacement is used.
+    It is highly recommended to use continuous distributions for continuous
+    parameters.

     Note that as of SciPy 0.12, the ``scipy.stats.distributions`` do not accept
     a custom RNG instance and always use the singleton RNG from
@@ -165,17 +170,39 @@ def __init__(self, param_distributions, n_iter, random_state=None):
         self.random_state = random_state

     def __iter__(self):
+        samples = []
+        # check if all distributions are given as lists
+        # in this case we want to sample without replacement
+        all_lists = np.all([not hasattr(v, "rvs")
+                            for v in self.param_distributions.values()])
         rnd = check_random_state(self.random_state)
-        # Always sort the keys of a dictionary, for reproducibility
-        items = sorted(self.param_distributions.items())
-        for _ in range(self.n_iter):
-            params = dict()
-            for k, v in items:
-                if hasattr(v, "rvs"):
-                    params[k] = v.rvs()
-                else:
-                    params[k] = v[rnd.randint(len(v))]
-            yield params
+
+        if all_lists:
+            # get complete grid and yield from it
+            param_grid = list(ParameterGrid(self.param_distributions))
+            grid_size = len(param_grid)
+
+            if grid_size < self.n_iter:
+                raise ValueError(
+                    "The total space of parameters %d is smaller "
+                    "than n_iter=%d." % (grid_size, self.n_iter)
+                    + " For exhaustive searches, use GridSearchCV.")
+            for i in sample_without_replacement(grid_size, self.n_iter,
+                                                random_state=rnd):
+                yield param_grid[i]
+
+        else:
Member:
Please add a comment that if the number of requested iterations is much smaller than the number of possible combinations of parameter values, we can do sampling without replacement naively by maintaining a tabu list of past samples, without running the risk of infinite loops.

Member (Author):
you mean in the all_list case, right?
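(Editorial aside: the naive scheme the reviewer describes above — rejection sampling against a tabu list of previously drawn candidates — might look roughly like the sketch below. The helper name is hypothetical and this is not what the PR implements; the merged code instead enumerates the grid and draws indices with sample_without_replacement.)

    # Hypothetical sketch of the reviewer's tabu-list idea, not part of this PR.
    # `items` is the sorted list of (name, list_of_values) pairs and `rnd` a
    # numpy RandomState, as used inside ParameterSampler.__iter__.
    def _naive_sample_without_replacement(items, n_iter, rnd):
        seen = set()  # tabu list of candidates already yielded
        while len(seen) < n_iter:
            params = dict((k, v[rnd.randint(len(v))]) for k, v in items)
            key = tuple(sorted(params.items()))
            if key not in seen:  # reject duplicates instead of yielding them again
                seen.add(key)
                yield params

As noted in the comment, this only stays cheap when n_iter is much smaller than the number of possible combinations; near the full grid size the rejection loop degenerates, which is why the PR samples indices of the complete grid instead.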

+            # Always sort the keys of a dictionary, for reproducibility
+            items = sorted(self.param_distributions.items())
+            while len(samples) < self.n_iter:
+                params = dict()
+                for k, v in items:
+                    if hasattr(v, "rvs"):
+                        params[k] = v.rvs()
+                    else:
+                        params[k] = v[rnd.randint(len(v))]
+                samples.append(params)
+                yield params

     def __len__(self):
         """Number of points that will be sampled."""
@@ -249,7 +276,7 @@ def _check_param_grid(param_grid):
                 raise ValueError("Parameter array should be one-dimensional.")

             check = [isinstance(v, k) for k in (list, tuple, np.ndarray)]
-            if not True in check:
+            if True not in check:
                 raise ValueError("Parameter values should be a list.")

             if len(v) == 0:
@@ -717,6 +744,12 @@ class RandomizedSearchCV(BaseSearchCV):
    distributions. The number of parameter settings that are tried is
    given by n_iter.

    If all parameters are presented as a list,
    sampling without replacement is performed. If at least one parameter
    is given as a distribution, sampling with replacement is used.
    It is highly recommended to use continuous distributions for continuous
    parameters.

    Parameters
    ----------
    estimator : object type that implements the "fit" and "predict" methods
54 changes: 34 additions & 20 deletions sklearn/tests/test_grid_search.py
@@ -26,7 +26,7 @@
 from sklearn.utils.testing import ignore_warnings
 from sklearn.utils.mocking import CheckingClassifier, MockDataFrame

-from scipy.stats import distributions
+from scipy.stats import bernoulli, expon, uniform

 from sklearn.externals.six.moves import zip
 from sklearn.base import BaseEstimator
@@ -214,7 +214,7 @@ def test_trivial_grid_scores():
     grid_search.fit(X, y)
     assert_true(hasattr(grid_search, "grid_scores_"))

-    random_search = RandomizedSearchCV(clf, {'foo_param': [0]})
+    random_search = RandomizedSearchCV(clf, {'foo_param': [0]}, n_iter=1)
     random_search.fit(X, y)
     assert_true(hasattr(random_search, "grid_scores_"))

@@ -530,7 +530,7 @@ def custom_scoring(estimator, X):
 def test_param_sampler():
     # test basic properties of param sampler
     param_distributions = {"kernel": ["rbf", "linear"],
-                           "C": distributions.uniform(0, 1)}
+                           "C": uniform(0, 1)}
     sampler = ParameterSampler(param_distributions=param_distributions,
                                n_iter=10, random_state=0)
     samples = [x for x in sampler]
@@ -549,8 +549,8 @@ def test_randomized_search_grid_scores():
     # XXX: as of today (scipy 0.12) it's not possible to set the random seed
     # of scipy.stats distributions: the assertions in this test should thus
     # not depend on the randomization
-    params = dict(C=distributions.expon(scale=10),
-                  gamma=distributions.expon(scale=0.1))
+    params = dict(C=expon(scale=10),
+                  gamma=expon(scale=0.1))
     n_cv_iter = 3
     n_search_iter = 30
     search = RandomizedSearchCV(SVC(), n_iter=n_search_iter, cv=n_cv_iter,
@@ -615,7 +615,7 @@ def test_pickle():
     pickle.dumps(grid_search)  # smoke test

     random_search = RandomizedSearchCV(clf, {'foo_param': [1, 2, 3]},
-                                       refit=True)
+                                       refit=True, n_iter=3)
     random_search.fit(X, y)
     pickle.dumps(random_search)  # smoke test

@@ -647,20 +647,7 @@ def test_grid_search_with_multioutput_data():

     # Test with a randomized search
     for est in estimators:
-        random_search = RandomizedSearchCV(est, est_parameters, cv=cv)
-        random_search.fit(X, y)
-        for parameters, _, cv_validation_scores in random_search.grid_scores_:
-            est.set_params(**parameters)
-
-            for i, (train, test) in enumerate(cv):
-                est.fit(X[train], y[train])
-                correct_score = est.score(X[test], y[test])
-                assert_almost_equal(correct_score,
-                                    cv_validation_scores[i])
-
-    # Test with a randomized search
-    for est in estimators:
-        random_search = RandomizedSearchCV(est, est_parameters, cv=cv)
+        random_search = RandomizedSearchCV(est, est_parameters, cv=cv, n_iter=3)
         random_search.fit(X, y)
         for parameters, _, cv_validation_scores in random_search.grid_scores_:
             est.set_params(**parameters)
@@ -758,3 +745,30 @@ def test_grid_search_failing_classifier_raise():

    # FailingClassifier issues a ValueError so this is what we look for.
    assert_raises(ValueError, gs.fit, X, y)


def test_parameters_sampler_replacement():
    # raise error if n_iter too large
    params = {'first': [0, 1], 'second': ['a', 'b', 'c']}
    sampler = ParameterSampler(params, n_iter=7)
    assert_raises(ValueError, list, sampler)
    # degenerates to GridSearchCV if n_iter the same as grid_size
    sampler = ParameterSampler(params, n_iter=6)
    samples = list(sampler)
    assert_equal(len(samples), 6)
    for values in ParameterGrid(params):
        assert_true(values in samples)

    # test sampling without replacement in a large grid
    params = {'a': range(10), 'b': range(10), 'c': range(10)}
    sampler = ParameterSampler(params, n_iter=99, random_state=42)
    samples = list(sampler)
    assert_equal(len(samples), 99)
    hashable_samples = ["a%db%dc%d" % (p['a'], p['b'], p['c']) for p in samples]
    assert_equal(len(set(hashable_samples)), 99)

    # doesn't go into infinite loops
    params_distribution = {'first': bernoulli(.5), 'second': ['a', 'b', 'c']}
    sampler = ParameterSampler(params_distribution, n_iter=7)
    samples = list(sampler)
    assert_equal(len(samples), 7)
2 changes: 1 addition & 1 deletion sklearn/tests/test_metaestimators.py
@@ -30,7 +30,7 @@ def __init__(self, name, construct, skip_methods=(),
                   skip_methods=['score']),
     DelegatorData('RandomizedSearchCV',
                   lambda est: RandomizedSearchCV(
-                      est, param_distributions={'param': [5]}, cv=2),
+                      est, param_distributions={'param': [5]}, cv=2, n_iter=1),
                   skip_methods=['score']),
     DelegatorData('RFE', RFE,
                   skip_methods=['transform', 'inverse_transform', 'score']),