diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py
index 928225ef7454e..e9b8cf7bac5b1 100644
--- a/sklearn/model_selection/_validation.py
+++ b/sklearn/model_selection/_validation.py
@@ -1539,7 +1539,7 @@ def validation_curve(estimator, X, y, *, param_name, param_range, groups=None,
 def _aggregate_score_dicts(scores):
     """Aggregate the list of dict to dict of np ndarray
 
-    The aggregated output of _fit_and_score will be a list of dict
+    The input to _aggregate_score_dicts is a list of dict
     of form [{'prec': 0.1, 'acc':1.0}, {'prec': 0.1, 'acc':1.0}, ...]
     Convert it to a dict of array {'prec': np.array([0.1 ...]), ...}
 
@@ -1559,5 +1559,9 @@ def _aggregate_score_dicts(scores):
     {'a': array([1, 2, 3, 10]),
      'b': array([10, 2, 3, 10])}
     """
-    return {key: np.asarray([score[key] for score in scores])
-            for key in scores[0]}
+    return {
+        key: np.asarray([score[key] for score in scores])
+        if isinstance(scores[0][key], numbers.Number)
+        else [score[key] for score in scores]
+        for key in scores[0]
+    }
diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py
index 00b11191d902a..1cb08cc13f767 100644
--- a/sklearn/model_selection/tests/test_validation.py
+++ b/sklearn/model_selection/tests/test_validation.py
@@ -356,6 +356,23 @@ def test_cross_validate_invalid_scoring_param():
                         cross_validate, SVC(), X, y, scoring="mse")
 
 
+def test_cross_validate_nested_estimator():
+    # Non-regression test to ensure that nested
+    # estimators are properly returned in a list
+    # https://github.com/scikit-learn/scikit-learn/pull/17745
+    (X, y) = load_iris(return_X_y=True)
+    pipeline = Pipeline([
+        ("imputer", SimpleImputer()),
+        ("classifier", MockClassifier()),
+    ])
+
+    results = cross_validate(pipeline, X, y, return_estimator=True)
+    estimators = results["estimator"]
+
+    assert isinstance(estimators, list)
+    assert all(isinstance(estimator, Pipeline) for estimator in estimators)
+
+
 def test_cross_validate():
     # Compute train and test mse/r2 scores
     cv = KFold()