[MRG + 1] FIX Validate and convert X, y and groups to ndarray before splitting #7593
Conversation
@@ -1708,3 +1714,23 @@ def _build_repr(self):
        params[key] = value

    return '%s(%s)' % (class_name, _pprint(params, offset=len(class_name)))


def _check_X_y_groups(X, y, groups):
Should this reside inside `utils.validation`?
Probably. Is the same applicable for `sample_weights`? What do we usually do with `sample_weights`? We might just write `check_X_y` and then do a `check_consistent_length(X, sample_weights)` and `check_array(sample_weights)`.
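The pattern suggested above — convert everything to ndarray up front and check that lengths agree — can be sketched without sklearn at all. This is a hypothetical numpy-only helper (the name `check_inputs` and its exact behaviour are assumptions for illustration, not the PR's `_check_X_y_groups`):

```python
import numpy as np


def check_inputs(X, y=None, sample_weights=None):
    """Hypothetical sketch: convert inputs to ndarray and check lengths.

    Mirrors the check_X_y / check_consistent_length / check_array
    pattern discussed above, using plain numpy.
    """
    X = np.asarray(X)
    if y is not None:
        y = np.asarray(y)
        if len(X) != len(y):
            raise ValueError("X and y have inconsistent lengths: "
                             "%d vs %d" % (len(X), len(y)))
    if sample_weights is not None:
        sample_weights = np.asarray(sample_weights)
        if len(X) != len(sample_weights):
            raise ValueError("X and sample_weights have "
                             "inconsistent lengths")
    return X, y, sample_weights
```

After this conversion, downstream splitting code can rely on fancy indexing (`X[train]`) working, which is exactly what breaks when users pass plain Python lists.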
                            allow_nd=True)
        check_consistent_length(X, y)
        if groups is not None:
            groups = check_array(groups, accept_sparse=['coo', 'csr', 'csc'],
groups can be infinite? and sparse? and nd? Is that tested? ;)
cannot be sparse, surely.
or nd
                            dtype=None, force_all_finite=False, ensure_2d=False,
                            allow_nd=True)
        if y is not None:
            y = check_array(y, accept_sparse=['coo', 'csr', 'csc'],
Same for y. Are these tested? Should they be? I guess we should be as loose as possible with the test as long as the cross-validation classes work.
There is a test for … And we could do the …
@@ -843,6 +877,20 @@ def test_shufflesplit_reproducible():
                 list(a for a, b in ss.split(X)))


def test_shufflesplit_list_input():
    # Check that when y is a list / list of string labels, it works.
    ss = ShuffleSplit(random_state=42)
Shouldn't that be `StratifiedShuffleSplit`?
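The behaviour the test above is after can be sketched like this: after the fix, both splitters should accept `X` and `y` given as plain lists (including string labels), because inputs are converted to ndarray before splitting. This assumes scikit-learn is installed; the data here is made up for illustration:

```python
from sklearn.model_selection import ShuffleSplit, StratifiedShuffleSplit

# y given as a plain list of string labels; after this PR both
# splitters accept it because inputs are converted first.
X = [[i] for i in range(10)]
y = ['a', 'b'] * 5

ss = ShuffleSplit(n_splits=3, test_size=0.3, random_state=42)
splits = list(ss.split(X, y))

# StratifiedShuffleSplit actually uses y, so it exercises the
# string-label path more thoroughly than ShuffleSplit does.
sss = StratifiedShuffleSplit(n_splits=3, test_size=0.4, random_state=42)
strat_splits = list(sss.split(X, y))
```

Using `StratifiedShuffleSplit` here matters because plain `ShuffleSplit` ignores `y`, so it would not catch a regression in label handling.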
@@ -1087,6 +1091,8 @@ def __init__(self, n_splits=5, test_size=0.2, train_size=None,
    def _iter_indices(self, X, y, groups):
        if groups is None:
            raise ValueError("The groups parameter should not be None")
        groups = check_array(groups, ensure_2d=False, dtype=None)
How about `GroupKFold`, `LeaveOneGroupOut`, `LeavePGroupsOut`?
Fixed... Thanks for the catch!!
Argh. There seemed to have been no tests for …
@@ -891,6 +901,8 @@ def get_n_splits(self, X, y, groups):
        """
        if groups is None:
            raise ValueError("The groups parameter should not be None")
        X, y, groups = indexable(X, y, groups)
        groups = check_array(groups, ensure_2d=False, dtype=None)
I'd do it the other way around, I think.
`check_array` followed by `indexable`?
yeah
Looks good apart from some nitpicks.
    for j, (cv, p_groups_out) in enumerate(((logo, 1), (lpgo_1, 1),
                                            (lpgo_2, 2))):

    groups = (np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]),
do we want these to be file-level constants?
    logo = LeaveOneGroupOut()
    lpgo_1 = LeavePGroupsOut(n_groups=1)
    lpgo_2 = LeavePGroupsOut(n_groups=2)
    lpgo_3 = LeavePGroupsOut(n_groups=3)
for this one you only test the repr, right?
              [1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3],
              ['1', '1', '1', '1', '2', '2', '2', '3', '3', '3', '3', '3'])

    all_n_splits = np.array([[3, 3, 3],
Why do you hard-code it like this? That seems hard to validate. It's just `scipy.misc.comb(len(np.unique(groups_i)), p_groups_out)`, right?
> It's just scipy.misc.comb(len(np.unique(groups_i)), p_groups_out), right

That is the implementation in `_split.py`. I thought it would be better to compare it against hand-calculated values?
Hm. The correctness of your "hand-calculated values" is not immediately obvious to me. How about

    n_groups = len(np.unique(groups_i))
    n_splits = n_groups if p_groups_out == 1 else n_groups * (n_groups - 1) / 2

? But I'm also fine leaving it as it is.

Why is `all_n_splits` of length 7 when `groups` is of length 6? (or github shows me a weird diff)
        # First test: no train group is in the test set and vice versa
        grps_train_unique = np.unique(groups_arr[train])
        grps_test_unique = np.unique(groups_arr[test])
        assert_false(np.any(np.in1d(groups_arr[train],
Why not test that the intersection is empty?

    assert_equal(set(groups_arr[train]).intersection(groups_arr[test]), set())
(or intersect1d if you prefer)
Sure. Thanks
Wait that is already done in the next 2 lines...
("third test")
The third test checks whether the indices are disjoint; my code checks whether the groups are disjoint.
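The disjointness checks being compared above are equivalent; here is a small self-contained illustration with hand-made indices (the data is an assumption, not taken from the PR). `np.isin` is used in place of the older `np.in1d` spelling from the diff:

```python
import numpy as np

groups_arr = np.array([1, 1, 2, 2, 3, 3])
train = np.array([0, 1, 2, 3])   # samples from groups 1 and 2
test = np.array([4, 5])          # samples from group 3

# Set-based check suggested by the reviewer: group labels are disjoint.
assert set(groups_arr[train]).intersection(groups_arr[test]) == set()

# Equivalent numpy formulations.
assert np.intersect1d(groups_arr[train], groups_arr[test]).size == 0
assert not np.any(np.isin(groups_arr[train], groups_arr[test]))
```

Note this checks disjointness of *group labels*, which is strictly stronger than checking that the train and test *index* arrays are disjoint.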
                                     grps_train_unique)))

        # Second test: train and test add up to all the data
        assert_equal(groups_arr[train].size +
`len(train) + len(test) == len(groups)`?
lgtm apart from minor comments
travis fails?
Sorry about that. Should be fixed now...
    np.testing.assert_equal(y_train2, y_train3)
    np.testing.assert_equal(X_test1, X_test3)
    np.testing.assert_equal(y_test3, y_test2)
    for stratify in ((y1, y2, y3), (None, None, None)):
Apologies for the delay! Have rebased and added the test... Could you check if it's okay?
    for stratify in ((y1, y2, y3), (None, None, None)):
        X_train1, X_test1, y_train1, y_test1 = train_test_split(
            X, y1, stratify=stratify[0], random_state=0)
I think `stratify=y1 if stratify else None` would be more readable (where `stratify in (True, False)` is iterated).
Done :)
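The suggested refactor — iterating a boolean instead of tuples of label arrays — can be sketched as follows. This assumes scikit-learn is installed and uses made-up data; it is the shape of the test, not the PR's exact code:

```python
from sklearn.model_selection import train_test_split

X = [[i] for i in range(10)]
y1 = [0, 1] * 5

# Iterate over a boolean and derive the stratify argument from it,
# which reads better than indexing into a tuple of label arrays.
results = {}
for stratify in (True, False):
    X_train, X_test, y_train, y_test = train_test_split(
        X, y1, stratify=y1 if stratify else None, random_state=0)
    results[stratify] = (y_train, y_test)
```

Both branches split the same data; only the stratified branch guarantees the class proportions of `y1` are preserved in the test set.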
(Maybe we should allow …)
Thanks for the patient review and merge!
needs a whatsnew maybe?
Fixes #7582 and #7126.

At sklearn 0.18.0: … That is fixed after this PR.

This PR also cleans up some docstrings and adds tests for `LeavePGroupsOut` and `LeaveOneGroupOut`.

@jnothman @amueller @lesteve Reviews please :)