Skip to content

Commit 6e38783

Browse files
committed
Keep random subset of the test set
1 parent 2f14fbf commit 6e38783

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

code/lstm.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,13 @@ def train_lstm(
415415
train, valid, test = load_data(n_words=n_words, valid_portion=0.05,
416416
maxlen=maxlen)
417417
if test_size > 0:
418-
test = (test[0][:test_size], test[1][:test_size])
418+
# The test set is sorted by size, but we want to keep random
419+
# size example. So we must select a random selection of the
420+
# examples.
421+
idx = numpy.arange(len(test[0]))
422+
random.shuffle(idx)
423+
idx = idx[:test_size]
424+
test = ([test[0][n] for n in idx], [test[1][n] for n in idx])
419425

420426
ydim = numpy.max(train[1]) + 1
421427

0 commit comments

Comments
 (0)