diff --git a/examples/svm/plot_rbf_parameters.py b/examples/svm/plot_rbf_parameters.py
index f298ebf01205c..9fb06dc21056a 100644
--- a/examples/svm/plot_rbf_parameters.py
+++ b/examples/svm/plot_rbf_parameters.py
@@ -28,7 +28,7 @@
 from sklearn.preprocessing import StandardScaler
 from sklearn.datasets import load_iris
 from sklearn.cross_validation import StratifiedKFold
-from sklearn.grid_search import GridSearchCV
+from sklearn.grid_search import GridSearchCV, ParameterGrid

 ##############################################################################
 # Load and prepare data set
@@ -110,7 +110,8 @@

 # We extract just the scores
 scores = [x[1] for x in score_dict]
-scores = np.array(scores).reshape(len(C_range), len(gamma_range))
+_, index = ParameterGrid(grid.param_grid).build_index(['C', 'gamma'])
+scores = np.asarray(scores)[index]

 # draw heatmap of accuracy as a function of gamma and C
 pl.figure(figsize=(8, 6))
diff --git a/sklearn/grid_search.py b/sklearn/grid_search.py
index 0d6615a6e0408..807924b586e6f 100644
--- a/sklearn/grid_search.py
+++ b/sklearn/grid_search.py
@@ -47,11 +47,20 @@ class ParameterGrid(object):
     Examples
     --------
     >>> from sklearn.grid_search import ParameterGrid
-    >>> param_grid = {'a':[1, 2], 'b':[True, False]}
-    >>> list(ParameterGrid(param_grid)) #doctest: +NORMALIZE_WHITESPACE
+    >>> param_grid = ParameterGrid({'a':[1, 2], 'b':[True, False]})
+    >>> list(param_grid) #doctest: +NORMALIZE_WHITESPACE
     [{'a': 1, 'b': True}, {'a': 1, 'b': False},
      {'a': 2, 'b': True}, {'a': 2, 'b': False}]

+    Using `build_index` to access points by parameter value:
+    >>> import numpy
+    >>> order, index = param_grid.build_index(['b', 'a'])
+    >>> index #doctest: +NORMALIZE_WHITESPACE
+    array([[0, 2], [1, 3]])
+    >>> numpy.asarray(list(param_grid))[index] #doctest: +NORMALIZE_WHITESPACE
+    array([[{'a': 1, 'b': True}, {'a': 2, 'b': True}],
+           [{'a': 1, 'b': False}, {'a': 2, 'b': False}]], dtype=object)
+
     See also
     --------
     :class:`GridSearchCV`:
@@ -77,18 +86,82 @@ def __iter__(self):
         """
         for p in self.param_grid:
             # Always sort the keys of a dictionary, for reproducibility
-            items = sorted(p.items())
+            items = self._ordered_items(p)
             keys, values = zip(*items)
             for v in product(*values):
                 params = dict(zip(keys, v))
                 yield params

+    @staticmethod
+    def _ordered_items(p):
+        return sorted(p.items())
+
     def __len__(self):
         """Number of points on the grid."""
         # Product function that can handle iterables (np.product can't).
-        product = partial(reduce, operator.mul)
-        return sum(product(len(v) for v in p.values())
-                   for p in self.param_grid)
+        return sum(self._len(p) for p in self.param_grid)
+
+    @staticmethod
+    def _len(params, product=partial(reduce, operator.mul)):
+        """Number of points for a single param dict"""
+        return product(len(v) for v in params.values())
+
+    def build_index(self, order=(), grid=0, ravel=False):
+        """Build an index over grid points by parameter values.
+
+        Parameters
+        ----------
+        `order` : sequence of strings, optional
+            Parameter names corresponding to the first axes of the returned
+            index. Any remaining parameters will be returned in arbitrary
+            order.
+        `grid` : integer, default 0
+            The grid index if an array of grids are represented.
+        `ravel` : boolean, default False
+            When True, only parameters listed in `order` are assigned their own
+            dimension, so the output index has `len(order) + 1` dimensions.
+            This simplifies aggregating over the indexed data grouped by
+            selected parameters.
+
+        Returns
+        -------
+        `order` : sequence of strings
+            Parameter names corresponding to axes in `index`. Where `ravel` is
+            True, the final axis is not included.
+        `index` : array
+            An integer index into the list of grid points, such that each
+            parameter corresponds to an axis.
+
+        Raises
+        ------
+        ValueError
+            If `order` names an unknown parameter or repeats a name.
+        """
+        n = self._len(self.param_grid[grid])
+        offset = sum(self._len(p) for p in self.param_grid[:grid])
+        keys, values = zip(*self._ordered_items(self.param_grid[grid]))
+        index = np.arange(offset, offset + n).reshape([len(v) for v in values])
+
+        if order:
+            axis_order = []
+            for key in order:
+                try:
+                    axis_order.append(keys.index(key))
+                except ValueError:
+                    raise ValueError('Key {!r} unknown'.format(key))
+            if len(axis_order) > len(set(axis_order)):
+                raise ValueError('order contains duplicate keys')
+            axis_order.extend(i for i in xrange(len(keys))
+                              if i not in axis_order)
+
+            keys = [keys[i] for i in axis_order]
+            index = index.transpose(axis_order)
+
+        if ravel:
+            keys = keys[:len(order)]
+            index = index.reshape(index.shape[:len(order)] + (-1,))
+
+        return list(keys), index


 class IterGrid(ParameterGrid):
diff --git a/sklearn/tests/test_grid_search.py b/sklearn/tests/test_grid_search.py
index b06e9e7ed6a26..29715e1090965 100644
--- a/sklearn/tests/test_grid_search.py
+++ b/sklearn/tests/test_grid_search.py
@@ -124,6 +124,75 @@ def test_parameter_grid():
                      for x, y in product(params2["foo"], params2["bar"])))


+def test_build_index():
+    params = {"foo": [4, 2],
+              "bar": ["ham", "spam", "eggs"]}
+    grid = ParameterGrid(params)
+
+    keys, index1a = grid.build_index(['foo', 'bar'])
+    assert_equal(keys, ['foo', 'bar'])
+    assert_equal(index1a.shape, (2, 3))
+
+    keys, index1b = grid.build_index(['foo'])
+    assert_equal(keys, ['foo', 'bar'])
+    assert_array_equal(index1a, index1b)
+
+    keys, index2a = grid.build_index(['bar', 'foo'])
+    assert_equal(keys, ['bar', 'foo'])
+    assert_equal(index2a.shape, (3, 2))
+
+    keys, index2b = grid.build_index(['bar'])
+    assert_equal(keys, ['bar', 'foo'])
+    assert_array_equal(index2a, index2b)
+
+    keys, index = grid.build_index()  # dimensions are implementation-specific?
+    assert_true(keys == ['foo', 'bar'] or keys == ['bar', 'foo'])
+    assert_true(np.all(index == index1a) or np.all(index == index2a))
+
+    points = np.asarray(list(grid))
+    assert_array_equal(
+        points[index1a].ravel(),
+        np.asarray([{"foo": x, "bar": y}
+                    for x, y in product(params["foo"], params["bar"])]))
+    assert_array_equal(
+        points[index2a].ravel(),
+        np.asarray([{"foo": y, "bar": x}
+                    for x, y in product(params["bar"], params["foo"])]))
+
+    # Test bad `order`s
+    assert_raises(ValueError, grid.build_index, ['foo', 'foo'])
+    assert_raises(ValueError, grid.build_index, ['argh'])
+
+    # Test a list of grids and the `grid` parameter
+
+    grid2 = ParameterGrid([params, {'foo': [3]}])
+
+    keys, index1c = grid2.build_index(['foo'])
+    assert_equal(keys, ['foo', 'bar'])
+    assert_array_equal(index1a, index1c)
+
+    keys, index1d = grid2.build_index(['foo'], grid=0)
+    assert_equal(keys, ['foo', 'bar'])
+    assert_array_equal(index1a, index1d)
+
+    keys, index3 = grid2.build_index(['foo'], grid=1)
+    assert_equal(keys, ['foo'])
+    assert_array_equal(index3, np.array([len(grid2) - 1]))
+
+    params3 = {"foo": [4, 2],
+               "bar": ["ham", "spam", "eggs"],
+               "jon": [2.0, 1.0]}
+    grid3 = ParameterGrid(params3)
+    keys, index = grid3.build_index(['foo'])
+    assert_true(keys == ['foo', 'bar', 'jon'] or keys == ['foo', 'jon', 'bar'])
+    keys, index_ravel = grid3.build_index(['foo'], ravel=True)
+    assert_equal(keys, ['foo'])
+    assert_equal(index.ndim, 3)
+    assert_equal(index_ravel.ndim, 2)
+    for row, row_ravel in zip(index, index_ravel):
+        assert_equal(sorted(row.flat), sorted(row_ravel))
+
+
 def test_grid_search():
     """Test that the best estimator contains the right value for foo_param"""
     clf = MockClassifier()