Skip to content

Commit 1f781e6

Browse files
nelson-liu authored and jnothman committed
[MRG+3] CV splitters: train/test_size default behavior will change in 0.21 (scikit-learn#7459)
1 parent 7ab0a96 commit 1f781e6

File tree

3 files changed

+96
-38
lines changed

3 files changed

+96
-38
lines changed

doc/whats_new.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,11 @@ API changes summary
454454
method ``check_decision_proba_consistency`` has been added in
455455
**sklearn.utils.estimator_checks** to check their consistency.
456456
:issue:`7578` by :user:`Shubham Bhardwaj <shubham0704>`
457+
458+
- In version 0.21, the default behavior of splitters that use the
459+
    ``test_size`` and ``train_size`` parameters will change, such that
460+
specifying ``train_size`` alone will cause ``test_size`` to be the
461+
remainder. :issue:`7459` by :user:`Nelson Liu <nelson-liu>`.
457462

458463
- All tree based estimators now accept a ``min_impurity_decrease``
459464
parameter in lieu of the ``min_impurity_split``, which is now deprecated.
@@ -506,7 +511,6 @@ API changes summary
506511
- ``utils.stats.rankdata``
507512
- ``neighbors.approximate.LSHForest``
508513

509-
510514
.. _changes_0_18_1:
511515

512516
Version 0.18.1

sklearn/model_selection/_split.py

Lines changed: 77 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1133,7 +1133,7 @@ def __init__(self, n_splits=5, n_repeats=10, random_state=None):
11331133
class BaseShuffleSplit(with_metaclass(ABCMeta)):
11341134
"""Base class for ShuffleSplit and StratifiedShuffleSplit"""
11351135

1136-
def __init__(self, n_splits=10, test_size=0.1, train_size=None,
1136+
def __init__(self, n_splits=10, test_size="default", train_size=None,
11371137
random_state=None):
11381138
_validate_shuffle_split_init(test_size, train_size)
11391139
self.n_splits = n_splits
@@ -1211,16 +1211,20 @@ class ShuffleSplit(BaseShuffleSplit):
12111211
12121212
Parameters
12131213
----------
1214-
n_splits : int (default 10)
1214+
n_splits : int, default 10
12151215
Number of re-shuffling & splitting iterations.
12161216
1217-
test_size : float, int, or None, default 0.1
1218-
If float, should be between 0.0 and 1.0 and represent the
1219-
proportion of the dataset to include in the test split. If
1220-
int, represents the absolute number of test samples. If None,
1221-
the value is automatically set to the complement of the train size.
1222-
1223-
train_size : float, int, or None (default is None)
1217+
test_size : float, int, None, default=0.1
1218+
If float, should be between 0.0 and 1.0 and represent the proportion
1219+
of the dataset to include in the test split. If int, represents the
1220+
absolute number of test samples. If None, the value is set to the
1221+
complement of the train size. By default (the parameter is
1222+
unspecified), the value is set to 0.1.
1223+
The default will change in version 0.21. It will remain 0.1 only
1224+
if ``train_size`` is unspecified, otherwise it will complement
1225+
the specified ``train_size``.
1226+
1227+
train_size : float, int, or None, default=None
12241228
If float, should be between 0.0 and 1.0 and represent the
12251229
proportion of the dataset to include in the train split. If
12261230
int, represents the absolute number of train samples. If None,
@@ -1260,7 +1264,8 @@ class ShuffleSplit(BaseShuffleSplit):
12601264

12611265
def _iter_indices(self, X, y=None, groups=None):
12621266
n_samples = _num_samples(X)
1263-
n_train, n_test = _validate_shuffle_split(n_samples, self.test_size,
1267+
n_train, n_test = _validate_shuffle_split(n_samples,
1268+
self.test_size,
12641269
self.train_size)
12651270
rng = check_random_state(self.random_state)
12661271
for i in range(self.n_splits):
@@ -1299,13 +1304,16 @@ class GroupShuffleSplit(ShuffleSplit):
12991304
n_splits : int (default 5)
13001305
Number of re-shuffling & splitting iterations.
13011306
1302-
test_size : float (default 0.2), int, or None
1303-
If float, should be between 0.0 and 1.0 and represent the
1304-
proportion of the groups to include in the test split. If
1305-
int, represents the absolute number of test groups. If None,
1306-
the value is automatically set to the complement of the train size.
1307+
test_size : float, int, None, optional
1308+
If float, should be between 0.0 and 1.0 and represent the proportion
1309+
of the dataset to include in the test split. If int, represents the
1310+
absolute number of test samples. If None, the value is set to the
1311+
complement of the train size. By default, the value is set to 0.2.
1312+
The default will change in version 0.21. It will remain 0.2 only
1313+
if ``train_size`` is unspecified, otherwise it will complement
1314+
the specified ``train_size``.
13071315
1308-
train_size : float, int, or None (default is None)
1316+
train_size : float, int, or None, default is None
13091317
If float, should be between 0.0 and 1.0 and represent the
13101318
proportion of the groups to include in the train split. If
13111319
int, represents the absolute number of train groups. If None,
@@ -1319,8 +1327,16 @@ class GroupShuffleSplit(ShuffleSplit):
13191327
13201328
'''
13211329

1322-
def __init__(self, n_splits=5, test_size=0.2, train_size=None,
1330+
def __init__(self, n_splits=5, test_size="default", train_size=None,
13231331
random_state=None):
1332+
if test_size == "default":
1333+
if train_size is not None:
1334+
warnings.warn("From version 0.21, test_size will always "
1335+
"complement train_size unless both "
1336+
"are specified.",
1337+
FutureWarning)
1338+
test_size = 0.2
1339+
13241340
super(GroupShuffleSplit, self).__init__(
13251341
n_splits=n_splits,
13261342
test_size=test_size,
@@ -1428,16 +1444,19 @@ class StratifiedShuffleSplit(BaseShuffleSplit):
14281444
14291445
Parameters
14301446
----------
1431-
n_splits : int (default 10)
1447+
n_splits : int, default 10
14321448
Number of re-shuffling & splitting iterations.
14331449
1434-
test_size : float (default 0.1), int, or None
1435-
If float, should be between 0.0 and 1.0 and represent the
1436-
proportion of the dataset to include in the test split. If
1437-
int, represents the absolute number of test samples. If None,
1438-
the value is automatically set to the complement of the train size.
1450+
test_size : float, int, None, optional
1451+
If float, should be between 0.0 and 1.0 and represent the proportion
1452+
of the dataset to include in the test split. If int, represents the
1453+
absolute number of test samples. If None, the value is set to the
1454+
complement of the train size. By default, the value is set to 0.1.
1455+
The default will change in version 0.21. It will remain 0.1 only
1456+
if ``train_size`` is unspecified, otherwise it will complement
1457+
the specified ``train_size``.
14391458
1440-
train_size : float, int, or None (default is None)
1459+
train_size : float, int, or None, default is None
14411460
If float, should be between 0.0 and 1.0 and represent the
14421461
proportion of the dataset to include in the train split. If
14431462
int, represents the absolute number of train samples. If None,
@@ -1468,7 +1487,7 @@ class StratifiedShuffleSplit(BaseShuffleSplit):
14681487
TRAIN: [0 2] TEST: [3 1]
14691488
"""
14701489

1471-
def __init__(self, n_splits=10, test_size=0.1, train_size=None,
1490+
def __init__(self, n_splits=10, test_size="default", train_size=None,
14721491
random_state=None):
14731492
super(StratifiedShuffleSplit, self).__init__(
14741493
n_splits, test_size, train_size, random_state)
@@ -1563,6 +1582,14 @@ def _validate_shuffle_split_init(test_size, train_size):
15631582
NOTE This does not take into account the number of samples which is known
15641583
only at split
15651584
"""
1585+
if test_size == "default":
1586+
if train_size is not None:
1587+
warnings.warn("From version 0.21, test_size will always "
1588+
"complement train_size unless both "
1589+
"are specified.",
1590+
FutureWarning)
1591+
test_size = 0.1
1592+
15661593
if test_size is None and train_size is None:
15671594
raise ValueError('test_size and train_size can not both be None')
15681595

@@ -1597,16 +1624,21 @@ def _validate_shuffle_split(n_samples, test_size, train_size):
15971624
Validation helper to check if the test/train sizes are meaningful wrt the
15981625
size of the data (n_samples)
15991626
"""
1600-
if (test_size is not None and np.asarray(test_size).dtype.kind == 'i' and
1627+
if (test_size is not None and
1628+
np.asarray(test_size).dtype.kind == 'i' and
16011629
test_size >= n_samples):
16021630
raise ValueError('test_size=%d should be smaller than the number of '
16031631
'samples %d' % (test_size, n_samples))
16041632

1605-
if (train_size is not None and np.asarray(train_size).dtype.kind == 'i' and
1633+
if (train_size is not None and
1634+
np.asarray(train_size).dtype.kind == 'i' and
16061635
train_size >= n_samples):
16071636
raise ValueError("train_size=%d should be smaller than the number of"
16081637
" samples %d" % (train_size, n_samples))
16091638

1639+
if test_size == "default":
1640+
test_size = 0.1
1641+
16101642
if np.asarray(test_size).dtype.kind == 'f':
16111643
n_test = ceil(test_size * n_samples)
16121644
elif np.asarray(test_size).dtype.kind == 'i':
@@ -1844,14 +1876,16 @@ def train_test_split(*arrays, **options):
18441876
Allowed inputs are lists, numpy arrays, scipy-sparse
18451877
matrices or pandas dataframes.
18461878
1847-
test_size : float, int, or None (default is None)
1848-
If float, should be between 0.0 and 1.0 and represent the
1849-
proportion of the dataset to include in the test split. If
1850-
int, represents the absolute number of test samples. If None,
1851-
the value is automatically set to the complement of the train size.
1852-
If train size is also None, test size is set to 0.25.
1879+
test_size : float, int, None, optional
1880+
If float, should be between 0.0 and 1.0 and represent the proportion
1881+
of the dataset to include in the test split. If int, represents the
1882+
absolute number of test samples. If None, the value is set to the
1883+
complement of the train size. By default, the value is set to 0.25.
1884+
The default will change in version 0.21. It will remain 0.25 only
1885+
if ``train_size`` is unspecified, otherwise it will complement
1886+
the specified ``train_size``.
18531887
1854-
train_size : float, int, or None (default is None)
1888+
train_size : float, int, or None, default None
18551889
If float, should be between 0.0 and 1.0 and represent the
18561890
proportion of the dataset to include in the train split. If
18571891
int, represents the absolute number of train samples. If None,
@@ -1917,7 +1951,7 @@ def train_test_split(*arrays, **options):
19171951
n_arrays = len(arrays)
19181952
if n_arrays == 0:
19191953
raise ValueError("At least one array required as input")
1920-
test_size = options.pop('test_size', None)
1954+
test_size = options.pop('test_size', 'default')
19211955
train_size = options.pop('train_size', None)
19221956
random_state = options.pop('random_state', None)
19231957
stratify = options.pop('stratify', None)
@@ -1926,6 +1960,14 @@ def train_test_split(*arrays, **options):
19261960
if options:
19271961
raise TypeError("Invalid parameters passed: %s" % str(options))
19281962

1963+
if test_size == 'default':
1964+
test_size = None
1965+
if train_size is not None:
1966+
warnings.warn("From version 0.21, test_size will always "
1967+
"complement train_size unless both "
1968+
"are specified.",
1969+
FutureWarning)
1970+
19291971
if test_size is None and train_size is None:
19301972
test_size = 0.25
19311973

sklearn/model_selection/tests/test_split.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from sklearn.utils.testing import assert_array_almost_equal
2121
from sklearn.utils.testing import assert_array_equal
2222
from sklearn.utils.testing import assert_warns_message
23+
from sklearn.utils.testing import assert_warns
2324
from sklearn.utils.testing import assert_raise_message
2425
from sklearn.utils.testing import ignore_warnings
2526
from sklearn.utils.validation import _num_samples
@@ -163,8 +164,8 @@ def test_cross_validator_with_default_params():
163164
skf_repr = "StratifiedKFold(n_splits=2, random_state=None, shuffle=False)"
164165
lolo_repr = "LeaveOneGroupOut()"
165166
lopo_repr = "LeavePGroupsOut(n_groups=2)"
166-
ss_repr = ("ShuffleSplit(n_splits=10, random_state=0, test_size=0.1, "
167-
"train_size=None)")
167+
ss_repr = ("ShuffleSplit(n_splits=10, random_state=0, "
168+
"test_size='default',\n train_size=None)")
168169
ps_repr = "PredefinedSplit(test_fold=array([1, 1, 2, 2]))"
169170

170171
n_splits_expected = [n_samples, comb(n_samples, p), n_splits, n_splits,
@@ -527,6 +528,7 @@ def test_shuffle_split():
527528
assert_array_equal(t3[1], t4[1])
528529

529530

531+
@ignore_warnings
530532
def test_stratified_shuffle_split_init():
531533
X = np.arange(7)
532534
y = np.asarray([0, 1, 1, 1, 2, 2, 2])
@@ -859,6 +861,7 @@ def test_leave_one_p_group_out_error_on_fewer_number_of_groups():
859861
LeavePGroupsOut(n_groups=3).split(X, y, groups))
860862

861863

864+
@ignore_warnings
862865
def test_repeated_cv_value_errors():
863866
# n_repeats is not integer or <= 0
864867
for cv in (RepeatedKFold, RepeatedStratifiedKFold):
@@ -1070,6 +1073,7 @@ def train_test_split_list_input():
10701073
np.testing.assert_equal(y_test3, y_test2)
10711074

10721075

1076+
@ignore_warnings
10731077
def test_shufflesplit_errors():
10741078
# When the {test|train}_size is a float/invalid, error is raised at init
10751079
assert_raises(ValueError, ShuffleSplit, test_size=None, train_size=None)
@@ -1366,6 +1370,14 @@ def test_nested_cv():
13661370
fit_params={'groups': groups})
13671371

13681372

1373+
def test_train_test_default_warning():
1374+
assert_warns(FutureWarning, ShuffleSplit, train_size=0.75)
1375+
assert_warns(FutureWarning, GroupShuffleSplit, train_size=0.75)
1376+
assert_warns(FutureWarning, StratifiedShuffleSplit, train_size=0.75)
1377+
assert_warns(FutureWarning, train_test_split, range(3),
1378+
train_size=0.75)
1379+
1380+
13691381
def test_build_repr():
13701382
class MockSplitter:
13711383
def __init__(self, a, b=0, c=None):

0 commit comments

Comments
 (0)