|
23 | 23 | from sklearn.ensemble import HistGradientBoostingClassifier
|
24 | 24 | from sklearn.exceptions import FitFailedWarning
|
25 | 25 | from sklearn.experimental import enable_halving_search_cv # noqa
|
| 26 | +from sklearn.feature_extraction.text import TfidfVectorizer |
26 | 27 | from sklearn.impute import SimpleImputer
|
27 | 28 | from sklearn.linear_model import (
|
28 | 29 | LinearRegression,
|
|
56 | 57 | )
|
57 | 58 | from sklearn.model_selection._search import BaseSearchCV
|
58 | 59 | from sklearn.model_selection.tests.common import OneTimeSplitter
|
| 60 | +from sklearn.naive_bayes import ComplementNB |
59 | 61 | from sklearn.neighbors import KernelDensity, KNeighborsClassifier, LocalOutlierFactor
|
60 | 62 | from sklearn.pipeline import Pipeline
|
61 | 63 | from sklearn.svm import SVC, LinearSVC
|
@@ -2492,6 +2494,35 @@ def test_search_estimator_param(SearchCV, param_search):
|
2492 | 2494 | assert gs.best_estimator_.named_steps["clf"].C == 0.01
|
2493 | 2495 |
|
2494 | 2496 |
|
| 2497 | +def test_search_with_2d_array(): |
| 2498 | + parameter_grid = { |
| 2499 | + "vect__ngram_range": ((1, 1), (1, 2)), # unigrams or bigrams |
| 2500 | + "vect__norm": ("l1", "l2"), |
| 2501 | + } |
| 2502 | + pipeline = Pipeline( |
| 2503 | + [ |
| 2504 | + ("vect", TfidfVectorizer()), |
| 2505 | + ("clf", ComplementNB()), |
| 2506 | + ] |
| 2507 | + ) |
| 2508 | + random_search = RandomizedSearchCV( |
| 2509 | + estimator=pipeline, |
| 2510 | + param_distributions=parameter_grid, |
| 2511 | + n_iter=3, |
| 2512 | + random_state=0, |
| 2513 | + n_jobs=2, |
| 2514 | + verbose=1, |
| 2515 | + cv=3, |
| 2516 | + ) |
| 2517 | + data_train = ["one", "two", "three", "four", "five"] |
| 2518 | + data_target = [0, 0, 1, 0, 1] |
| 2519 | + random_search.fit(data_train, data_target) |
| 2520 | + result = random_search.cv_results_["param_vect__ngram_range"] |
| 2521 | + expected_data = np.empty(3, dtype=object) |
| 2522 | + expected_data[:] = [(1, 2), (1, 2), (1, 1)] |
| 2523 | + np.testing.assert_array_equal(result.data, expected_data) |
| 2524 | + |
| 2525 | + |
2495 | 2526 | # Metadata Routing Tests
|
2496 | 2527 | # ======================
|
2497 | 2528 |
|
|
0 commit comments