diff --git a/examples/model_selection/plot_learning_curve.py b/examples/model_selection/plot_learning_curve.py index a7045a0bf6c88..6df16956e6937 100644 --- a/examples/model_selection/plot_learning_curve.py +++ b/examples/model_selection/plot_learning_curve.py @@ -179,9 +179,9 @@ def plot_learning_curve( X, y = load_digits(return_X_y=True) title = "Learning Curves (Naive Bayes)" -# Cross validation with 100 iterations to get smoother mean test and train +# Cross validation with 50 iterations to get smoother mean test and train # score curves, each time with 20% data randomly selected as a validation set. -cv = ShuffleSplit(n_splits=100, test_size=0.2, random_state=0) +cv = ShuffleSplit(n_splits=50, test_size=0.2, random_state=0) estimator = GaussianNB() plot_learning_curve( @@ -190,7 +190,7 @@ def plot_learning_curve( title = r"Learning Curves (SVM, RBF kernel, $\gamma=0.001$)" # SVC is more expensive so we do a lower number of CV iterations: -cv = ShuffleSplit(n_splits=10, test_size=0.2, random_state=0) +cv = ShuffleSplit(n_splits=5, test_size=0.2, random_state=0) estimator = SVC(gamma=0.001) plot_learning_curve( estimator, title, X, y, axes=axes[:, 1], ylim=(0.7, 1.01), cv=cv, n_jobs=4