Skip to content

Commit 6a32f03

Browse files
committed
sort dataset by length to speed up error computation
1 parent ec112fa commit 6a32f03

File tree

2 files changed

+29
-7
lines changed

2 files changed

+29
-7
lines changed

code/imdb.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,9 @@ def get_dataset_file(dataset, default_dataset, origin):
7474
return dataset
7575

7676

77-
def load_data(path="imdb.pkl", n_words=100000, valid_portion=0.1, maxlen=None):
78-
''' Loads the dataset
77+
def load_data(path="imdb.pkl", n_words=100000, valid_portion=0.1, maxlen=None,
78+
sort_by_len=True):
79+
'''Loads the dataset
7980
8081
:type path: String
8182
:param path: The path to the dataset (here IMDB)
@@ -87,6 +88,12 @@ def load_data(path="imdb.pkl", n_words=100000, valid_portion=0.1, maxlen=None):
8788
the validation set.
8889
:type maxlen: None or positive int
8990
:param maxlen: the max sequence length we use in the train/valid set.
91+
:type sort_by_len: bool
92+
:param sort_by_len: Sort by the sequence length for the train,
93+
valid and test set. This allows faster execution as it causes
94+
less padding per minibatch. Another mechanism must be used to
95+
shuffle the train set at each epoch.
96+
9097
'''
9198

9299
#############
@@ -140,6 +147,22 @@ def remove_unk(x):
140147
valid_set_x = remove_unk(valid_set_x)
141148
test_set_x = remove_unk(test_set_x)
142149

150+
def len_argsort(seq):
151+
return sorted(range(len(seq)), key=lambda x: len(seq[x]))
152+
153+
if sort_by_len:
154+
sorted_index = len_argsort(test_set_x)
155+
test_set_x = [test_set_x[i] for i in sorted_index]
156+
test_set_y = [test_set_y[i] for i in sorted_index]
157+
158+
sorted_index = len_argsort(valid_set_x)
159+
valid_set_x = [valid_set_x[i] for i in sorted_index]
160+
valid_set_y = [valid_set_y[i] for i in sorted_index]
161+
162+
sorted_index = len_argsort(train_set_x)
163+
train_set_x = [train_set_x[i] for i in sorted_index]
164+
train_set_y = [train_set_y[i] for i in sorted_index]
165+
143166
train = (train_set_x, train_set_y)
144167
valid = (valid_set_x, valid_set_y)
145168
test = (test_set_x, test_set_y)

code/lstm.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -456,10 +456,8 @@ def train_lstm(
456456

457457
print 'Optimization'
458458

459-
kf_valid = get_minibatches_idx(len(valid[0]), valid_batch_size,
460-
shuffle=True)
461-
kf_test = get_minibatches_idx(len(test[0]), valid_batch_size,
462-
shuffle=True)
459+
kf_valid = get_minibatches_idx(len(valid[0]), valid_batch_size)
460+
kf_test = get_minibatches_idx(len(test[0]), valid_batch_size)
463461

464462
print "%d train examples" % len(train[0])
465463
print "%d valid examples" % len(valid[0])
@@ -561,7 +559,8 @@ def train_lstm(
561559
best_p = unzip(tparams)
562560

563561
use_noise.set_value(0.)
564-
train_err = pred_error(f_pred, prepare_data, train, kf)
562+
kf_train_sorted = get_minibatches_idx(len(train[0]), batch_size)
563+
train_err = pred_error(f_pred, prepare_data, train, kf_train_sorted)
565564
valid_err = pred_error(f_pred, prepare_data, valid, kf_valid)
566565
test_err = pred_error(f_pred, prepare_data, test, kf_test)
567566

0 commit comments

Comments
 (0)