WIP new, simpler scorer API #2123
@@ -11,7 +11,7 @@
 # License: BSD 3 clause

 from abc import ABCMeta, abstractmethod
-from collections import Mapping, namedtuple, Sized
+from collections import Mapping, namedtuple, Sequence, Sized
 from functools import partial, reduce
 from itertools import product
 import numbers
@@ -28,7 +28,7 @@
 from .externals import six
 from .utils import safe_mask, check_random_state
 from .utils.validation import _num_samples, check_arrays
-from .metrics import SCORERS, Scorer
+from .metrics import make_scorer, SCORERS


 __all__ = ['GridSearchCV', 'ParameterGrid', 'fit_grid_point',
@@ -316,8 +316,10 @@ def fit_grid_point(X, y, base_estimator, parameters, train, test, scorer,
     else:
         this_score = clf.score(X_test)

-    if not isinstance(this_score, numbers.Number):
-        raise ValueError("scoring must return a number, got %s (%s)"
+    if not isinstance(this_score, numbers.Number) \
+            and not (isinstance(this_score, Sequence)
+                     and isinstance(this_score[0], numbers.Number)):
+        raise ValueError("scoring must return a number or tuple, got %s (%s)"
                          " instead." % (str(this_score), type(this_score)))

     if verbose > 2:
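The widened check above can be exercised standalone. A minimal sketch, with two adjustments that are assumptions of this sketch rather than part of the PR: the check is wrapped in a helper function (`check_score` is a name invented here; the PR does this inline in `fit_grid_point`), and `Sequence` is imported from `collections.abc`, since the PR's Python 2-era `from collections import Sequence` no longer works on modern Python:

```python
import numbers
from collections.abc import Sequence  # the PR imports this from `collections` (Python 2 era)


def check_score(this_score):
    """Sketch of the widened check: accept a plain number, or a sequence
    (e.g. a namedtuple from a structured scorer) whose first element is
    the number used for model selection."""
    if not isinstance(this_score, numbers.Number) \
            and not (isinstance(this_score, Sequence)
                     and isinstance(this_score[0], numbers.Number)):
        raise ValueError("scoring must return a number or tuple, got %s (%s)"
                         " instead." % (str(this_score), type(this_score)))
    return this_score
```

A namedtuple is a tuple and hence a `Sequence`, so a structured score like `(f1, precision, recall)` passes as long as its first element is a number.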
@@ -364,10 +366,17 @@ class _CVScoreTuple (namedtuple('_CVScoreTuple',

     def __repr__(self):
         """Simple custom repr to summarize the main info"""
+        std = np.std([sc if isinstance(sc, numbers.Number) else sc[0]
+                      for sc in self.cv_validation_scores])
         return "mean: {0:.5f}, std: {1:.5f}, params: {2}".format(
-            self.mean_validation_score,
-            np.std(self.cv_validation_scores),
-            self.parameters)
+            self.mean_validation_score, std, self.parameters)
+
+    def __str__(self):
+        """More extensive reporting than from repr."""
+        per_fold = ("\n fold {0}: {1}".format(i, sc)
+                    for i, sc in enumerate(self.cv_validation_scores))
+        return repr(self) + "".join(per_fold)


 class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,

Review comments on this hunk:

> Two remarks independent from this PR but that I think should be addressed now (i.e. before merge): …

> The current state of affairs in master is that …

> I know, and I think that it is wrong.

> I don't think so. It's not an attrib of an estimator, but an attrib of an object returned by an underscored attrib of an estimator.

> Fair enough. But I still think that it would be good (not mandatory, …
@@ -392,6 +401,33 @@ def __init__(self, estimator, scoring=None, loss_func=None,
         self.pre_dispatch = pre_dispatch
         self._check_estimator()

+    def report(self, file=None):
+        """Generate a report of the scores achieved.
+
+        Reports on the scores achieved across the folds for the various
+        parameter settings tried. This also prints the additional information
+        reported by some scorers, such as "f1", which tracks precision and
+        recall as well.
+
+        Parameters
+        ----------
+        file : file-like, optional
+            File to which the report is written. If None or not given, the
+            report is returned as a string.
+        """
+        if not hasattr(self, "cv_scores_"):
+            raise AttributeError("no cv_scores_ found; run fit first")
+
+        return_string = (file is None)
+        if return_string:
+            file = six.StringIO()
+
+        for cvs in self.cv_scores_:
+            print(cvs, file=file)
+
+        if return_string:
+            return file.getvalue()
+
     def score(self, X, y=None):
         """Returns the score on the given test data and labels, if the search
         estimator has been refit. The ``score`` function of the best estimator

Review comments on the new `report` method:

> I'm not convinced by the format of this. Do we really need a report function that's little different from …

> Or, indeed which is identical to …

> I think it would be much more useful to output something like a CSV, but that requires interpreting the data more.

> It's a proof of concept. I wanted to make clear in some way that just …

> In the long run, we might want such features, but in the short run, I'd …

> I think that teaching people to use pprint is a good idea. I don't think …
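The added method, extracted here as a standalone function for illustration (`report_scores` is a name invented for this sketch, and `io.StringIO` stands in for the `six.StringIO` the PR uses for Python 2/3 compatibility):

```python
from io import StringIO  # the PR uses six.StringIO


def report_scores(cv_scores, file=None):
    """Standalone sketch of the report() method: print one score tuple
    per line, or return the whole report as a string when no file is given."""
    return_string = (file is None)
    if return_string:
        file = StringIO()

    for cvs in cv_scores:
        print(cvs, file=file)

    if return_string:
        return file.getvalue()
```

When `file` is given, the report is written there and the function returns `None`; otherwise the collected buffer contents are returned.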
@@ -465,13 +501,13 @@ def _fit(self, X, y, parameter_iterable):
                           "deprecated and will be removed in 0.15. "
                           "Either use strings or score objects."
                           "The relevant new parameter is called ''scoring''. ")
-            scorer = Scorer(self.loss_func, greater_is_better=False)
+            scorer = make_scorer(self.loss_func, greater_is_better=False)
         elif self.score_func is not None:
             warnings.warn("Passing function as ``score_func`` is "
                           "deprecated and will be removed in 0.15. "
                           "Either use strings or score objects."
                           "The relevant new parameter is called ''scoring''.")
-            scorer = Scorer(self.score_func)
+            scorer = make_scorer(self.score_func)
         elif isinstance(self.scoring, six.string_types):
             scorer = SCORERS[self.scoring]
         else:
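Why can `make_scorer(..., greater_is_better=False)` replace the old `Scorer` class here? A hypothetical minimal sketch of the sign convention (an illustration written for this review, not scikit-learn's actual implementation): wrapping a loss with `greater_is_better=False` negates its value, so every scorer becomes "higher is better" and the search can later sort with `reverse=True` unconditionally, as the last hunk in this diff does.

```python
def make_scorer_sketch(score_func, greater_is_better=True):
    """Hypothetical sketch of make_scorer's sign handling."""
    # A loss (greater_is_better=False) is negated, so the returned
    # scorer can always be maximized during model selection.
    sign = 1 if greater_is_better else -1

    def scorer(estimator, X, y):
        return sign * score_func(y, estimator.predict(X))

    return scorer
```

Under this convention the `best_score_` of a loss-based search is a negative number, which is the trade-off for a single, uniform sort direction.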
@@ -507,7 +543,7 @@ def _fit(self, X, y, parameter_iterable):
                 for parameters in parameter_iterable
                 for train, test in cv)

-        # Out is a list of triplet: score, estimator, n_test_samples
+        # Out is a list of triples: score, estimator, n_test_samples
         n_fits = len(out)
         n_folds = len(cv)

@@ -519,7 +555,11 @@ def _fit(self, X, y, parameter_iterable):
             all_scores = []
             for this_score, parameters, this_n_test_samples in \
                     out[grid_start:grid_start + n_folds]:
-                all_scores.append(this_score)
+                full_info = this_score
+                if isinstance(this_score, Sequence):
+                    # Structured score.
+                    this_score = this_score[0]
+                all_scores.append(full_info)
                 if self.iid:
                     this_score *= this_n_test_samples
                     n_test_samples += this_n_test_samples
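The loop's handling of structured scores, pulled out as a sketch (`split_fold_scores` is a helper name invented here; the PR does this inline, and imports `Sequence` from `collections` rather than the modern `collections.abc`):

```python
from collections.abc import Sequence


def split_fold_scores(fold_scores):
    """Keep each fold's full (possibly structured) score for reporting,
    but rank models on a scalar per fold."""
    all_scores = []      # full info, stored in cv_scores_
    ranking_scores = []  # scalars, used to average and pick the best model
    for this_score in fold_scores:
        full_info = this_score
        if isinstance(this_score, Sequence):
            # Structured score: the first field is the selection scalar.
            this_score = this_score[0]
        all_scores.append(full_info)
        ranking_scores.append(this_score)
    return all_scores, ranking_scores
```

This is what lets a scorer return, say, `(f1, precision, recall)`: the extra fields survive into `cv_scores_` without affecting which parameter setting wins.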
@@ -530,18 +570,14 @@ def _fit(self, X, y, parameter_iterable):
                 score /= float(n_folds)
             scores.append((score, parameters))
             # TODO: shall we also store the test_fold_sizes?
-            cv_scores.append(_CVScoreTuple(
-                parameters,
-                score,
-                np.array(all_scores)))
+            cv_scores.append(_CVScoreTuple(parameters, score, all_scores))
         # Store the computed scores
         self.cv_scores_ = cv_scores

         # Find the best parameters by comparing on the mean validation score:
         # note that `sorted` is deterministic in the way it breaks ties
-        greater_is_better = getattr(self.scorer_, 'greater_is_better', True)
         best = sorted(cv_scores, key=lambda x: x.mean_validation_score,
-                      reverse=greater_is_better)[0]
+                      reverse=True)[0]
         self.best_params_ = best.parameters
         self.best_score_ = best.mean_validation_score

Review comments on this hunk:

> Hmm... Sticking all the scores in one field has its advantages, but it's not clear how we fit training scores or times in here without changing the length of the namedtuple (breaking forwards compatibility), or without somehow modifying and restructuring the namedtuple returned by the scorer. I still think …

> +1 on both accounts.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why did that change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I changed the scoring from accuracy (the default) to F1 score to demo and test the structured return values from
f_scorer
and F1 score ≤ accuracy. This is also why the best parameter set changed.
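The structured return value discussed throughout this review can be illustrated end to end. `f_scorer_sketch` below is a stand-in written for this sketch (the PR's actual `f_scorer` lives in `sklearn.metrics`): it returns a namedtuple whose first field, the F1 score, would drive model selection, while precision and recall ride along for reporting.

```python
from collections import namedtuple

# Field order matters: element [0] is the scalar used to rank models.
PRF = namedtuple("PRF", ["f1", "precision", "recall"])


def f_scorer_sketch(precision, recall):
    """Build the structured score a scorer like "f1" might return.
    A namedtuple is a tuple, hence a Sequence, so fit_grid_point's
    widened check accepts it."""
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return PRF(f1, precision, recall)
```

Because the result is still a tuple, existing code that only cares about the scalar can index `score[0]`, while `report()` can print the full namedtuple per fold.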