@@ -18,7 +18,7 @@
 datasets = {'imdb': (imdb.load_data, imdb.prepare_data)}
 
 
-def get_minibatches_idx(n, nb_batches, shuffle=False):
+def get_minibatches_idx(n, minibatch_size, shuffle=False):
     """
     Used to shuffle the dataset at each iteration.
     """
@@ -30,17 +30,16 @@ def get_minibatches_idx(n, nb_batches, shuffle=False):
 
     minibatches = []
     minibatch_start = 0
-    for i in range(nb_batches):
-        if i < n % nb_batches:
-            minibatch_size = n // nb_batches + 1
-        else:
-            minibatch_size = n // nb_batches
-
+    for i in range(n // minibatch_size):
         minibatches.append(idx_list[minibatch_start:
                                     minibatch_start + minibatch_size])
         minibatch_start += minibatch_size
 
-    return zip(range(nb_batches), minibatches)
+    if (minibatch_start != n):
+        # Make a minibatch out of what is left
+        minibatches.append(idx_list[minibatch_start:])
+
+    return zip(range(len(minibatches)), minibatches)
 
 
 def get_dataset(name):
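
As an aside (not part of the commit): the patched helper now takes the minibatch size directly and puts any remainder into a final, smaller minibatch. A runnable sketch of the new behavior, with the function preamble above line 30 (building and optionally shuffling the index list) assumed from context:

import numpy

def get_minibatches_idx(n, minibatch_size, shuffle=False):
    # Preamble assumed from context; not shown in the diff above.
    idx_list = numpy.arange(n, dtype="int32")
    if shuffle:
        numpy.random.shuffle(idx_list)

    minibatches = []
    minibatch_start = 0
    for i in range(n // minibatch_size):
        minibatches.append(idx_list[minibatch_start:
                                    minibatch_start + minibatch_size])
        minibatch_start += minibatch_size

    if (minibatch_start != n):
        # Make a minibatch out of what is left
        minibatches.append(idx_list[minibatch_start:])

    return zip(range(len(minibatches)), minibatches)

# With n=10 and minibatch_size=3: three full minibatches plus one
# leftover minibatch holding the remaining 10 % 3 = 1 index.
batches = list(get_minibatches_idx(10, 3))
# -> [(0, array([0, 1, 2])), (1, array([3, 4, 5])),
#     (2, array([6, 7, 8])), (3, array([9]))]
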
@@ -446,11 +445,9 @@ def test_lstm(
 
     print 'Optimization'
 
-    kf_valid = get_minibatches_idx(len(valid[0]),
-                                   len(valid[0]) / valid_batch_size,
+    kf_valid = get_minibatches_idx(len(valid[0]), valid_batch_size,
                                    shuffle=True)
-    kf_test = get_minibatches_idx(len(test[0]),
-                                  len(test[0]) / valid_batch_size,
+    kf_test = get_minibatches_idx(len(test[0]), valid_batch_size,
                                   shuffle=True)
 
     history_errs = []
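
Illustration of why these call sites simplify (hypothetical sizes, not part of the commit): previously each caller pre-computed a batch count with Python 2 integer division and the helper spread the remainder across uneven batches; now the size is passed through unchanged:

# Hypothetical sizes: 105 validation examples, valid_batch_size = 64.
# Old call: nb_batches = 105 / 64 = 1 (Python 2 integer division),
# so the helper returned a single minibatch of all 105 samples.
# New call: one minibatch of 64 plus a leftover minibatch of 41.
kf_valid = get_minibatches_idx(105, 64, shuffle=True)
print [len(index) for _, index in kf_valid]
# -> [64, 41]
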
@@ -469,8 +466,7 @@ def test_lstm(
         n_samples = 0
 
         # Get new shuffled index for the training set.
-        kf = get_minibatches_idx(len(train[0]), len(train[0])/batch_size,
-                                 shuffle=True)
+        kf = get_minibatches_idx(len(train[0]), batch_size, shuffle=True)
 
         for _, train_index in kf:
             uidx += 1
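
One behavioral note (illustration only, not part of the commit): because the index list is reshuffled on every call, the smaller leftover minibatch holds different samples each epoch rather than always the same tail of the training set:

# Two hypothetical epochs over 10 samples with batch_size = 3.
for eidx in range(2):
    kf = get_minibatches_idx(10, 3, shuffle=True)
    print [list(index) for _, index in kf]
# e.g. [[7, 2, 9], [0, 4, 6], [1, 8, 5], [3]]
#      [[4, 0, 3], [8, 1, 7], [5, 9, 2], [6]]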