Skip to content

Commit c6fdcff

Browse files
committed
Fixed function get_minibatches_idx()
1 parent 194adad commit c6fdcff

File tree

1 file changed

+10
-14
lines changed

1 file changed

+10
-14
lines changed

code/lstm.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
datasets = {'imdb': (imdb.load_data, imdb.prepare_data)}
1919

2020

21-
def get_minibatches_idx(n, nb_batches, shuffle=False):
21+
def get_minibatches_idx(n, minibatch_size, shuffle=False):
2222
"""
2323
Used to shuffle the dataset at each iteration.
2424
"""
@@ -30,17 +30,16 @@ def get_minibatches_idx(n, nb_batches, shuffle=False):
3030

3131
minibatches = []
3232
minibatch_start = 0
33-
for i in range(nb_batches):
34-
if i < n % nb_batches:
35-
minibatch_size = n // nb_batches + 1
36-
else:
37-
minibatch_size = n // nb_batches
38-
33+
for i in range(n // minibatch_size):
3934
minibatches.append(idx_list[minibatch_start:
4035
minibatch_start + minibatch_size])
4136
minibatch_start += minibatch_size
4237

43-
return zip(range(nb_batches), minibatches)
38+
if (minibatch_start != n):
39+
# Make a minibatch out of what is left
40+
minibatches.append(idx_list[minibatch_start:])
41+
42+
return zip(range(len(minibatches)), minibatches)
4443

4544

4645
def get_dataset(name):
@@ -446,11 +445,9 @@ def test_lstm(
446445

447446
print 'Optimization'
448447

449-
kf_valid = get_minibatches_idx(len(valid[0]),
450-
len(valid[0]) / valid_batch_size,
448+
kf_valid = get_minibatches_idx(len(valid[0]), valid_batch_size,
451449
shuffle=True)
452-
kf_test = get_minibatches_idx(len(test[0]),
453-
len(test[0]) / valid_batch_size,
450+
kf_test = get_minibatches_idx(len(test[0]), valid_batch_size,
454451
shuffle=True)
455452

456453
history_errs = []
@@ -469,8 +466,7 @@ def test_lstm(
469466
n_samples = 0
470467

471468
# Get new shuffled index for the training set.
472-
kf = get_minibatches_idx(len(train[0]), len(train[0])/batch_size,
473-
shuffle=True)
469+
kf = get_minibatches_idx(len(train[0]), batch_size, shuffle=True)
474470

475471
for _, train_index in kf:
476472
uidx += 1

0 commit comments

Comments
 (0)