From 945aea9ee3f955b56d8ce93c06faf7f71b653e7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joa=CC=83o=20Pedro=20Morais?= <15629444+ojoaomorais@users.noreply.github.com> Date: Tue, 21 May 2024 15:37:48 -0300 Subject: [PATCH 1/2] removing warnings from plot_cv_indices examples --- examples/model_selection/plot_cv_indices.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/model_selection/plot_cv_indices.py b/examples/model_selection/plot_cv_indices.py index e6c3580c787f0..58e09644a59d7 100644 --- a/examples/model_selection/plot_cv_indices.py +++ b/examples/model_selection/plot_cv_indices.py @@ -12,10 +12,15 @@ """ +import warnings + import matplotlib.pyplot as plt import numpy as np from matplotlib.patches import Patch +# Removing warnings from examples +warnings.filterwarnings("ignore") + from sklearn.model_selection import ( GroupKFold, GroupShuffleSplit, From bf10111d8e4ea2fcaa91994b2a82e3c854514b4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joa=CC=83o=20Pedro=20Morais?= <15629444+ojoaomorais@users.noreply.github.com> Date: Wed, 22 May 2024 08:46:18 -0300 Subject: [PATCH 2/2] Applying #29072 PR suggestions --- examples/model_selection/plot_cv_indices.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/examples/model_selection/plot_cv_indices.py b/examples/model_selection/plot_cv_indices.py index 58e09644a59d7..d456546891069 100644 --- a/examples/model_selection/plot_cv_indices.py +++ b/examples/model_selection/plot_cv_indices.py @@ -12,15 +12,10 @@ """ -import warnings - import matplotlib.pyplot as plt import numpy as np from matplotlib.patches import Patch -# Removing warnings from examples -warnings.filterwarnings("ignore") - from sklearn.model_selection import ( GroupKFold, GroupShuffleSplit, @@ -104,9 +99,10 @@ def visualize_groups(classes, groups, name): def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10): """Create a sample plot for indices of a cross-validation object.""" - + use_groups = "Group" in type(cv).__name__ + groups = group if use_groups else None # Generate the training/testing visualizations for each CV split - for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)): + for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=groups)): # Fill in indices with the training/test groups indices = np.array([np.nan] * len(X)) indices[tt] = 1