Skip to content

ENH method to index ParameterGrid points by parameter values #1842

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions examples/svm/plot_rbf_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import load_iris
from sklearn.cross_validation import StratifiedKFold
from sklearn.grid_search import GridSearchCV
from sklearn.grid_search import GridSearchCV, ParameterGrid

##############################################################################
# Load and prepare data set
Expand Down Expand Up @@ -110,7 +110,8 @@

# We extract just the scores
scores = [x[1] for x in score_dict]
scores = np.array(scores).reshape(len(C_range), len(gamma_range))
_, index = ParameterGrid(grid.param_grid).build_index(['C', 'gamma'])
scores = np.asarray(scores)[index]

# draw heatmap of accuracy as a function of gamma and C
pl.figure(figsize=(8, 6))
Expand Down
80 changes: 74 additions & 6 deletions sklearn/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,20 @@ class ParameterGrid(object):
Examples
--------
>>> from sklearn.grid_search import ParameterGrid
>>> param_grid = {'a':[1, 2], 'b':[True, False]}
>>> list(ParameterGrid(param_grid)) #doctest: +NORMALIZE_WHITESPACE
>>> param_grid = ParameterGrid({'a':[1, 2], 'b':[True, False]})
>>> list(param_grid) #doctest: +NORMALIZE_WHITESPACE
[{'a': 1, 'b': True}, {'a': 1, 'b': False},
{'a': 2, 'b': True}, {'a': 2, 'b': False}]

Using `build_index` to access points by parameter value:
>>> import numpy
>>> order, index = param_grid.build_index(['b', 'a'])
>>> index #doctest: +NORMALIZE_WHITESPACE
array([[0, 2], [1, 3]])
>>> numpy.asarray(list(param_grid))[index] #doctest: +NORMALIZE_WHITESPACE
array([[{'a': 1, 'b': True}, {'a': 2, 'b': True}],
[{'a': 1, 'b': False}, {'a': 2, 'b': False}]], dtype=object)

See also
--------
:class:`GridSearchCV`:
Expand All @@ -77,18 +86,77 @@ def __iter__(self):
"""
for p in self.param_grid:
# Always sort the keys of a dictionary, for reproducibility
items = sorted(p.items())
items = self._ordered_items(p)
keys, values = zip(*items)
for v in product(*values):
params = dict(zip(keys, v))
yield params

@staticmethod
def _ordered_items(p):
return sorted(p.items())

def __len__(self):
"""Number of points on the grid."""
# Product function that can handle iterables (np.product can't).
product = partial(reduce, operator.mul)
return sum(product(len(v) for v in p.values())
for p in self.param_grid)
return sum(self._len(p) for p in self.param_grid)

@staticmethod
def _len(params, product=partial(reduce, operator.mul)):
"""Number of points for a single param dict"""
return product(len(v) for v in params.values())

def build_index(self, order=(), grid=0, ravel=False):
"""Build an index over grid points by parameter values.

Parameters
----------
`order` : sequence of strings, optional
Parameter names corresponding to the first axes of the returned
index. Any remaining parameters will be returned in arbitrary
order.
`grid` : integer, default 0
The grid index if an array of grids are represented.
`ravel` : boolean, default False
When True, only parameters listed in `order` are assigned their own
dimension, so the output index has `len(order) + 1` dimensions.
This simplifies aggregating over the indexed data grouped by
selected parameters.

Returns
-------
`order` : sequence of strings
Parameter names corresponding to axes in `index`. Where `ravel` is
True, the final axis is not included.
`index` : array
An integer index into the list of grid points, such that each
parameter corresponds to an axis.
"""
n = self._len(self.param_grid[grid])
offset = sum(self._len(p) for p in self.param_grid[:grid])
keys, values = zip(*self._ordered_items(self.param_grid[grid]))
index = np.arange(offset, offset + n).reshape([len(v) for v in values])

if order:
axis_order = []
for key in order:
try:
axis_order.append(keys.index(key))
except IndexError:
raise ValueError('Key {!r} unknown'.format(key))
if len(axis_order) > len(set(axis_order)):
raise ValueError('order contains duplicate keys')
axis_order.extend(i for i in xrange(len(keys))
if i not in axis_order)

keys = np.asarray(keys)[axis_order]
index = index.transpose(axis_order)

if ravel:
keys = keys[:len(order)]
index = index.reshape(index.shape[:len(order)] + (-1,))

return list(keys), index


class IterGrid(ParameterGrid):
Expand Down
70 changes: 70 additions & 0 deletions sklearn/tests/test_grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,76 @@ def test_parameter_grid():
for x, y in product(params2["foo"], params2["bar"])))


def test_build_index():
params = {"foo": [4, 2],
"bar": ["ham", "spam", "eggs"]}
grid = ParameterGrid(params)

keys, index1a = grid.build_index(['foo', 'bar'])
assert_equal(keys, ['foo', 'bar'])
assert_equal(index1a.shape, (2, 3))

keys, index1b = grid.build_index(['foo'])
assert_equal(keys, ['foo', 'bar'])
assert_array_equal(index1a, index1b)

keys, index2a = grid.build_index(['bar', 'foo'])
assert_equal(keys, ['bar', 'foo'])
assert_equal(index2a.shape, (3, 2))

keys, index2b = grid.build_index(['bar'])
assert_equal(keys, ['bar', 'foo'])
assert_array_equal(index2a, index2b)

keys, index = grid.build_index() # dimensions are implementation-specific?
assert_true(keys == ['foo', 'bar'] or keys == ['bar', 'foo'])
assert_true(np.all(index == index1a) or np.all(index == index2a))

points = np.asarray(list(grid))
points[index1a]
assert_array_equal(
points[index1a].ravel(),
np.asarray([{"foo": x, "bar": y}
for x, y in product(params["foo"], params["bar"])]))
assert_array_equal(
points[index2a].ravel(),
np.asarray([{"foo": y, "bar": x}
for x, y in product(params["bar"], params["foo"])]))

# Test bad `order`s
assert_raises(ValueError, grid.build_index, ['foo', 'foo'])
assert_raises(ValueError, grid.build_index, ['argh'])

# Test a list of grids and the `grid` parameter

grid2 = ParameterGrid([params, {'foo': [3]}])

keys, index1c = grid2.build_index(['foo'])
assert_equal(keys, ['foo', 'bar'])
assert_array_equal(index1a, index1c)

keys, index1d = grid2.build_index(['foo'], grid=0)
assert_equal(keys, ['foo', 'bar'])
assert_array_equal(index1a, index1d)

keys, index3 = grid2.build_index(['foo'], grid=1)
assert_equal(keys, ['foo'])
assert_array_equal(index3, np.array([len(grid2) - 1]))

params3 = {"foo": [4, 2],
"bar": ["ham", "spam", "eggs"],
"jon": [2.0, 1.0]}
grid3 = ParameterGrid(params3)
keys, index = grid3.build_index(['foo'])
assert_true(keys == ['foo', 'bar', 'jon'] or keys == ['foo', 'jon', 'bar'])
keys, index_ravel = grid3.build_index(['foo'], ravel=True)
assert_equal(keys, ['foo'])
assert_equal(index.ndim, 3)
assert_equal(index_ravel.ndim, 2)
for row, row_ravel in zip(index, index_ravel):
assert_equal(sorted(row.flat), sorted(row_ravel))


def test_grid_search():
"""Test that the best estimator contains the right value for foo_param"""
clf = MockClassifier()
Expand Down