Skip to content

Commit c3478e3

Browse files
committed
ENH/FIX timing and training score.
* ENH separate fit / score times * Make score_time=0 if errored; Ignore warnings in test * Cleanup docstrings * ENH Use helper to store the results * Move fit time computation to else of try...except...else * DOC readable sample scores * COSMIT Add a commnent on why time test is >= 0 instead of > 0 (Windows time.time precision is not accurate enought to be non-zero for trivial fits)
1 parent 30df3f0 commit c3478e3

File tree

4 files changed

+269
-185
lines changed

4 files changed

+269
-185
lines changed

doc/whats_new.rst

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,20 @@ Model Selection Enhancements and API Changes
100100
The parameter ``n_labels`` in the newly renamed
101101
:class:`model_selection.LeavePGroupsOut` is changed to ``n_groups``.
102102

103+
- Training scores and Timing information
104+
105+
``cv_results_`` also includes the training scores for each
106+
cross-validation split (with keys such as ``'split0_train_score'``), as
107+
well as their mean (``'mean_train_score'``) and standard deviation
108+
(``'std_train_score'``). To avoid the cost of evaluating training score,
109+
set ``return_train_score=False``.
110+
111+
Additionally the mean and standard deviation of the times taken to split,
112+
train and score the model across all the cross-validation splits is
113+
available at the key ``'mean_time'`` and ``'std_time'`` respectively.
114+
115+
Changelog
116+
---------
103117

104118
New features
105119
............
@@ -349,6 +363,12 @@ Enhancements
349363
now accept arbitrary kernel functions in addition to strings ``knn`` and ``rbf``.
350364
(`#5762 <https://github.com/scikit-learn/scikit-learn/pull/5762>`_) By `Utkarsh Upadhyay`_.
351365

366+
- The training scores and time taken for training followed by scoring for
367+
each search candidate are now available at the ``cv_results_`` dict.
368+
See :ref:`model_selection_changes` for more information.
369+
(`#7324 <https://github.com/scikit-learn/scikit-learn/pull/7325>`)
370+
By `Eugene Chen`_ and `Raghav RV`_.
371+
352372

353373
Bug fixes
354374
.........
@@ -4651,3 +4671,5 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
46514671
.. _Russell Smith: https://github.com/rsmith54
46524672

46534673
.. _Utkarsh Upadhyay: https://github.com/musically-ut
4674+
4675+
.. _Eugene Chen: https://github.com/eyc88

sklearn/model_selection/_search.py

Lines changed: 75 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,9 @@ def fit_grid_point(X, y, estimator, parameters, train, test, scorer,
319319
"""
320320
score, n_samples_test, _ = _fit_and_score(estimator, X, y, scorer, train,
321321
test, verbose, parameters,
322-
fit_params, error_score)
322+
fit_params=fit_params,
323+
return_n_test_samples=True,
324+
error_score=error_score)
323325
return score, parameters, n_samples_test
324326

325327

@@ -552,77 +554,61 @@ def _fit(self, X, y, groups, parameter_iterable):
552554
pre_dispatch=pre_dispatch
553555
)(delayed(_fit_and_score)(clone(base_estimator), X, y, self.scorer_,
554556
train, test, self.verbose, parameters,
555-
self.fit_params,
557+
fit_params=self.fit_params,
556558
return_train_score=self.return_train_score,
557-
return_parameters=True,
559+
return_n_test_samples=True,
560+
return_times=True, return_parameters=True,
558561
error_score=self.error_score)
559562
for parameters in parameter_iterable
560563
for train, test in cv.split(X, y, groups))
561564

562565
# if one choose to see train score, "out" will contain train score info
563566
if self.return_train_score:
564-
train_scores, test_scores, test_sample_counts, time, parameters =\
565-
zip(*out)
567+
(train_scores, test_scores, test_sample_counts,
568+
fit_time, score_time, parameters) = zip(*out)
566569
else:
567-
test_scores, test_sample_counts, time, parameters = zip(*out)
570+
(test_scores, test_sample_counts,
571+
fit_time, score_time, parameters) = zip(*out)
568572

569573
candidate_params = parameters[::n_splits]
570574
n_candidates = len(candidate_params)
571575

572-
# if one choose to return train score, reshape the train_scores array
573-
if self.return_train_score:
574-
train_scores = np.array(train_scores,
575-
dtype=np.float64).reshape(n_candidates,
576+
results = dict()
577+
578+
def _store(key_name, array, weights=None, splits=False, rank=False):
579+
"""A small helper to store the scores/times to the cv_results_"""
580+
array = np.array(array, dtype=np.float64).reshape(n_candidates,
576581
n_splits)
577-
test_scores = np.array(test_scores,
578-
dtype=np.float64).reshape(n_candidates,
579-
n_splits)
582+
if splits:
583+
for split_i in range(n_splits):
584+
results["split%d_%s"
585+
% (split_i, key_name)] = array[:, split_i]
586+
587+
array_means = np.average(array, axis=1, weights=weights)
588+
results['mean_%s' % key_name] = array_means
589+
# Weighted std is not directly available in numpy
590+
array_stds = np.sqrt(np.average((array -
591+
array_means[:, np.newaxis]) ** 2,
592+
axis=1, weights=weights))
593+
results['std_%s' % key_name] = array_stds
594+
595+
if rank:
596+
results["rank_%s" % key_name] = np.asarray(
597+
rankdata(-array_means, method='min'), dtype=np.int32)
598+
599+
# Computed the (weighted) mean and std for test scores alone
580600
# NOTE test_sample counts (weights) remain the same for all candidates
581601
test_sample_counts = np.array(test_sample_counts[:n_splits],
582602
dtype=np.int)
583603

584-
# Computed the (weighted) mean and std for test scores
585-
weights = test_sample_counts if self.iid else None
586-
test_means = np.average(test_scores, axis=1, weights=weights)
587-
test_stds = np.sqrt(
588-
np.average((test_scores - test_means[:, np.newaxis]) ** 2, axis=1,
589-
weights=weights))
590-
591-
time = np.array(time, dtype=np.float64).reshape(n_candidates, n_splits)
592-
time_means = np.average(time, axis=1)
593-
time_stds = np.sqrt(
594-
np.average((time - time_means[:, np.newaxis]) ** 2,
595-
axis=1))
596-
if self.return_train_score:
597-
train_means = np.average(train_scores, axis=1)
598-
train_stds = np.sqrt(
599-
np.average((train_scores - train_means[:, np.newaxis]) ** 2,
600-
axis=1))
601-
602-
cv_results = dict()
603-
for split_i in range(n_splits):
604-
cv_results["split%d_test_score" % split_i] = test_scores[:,
605-
split_i]
606-
cv_results["mean_test_score"] = means
607-
cv_results["std_test_score"] = stds
604+
_store('test_score', test_scores, splits=True, rank=True,
605+
weights=test_sample_counts if self.iid else None)
606+
_store('train_score', train_scores, splits=True)
607+
_store('fit_time', fit_time)
608+
_store('score_time', score_time)
608609

609-
if self.return_train_score:
610-
for split_i in range(n_splits):
611-
results["train_split%d_score" % split_i] = (
612-
train_scores[:, split_i])
613-
results["mean_train_score"] = train_means
614-
results["std_train_scores"] = train_stds
615-
results["rank_train_scores"] = np.asarray(rankdata(-train_means,
616-
method='min'),
617-
dtype=np.int32)
618-
619-
results["mean_test_time"] = time_means
620-
results["std_test_time"] = time_stds
621-
ranks = np.asarray(rankdata(-test_means, method='min'), dtype=np.int32)
622-
623-
best_index = np.flatnonzero(ranks == 1)[0]
610+
best_index = np.flatnonzero(results["rank_test_score"] == 1)[0]
624611
best_parameters = candidate_params[best_index]
625-
cv_results["rank_test_score"] = ranks
626612

627613
# Use one np.MaskedArray and mask all the places where the param is not
628614
# applicable for that candidate. Use defaultdict as each candidate may
@@ -636,12 +622,12 @@ def _fit(self, X, y, groups, parameter_iterable):
636622
# Setting the value at an index also unmasks that index
637623
param_results["param_%s" % name][cand_i] = value
638624

639-
cv_results.update(param_results)
625+
results.update(param_results)
640626

641627
# Store a list of param dicts at the key 'params'
642-
cv_results['params'] = candidate_params
628+
results['params'] = candidate_params
643629

644-
self.cv_results_ = cv_results
630+
self.cv_results_ = results
645631
self.best_index_ = best_index
646632
self.n_splits_ = n_splits
647633

@@ -783,8 +769,8 @@ class GridSearchCV(BaseSearchCV):
783769
FitFailedWarning is raised. This parameter does not affect the refit
784770
step, which will always raise the error.
785771
786-
return_train_score: boolean, default=True
787-
If ``'False'``, the results_ attribute will not include training
772+
return_train_score : boolean, default=True
773+
If ``'False'``, the ``cv_results_`` attribute will not include training
788774
scores.
789775
790776
@@ -809,10 +795,12 @@ class GridSearchCV(BaseSearchCV):
809795
scoring=..., verbose=...)
810796
>>> sorted(clf.cv_results_.keys())
811797
... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
812-
['mean_test_score', 'mean_test_time', 'mean_train_score',...
813-
'param_C', 'param_kernel', 'params', 'rank_test_score',...
814-
'split0_test_score', 'split1_test_score',...
815-
'split2_test_score', 'std_test_score', 'std_test_time'...]
798+
['mean_fit_time', 'mean_score_time', 'mean_test_score',...
799+
'mean_train_score', 'param_C', 'param_kernel', 'params',...
800+
'rank_test_score', 'split0_test_score',...
801+
'split0_train_score', 'split1_test_score', 'split1_train_score',...
802+
'split2_test_score', 'split2_train_score',...
803+
'std_fit_time', 'std_score_time', 'std_test_score', 'std_train_score'...]
816804
817805
Attributes
818806
----------
@@ -843,25 +831,24 @@ class GridSearchCV(BaseSearchCV):
843831
mask = [ True True False False]...),
844832
'param_degree': masked_array(data = [2.0 3.0 -- --],
845833
mask = [False False True True]...),
846-
'split0_test_score' : [0.8, 0.7, 0.8, 0.9],
847-
'split1_test_score' : [0.82, 0.5, 0.7, 0.78],
848-
'mean_test_score' : [0.81, 0.60, 0.75, 0.82],
849-
'std_test_score' : [0.02, 0.01, 0.03, 0.03],
850-
'rank_test_score' : [2, 4, 3, 1],
851-
'split0_train_score': [0.9, 0.8, 0.85, 1.]
852-
'split1_train_score': [0.95, 0.7, 0.8, 0.8]
853-
'mean_train_score' : [0.93, 0.75, 0.83, 0.9]
854-
'std_train_score' : [0.02, 0.01, 0.03, 0.03],
855-
'rank_train_score' : [2, 4, 3, 1],
856-
'mean_test_time' : [0.00073, 0.00063, 0.00043, 0.00049]
857-
'std_test_time' : [1.62e-4, 3.37e-5, 1.42e-5, 1.1e-5]
858-
'params' : [{'kernel': 'poly', 'degree': 2}, ...],
834+
'split0_test_score' : [0.8, 0.7, 0.8, 0.9],
835+
'split1_test_score' : [0.82, 0.5, 0.7, 0.78],
836+
'mean_test_score' : [0.81, 0.60, 0.75, 0.82],
837+
'std_test_score' : [0.02, 0.01, 0.03, 0.03],
838+
'rank_test_score' : [2, 4, 3, 1],
839+
'split0_train_score' : [0.8, 0.9, 0.7],
840+
'split1_train_score' : [0.82, 0.5, 0.7],
841+
'mean_train_score' : [0.81, 0.7, 0.7],
842+
'std_train_score' : [0.03, 0.03, 0.04],
843+
'mean_fit_time' : [0.73, 0.63, 0.43, 0.49],
844+
'std_fit_time' : [0.01, 0.02, 0.01, 0.01],
845+
'mean_score_time' : [0.007, 0.06, 0.04, 0.04],
846+
'std_score_time' : [0.001, 0.002, 0.003, 0.005],
847+
'params' : [{'kernel': 'poly', 'degree': 2}, ...],
859848
}
860849
861850
NOTE that the key ``'params'`` is used to store a list of parameter
862-
settings dict for all the parameter candidates. Besides,
863-
``'train_mean_score'``, ``'train_split*_score'``, ... will be present
864-
when ``return_train_score=True``.
851+
settings dict for all the parameter candidates.
865852
866853
best_estimator_ : estimator
867854
Estimator that was chosen by the search, i.e. estimator
@@ -920,7 +907,7 @@ class GridSearchCV(BaseSearchCV):
920907
def __init__(self, estimator, param_grid, scoring=None, fit_params=None,
921908
n_jobs=1, iid=True, refit=True, cv=None, verbose=0,
922909
pre_dispatch='2*n_jobs', error_score='raise',
923-
return_train_score=False):
910+
return_train_score=True):
924911
super(GridSearchCV, self).__init__(
925912
estimator=estimator, scoring=scoring, fit_params=fit_params,
926913
n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose,
@@ -1059,8 +1046,8 @@ class RandomizedSearchCV(BaseSearchCV):
10591046
FitFailedWarning is raised. This parameter does not affect the refit
10601047
step, which will always raise the error.
10611048
1062-
return_train_score: boolean, default=True
1063-
If ``'False'``, the results_ attribute will not include training
1049+
return_train_score : boolean, default=True
1050+
If ``'False'``, the ``cv_results_`` attribute will not include training
10641051
scores.
10651052
10661053
Attributes
@@ -1095,19 +1082,16 @@ class RandomizedSearchCV(BaseSearchCV):
10951082
'split0_train_score' : [0.8, 0.9, 0.7],
10961083
'split1_train_score' : [0.82, 0.5, 0.7],
10971084
'mean_train_score' : [0.81, 0.7, 0.7],
1098-
'std_train_score' : [0.00073, 0.00063, 0.00043]
1099-
'rank_train_score' : [1.62e-4, 3.37e-5, 1.1e-5]
1100-
'test_mean_time' : [0.00073, 0.00063, 0.00043]
1101-
'test_std_time' : [1.62e-4, 3.37e-5, 1.1e-5]
1102-
'test_std_score' : [0.02, 0.2, 0.],
1103-
'test_rank_score' : [3, 1, 1],
1085+
'std_train_score' : [0.03, 0.03, 0.04],
1086+
'mean_fit_time' : [0.73, 0.63, 0.43, 0.49],
1087+
'std_fit_time' : [0.01, 0.02, 0.01, 0.01],
1088+
'mean_score_time' : [0.007, 0.06, 0.04, 0.04],
1089+
'std_score_time' : [0.001, 0.002, 0.003, 0.005],
11041090
'params' : [{'kernel' : 'rbf', 'gamma' : 0.1}, ...],
11051091
}
11061092
11071093
NOTE that the key ``'params'`` is used to store a list of parameter
1108-
settings dict for all the parameter candidates. Besides,
1109-
'train_mean_score', 'train_split*_score', ... will be present when
1110-
return_train_score is set to True.
1094+
settings dict for all the parameter candidates.
11111095
11121096
best_estimator_ : estimator
11131097
Estimator that was chosen by the search, i.e. estimator
@@ -1162,7 +1146,7 @@ class RandomizedSearchCV(BaseSearchCV):
11621146
def __init__(self, estimator, param_distributions, n_iter=10, scoring=None,
11631147
fit_params=None, n_jobs=1, iid=True, refit=True, cv=None,
11641148
verbose=0, pre_dispatch='2*n_jobs', random_state=None,
1165-
error_score='raise', return_train_score=False):
1149+
error_score='raise', return_train_score=True):
11661150
self.param_distributions = param_distributions
11671151
self.n_iter = n_iter
11681152
self.random_state = random_state

sklearn/model_selection/_validation.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
"""
23
The :mod:`sklearn.model_selection._validation` module includes classes and
34
functions to validate the model.
@@ -142,7 +143,8 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,
142143

143144
def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
144145
parameters, fit_params, return_train_score=False,
145-
return_parameters=False, error_score='raise'):
146+
return_parameters=False, return_n_test_samples=False,
147+
return_times=False, error_score='raise'):
146148
"""Fit estimator and compute scores for a given dataset split.
147149
148150
Parameters
@@ -199,8 +201,11 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
199201
n_test_samples : int
200202
Number of test samples.
201203
202-
scoring_time : float
203-
Time spent for fitting and scoring in seconds.
204+
fit_time : float
205+
Time spent for fitting in seconds.
206+
207+
score_time : float
208+
Time spent for scoring in seconds.
204209
205210
parameters : dict or None, optional
206211
The parameters that have been evaluated.
@@ -233,6 +238,9 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
233238
estimator.fit(X_train, y_train, **fit_params)
234239

235240
except Exception as e:
241+
# Note fit time as time until error
242+
fit_time = time.time() - start_time
243+
score_time = 0.0
236244
if error_score == 'raise':
237245
raise
238246
elif isinstance(error_score, numbers.Number):
@@ -248,20 +256,24 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
248256
" make sure that it has been spelled correctly.)")
249257

250258
else:
259+
fit_time = time.time() - start_time
251260
test_score = _score(estimator, X_test, y_test, scorer)
261+
score_time = time.time() - start_time - fit_time
252262
if return_train_score:
253263
train_score = _score(estimator, X_train, y_train, scorer)
254264

255-
scoring_time = time.time() - start_time
256-
257265
if verbose > 2:
258266
msg += ", score=%f" % test_score
259267
if verbose > 1:
260-
end_msg = "%s -%s" % (msg, logger.short_format_time(scoring_time))
268+
end_msg = "%s -%s" % (msg, logger.short_format_time(score_time))
261269
print("[CV] %s %s" % ((64 - len(end_msg)) * '.', end_msg))
262270

263-
ret = [train_score] if return_train_score else []
264-
ret.extend([test_score, _num_samples(X_test), scoring_time])
271+
ret = [train_score, test_score] if return_train_score else [test_score]
272+
273+
if return_n_test_samples:
274+
ret.append(_num_samples(X_test))
275+
if return_times:
276+
ret.extend([fit_time, score_time])
265277
if return_parameters:
266278
ret.append(parameters)
267279
return ret
@@ -758,7 +770,7 @@ def learning_curve(estimator, X, y, groups=None,
758770
verbose, parameters=None, fit_params=None, return_train_score=True)
759771
for train, test in cv_iter
760772
for n_train_samples in train_sizes_abs)
761-
out = np.array(out)[:, :2]
773+
out = np.array(out)
762774
n_cv_folds = out.shape[0] // n_unique_ticks
763775
out = out.reshape(n_cv_folds, n_unique_ticks, 2)
764776

@@ -941,7 +953,7 @@ def validation_curve(estimator, X, y, param_name, param_range, groups=None,
941953
parameters={param_name: v}, fit_params=None, return_train_score=True)
942954
for train, test in cv.split(X, y, groups) for v in param_range)
943955

944-
out = np.asarray(out)[:, :2]
956+
out = np.asarray(out)
945957
n_params = len(param_range)
946958
n_cv_folds = out.shape[0] // n_params
947959
out = out.reshape(n_cv_folds, n_params, 2).transpose((2, 1, 0))

0 commit comments

Comments
 (0)