
Stratified splitters user warnings #28628

Open · wants to merge 6 commits into main

Conversation

myenugula (Contributor)

Reference Issues/PRs

#28609

What does this implement/fix? Explain your changes.

Implement a warning in StratifiedKFold, StratifiedGroupKFold, and StratifiedShuffleSplit to alert users when only a single class is present in the target variable, suggesting that stratified splitting might not be appropriate and guiding towards more suitable cross-validation strategies.
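
For reference, here is a minimal sketch of the kind of single-class check behind these warnings; the helper name and its placement are illustrative assumptions, not the actual diff in _split.py.

import warnings
import numpy as np

def _warn_if_single_class(y, cls_name, alternative=None):
    # Hypothetical helper: warn when the target contains fewer than two classes.
    if len(np.unique(y)) < 2:
        msg = (
            f"Only one class present in y. {cls_name} is designed to be used "
            "with data that contains two or more classes."
        )
        if alternative is not None:
            msg += f" Consider using {alternative} instead."
        warnings.warn(msg, UserWarning)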

Any other comments?

Here is a code sample using the three classes; each one emits the corresponding user warning.

import numpy as np
from sklearn.model_selection import StratifiedKFold, StratifiedGroupKFold, StratifiedShuffleSplit
import warnings

# Function to suppress the traceback of warnings
def custom_formatwarning(msg, *args, **kwargs):
    return str(msg) + '\n'

warnings.formatwarning = custom_formatwarning
warnings.filterwarnings('default', category=UserWarning)

# StratifiedKFold example
print("StratifiedKFold Example:")
X = np.random.rand(10, 2)  # 10 samples, 2 features
y_skf = np.zeros(10)  # Target variable with one class
skf = StratifiedKFold(n_splits=3)
for train_index, test_index in skf.split(X, y_skf):
    print(f"TRAIN indices: {train_index}, TEST indices: {test_index}")

# StratifiedGroupKFold example
print("\nStratifiedGroupKFold Example:")
X = np.random.rand(15, 2)  # 15 samples, 2 features
y_sgkf = np.ones(15)  # Target variable with one class
groups = np.array([1, 2, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 6, 7, 7])
sgkf = StratifiedGroupKFold(n_splits=3)
for train_index, test_index in sgkf.split(X, y_sgkf, groups=groups):
    print(f"TRAIN indices: {train_index}, TEST indices: {test_index}")

# StratifiedShuffleSplit example
print("\nStratifiedShuffleSplit Example:")
X = np.random.rand(20, 2)  # 20 samples, 2 features
y_sss = np.zeros(20)  # Target variable with one class
sss = StratifiedShuffleSplit(n_splits=5, test_size=0.5, random_state=42)
for train_index, test_index in sss.split(X, y_sss):
    print(f"TRAIN indices: {train_index}, TEST indices: {test_index}")

Output:

Only one class present in y. StratifiedKFold is designed to be used with data that contains two or more classes. Consider using KFold instead.
Only one class present in y. StratifiedGroupKFold is designed to be used with data that contains two or more classes. The single-class scenario might not be suitable for stratified folds.
Only one class present in y. StratifiedShuffleSplit is designed to be used with data that contains two or more classes. Consider using ShuffleSplit instead.


github-actions bot commented Mar 14, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: c21608d.

@oasidorshin

@myenugula Also don't forget train_test_split with stratify=True and RepeatedStratifiedKFold.

@myenugula (Contributor, Author)

@oasidorshin The RepeatedStratifiedKFold class uses StratifiedKFold internally to split, which is how the warning gets raised. For example:

import numpy as np
from sklearn.model_selection import RepeatedStratifiedKFold
X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
y = np.array([0, 0, 0, 0])
rskf = RepeatedStratifiedKFold(n_splits=2, n_repeats=1)

for train, test in rskf.split(X, y):
    pass

Output:

/Users/myenugula/PycharmProjects/scikit-learn/sklearn/model_selection/_split.py:850: UserWarning: Only one class present in y. StratifiedKFold is designed to be used with data that contains two or more classes. Consider using KFold instead.

Similarly, train_test_split uses StratifiedShuffleSplit to split the data when stratify is passed. For example:

import numpy as np
from sklearn.model_selection import train_test_split
X, y = np.arange(10).reshape((5, 2)), range(5)
train_test_split(y, shuffle=True, stratify=[0, 0, 0, 0, 0])

Output:

/Users/myenugula/PycharmProjects/scikit-learn/sklearn/model_selection/_split.py:2343: UserWarning: Only one class present in y. StratifiedShuffleSplit is designed to be used with data that contains two or more classes. Consider using ShuffleSplit instead.

However, I'm thinking of replacing

        if stratify is not None:
            CVClass = StratifiedShuffleSplit
        else:
            CVClass = ShuffleSplit

with

        if stratify is not None and len(np.unique(stratify)) > 1:
            CVClass = StratifiedShuffleSplit
        else:
            CVClass = ShuffleSplit

in train_test_split to avoid the user warning when only one class is passed.
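
As a quick illustration of the behaviour this targets, the call below currently triggers the warning through StratifiedShuffleSplit; with the check sketched above (an assumption about the proposed behaviour, not merged code), it would fall back to ShuffleSplit and stay silent.

import numpy as np
from sklearn.model_selection import train_test_split

X = np.arange(10).reshape((5, 2))
y = [0, 0, 0, 0, 0]  # single-class stratify target
# With this PR's warning, stratifying on one class emits the UserWarning via
# StratifiedShuffleSplit; the proposed check would route this to ShuffleSplit.
X_train, X_test, y_train, y_test = train_test_split(X, y, shuffle=True, stratify=y)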

@oasidorshin

@myenugula Thank you, looks good!

I think that

        if stratify is not None and len(np.unique(stratify)) > 1:
            CVClass = StratifiedShuffleSplit
        else:
            CVClass = ShuffleSplit

is the way to go

@myenugula (Contributor, Author)

Any update on this?

@VyankateshRohokale

Is there something more you want to add to this issue? Maybe I could help.

@myenugula (Contributor, Author)

Hello, any update?

@myenugula (Contributor, Author)

Hi @lesteve, could you please review this PR?
