[MRG + 1] FIX Validate and convert X, y and groups to ndarray before splitting #7593


Merged
merged 18 commits into from
Nov 3, 2016

Conversation

@raghavrv (Member) commented Oct 6, 2016

Fixes #7582 and #7126

At sklearn 0.18.0

>>> from sklearn.model_selection import train_test_split
>>> X, y = [[1,], [2,], [3,], [4,], [5,], [6,]], ['1', '2', '1', '2', '1', '2']
>>> _ = train_test_split(X, y, stratify=y)
IndexError: index 0 is out of bounds for axis 1 with size 0

That is fixed after this PR.
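After this PR, the same snippet runs cleanly; a minimal sanity check (assuming a fixed scikit-learn, >= 0.18.1):

```python
# Reproduction of the reported regression, run against a fixed scikit-learn;
# on 0.18.0 this raised IndexError when y was a plain list of string labels.
from sklearn.model_selection import train_test_split

X = [[1], [2], [3], [4], [5], [6]]
y = ['1', '2', '1', '2', '1', '2']

# With the default test_size of 0.25, 6 samples split into 4 train / 2 test;
# stratification keeps the class ratio (one of each label in the test set).
X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, random_state=0)

assert len(X_train) + len(X_test) == 6
assert sorted(y_train + y_test) == sorted(y)
```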

This PR also cleans up some docstrings and adds tests for LeavePGroupsOut and LeaveOneGroupOut...

@jnothman @amueller @lesteve Reviews please :)

@@ -1708,3 +1714,23 @@ def _build_repr(self):
params[key] = value

return '%s(%s)' % (class_name, _pprint(params, offset=len(class_name)))


def _check_X_y_groups(X, y, groups):
Member Author

Should this reside inside utils.validation?

Member

probably. Is the same applicable for sample_weights? What do we usually do with sample_weights?
We might just write check_X_y and then do a check_consistent_length(X, sample_weights) and check_array(sample_weights).
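A sketch of that suggestion, using scikit-learn's public validation helpers (the wrapper name `_validate_with_weights` is illustrative, not part of the PR):

```python
# Validate X/y jointly with check_X_y, then check sample_weight separately:
# convert it to a 1-d ndarray and require it to match X in length.
from sklearn.utils import check_array, check_X_y, check_consistent_length


def _validate_with_weights(X, y, sample_weight=None):
    # Hypothetical helper combining the checks suggested in the review.
    X, y = check_X_y(X, y)
    if sample_weight is not None:
        sample_weight = check_array(sample_weight, ensure_2d=False)
        check_consistent_length(X, sample_weight)
    return X, y, sample_weight


X, y, sw = _validate_with_weights([[1], [2]], [0, 1], [0.5, 0.5])
assert sw.shape == (2,)
```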

allow_nd=True)
check_consistent_length(X, y)
if groups is not None:
groups = check_array(groups, accept_sparse=['coo', 'csr', 'csc'],
@amueller (Member) Oct 7, 2016

groups can be infinite? and sparse? and nd? Is that tested? ;)

Member

cannot be sparse, surely.

Member

or nd

dtype=None, force_all_finite=False, ensure_2d=False,
allow_nd=True)
if y is not None:
y = check_array(y, accept_sparse=['coo', 'csr', 'csc'],
Member

Same for y. Are these tested? Should they be? I guess we should be as loose as possible with the test as long as the cross-validation classes work.

@raghavrv (Member Author) commented Oct 8, 2016

There is a test for train_test_split which tests support for nd arrays... And we cannot allow nd only there, as it uses ShuffleSplit internally... You are correct, groups cannot have nan or be nd, but they can be sparse I think...

And we could do the check_X_y followed by checks for groups, but it doesn't allow a None for y

@raghavrv (Member Author)

Okay, I did away with the helper and added case-by-case minimal validation for y and groups. For X, only indexability is checked. One more pass @jnothman @amueller please!
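The approach described here can be sketched as follows (the helper name `_minimal_split_validation` is illustrative; the PR inlines these checks per splitter rather than keeping a shared helper):

```python
# Minimal per-argument validation: X only needs to be indexable, while y and
# groups are converted to 1-d ndarrays (dtype=None preserves string labels).
import numpy as np
from sklearn.utils import check_array, indexable


def _minimal_split_validation(X, y=None, groups=None):
    # Hypothetical helper mirroring the case-by-case checks in the PR.
    X, y, groups = indexable(X, y, groups)
    if y is not None:
        y = check_array(y, ensure_2d=False, dtype=None)
    if groups is not None:
        groups = check_array(groups, ensure_2d=False, dtype=None)
    return X, y, groups


X, y, groups = _minimal_split_validation([[1], [2]], ['a', 'b'], [0, 1])
assert isinstance(y, np.ndarray)       # list of strings became an ndarray
assert isinstance(groups, np.ndarray)
```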

@@ -843,6 +877,20 @@ def test_shufflesplit_reproducible():
list(a for a, b in ss.split(X)))


def test_shufflesplit_list_input():
# Check that when y is a list / list of string labels, it works.
ss = ShuffleSplit(random_state=42)
Member

shouldn't that be StratifiedShuffleSplit?
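Either splitter should accept a plain list y after this PR; a minimal check with StratifiedShuffleSplit, the splitter from the linked regression (sizes here are illustrative):

```python
# StratifiedShuffleSplit with y given as a plain list of string labels --
# the case that used to fail before this fix.
from sklearn.model_selection import StratifiedShuffleSplit

X = [[i] for i in range(8)]
y = ['a', 'b'] * 4  # two balanced classes

sss = StratifiedShuffleSplit(n_splits=2, test_size=0.25, random_state=42)
splits = list(sss.split(X, y))

assert len(splits) == 2
for train, test in splits:
    # Every sample lands in exactly one side of each split.
    assert len(train) + len(test) == 8
    assert set(train).isdisjoint(test)
```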

@@ -1087,6 +1091,8 @@ def __init__(self, n_splits=5, test_size=0.2, train_size=None,
def _iter_indices(self, X, y, groups):
if groups is None:
raise ValueError("The groups parameter should not be None")
groups = check_array(groups, ensure_2d=False, dtype=None)
Member

How about GroupKFold, LeaveOneGroupOut, LeavePGroupsOut?

Member Author

Fixed... Thanks for the catch!!
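A quick check of what this exchange covers: the group-based splitters accepting groups given as a plain list (splitter parameters here are illustrative):

```python
# GroupKFold, LeaveOneGroupOut and LeavePGroupsOut with groups as a plain
# list of string labels; each split must still cover all samples.
from sklearn.model_selection import GroupKFold, LeaveOneGroupOut, LeavePGroupsOut

X = [[i] for i in range(6)]
y = [0, 1] * 3
groups = ['a', 'a', 'b', 'b', 'c', 'c']  # three groups of two samples

for cv in (GroupKFold(n_splits=3), LeaveOneGroupOut(), LeavePGroupsOut(n_groups=2)):
    for train, test in cv.split(X, y, groups):
        assert len(train) + len(test) == 6
```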

@raghavrv (Member Author)

I fixed #7126 along the way... One more look at this @amueller @jnothman

@raghavrv (Member Author)

Argh. There seem to have been no tests for LeavePGroupsOut and LeaveOneGroupOut in the old/new tests... Have added them too...

@@ -891,6 +901,8 @@ def get_n_splits(self, X, y, groups):
"""
if groups is None:
raise ValueError("The groups parameter should not be None")
X, y, groups = indexable(X, y, groups)
groups = check_array(groups, ensure_2d=False, dtype=None)
Member

I'd do it the other way around, I think.

Member Author

check_array followed by indexable?

Member

yeah

@amueller (Member) left a comment

Looks good apart from some nitpicks.


for j, (cv, p_groups_out) in enumerate(((logo, 1), (lpgo_1, 1),
(lpgo_2, 2))):
groups = (np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]),
Member

do we want these to be file-level constants?

logo = LeaveOneGroupOut()
lpgo_1 = LeavePGroupsOut(n_groups=1)
lpgo_2 = LeavePGroupsOut(n_groups=2)
lpgo_3 = LeavePGroupsOut(n_groups=3)
Member

for this one you only test the repr, right?

[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3],
['1', '1', '1', '1', '2', '2', '2', '3', '3', '3', '3', '3'])

all_n_splits = np.array([[3, 3, 3],
Member

why do you hard-code it like this? that seems hard to validate. It's just scipy.misc.comb(len(np.unique(groups_i)), p_groups_out) right?

Member Author

> It's just scipy.misc.comb(len(np.unique(groups_i)), p_groups_out) right

That is the implementation in _split.py. I thought it would be better to compare it against hand-calculated values?

Member

Hm. The correctness of your "hand calculated values" is not immediately obvious to me.
How about

n_groups = len(np.unique(groups_i))
n_splits = n_groups if p_groups_out == 1 else n_groups * (n_groups - 1) / 2 ?

but I'm also fine leaving it like it is.
Why is all_n_splits of length 7 when groups is of length 6? (or github shows me a weird diff)
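The two formulas agree; a cross-check of the expected split counts against LeavePGroupsOut itself (note scipy.special.comb is the current home of the long-deprecated scipy.misc.comb):

```python
# Verify get_n_splits against the combinatorial formula discussed above:
# the number of splits is C(n_groups, p) for LeavePGroupsOut(n_groups=p).
import numpy as np
from scipy.special import comb
from sklearn.model_selection import LeavePGroupsOut

groups = np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3])  # 3 distinct groups
X = y = np.ones(len(groups))

for p in (1, 2):
    lpgo = LeavePGroupsOut(n_groups=p)
    expected = comb(len(np.unique(groups)), p, exact=True)
    assert lpgo.get_n_splits(X, y, groups) == expected
```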

# First test: no train group is in the test set and vice versa
grps_train_unique = np.unique(groups_arr[train])
grps_test_unique = np.unique(groups_arr[test])
assert_false(np.any(np.in1d(groups_arr[train],
Member

why not test the intersection is empty?
assert_equal(set(groups_arr[train]).intersection(groups_arr[test]), set())

Member

(or intersect1d if you prefer)

Member Author

Sure. Thanks

Member Author

Wait that is already done in the next 2 lines...

Member Author

("third test")

Member

The third test checks whether the indices are disjoint; my code checks whether the groups are disjoint.
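The two equivalent forms of the group-disjointness check from this thread, as a quick sketch (the train/test indices here are made up for illustration):

```python
# Assert that no group label appears on both sides of a split, in both the
# set-intersection form and the np.intersect1d form suggested above.
import numpy as np

groups = np.array([1, 1, 2, 2, 3, 3])
train = np.array([0, 1, 2, 3])  # groups 1 and 2
test = np.array([4, 5])         # group 3

assert set(groups[train]).intersection(groups[test]) == set()
assert np.intersect1d(groups[train], groups[test]).size == 0
```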

grps_train_unique)))

# Second test: train and test add up to all the data
assert_equal(groups_arr[train].size +
Member

len(train) + len(test) = len(groups)?

@amueller amueller changed the title [MRG] FIX Validate and convert X, y and groups to ndarray before splitting [MRG + 1] FIX Validate and convert X, y and groups to ndarray before splitting Oct 19, 2016
@amueller (Member)

lgtm apart from minor comments

@raghavrv (Member Author)

Have addressed your comments. A 2nd look please? @jnothman @vene @TomDLT ?

@raghavrv raghavrv changed the title [MRG + 1] FIX Validate and convert X, y and groups to ndarray before splitting [MRG + 2] FIX Validate and convert X, y and groups to ndarray before splitting Oct 20, 2016
@raghavrv raghavrv changed the title [MRG + 2] FIX Validate and convert X, y and groups to ndarray before splitting [MRG + 1] FIX Validate and convert X, y and groups to ndarray before splitting Oct 20, 2016
@amueller (Member)

travis fails?

@raghavrv (Member Author)

Sorry about that. Should be fixed now...

np.testing.assert_equal(y_train2, y_train3)
np.testing.assert_equal(X_test1, X_test3)
np.testing.assert_equal(y_test3, y_test2)
for stratify in ((y1, y2, y3), (None, None, None)):
Member Author

Does this seem okay? @jnothman @amueller

@raghavrv (Member Author) commented Nov 3, 2016

Apologies for the delay! Have rebased and added the test... Could you check if it's okay?


for stratify in ((y1, y2, y3), (None, None, None)):
X_train1, X_test1, y_train1, y_test1 = train_test_split(
X, y1, stratify=stratify[0], random_state=0)
Member

I think stratify=y1 if stratify else None would be more readable (where stratify in (True, False) is iterated)

Member Author

Done :)
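The loop structure suggested in this exchange, as a runnable sketch (y1 here stands in for one of the test's label arrays):

```python
# Iterate a boolean flag and pass the label array only when stratifying --
# the more readable form jnothman suggested for the train_test_split test.
from sklearn.model_selection import train_test_split

X = [[i] for i in range(8)]
y1 = ['a', 'b'] * 4

for stratify in (True, False):
    X_tr, X_te, y_tr, y_te = train_test_split(
        X, y1, stratify=y1 if stratify else None, random_state=0)
    assert len(X_tr) + len(X_te) == 8
```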

@jnothman (Member) commented Nov 3, 2016

(Maybe we should allow stratify to be an int index into the *args)

@jnothman jnothman merged commit d7c956a into scikit-learn:master Nov 3, 2016
@raghavrv raghavrv deleted the check_X_y_groups branch November 3, 2016 22:57
@raghavrv (Member Author) commented Nov 3, 2016

Thanks for the patient review and merge!

amueller pushed a commit to amueller/scikit-learn that referenced this pull request Nov 9, 2016
@amueller (Member)

needs a whatsnew maybe?

sergeyf pushed a commit to sergeyf/scikit-learn that referenced this pull request Feb 28, 2017
Sundrique pushed a commit to Sundrique/scikit-learn that referenced this pull request Jun 14, 2017
paulha pushed a commit to paulha/scikit-learn that referenced this pull request Aug 19, 2017
maskani-moh pushed a commit to maskani-moh/scikit-learn that referenced this pull request Nov 15, 2017

Successfully merging this pull request may close these issues.

REGRESSION: StratifiedShuffleSplit errors on list y
4 participants