Skip to content

Commit 7ce7134

Browse files
jnothman authored and lesteve committed
[MRG+2] ENH Loop over candidates as outer loop in search (#8322)
This encourages concurrent fits to be over *different datasets* so that fits over the same data subset are more likely to run in serial and hence generate cache hits where memoisation is used.
1 parent 098fd31 commit 7ce7134

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

sklearn/model_selection/_search.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -600,8 +600,8 @@ def fit(self, X, y=None, groups=None, **fit_params):
600600
return_n_test_samples=True,
601601
return_times=True, return_parameters=False,
602602
error_score=self.error_score)
603-
for train, test in cv.split(X, y, groups)
604-
for parameters in candidate_params)
603+
for parameters, (train, test) in product(candidate_params,
604+
cv.split(X, y, groups)))
605605

606606
# if one choose to see train score, "out" will contain train score info
607607
if self.return_train_score:
@@ -615,8 +615,8 @@ def fit(self, X, y=None, groups=None, **fit_params):
615615
def _store(key_name, array, weights=None, splits=False, rank=False):
616616
"""A small helper to store the scores/times to the cv_results_"""
617617
# When iterated first by splits, then by parameters
618-
array = np.array(array, dtype=np.float64).reshape(n_splits,
619-
n_candidates).T
618+
array = np.array(array, dtype=np.float64).reshape(n_candidates,
619+
n_splits)
620620
if splits:
621621
for split_i in range(n_splits):
622622
results["split%d_%s"
@@ -636,7 +636,7 @@ def _store(key_name, array, weights=None, splits=False, rank=False):
636636

637637
# Computed the (weighted) mean and std for test scores alone
638638
# NOTE test_sample counts (weights) remain the same for all candidates
639-
test_sample_counts = np.array(test_sample_counts[::n_candidates],
639+
test_sample_counts = np.array(test_sample_counts[:n_splits],
640640
dtype=np.int)
641641

642642
_store('test_score', test_scores, splits=True, rank=True,

0 commit comments

Comments
 (0)