Skip to content

Commit 10c690f

Browse files
committed
MISC: index arrays with integers
Indexing arrays with floats will be invalid in future releases of numpy and create a DeprecationWarning in numpy 1.8
1 parent 4d73592 commit 10c690f

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

sklearn/cross_validation.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -649,7 +649,7 @@ def __init__(self, n, n_iter=3, train_size=.5, test_size=None,
649649
self.n_iter = n_iter
650650
if (isinstance(train_size, numbers.Real) and train_size >= 0.0
651651
and train_size <= 1.0):
652-
self.train_size = ceil(train_size * n)
652+
self.train_size = int(ceil(train_size * n))
653653
elif isinstance(train_size, numbers.Integral):
654654
self.train_size = train_size
655655
else:
@@ -660,7 +660,7 @@ def __init__(self, n, n_iter=3, train_size=.5, test_size=None,
660660
(self.train_size, n))
661661

662662
if isinstance(test_size, numbers.Real) and 0.0 <= test_size <= 1.0:
663-
self.test_size = ceil(test_size * n)
663+
self.test_size = int(ceil(test_size * n))
664664
elif isinstance(test_size, numbers.Integral):
665665
self.test_size = test_size
666666
elif test_size is None:
@@ -881,7 +881,7 @@ def _validate_shuffle_split(n, test_size, train_size):
881881
'samples %d. Reduce test_size and/or '
882882
'train_size.' % (n_train + n_test, n))
883883

884-
return n_train, n_test
884+
return int(n_train), int(n_test)
885885

886886

887887
def _validate_stratified_shuffle_split(y, test_size, train_size):

0 commit comments

Comments
 (0)