diff --git a/examples/plot_validation_contours.py b/examples/plot_validation_contours.py
new file mode 100644
index 0000000000000..ac8c308f53a2b
--- /dev/null
+++ b/examples/plot_validation_contours.py
@@ -0,0 +1,50 @@
+import matplotlib.pyplot as plt
+import numpy as np
+from sklearn.datasets import load_digits
+from sklearn.svm import SVC
+from sklearn.learning_curve import validation_curve
+from sklearn.externals.joblib import Memory
+
+memory = Memory(cachedir=".", verbose=0)
+
+@memory.cache
+def grid(X, y, Cs, gammas):
+    param_grid = {"C": Cs, "gamma": gammas}
+
+    tr, te, times = validation_curve(SVC(kernel="rbf"), X, y, param_grid, cv=3)
+
+    shape = (len(Cs), len(gammas))
+    tr = tr.mean(axis=1).reshape(shape)
+    te = te.mean(axis=1).reshape(shape)
+    times = times.mean(axis=1).reshape(shape)
+
+    return tr, te, times
+
+digits = load_digits()
+X, y = digits.data, digits.target
+
+gammas = np.logspace(-6, -1, 5)
+Cs = np.logspace(-3, 3, 5)
+
+tr, te, times = grid(X, y, Cs, gammas)
+
+
+for title, values in (("Training accuracy", tr),
+                      ("Test accuracy", te),
+                      ("Training time", times)):
+
+    plt.figure()
+
+    plt.title(title)
+    plt.xlabel("C")
+    plt.xscale("log")
+
+    plt.ylabel("gamma")
+    plt.yscale("log")
+
+    X1, X2 = np.meshgrid(Cs, gammas)
+    cs = plt.contour(X1, X2, values.T)
+
+    plt.colorbar(cs)
+
+plt.show()
diff --git a/examples/plot_validation_curve.py b/examples/plot_validation_curve.py
index 7b5f05050183a..0c6a056089c0b
--- a/examples/plot_validation_curve.py
+++ b/examples/plot_validation_curve.py
@@ -23,9 +23,9 @@
 X, y = digits.data, digits.target
 
 param_range = np.logspace(-6, -1, 5)
-train_scores, test_scores = validation_curve(
-    SVC(), X, y, param_name="gamma", param_range=param_range,
-    cv=10, scoring="accuracy", n_jobs=1)
+param_grid = {"gamma": param_range}
+train_scores, test_scores, train_times = validation_curve(
+    SVC(), X, y, param_grid, cv=10, scoring="accuracy", n_jobs=1)
 train_scores_mean = np.mean(train_scores, axis=1)
 train_scores_std = np.std(train_scores, axis=1)
 test_scores_mean = np.mean(test_scores, axis=1)
diff --git a/sklearn/cross_validation.py b/sklearn/cross_validation.py
index c679ea9caf6cb..46cf2e926333d 100644
--- a/sklearn/cross_validation.py
+++ b/sklearn/cross_validation.py
@@ -27,7 +27,8 @@
 from .utils.fixes import unique
 from .externals.joblib import Parallel, delayed, logger
 from .externals.six import with_metaclass
-from .metrics.scorer import check_scoring
+from .metrics.scorer import check_scoring, _evaluate_scorers
+
 
 __all__ = ['Bootstrap',
            'KFold',
@@ -1041,7 +1042,7 @@ def __len__(self):
 def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1,
                     verbose=0, fit_params=None, score_func=None,
                     pre_dispatch='2*n_jobs'):
-    """Evaluate a score by cross-validation
+    """Evaluate test score by cross-validation
 
     Parameters
     ----------
@@ -1055,10 +1056,12 @@ def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1,
         The target variable to try to predict in the case of
         supervised learning.
 
-    scoring : string, callable or None, optional, default: None
+    scoring : string, callable, list of strings/callables or None, optional,
+              default: None
         A string (see model evaluation documentation) or
         a scorer callable object / function with signature
         ``scorer(estimator, X, y)``.
+        Lists can be used to evaluate multiple metrics at once.
 
     cv : cross-validation generator, optional, default: None
         A cross-validation generator.
If None, a 3-fold cross @@ -1094,78 +1097,57 @@ def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1, Returns ------- - scores : array of float, shape=(len(list(cv)),) + scores : array of float, shape=(n_folds,) or (n_scoring, n_folds) Array of scores of the estimator for each run of the cross validation. + The returned array is 2d if `scoring` is a list. """ X, y = check_arrays(X, y, sparse_format='csr', allow_lists=True) cv = _check_cv(cv, X, y, classifier=is_classifier(estimator)) - scorer = check_scoring(estimator, score_func=score_func, scoring=scoring) - # We clone the estimator to make sure that all the folds are - # independent, and that it is pickle-able. - parallel = Parallel(n_jobs=n_jobs, verbose=verbose, - pre_dispatch=pre_dispatch) - scores = parallel(delayed(_fit_and_score)(clone(estimator), X, y, scorer, - train, test, verbose, None, - fit_params) - for train, test in cv) - return np.array(scores)[:, 0] - - -def _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters, - fit_params, return_train_score=False, - return_parameters=False): - """Fit estimator and compute scores for a given dataset split. - - Parameters - ---------- - estimator : estimator object implementing 'fit' - The object to use to fit the data. - - X : array-like of shape at least 2D - The data to fit. - - y : array-like, optional, default: None - The target variable to try to predict in the case of - supervised learning. - - scoring : callable - A scorer callable object / function with signature - ``scorer(estimator, X, y)``. - train : array-like, shape = (n_train_samples,) - Indices of training samples. + if isinstance(scoring, (tuple, list)): + scorers = [check_scoring(estimator, scoring=s) for s in scoring] + ret_1d = False + else: + scorers = [check_scoring(estimator, score_func=score_func, + scoring=scoring)] + ret_1d = True - test : array-like, shape = (n_test_samples,) - Indices of test samples. + parallel = Parallel(n_jobs=n_jobs, verbose=verbose, + pre_dispatch=pre_dispatch) - verbose : integer - The verbosity level. + # `out` is a list of size n_folds. Each element of the list is a tuple + # (test_scores, n_test, train_time) + out = parallel(delayed(_fit_and_score)(clone(estimator), X, y, scorers, + train, test, verbose, None, + fit_params) + for train, test in cv) - parameters : dict or None - Parameters to be set on the estimator. + # Retrieve n_scorers x n_folds 2d-array. + test_scores = np.array([o[0] for o in out]).T - fit_params : dict or None - Parameters that will be passed to ``estimator.fit``. + if ret_1d: + return test_scores[0] + else: + return test_scores - return_train_score : boolean, optional, default: False - Compute and return score on training set. - return_parameters : boolean, optional, default: False - Return parameters that has been used for the estimator. +def _fit_and_score(estimator, X, y, scorers, train, test, verbose, parameters, + fit_params, return_train_scores=False): + """Fit estimator and compute scores for a given dataset split. Returns ------- - test_score : float - Score on test set. + train_score : array of floats, optional + Scores on training set. - train_score : float, optional - Score on training set. + test_score : array of floats + Scores on test set. n_test_samples : int Number of test samples. - scoring_time : float - Time spent for fitting and scoring in seconds. + train_time : float + Time spent for fitting in seconds. parameters : dict or None, optional The parameters that have been evaluated. 
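# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the patch): with the change to
# `cross_val_score` above, `scoring` may also be a list of metric names or
# scorer callables, in which case the returned array has shape
# (n_scoring, n_folds).  The estimator and dataset below are arbitrary
# choices for demonstration.
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.cross_validation import cross_val_score

X, y = make_classification(n_classes=2, random_state=0)
clf = LogisticRegression()

scores = cross_val_score(clf, X, y, cv=5, scoring=["accuracy", "roc_auc"])
print(scores.shape)         # (2, 5): one row per metric, one column per fold
print(scores.mean(axis=1))  # mean accuracy and mean ROC AUC across folds
# ---------------------------------------------------------------------------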
@@ -1188,30 +1170,35 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters, if parameters is not None: estimator.set_params(**parameters) - start_time = time.time() - X_train, y_train = _safe_split(estimator, X, y, train) X_test, y_test = _safe_split(estimator, X, y, test, train) + + start_time = time.time() + if y_train is None: estimator.fit(X_train, **fit_params) else: estimator.fit(X_train, y_train, **fit_params) - test_score = _score(estimator, X_test, y_test, scorer) - if return_train_score: - train_score = _score(estimator, X_train, y_train, scorer) - scoring_time = time.time() - start_time + train_time = time.time() - start_time + + test_scores = _evaluate_scorers(estimator, X_test, y_test, scorers) + + if return_train_scores: + if len(scorers) == 1: + train_scores = np.array([scorers[0](estimator, X_train, y_train)]) + else: + train_scores = _evaluate_scorers(estimator, X_train, y_train, + scorers) if verbose > 2: - msg += ", score=%f" % test_score + msg += ", score=%s" % test_scores if verbose > 1: - end_msg = "%s -%s" % (msg, logger.short_format_time(scoring_time)) + end_msg = "%s -%s" % (msg, logger.short_format_time(train_time)) print("[CV] %s %s" % ((64 - len(end_msg)) * '.', end_msg)) - ret = [train_score] if return_train_score else [] - ret.extend([test_score, _num_samples(X_test), scoring_time]) - if return_parameters: - ret.append(parameters) + ret = [train_scores] if return_train_scores else [] + ret += [test_scores, _num_samples(X_test), train_time, parameters] return ret @@ -1247,18 +1234,6 @@ def _safe_split(estimator, X, y, indices, train_indices=None): return X_subset, y_subset -def _score(estimator, X_test, y_test, scorer): - """Compute the score of an estimator on a given test set.""" - if y_test is None: - score = scorer(estimator, X_test) - else: - score = scorer(estimator, X_test, y_test) - if not isinstance(score, numbers.Number): - raise ValueError("scoring must return a number, got %s (%s) instead." 
- % (str(score), type(score))) - return score - - def _permutation_test_score(estimator, X, y, cv, scorer): """Auxiliary function for permutation_test_score""" avg_score = [] diff --git a/sklearn/feature_selection/rfe.py b/sklearn/feature_selection/rfe.py index 01c99ceb526f4..05f376250c024 100644 --- a/sklearn/feature_selection/rfe.py +++ b/sklearn/feature_selection/rfe.py @@ -13,7 +13,7 @@ from ..base import clone from ..base import is_classifier from ..cross_validation import _check_cv as check_cv -from ..cross_validation import _safe_split, _score +from ..cross_validation import _safe_split from .base import SelectorMixin from ..metrics.scorer import check_scoring @@ -342,7 +342,7 @@ def fit(self, X, y): mask = np.where(ranking_ <= k + 1)[0] estimator = clone(self.estimator) estimator.fit(X_train[:, mask], y_train) - score = _score(estimator, X_test[:, mask], y_test, scorer) + score = scorer(estimator, X_test[:, mask], y_test) if self.verbose > 0: print("Finished fold with %d / %d feature ranks, score=%f" diff --git a/sklearn/grid_search.py b/sklearn/grid_search.py index 8d217521f1269..c6bcb358def9c 100644 --- a/sklearn/grid_search.py +++ b/sklearn/grid_search.py @@ -30,8 +30,8 @@ from .metrics.scorer import check_scoring -__all__ = ['GridSearchCV', 'ParameterGrid', 'fit_grid_point', - 'ParameterSampler', 'RandomizedSearchCV'] +__all__ = ['GridSearchCV', 'ParameterGrid', 'ParameterSampler', + 'RandomizedSearchCV'] class ParameterGrid(object): @@ -182,6 +182,103 @@ def __len__(self): return self.n_iter +def _fit_param_iter(estimator, X, y, scoring, parameter_iterable, refit, + cv, pre_dispatch, fit_params, iid, n_jobs, verbose): + """Actual fitting, performing the search over parameters.""" + + estimator = clone(estimator) + + if isinstance(scoring, (tuple, list)): + scorers = [check_scoring(estimator, scoring=s) for s in scoring] + ret_1d = False + else: + scorers = [check_scoring(estimator, scoring=scoring)] + ret_1d = True + + n_samples = _num_samples(X) + X, y = check_arrays(X, y, allow_lists=True, sparse_format='csr') + + if y is not None: + if len(y) != n_samples: + raise ValueError('Target variable (y) has a different number ' + 'of samples (%i) than data (X: %i samples)' + % (len(y), n_samples)) + y = np.asarray(y) + cv = check_cv(cv, X, y, classifier=is_classifier(estimator)) + + if verbose > 0: + if isinstance(parameter_iterable, Sized): + n_candidates = len(parameter_iterable) + print("Fitting {0} folds for each of {1} candidates, totalling" + " {2} fits".format(len(cv), n_candidates, + n_candidates * len(cv))) + + out = Parallel(n_jobs=n_jobs, verbose=verbose, pre_dispatch=pre_dispatch)( + delayed(_fit_and_score)( + clone(estimator), X, y, scorers, train, test, + verbose, parameters, fit_params) + for parameters in parameter_iterable + for train, test in cv) + + # `out` is a list of tuples (fold_scores, n_test, train_time, params). + n_fits = len(out) + n_folds = len(cv) + n_scorers = len(scorers) + + grid_scores = [] + for i in xrange(n_scorers): + grid_scores.append([]) + + for grid_start in range(0, n_fits, n_folds): + + grid_stop = grid_start + n_folds + fold_scores, n_test, _, parameters = zip(*out[grid_start:grid_stop]) + # `params` contains the same parameters n_fold times. + parameters = parameters[0] + # `fold_scores` is an n_folds x n_scorers 2-d array. 
+ fold_scores = np.array(fold_scores) + weights = n_test if iid else None + mean_scores = np.average(fold_scores, axis=0, weights=weights) + + for i in xrange(n_scorers): + # TODO: shall we also store the test_fold_sizes? + tup = _CVScoreTuple(parameters, mean_scores[i], fold_scores[:, i]) + grid_scores[i].append(tup) + + # Find the best parameters by comparing on the mean validation score: + # note that `sorted` is deterministic in the way it breaks ties + bests = [sorted(grid_scores[i], key=lambda x: x.mean_validation_score, + reverse=True)[0] for i in xrange(n_scorers)] + best_params = [best.parameters for best in bests] + best_scores = [best.mean_validation_score for best in bests] + + best_estimators = [] + if refit: + for i in xrange(len(scorers)): + estimator = clone(estimator) + best_estimator = estimator.set_params(**best_params[i]) + best_estimators.append(best_estimator) + if y is not None: + best_estimator.fit(X, y, **fit_params) + else: + best_estimator.fit(X, **fit_params) + + if ret_1d: + scorers = scorers[0] + best_params = best_params[0] + best_scores = best_scores[0] + grid_scores = grid_scores[0] + if refit: + best_estimators = best_estimators[0] + + ret = [scorers, best_params, best_scores, grid_scores] + + if refit: + ret.append(best_estimators) + + return ret + + def fit_grid_point(X, y, estimator, parameters, train, test, scorer, verbose, **fit_params): """Run fit on one set of parameters. @@ -228,10 +325,13 @@ def fit_grid_point(X, y, estimator, parameters, train, test, scorer, n_samples_test : int Number of test samples in this split. """ - score, n_samples_test, _ = _fit_and_score(estimator, X, y, scorer, train, + warnings.warn("fit_grid_point is deprecated and will be " + "removed in 0.17", DeprecationWarning, stacklevel=1) + + scores, n_samples_test, _ = _fit_and_score(estimator, X, y, [scorer], train, test, verbose, parameters, fit_params) - return score, parameters, n_samples_test + return scores[0], parameters, n_samples_test def _check_param_grid(param_grid): @@ -340,93 +440,25 @@ def transform(self): return self.best_estimator_.transform def _fit(self, X, y, parameter_iterable): - """Actual fitting, performing the search over parameters.""" - - estimator = self.estimator - cv = self.cv - self.scorer_ = check_scoring(self.estimator, scoring=self.scoring, - loss_func=self.loss_func, - score_func=self.score_func) - - n_samples = _num_samples(X) - X, y = check_arrays(X, y, allow_lists=True, sparse_format='csr') - - if y is not None: - if len(y) != n_samples: - raise ValueError('Target variable (y) has a different number ' - 'of samples (%i) than data (X: %i samples)' - % (len(y), n_samples)) - y = np.asarray(y) - cv = check_cv(cv, X, y, classifier=is_classifier(estimator)) - - if self.verbose > 0: - if isinstance(parameter_iterable, Sized): - n_candidates = len(parameter_iterable) - print("Fitting {0} folds for each of {1} candidates, totalling" - " {2} fits".format(len(cv), n_candidates, - n_candidates * len(cv))) - - base_estimator = clone(self.estimator) - - pre_dispatch = self.pre_dispatch - - out = Parallel( - n_jobs=self.n_jobs, verbose=self.verbose, - pre_dispatch=pre_dispatch - )( - delayed(_fit_and_score)(clone(base_estimator), X, y, self.scorer_, - train, test, self.verbose, parameters, - self.fit_params, return_parameters=True) - for parameters in parameter_iterable - for train, test in cv) + ret = _fit_param_iter(self.estimator, X, y, self.scoring, + parameter_iterable, self.refit, self.cv, + self.pre_dispatch, self.fit_params, self.iid, + 
self.n_jobs, self.verbose) - # Out is a list of triplet: score, estimator, n_test_samples - n_fits = len(out) - n_folds = len(cv) - - scores = list() - grid_scores = list() - for grid_start in range(0, n_fits, n_folds): - n_test_samples = 0 - score = 0 - all_scores = [] - for this_score, this_n_test_samples, _, parameters in \ - out[grid_start:grid_start + n_folds]: - all_scores.append(this_score) - if self.iid: - this_score *= this_n_test_samples - n_test_samples += this_n_test_samples - score += this_score - if self.iid: - score /= float(n_test_samples) - else: - score /= float(n_folds) - scores.append((score, parameters)) - # TODO: shall we also store the test_fold_sizes? - grid_scores.append(_CVScoreTuple( - parameters, - score, - np.array(all_scores))) - # Store the computed scores - self.grid_scores_ = grid_scores - - # Find the best parameters by comparing on the mean validation score: - # note that `sorted` is deterministic in the way it breaks ties - best = sorted(grid_scores, key=lambda x: x.mean_validation_score, - reverse=True)[0] - self.best_params_ = best.parameters - self.best_score_ = best.mean_validation_score + self.scorer_ = ret[0] + self.best_params_ = ret[1] + self.best_score_ = ret[2] + self.grid_scores_ = ret[3] if self.refit: - # fit the best estimator using the entire dataset - # clone first to work around broken estimators - best_estimator = clone(base_estimator).set_params( - **best.parameters) - if y is not None: - best_estimator.fit(X, y, **self.fit_params) + if isinstance(ret[4], list): + self.best_estimators_ = ret[4] + # By default, select the best estimator corresponding to the + # first scorer. + self.best_estimator_ = ret[4][0] else: - best_estimator.fit(X, **self.fit_params) - self.best_estimator_ = best_estimator + self.best_estimator_ = ret[4] + return self @@ -599,6 +631,7 @@ def fit(self, X, y=None, **params): warnings.warn("Additional parameters to GridSearchCV are ignored!" " The params argument will be removed in 0.15.", DeprecationWarning) + return self._fit(X, y, ParameterGrid(self.param_grid)) diff --git a/sklearn/learning_curve.py b/sklearn/learning_curve.py index 21a3f12111ee6..5185eb915ff0a 100644 --- a/sklearn/learning_curve.py +++ b/sklearn/learning_curve.py @@ -10,7 +10,8 @@ from .cross_validation import _check_cv from .utils import check_arrays from .externals.joblib import Parallel, delayed -from .cross_validation import _safe_split, _score, _fit_and_score +from .cross_validation import _safe_split, _fit_and_score +from .grid_search import ParameterGrid from .metrics.scorer import check_scoring @@ -127,17 +128,18 @@ def learning_curve(estimator, X, y, train_sizes=np.linspace(0.1, 1.0, 10), clone(estimator), X, y, classes, train, test, train_sizes_abs, scorer, verbose) for train, test in cv) else: + # ret is a list of size n_folds. Each element of the list contains the + # tuple returned by _fit_and_score. 
out = parallel(delayed(_fit_and_score)( - clone(estimator), X, y, scorer, train[:n_train_samples], test, - verbose, parameters=None, fit_params=None, return_train_score=True) + clone(estimator), X, y, [scorer], train[:n_train_samples], test, + verbose, parameters=None, fit_params=None, return_train_scores=True) for train, test in cv for n_train_samples in train_sizes_abs) - out = np.array(out)[:, :2] - n_cv_folds = out.shape[0] / n_unique_ticks - out = out.reshape(n_cv_folds, n_unique_ticks, 2) - out = np.asarray(out).transpose((2, 1, 0)) + out = np.array(out).reshape(len(cv), len(train_sizes_abs), -1) + train_scores = out[:, :, 0].T + test_scores = out[:, :, 1].T - return train_sizes_abs, out[0], out[1] + return train_sizes_abs, train_scores, test_scores def _translate_train_sizes(train_sizes, n_max_training_samples): @@ -204,23 +206,27 @@ def _incremental_fit_estimator(estimator, X, y, classes, train, test, """Train estimator on training subsets incrementally and compute scores.""" train_scores, test_scores = [], [] partitions = zip(train_sizes, np.split(train, train_sizes)[:-1]) + for n_train_samples, partial_train in partitions: train_subset = train[:n_train_samples] X_train, y_train = _safe_split(estimator, X, y, train_subset) X_partial_train, y_partial_train = _safe_split(estimator, X, y, partial_train) X_test, y_test = _safe_split(estimator, X, y, test, train_subset) + if y_partial_train is None: estimator.partial_fit(X_partial_train, classes=classes) else: estimator.partial_fit(X_partial_train, y_partial_train, classes=classes) - train_scores.append(_score(estimator, X_train, y_train, scorer)) - test_scores.append(_score(estimator, X_test, y_test, scorer)) + + train_scores.append(scorer(estimator, X_train, y_train)) + test_scores.append(scorer(estimator, X_test, y_test)) + return np.array((train_scores, test_scores)).T -def validation_curve(estimator, X, y, param_name, param_range, cv=None, +def validation_curve(estimator, X, y, param_grid, cv=None, scoring=None, n_jobs=1, pre_dispatch="all", verbose=0): """Validation curve. @@ -244,11 +250,9 @@ def validation_curve(estimator, X, y, param_name, param_range, cv=None, Target relative to X for classification or regression; None for unsupervised learning. - param_name : string - Name of the parameter that will be varied. - - param_range : array-like, shape (n_values,) - The values of the parameter that will be evaluated. + param_grid : dict or list of dictionaries + Dictionary with parameters names (string) as keys and lists of + parameter settings to try as values. cv : integer, cross-validation generator, optional If an integer is passed, it is the number of folds (defaults to 3). @@ -273,12 +277,17 @@ def validation_curve(estimator, X, y, param_name, param_range, cv=None, Returns ------- - train_scores : array, shape (n_ticks, n_cv_folds) + train_scores : array, shape (n_params, n_cv_folds) or + (n_scorers, n_params, n_cv_folds) Scores on training sets. - test_scores : array, shape (n_ticks, n_cv_folds) + test_scores : array, shape (n_params, n_cv_folds) or + (n_scorers, n_params, n_cv_folds) Scores on test set. + train_times : array, shape (n_params, n_cv_folds) + Training times. 
+ Notes ----- See @@ -286,18 +295,34 @@ def validation_curve(estimator, X, y, param_name, param_range, cv=None, """ X, y = check_arrays(X, y, sparse_format='csr', allow_lists=True) cv = _check_cv(cv, X, y, classifier=is_classifier(estimator)) - scorer = check_scoring(estimator, scoring=scoring) + + if isinstance(scoring, (tuple, list)): + scorer = [check_scoring(estimator, scoring=s) for s in scoring] + one_scorer = False + else: + scorer = [check_scoring(estimator, scoring=scoring)] + one_scorer = True + + param_grid = ParameterGrid(param_grid) + n_params = len(param_grid) parallel = Parallel(n_jobs=n_jobs, pre_dispatch=pre_dispatch, verbose=verbose) out = parallel(delayed(_fit_and_score)( estimator, X, y, scorer, train, test, verbose, - parameters={param_name : v}, fit_params=None, return_train_score=True) - for train, test in cv for v in param_range) + parameters=params, fit_params=None, return_train_scores=True) + for train, test in cv for params in param_grid) + + n_folds = len(out) / n_params + + shape = (n_folds, n_params, -1) + train_scores = np.array([o[0] for o in out]).reshape(shape).T + test_scores = np.array([o[1] for o in out]).reshape(shape).T + + train_times = np.array([o[3] for o in out]).reshape(n_folds, n_params).T - out = np.asarray(out)[:, :2] - n_params = len(param_range) - n_cv_folds = out.shape[0] / n_params - out = out.reshape(n_cv_folds, n_params, 2).transpose((2, 1, 0)) + if one_scorer: + train_scores = train_scores[0] + test_scores = test_scores[0] - return out[0], out[1] + return train_scores, test_scores, train_times diff --git a/sklearn/metrics/scorer.py b/sklearn/metrics/scorer.py index 1d26671b67851..5be832dbbd51e 100644 --- a/sklearn/metrics/scorer.py +++ b/sklearn/metrics/scorer.py @@ -18,8 +18,8 @@ # Arnaud Joly # License: Simplified BSD -from abc import ABCMeta, abstractmethod from warnings import warn +import numbers import numpy as np @@ -30,132 +30,135 @@ from .cluster import adjusted_rand_score from ..utils.multiclass import type_of_target from ..externals import six +from ..base import is_classifier -class _BaseScorer(six.with_metaclass(ABCMeta, object)): - def __init__(self, score_func, sign, kwargs): - self._kwargs = kwargs - self._score_func = score_func - self._sign = sign +class _Scorer(object): + + def __init__(self, score_func, greater_is_better=True, needs_proba=False, + needs_threshold=False, kwargs={}): + self.score_func = score_func + self.greater_is_better = greater_is_better + self.needs_proba = needs_proba + self.needs_threshold = needs_threshold + self.kwargs = kwargs - @abstractmethod def __call__(self, estimator, X, y): - pass + return _evaluate_scorers(estimator, X, y, [self])[0] def __repr__(self): kwargs_string = "".join([", %s=%s" % (str(k), str(v)) - for k, v in self._kwargs.items()]) - return ("make_scorer(%s%s%s%s)" - % (self._score_func.__name__, - "" if self._sign > 0 else ", greater_is_better=False", - self._factory_args(), kwargs_string)) - - def _factory_args(self): - """Return non-default make_scorer arguments for repr.""" - return "" - - -class _PredictScorer(_BaseScorer): - def __call__(self, estimator, X, y_true): - """Evaluate predicted target values for X relative to y_true. - - Parameters - ---------- - estimator : object - Trained estimator to use for scoring. Must have a predict_proba - method; the output of that is used to compute the score. - - X : array-like or sparse matrix - Test data that will be fed to estimator.predict. - - y_true : array-like - Gold standard target values for X. 
- - Returns - ------- - score : float - Score function applied to prediction of estimator on X. - """ - y_pred = estimator.predict(X) - return self._sign * self._score_func(y_true, y_pred, **self._kwargs) - - -class _ProbaScorer(_BaseScorer): - def __call__(self, clf, X, y): - """Evaluate predicted probabilities for X relative to y_true. - - Parameters - ---------- - clf : object - Trained classifier to use for scoring. Must have a predict_proba - method; the output of that is used to compute the score. - - X : array-like or sparse matrix - Test data that will be fed to clf.predict_proba. - - y : array-like - Gold standard target values for X. These must be class labels, - not probabilities. - - Returns - ------- - score : float - Score function applied to prediction of estimator on X. - """ - y_pred = clf.predict_proba(X) - return self._sign * self._score_func(y, y_pred, **self._kwargs) - - def _factory_args(self): - return ", needs_proba=True" - - -class _ThresholdScorer(_BaseScorer): - def __call__(self, clf, X, y): - """Evaluate decision function output for X relative to y_true. - - Parameters - ---------- - clf : object - Trained classifier to use for scoring. Must have either a - decision_function method or a predict_proba method; the output of - that is used to compute the score. - - X : array-like or sparse matrix - Test data that will be fed to clf.decision_function or - clf.predict_proba. - - y : array-like - Gold standard target values for X. These must be class labels, - not decision function values. - - Returns - ------- - score : float - Score function applied to prediction of estimator on X. - """ - y_type = type_of_target(y) - if y_type not in ("binary", "multilabel-indicator"): - raise ValueError("{0} format is not supported".format(y_type)) - + for k, v in self.kwargs.items()]) + return ("make_scorer(%s%s%s)" + % (self.score_func.__name__, + "" if self.greater_is_better else ", greater_is_better=False", + kwargs_string)) + + +def _evaluate_scorers(estimator, X, y, scorers): + """Evaluate a list of scorers. `scorers` may contain _Scorer objects or + callables of the form callable(estimator, X, y).""" + + if len(scorers) == 1 and not isinstance(scorers[0], _Scorer): + # We won't need any predictions if there is only one callable in the + # list. + return np.array([scorers[0](estimator, X, y)]) + + has_pb = hasattr(estimator, "predict_proba") + has_df = hasattr(estimator, "decision_function") + _is_classifier = is_classifier(estimator) + _type_of_y = type_of_target(y) if y is not None else None + + # Make a first pass through scorers to determine if we need + # predict_proba or decision_function. + needs_proba = False + needs_df = False + for scorer in scorers: + if not isinstance(scorer, _Scorer): + continue # assumed to be a callable + + if scorer.needs_proba: + if not has_pb: + raise ValueError("%s needs probabilities but predict_proba is" + "not available in %s." % (scorer, estimator)) + needs_proba = True + + elif scorer.needs_threshold: + if has_pb: + # We choose predict_proba first because its interface + # is more consistent across the project. + needs_proba = True + continue + + if _is_classifier and not has_df: + raise ValueError("%s needs continuous outputs but neither" + "predict_proba nor decision_function " + "are available in %s." % (scorer, estimator)) + + if _is_classifier: + needs_df = True + + # Compute predict_proba if needed. 
+ y_proba = None + y_pred = None + if needs_proba: try: - y_pred = clf.decision_function(X) + y_proba = estimator.predict_proba(X) + + y_pred = estimator.classes_[y_proba.argmax(axis=1)] - # For multi-output multi-class estimator - if isinstance(y_pred, list): - y_pred = np.vstack(p for p in y_pred).T + if _type_of_y == "binary": + y_proba = y_proba[:, 1] except (NotImplementedError, AttributeError): - y_pred = clf.predict_proba(X) + # SVC has predict_proba but it may raise NotImplementedError + # if probabilities are not enabled. + needs_proba = False + needs_df = True + + # Compute decision_function. + df = None + if needs_df: + df = estimator.decision_function(X) + + if len(df.shape) == 2 and df.shape[1] >= 2: + y_pred = estimator.classes_[df.argmax(axis=1)] + else: + y_pred = estimator.classes_[(df >= 0).astype(int)] + + # Compute y_pred if needed. + if y_pred is None: + y_pred = estimator.predict(X) - if y_type == "binary": - y_pred = y_pred[:, 1] - elif isinstance(y_pred, list): - y_pred = np.vstack([p[:, -1] for p in y_pred]).T + # Compute scores. + scores = [] + for scorer in scorers: + if not isinstance(scorer, _Scorer): + scores.append(scorer(estimator, X, y)) + continue - return self._sign * self._score_func(y, y_pred, **self._kwargs) + if scorer.needs_proba: + score = scorer.score_func(y, y_proba, **scorer.kwargs) - def _factory_args(self): - return ", needs_threshold=True" + elif scorer.needs_threshold: + if y_proba is not None: + score = scorer.score_func(y, y_proba, **scorer.kwargs) + elif df is not None: + score = scorer.score_func(y, df, **scorer.kwargs) + else: + score = scorer.score_func(y, y_pred, **scorer.kwargs) + + else: + score = scorer.score_func(y, y_pred, **scorer.kwargs) + + if not isinstance(score, numbers.Number): + raise ValueError("scoring must return a number, got %s (%s)" + " instead." % (str(score), type(score))) + + sign = 1 if scorer.greater_is_better else -1 + scores.append(sign * score) + + return np.array(scores) def get_scorer(scoring): @@ -171,9 +174,12 @@ def get_scorer(scoring): return scorer -def _passthrough_scorer(estimator, *args, **kwargs): +def _default_scorer(estimator, X, y, *args, **kwargs): """Function that wraps estimator.score""" - return estimator.score(*args, **kwargs) + if y is None: + return estimator.score(X, *args, **kwargs) + else: + return estimator.score(X, y, *args, **kwargs) def check_scoring(estimator, scoring=None, allow_none=False, loss_func=None, @@ -198,10 +204,13 @@ def check_scoring(estimator, scoring=None, allow_none=False, loss_func=None, Returns ------- - scoring : callable + scorer : callable A scorer callable object / function with signature ``scorer(estimator, X, y)``. """ + if isinstance(scoring, _Scorer): + return scoring + has_scoring = not (scoring is None and loss_func is None and score_func is None) if not hasattr(estimator, 'fit'): @@ -229,7 +238,7 @@ def check_scoring(estimator, scoring=None, allow_none=False, loss_func=None, scorer = get_scorer(scoring) return scorer elif hasattr(estimator, 'score'): - return _passthrough_scorer + return _default_scorer elif not has_scoring: if allow_none: return None @@ -293,17 +302,15 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False, >>> grid = GridSearchCV(LinearSVC(), param_grid={'C': [1, 10]}, ... 
scoring=ftwo_scorer) """ - sign = 1 if greater_is_better else -1 if needs_proba and needs_threshold: raise ValueError("Set either needs_proba or needs_threshold to True," " but not both.") - if needs_proba: - cls = _ProbaScorer - elif needs_threshold: - cls = _ThresholdScorer - else: - cls = _PredictScorer - return cls(score_func, sign, kwargs) + + return _Scorer(score_func, + greater_is_better=greater_is_better, + needs_proba=needs_proba, + needs_threshold=needs_threshold, + kwargs=kwargs) # Standard regression scores diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index 6a97044ccd3e6..cff1e72c5e6fd 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -8,22 +8,94 @@ from sklearn.utils.testing import assert_true from sklearn.utils.testing import ignore_warnings -from sklearn.metrics import (f1_score, r2_score, roc_auc_score, fbeta_score, - log_loss) +from sklearn.metrics import (accuracy_score, f1_score, r2_score, roc_auc_score, + fbeta_score, log_loss, mean_squared_error, + average_precision_score) from sklearn.metrics.cluster import adjusted_rand_score -from sklearn.metrics.scorer import check_scoring +from sklearn.metrics.scorer import check_scoring, _evaluate_scorers from sklearn.metrics import make_scorer, SCORERS -from sklearn.svm import LinearSVC +from sklearn.svm import LinearSVC, SVC from sklearn.cluster import KMeans from sklearn.linear_model import Ridge, LogisticRegression -from sklearn.tree import DecisionTreeClassifier +from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor from sklearn.datasets import make_blobs +from sklearn.datasets import load_iris +from sklearn.datasets import make_classification from sklearn.datasets import make_multilabel_classification from sklearn.datasets import load_diabetes from sklearn.cross_validation import train_test_split, cross_val_score from sklearn.grid_search import GridSearchCV from sklearn.multiclass import OneVsRestClassifier +# FIXME: temporary, to demonstrate ranking with several relevance levels. 
+def dcg_score(y_true, y_score, k=10, gains="exponential"): + order = np.argsort(y_score)[::-1] + y_true = np.take(y_true, order[:k]) + + if gains == "exponential": + gains = 2 ** y_true - 1 + elif gains == "linear": + gains = y_true + else: + raise ValueError("Invalid gains option.") + + # highest rank is 1 so +2 instead of +1 + discounts = np.log2(np.arange(len(y_true)) + 2) + return np.sum(gains / discounts) + +dcg_scorer = make_scorer(dcg_score, needs_threshold=True) + + +class EstimatorWithoutFit(object): + """Dummy estimator to test check_scoring""" + pass + + +class EstimatorWithFit(object): + """Dummy estimator to test check_scoring""" + def fit(self, X, y): + return self + + +class EstimatorWithFitAndScore(object): + """Dummy estimator to test check_scoring""" + def fit(self, X, y): + return self + + def score(self, X, y): + return 1.0 + + +class EstimatorWithFitAndPredict(object): + """Dummy estimator to test check_scoring""" + def fit(self, X, y): + self.y = y + return self + + def predict(self, X): + return self.y + + +def test_check_scoring(): + """Test all branches of check_scoring""" + estimator = EstimatorWithoutFit() + assert_raises(TypeError, check_scoring, estimator) + + estimator = EstimatorWithFitAndScore() + estimator.fit([[1]], [1]) + scorer = check_scoring(estimator) + assert_almost_equal(scorer(estimator, [[1]], [1]), 1.0) + + estimator = EstimatorWithFitAndPredict() + estimator.fit([[1]], [1]) + assert_raises(TypeError, check_scoring, estimator) + + scorer = check_scoring(estimator, "accuracy") + assert_almost_equal(scorer(estimator, [[1]], [1]), 1.0) + + estimator = EstimatorWithFit() + assert_raises(TypeError, check_scoring, estimator) + class EstimatorWithoutFit(object): """Dummy estimator to test check_scoring""" @@ -170,27 +242,6 @@ def test_thresholded_scorers_multilabel_indicator_data(): random_state=0) X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) - # Multi-output multi-class predict_proba - clf = DecisionTreeClassifier() - clf.fit(X_train, y_train) - y_proba = clf.predict_proba(X_test) - score1 = SCORERS['roc_auc'](clf, X_test, y_test) - score2 = roc_auc_score(y_test, np.vstack(p[:, -1] for p in y_proba).T) - assert_almost_equal(score1, score2) - - # Multi-output multi-class decision_function - # TODO Is there any yet? - clf = DecisionTreeClassifier() - clf.fit(X_train, y_train) - clf._predict_proba = clf.predict_proba - clf.predict_proba = None - clf.decision_function = lambda X: [p[:, 1] for p in clf._predict_proba(X)] - - y_proba = clf.decision_function(X_test) - score1 = SCORERS['roc_auc'](clf, X_test, y_test) - score2 = roc_auc_score(y_test, np.vstack(p for p in y_proba).T) - assert_almost_equal(score1, score2) - # Multilabel predict_proba clf = OneVsRestClassifier(DecisionTreeClassifier()) clf.fit(X_train, y_train) @@ -218,6 +269,112 @@ def test_unsupervised_scorers(): assert_almost_equal(score1, score2) +def test_evaluate_scorers_binary(): + X, y = make_classification(n_classes=2, random_state=0) + + # Test a classifier with decision_function. + for clf in (SVC(), LinearSVC()): + clf.fit(X, y) + + s1, s2 = _evaluate_scorers(clf, X, y, [SCORERS["f1"], + SCORERS["roc_auc"]]) + df = clf.decision_function(X) + y_pred = clf.predict(X) + + assert_almost_equal(s1, f1_score(y, y_pred)) + assert_almost_equal(s2, roc_auc_score(y, df)) + + # Test a classifier with predict_proba. 
+ clf = LogisticRegression() + clf.fit(X, y) + + s1, s2 = _evaluate_scorers(clf, X, y, [SCORERS["f1"], + SCORERS["roc_auc"]]) + y_proba = clf.predict_proba(X)[:, 1] + y_pred = clf.predict(X) + + assert_almost_equal(s1, f1_score(y, y_pred)) + assert_almost_equal(s2, roc_auc_score(y, y_proba)) + + +def test_evaluate_scorers_multiclass(): + iris = load_iris() + X, y = iris.data, iris.target + + # Test a classifier with decision_function. + clf = LinearSVC() + clf.fit(X, y) + + s1, s2 = _evaluate_scorers(clf, X, y, [SCORERS["f1"], + SCORERS["accuracy"]]) + y_pred = clf.predict(X) + + assert_almost_equal(s1, f1_score(y, y_pred)) + assert_almost_equal(s2, accuracy_score(y, y_pred)) + + # Test a classifier with predict_proba. + clf = LogisticRegression() + clf.fit(X, y) + + s1, s2, s3 = _evaluate_scorers(clf, X, y, [SCORERS["f1"], + SCORERS["accuracy"], + SCORERS["log_loss"]]) + y_proba = clf.predict_proba(X) + y_pred = clf.predict(X) + + assert_almost_equal(s1, f1_score(y, y_pred)) + assert_almost_equal(s2, accuracy_score(y, y_pred)) + assert_almost_equal(s3, -log_loss(y, y_proba)) + + +def test_evaluate_scorers_regression(): + diabetes = load_diabetes() + X, y = diabetes.data, diabetes.target + + reg = Ridge() + reg.fit(X, y) + + s1, s2 = _evaluate_scorers(reg, X, y, [SCORERS["r2"], + SCORERS["mean_squared_error"]]) + y_pred = reg.predict(X) + + assert_almost_equal(s1, r2_score(y, y_pred)) + assert_almost_equal(s2, -mean_squared_error(y, y_pred)) + + +def test_evaluate_scorers_ranking_by_regression(): + X, y = make_classification(n_classes=2, random_state=0) + + reg = DecisionTreeRegressor() + reg.fit(X, y) + + s1, s2 = _evaluate_scorers(reg, X, y, [SCORERS["roc_auc"], + SCORERS["average_precision"]]) + y_pred = reg.predict(X) + + assert_almost_equal(s1, roc_auc_score(y, y_pred)) + assert_almost_equal(s2, average_precision_score(y, y_pred)) + + diabetes = load_diabetes() + X, y = diabetes.data, diabetes.target + + reg.fit(X, y) + + s1, s2 = _evaluate_scorers(reg, X, y, [SCORERS["r2"], + dcg_scorer]) + y_pred = reg.predict(X) + + assert_almost_equal(s1, r2_score(y, y_pred)) + assert_almost_equal(s2, dcg_score(y, y_pred)) + + +def test_evaluate_scorers_exceptions(): + clf = LinearSVC() + # log_loss needs probabilities but LinearSVC does not have predict_proba. 
+ assert_raises(ValueError, _evaluate_scorers, clf, [], [], + [SCORERS["log_loss"]]) + + @ignore_warnings def test_raises_on_score_list(): """Test that when a list of scores is returned, we raise proper errors.""" diff --git a/sklearn/tests/test_cross_validation.py b/sklearn/tests/test_cross_validation.py index 206e34e03cc4b..f1e5cc22731ce 100644 --- a/sklearn/tests/test_cross_validation.py +++ b/sklearn/tests/test_cross_validation.py @@ -22,6 +22,7 @@ from sklearn import cross_validation as cval from sklearn.base import BaseEstimator +from sklearn.datasets import make_classification from sklearn.datasets import make_regression from sklearn.datasets import load_digits from sklearn.datasets import load_iris @@ -32,7 +33,7 @@ from sklearn.metrics import make_scorer from sklearn.externals import six -from sklearn.linear_model import Ridge +from sklearn.linear_model import Ridge, Perceptron from sklearn.svm import SVC @@ -510,6 +511,23 @@ def test_cross_val_score_precomputed(): linear_kernel.tolist(), y) +def test_cross_val_score_multiple_scorers(): + X, y = make_classification(n_classes=2, random_state=0) + clf = Perceptron(random_state=0) + + scores = cval.cross_val_score(clf, X, y, cv=3, scoring=["f1", "roc_auc"]) + assert_equal(scores.shape, (2, 3)) + + # Check that the results are the same as when cross_val_score is called + # individually. + f1_scores = cval.cross_val_score(clf, X, y, cv=3, scoring="f1") + auc_scores = cval.cross_val_score(clf, X, y, cv=3, scoring="roc_auc") + scores2 = np.array([f1_scores, auc_scores]) + assert_equal(scores2.shape, (2, 3)) + + assert_array_almost_equal(scores, scores2) + + def test_cross_val_score_fit_params(): clf = MockClassifier() n_samples = X.shape[0] diff --git a/sklearn/tests/test_grid_search.py b/sklearn/tests/test_grid_search.py index 92e3ffa9f19c2..248e1b02c9ebc 100644 --- a/sklearn/tests/test_grid_search.py +++ b/sklearn/tests/test_grid_search.py @@ -24,7 +24,7 @@ from scipy.stats import distributions -from sklearn.base import BaseEstimator +from sklearn.base import BaseEstimator, clone from sklearn.datasets import make_classification from sklearn.datasets import make_blobs from sklearn.datasets import make_multilabel_classification @@ -644,3 +644,34 @@ def test_grid_search_with_multioutput_data(): correct_score = est.score(X[test], y[test]) assert_almost_equal(correct_score, cv_validation_scores[i]) + + +def test_multiple_grid_search_cv(): + clf = LinearSVC(random_state=0) + X, y = make_blobs(random_state=0, centers=2) + param_grid = {"C": [0.1, 1, 10]} + scoring = ["f1", "roc_auc"] + + gs = GridSearchCV(clf, param_grid, scoring=scoring) + rs = RandomizedSearchCV(clf, param_grid, scoring=scoring, random_state=0) + + for n, est in enumerate((gs, rs)): + est.fit(X, y) + + for attr in ("scorer_", "best_score_", "grid_scores_", "best_params_"): + attr = getattr(est, attr) + assert_equal(len(attr), 2) + + est_f1 = clone(est) + est_f1.scoring = "f1" + est_f1.fit(X, y) + + est_auc = clone(est) + est_auc.scoring = "roc_auc" + est_auc.fit(X, y) + + for attr in ("best_score_", "best_params_"): + assert_equal(getattr(est, attr)[0], + getattr(est_f1, attr)) + assert_equal(getattr(est, attr)[1], + getattr(est_auc, attr)) diff --git a/sklearn/tests/test_learning_curve.py b/sklearn/tests/test_learning_curve.py index 42985823345ed..72e8bda244e2c 100644 --- a/sklearn/tests/test_learning_curve.py +++ b/sklearn/tests/test_learning_curve.py @@ -15,6 +15,7 @@ from sklearn.datasets import make_classification from sklearn.cross_validation import KFold from 
sklearn.linear_model import PassiveAggressiveClassifier +from sklearn.svm import LinearSVC class MockImprovingEstimator(BaseEstimator): @@ -237,8 +238,42 @@ def test_validation_curve(): n_redundant=0, n_classes=2, n_clusters_per_class=1, random_state=0) param_range = np.linspace(0, 1, 10) - train_scores, test_scores = validation_curve(MockEstimatorWithParameter(), - X, y, param_name="param", - param_range=param_range, cv=2) + param_grid = {"param": param_range} + est = MockEstimatorWithParameter() + train_scores, test_scores, train_times = validation_curve(est, X, y, + param_grid, cv=2) + assert_equal(train_scores.shape, (10, 2)) + assert_equal(test_scores.shape, (10, 2)) + assert_equal(train_times.shape, (10, 2)) assert_array_almost_equal(train_scores.mean(axis=1), param_range) assert_array_almost_equal(test_scores.mean(axis=1), 1 - param_range) + + +def test_validation_curve_2d(): + X, y = make_classification(n_classes=2, random_state=0) + param_grid = {"C": [1, 10, 100], "fit_intercept": [True, False]} + clf = LinearSVC(random_state=0) + train_scores, test_scores, train_times = validation_curve(clf, X, y, + param_grid, cv=2) + assert_equal(train_scores.shape, (6, 2)) + assert_equal(test_scores.shape, (6, 2)) + assert_equal(train_times.shape, (6, 2)) + + +def test_validation_curve_multiple_scorers(): + X, y = make_classification(n_classes=2, random_state=0) + clf = LinearSVC(random_state=0) + param_grid = {"C": [0.1, 1, 10, 100]} + scoring = ["f1", "roc_auc"] + train_scores, test_scores, train_times = validation_curve(clf, X, y, + param_grid, cv=3, + scoring=scoring) + assert_equal(train_scores.shape, (2, 4, 3)) + assert_equal(test_scores.shape, (2, 4, 3)) + assert_equal(train_times.shape, (4, 3)) + + for i, scoring in enumerate(("f1", "roc_auc")): + tr, te, ti = validation_curve(clf, X, y, param_grid, cv=3, + scoring=scoring) + assert_array_almost_equal(train_scores[i], tr) + assert_array_almost_equal(test_scores[i], te)
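# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the patch): when `scoring` is
# a list, GridSearchCV and RandomizedSearchCV report per-metric results.
# `scorer_`, `best_score_`, `best_params_` and `grid_scores_` become lists
# with one entry per scorer, `best_estimators_` holds one refitted estimator
# per scorer, and `best_estimator_` is the one selected by the first scorer.
from sklearn.datasets import make_blobs
from sklearn.grid_search import GridSearchCV
from sklearn.svm import LinearSVC

X, y = make_blobs(centers=2, random_state=0)
search = GridSearchCV(LinearSVC(random_state=0),
                      param_grid={"C": [0.1, 1, 10]},
                      scoring=["f1", "roc_auc"])
search.fit(X, y)

print(search.best_score_)            # [best mean f1, best mean roc_auc]
print(search.best_params_)           # one parameter dict per scorer
print(len(search.best_estimators_))  # 2: one refitted estimator per scorer
# ---------------------------------------------------------------------------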
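# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the patch): `validation_curve`
# now takes a `param_grid` (so several parameters can be varied jointly) and
# also returns per-fold training times; with a list of scorers the score
# arrays gain a leading n_scorers axis.
from sklearn.datasets import make_classification
from sklearn.learning_curve import validation_curve
from sklearn.svm import LinearSVC

X, y = make_classification(n_classes=2, random_state=0)
param_grid = {"C": [0.1, 1, 10, 100]}

train_scores, test_scores, train_times = validation_curve(
    LinearSVC(random_state=0), X, y, param_grid,
    cv=3, scoring=["f1", "roc_auc"])

print(train_scores.shape)        # (2, 4, 3): (n_scorers, n_params, n_folds)
print(train_times.shape)         # (4, 3): (n_params, n_folds)
print(test_scores.mean(axis=2))  # mean test score per scorer and parameter
# ---------------------------------------------------------------------------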
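# ---------------------------------------------------------------------------
# Sketch of the new scorer-evaluation path (illustrative only; it relies on
# the private helper `_evaluate_scorers`, so the exact call may change):
# predictions, probabilities or decision values are computed at most once and
# shared by every scorer in the list, so evaluating several metrics needs
# only a single pass over the data.
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import SCORERS
from sklearn.metrics.scorer import _evaluate_scorers

X, y = make_classification(n_classes=2, random_state=0)
clf = LogisticRegression().fit(X, y)

f1, auc, neg_ll = _evaluate_scorers(
    clf, X, y, [SCORERS["f1"], SCORERS["roc_auc"], SCORERS["log_loss"]])
print(f1, auc, neg_ll)  # log_loss is sign-flipped so that greater is better
# ---------------------------------------------------------------------------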