[WIP] Multiple-metric grid search #2759

Closed
wants to merge 57 commits into from

57 commits
b217697
Refactor cv code
AlexanderFabisch Jan 9, 2014
c4d6278
Clean up
AlexanderFabisch Jan 9, 2014
1599952
Refactor RFE and add _check_scorable
AlexanderFabisch Jan 10, 2014
5e52031
FIX typo in docstring
AlexanderFabisch Jan 10, 2014
4b5f468
Merge `fit_grid_point` into `_cross_val_score`
AlexanderFabisch Jan 10, 2014
38081fd
Return time
AlexanderFabisch Jan 10, 2014
30c86ea
Move set_params back to fit_grid_point
AlexanderFabisch Jan 10, 2014
389ed8d
Log score and time in 'cross_val_score'
AlexanderFabisch Jan 11, 2014
1fa3ec3
check_scorable returns scorer
AlexanderFabisch Jan 11, 2014
5b8933d
Clean up
AlexanderFabisch Jan 12, 2014
70aaef2
Replace '_fit_estimator' by '_cross_val_score'
AlexanderFabisch Jan 12, 2014
13c7915
Fix PEP8, style and documentation
AlexanderFabisch Jan 12, 2014
7b951d8
Remove wrong variable names
AlexanderFabisch Jan 12, 2014
5b211cd
Remove helper function '_fit'
AlexanderFabisch Jan 14, 2014
365368e
Merge branch 'refactor_cv' of https://github.com/AlexanderFabisch/sci…
mblondel Jan 14, 2014
13bc90e
Add evaluate_scorers function.
mblondel Jan 14, 2014
4b2cd18
Add more tests for evaluate_scorers.
mblondel Jan 15, 2014
91ff498
Support ranking by regression.
mblondel Jan 15, 2014
4a934f0
Support SVC.
mblondel Jan 15, 2014
314497a
Handle multi-label case.
mblondel Jan 15, 2014
754c72d
Test ranking with more than two relevance levels.
mblondel Jan 15, 2014
79656d5
Merge branch 'multiple_grid_search' of https://github.com/mblondel/sc…
mblondel Jan 16, 2014
f6a44a0
Merge branch 'master' into multiple_grid_search
mblondel Jan 16, 2014
7f4d7ad
Rename evaluate_scorers to _evaluate_scorers.
mblondel Jan 16, 2014
a756083
Remove _score utility function.
mblondel Jan 16, 2014
b4255d8
Support for multiple scorers in cross_val_score.
mblondel Jan 16, 2014
264013f
Refactoring for allowing multiple scorers.
mblondel Jan 16, 2014
0feed96
Define `parameters` upfront.
mblondel Jan 16, 2014
0a66748
Use more informative name.
mblondel Jan 16, 2014
6f68bfb
Put __repr__ back.
mblondel Jan 16, 2014
114bec6
Deprecate fit_grid_point.
mblondel Jan 16, 2014
aff769d
Add grid_search_cv.
mblondel Jan 16, 2014
55f4126
Refactor code.
mblondel Jan 16, 2014
4bd6c91
Add randomized_search_cv.
mblondel Jan 16, 2014
b02a7e8
Remove multi-output multiclass support from scorers for now.
mblondel Jan 16, 2014
47dd41c
Update docstrings.
mblondel Jan 16, 2014
75a8762
Merge branch 'master' into multiple_grid_search
mblondel Jan 16, 2014
c4905c3
Support multiple metrics directly in GridSearchCV and
mblondel Jan 17, 2014
40f6ef7
Merge branch 'master' into multiple_grid_search
mblondel Jan 17, 2014
aad77c8
Simplify inner loop.
mblondel Feb 3, 2014
a7e79f3
Fix incorrect comment.
mblondel Feb 3, 2014
8040432
Fix comments.
mblondel Feb 3, 2014
2a96384
Return training time only.
mblondel Feb 3, 2014
5933d98
Remove return_parameters.
mblondel Feb 3, 2014
4ee0a8e
Cosmit: used += instead of extend.
mblondel Feb 3, 2014
c08bdd8
Add cross_val_report.
mblondel Feb 3, 2014
e0dfe23
Remove score_func from cross_val_report.
mblondel Feb 4, 2014
9b5fe9a
Accept tuples too.
mblondel Feb 7, 2014
0346fa3
Accept callables in _evaluate_scorers.
mblondel Feb 7, 2014
015e01e
Unused imports.
mblondel Feb 7, 2014
eaa3aeb
Clone early.
mblondel Feb 7, 2014
5d8570b
Merge branch 'master' into multiple_grid_search
mblondel Feb 7, 2014
96c36c7
Multiple scorer support in validation_curve.
mblondel Feb 9, 2014
d4ffc1f
Add rudimentary validation with contours example.
mblondel Feb 9, 2014
c33f0f9
Support param_grid in validation_curve.
mblondel Feb 9, 2014
34ba906
Return training times.
mblondel Feb 9, 2014
7317c31
Remove cross_val_report.
mblondel Feb 9, 2014
50 changes: 50 additions & 0 deletions examples/plot_validation_contours.py
@@ -0,0 +1,50 @@
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_digits
from sklearn.svm import SVC
from sklearn.learning_curve import validation_curve
from sklearn.externals.joblib import Memory

memory = Memory(cachedir=".", verbose=0)


@memory.cache
def grid(X, y, Cs, gammas):
    param_grid = {"C": Cs, "gamma": gammas}

    tr, te, times = validation_curve(SVC(kernel="rbf"), X, y, param_grid, cv=3)

    shape = (len(Cs), len(gammas))
    tr = tr.mean(axis=1).reshape(shape)
    te = te.mean(axis=1).reshape(shape)
    times = times.mean(axis=1).reshape(shape)

    return tr, te, times


digits = load_digits()
X, y = digits.data, digits.target

gammas = np.logspace(-6, -1, 5)
Cs = np.logspace(-3, 3, 5)

tr, te, times = grid(X, y, Cs, gammas)


for title, values in (("Training accuracy", tr),
                      ("Test accuracy", te),
                      ("Training time", times)):

    plt.figure()

    plt.title(title)
    plt.xlabel("C")
    plt.xscale("log")

    plt.ylabel("gamma")
    plt.yscale("log")

    X1, X2 = np.meshgrid(Cs, gammas)
    cs = plt.contour(X1, X2, values)

    plt.colorbar(cs)

plt.show()
6 changes: 3 additions & 3 deletions examples/plot_validation_curve.py
@@ -23,9 +23,9 @@
X, y = digits.data, digits.target

param_range = np.logspace(-6, -1, 5)
train_scores, test_scores = validation_curve(
SVC(), X, y, param_name="gamma", param_range=param_range,
cv=10, scoring="accuracy", n_jobs=1)
param_grid = {"gamma": param_range}
train_scores, test_scores, train_times = validation_curve(
SVC(), X, y, param_grid, cv=10, scoring="accuracy", n_jobs=1)
train_scores_mean = np.mean(train_scores, axis=1)
train_scores_std = np.std(train_scores, axis=1)
test_scores_mean = np.mean(test_scores, axis=1)
135 changes: 55 additions & 80 deletions sklearn/cross_validation.py
@@ -27,7 +27,8 @@
from .utils.fixes import unique
from .externals.joblib import Parallel, delayed, logger
from .externals.six import with_metaclass
from .metrics.scorer import check_scoring
from .metrics.scorer import check_scoring, _evaluate_scorers


__all__ = ['Bootstrap',
'KFold',
@@ -1041,7 +1042,7 @@ def __len__(self):
def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1,
verbose=0, fit_params=None, score_func=None,
pre_dispatch='2*n_jobs'):
"""Evaluate a score by cross-validation
"""Evaluate test score by cross-validation

Parameters
----------
@@ -1055,10 +1056,12 @@ def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1,
The target variable to try to predict in the case of
supervised learning.

scoring : string, callable or None, optional, default: None
scoring : string, callable, list of strings/callables or None, optional,
default: None
A string (see model evaluation documentation) or
a scorer callable object / function with signature
``scorer(estimator, X, y)``.
A list can be given to evaluate multiple metrics at once.

cv : cross-validation generator, optional, default: None
A cross-validation generator. If None, a 3-fold cross
@@ -1094,78 +1097,57 @@

Returns
-------
scores : array of float, shape=(len(list(cv)),)
scores : array of float, shape=(n_folds,) or (n_scoring, n_folds)
Array of scores of the estimator for each run of the cross validation.
The returned array is 2d if `scoring` is a list.
"""
X, y = check_arrays(X, y, sparse_format='csr', allow_lists=True)
cv = _check_cv(cv, X, y, classifier=is_classifier(estimator))
scorer = check_scoring(estimator, score_func=score_func, scoring=scoring)
# We clone the estimator to make sure that all the folds are
# independent, and that it is pickle-able.
parallel = Parallel(n_jobs=n_jobs, verbose=verbose,
pre_dispatch=pre_dispatch)
scores = parallel(delayed(_fit_and_score)(clone(estimator), X, y, scorer,
train, test, verbose, None,
fit_params)
for train, test in cv)
return np.array(scores)[:, 0]


def _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters,
fit_params, return_train_score=False,
return_parameters=False):
"""Fit estimator and compute scores for a given dataset split.

Parameters
----------
estimator : estimator object implementing 'fit'
The object to use to fit the data.

X : array-like of shape at least 2D
The data to fit.

y : array-like, optional, default: None
The target variable to try to predict in the case of
supervised learning.

scoring : callable
A scorer callable object / function with signature
``scorer(estimator, X, y)``.

train : array-like, shape = (n_train_samples,)
Indices of training samples.
if isinstance(scoring, (tuple, list)):
scorers = [check_scoring(estimator, scoring=s) for s in scoring]
ret_1d = False
else:
scorers = [check_scoring(estimator, score_func=score_func,
scoring=scoring)]
ret_1d = True

test : array-like, shape = (n_test_samples,)
Indices of test samples.
parallel = Parallel(n_jobs=n_jobs, verbose=verbose,
pre_dispatch=pre_dispatch)

verbose : integer
The verbosity level.
# `out` is a list of size n_folds. Each element of the list is a tuple
# (test_scores, n_test, train_time)
out = parallel(delayed(_fit_and_score)(clone(estimator), X, y, scorers,
train, test, verbose, None,
fit_params)
for train, test in cv)

parameters : dict or None
Parameters to be set on the estimator.
# Retrieve n_scorers x n_folds 2d-array.
test_scores = np.array([o[0] for o in out]).T

fit_params : dict or None
Parameters that will be passed to ``estimator.fit``.
if ret_1d:
return test_scores[0]
else:
return test_scores

return_train_score : boolean, optional, default: False
Compute and return score on training set.

return_parameters : boolean, optional, default: False
Return parameters that has been used for the estimator.
def _fit_and_score(estimator, X, y, scorers, train, test, verbose, parameters,
fit_params, return_train_scores=False):
"""Fit estimator and compute scores for a given dataset split.

Returns
-------
test_score : float
Score on test set.
train_score : array of floats, optional
Scores on training set.

train_score : float, optional
Score on training set.
test_score : array of floats
Scores on test set.

n_test_samples : int
Number of test samples.

scoring_time : float
Time spent for fitting and scoring in seconds.
train_time : float
Time spent for fitting in seconds.

parameters : dict or None, optional
The parameters that have been evaluated.
@@ -1188,30 +1170,35 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters,
if parameters is not None:
estimator.set_params(**parameters)

start_time = time.time()

X_train, y_train = _safe_split(estimator, X, y, train)
X_test, y_test = _safe_split(estimator, X, y, test, train)

start_time = time.time()
if y_train is None:
estimator.fit(X_train, **fit_params)
else:
estimator.fit(X_train, y_train, **fit_params)
test_score = _score(estimator, X_test, y_test, scorer)
if return_train_score:
train_score = _score(estimator, X_train, y_train, scorer)

scoring_time = time.time() - start_time
train_time = time.time() - start_time

test_scores = _evaluate_scorers(estimator, X_test, y_test, scorers)

if return_train_scores:
if len(scorers) == 1:
Inline review comment:
Member: what is the benefit of this if?
Member Author: I think the len(scorers) == 1 case doesn't need to be treated separately indeed.

train_scores = np.array([scorers[0](estimator, X_train, y_train)])
else:
train_scores = _evaluate_scorers(estimator, X_train, y_train,
scorers)
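# NOTE (from the inline review above): the len(scorers) == 1 special case is
# probably unnecessary; _evaluate_scorers could be called unconditionally to
# compute the training scores as well.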

if verbose > 2:
msg += ", score=%f" % test_score
msg += ", score=%s" % test_scores
if verbose > 1:
end_msg = "%s -%s" % (msg, logger.short_format_time(scoring_time))
end_msg = "%s -%s" % (msg, logger.short_format_time(train_time))
print("[CV] %s %s" % ((64 - len(end_msg)) * '.', end_msg))

ret = [train_score] if return_train_score else []
ret.extend([test_score, _num_samples(X_test), scoring_time])
if return_parameters:
ret.append(parameters)
ret = [train_scores] if return_train_scores else []
ret += [test_scores, _num_samples(X_test), train_time, parameters]
return ret


@@ -1247,18 +1234,6 @@ def _safe_split(estimator, X, y, indices, train_indices=None):
return X_subset, y_subset


def _score(estimator, X_test, y_test, scorer):
"""Compute the score of an estimator on a given test set."""
if y_test is None:
score = scorer(estimator, X_test)
else:
score = scorer(estimator, X_test, y_test)
if not isinstance(score, numbers.Number):
raise ValueError("scoring must return a number, got %s (%s) instead."
% (str(score), type(score)))
return score


def _permutation_test_score(estimator, X, y, cv, scorer):
"""Auxiliary function for permutation_test_score"""
avg_score = []
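The `_evaluate_scorers` helper imported from `sklearn.metrics.scorer` is not part of this diff. Judging from its call sites above, it takes a fitted estimator, data and a list of scorer callables and returns one score per scorer as an array; the commits ("Add evaluate_scorers function.", "Support SVC.", "Handle multi-label case.") suggest the real implementation also reuses predictions across scorers. A rough sketch under those assumptions, not the code in this branch:

import numpy as np

def _evaluate_scorers(estimator, X, y, scorers):
    # Naive sketch: call each scorer on the fitted estimator and collect the
    # results in a 1d array, one entry per scorer.  Unlike the helper added in
    # this PR, this version does not share predictions between scorers.
    scores = []
    for scorer in scorers:
        if y is None:
            scores.append(scorer(estimator, X))
        else:
            scores.append(scorer(estimator, X, y))
    return np.array(scores)
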
4 changes: 2 additions & 2 deletions sklearn/feature_selection/rfe.py
@@ -13,7 +13,7 @@
from ..base import clone
from ..base import is_classifier
from ..cross_validation import _check_cv as check_cv
from ..cross_validation import _safe_split, _score
from ..cross_validation import _safe_split
from .base import SelectorMixin
from ..metrics.scorer import check_scoring

@@ -342,7 +342,7 @@ def fit(self, X, y):
mask = np.where(ranking_ <= k + 1)[0]
estimator = clone(self.estimator)
estimator.fit(X_train[:, mask], y_train)
score = _score(estimator, X_test[:, mask], y_test, scorer)
score = scorer(estimator, X_test[:, mask], y_test)

if self.verbose > 0:
print("Finished fold with %d / %d feature ranks, score=%f"