Skip to content

Commit 3001e6d

Browse files
committed
FIX make shuffle / resample pass-through indexing utilities
1 parent f0f4c79 commit 3001e6d

File tree

2 files changed

+47
-17
lines changed

2 files changed

+47
-17
lines changed

sklearn/utils/__init__.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,9 @@ def resample(*arrays, **options):
169169
170170
Parameters
171171
----------
172-
*arrays : sequence of arrays or scipy.sparse matrices with same shape[0]
172+
*arrays : sequence of indexable data-structures
173+
Indexable data-structures can be arrays, lists, dataframes or scipy
174+
sparse matrices with consistent first dimension.
173175
174176
replace : boolean, True by default
175177
Implements resampling with replacement. If False, this will implement
@@ -184,16 +186,15 @@ def resample(*arrays, **options):
184186
185187
Returns
186188
-------
187-
resampled_arrays : sequence of arrays or scipy.sparse matrices with same \
188-
shape[0]
189-
Sequence of resampled views of the collections. The original arrays are
189+
resampled_arrays : sequence of indexable data-structures
190+
Sequence of resampled views of the collections. The original arrays are
190191
not impacted.
191192
192193
Examples
193194
--------
194195
It is possible to mix sparse and dense arrays in the same run::
195196
196-
>>> X = [[1., 0.], [2., 1.], [0., 0.]]
197+
>>> X = np.array([[1., 0.], [2., 1.], [0., 0.]])
197198
>>> y = np.array([0, 1, 2])
198199
199200
>>> from scipy.sparse import coo_matrix
@@ -247,8 +248,6 @@ def resample(*arrays, **options):
247248
max_n_samples, n_samples))
248249

249250
check_consistent_length(*arrays)
250-
arrays = [check_array(x, accept_sparse='csr', ensure_2d=False,
251-
allow_nd=True) for x in arrays]
252251

253252
if replace:
254253
indices = random_state.randint(0, n_samples, size=(max_n_samples,))
@@ -257,12 +256,9 @@ def resample(*arrays, **options):
257256
random_state.shuffle(indices)
258257
indices = indices[:max_n_samples]
259258

260-
resampled_arrays = []
261-
262-
for array in arrays:
263-
array = array[indices]
264-
resampled_arrays.append(array)
265-
259+
# convert sparse matrices to CSR for row-based indexing
260+
arrays = [a.tocsr() if issparse(a) else a for a in arrays]
261+
resampled_arrays = [safe_indexing(a, indices) for a in arrays]
266262
if len(resampled_arrays) == 1:
267263
# syntactic sugar for the unit argument case
268264
return resampled_arrays[0]
@@ -278,7 +274,9 @@ def shuffle(*arrays, **options):
278274
279275
Parameters
280276
----------
281-
*arrays : sequence of arrays or scipy.sparse matrices with same shape[0]
277+
*arrays : sequence of indexable data-structures
278+
Indexable data-structures can be arrays, lists, dataframes or scipy
279+
sparse matrices with consistent first dimension.
282280
283281
random_state : int or RandomState instance
284282
Control the shuffling for reproducible behavior.
@@ -289,16 +287,15 @@ def shuffle(*arrays, **options):
289287
290288
Returns
291289
-------
292-
shuffled_arrays : sequence of arrays or scipy.sparse matrices with same \
293-
shape[0]
290+
shuffled_arrays : sequence of indexable data-structures
294291
Sequence of shuffled views of the collections. The original arrays are
295292
not impacted.
296293
297294
Examples
298295
--------
299296
It is possible to mix sparse and dense arrays in the same run::
300297
301-
>>> X = [[1., 0.], [2., 1.], [0., 0.]]
298+
>>> X = np.array([[1., 0.], [2., 1.], [0., 0.]])
302299
>>> y = np.array([0, 1, 2])
303300
304301
>>> from scipy.sparse import coo_matrix

sklearn/utils/tests/test_utils.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,36 @@ def to_tuple(A): # to make the inner arrays hashable
186186
S = set(to_tuple(A))
187187
shuffle(A) # shouldn't raise a ValueError for dim = 3
188188
assert_equal(set(to_tuple(A)), S)
189+
190+
191+
def test_shuffle_dont_convert_to_array():
192+
# Check that shuffle does not try to convert to numpy arrays with float
193+
# dtypes can let any indexable datastructure pass-through.
194+
a = ['a', 'b', 'c']
195+
b = np.array(['a', 'b', 'c'], dtype=object)
196+
c = [1, 2, 3]
197+
d = MockDataFrame(np.array([['a', 0],
198+
['b', 1],
199+
['c', 2]],
200+
dtype=object))
201+
e = sp.csc_matrix(np.arange(6).reshape(3, 2))
202+
a_s, b_s, c_s, d_s, e_s = shuffle(a, b, c, d, e, random_state=0)
203+
204+
assert_equal(a_s, ['c', 'b', 'a'])
205+
assert_equal(type(a_s), list)
206+
207+
assert_array_equal(b_s, ['c', 'b', 'a'])
208+
assert_equal(b_s.dtype, object)
209+
210+
assert_equal(c_s, [3, 2, 1])
211+
assert_equal(type(c_s), list)
212+
213+
assert_array_equal(d_s, np.array([['c', 2],
214+
['b', 1],
215+
['a', 0]],
216+
dtype=object))
217+
assert_equal(type(d_s), MockDataFrame)
218+
219+
assert_array_equal(e_s.toarray(), np.array([[4, 5],
220+
[2, 3],
221+
[0, 1]]))

0 commit comments

Comments
 (0)