@@ -74,8 +74,9 @@ def get_dataset_file(dataset, default_dataset, origin):
     return dataset
 
 
-def load_data(path="imdb.pkl", n_words=100000, valid_portion=0.1, maxlen=None):
-    ''' Loads the dataset
+def load_data(path="imdb.pkl", n_words=100000, valid_portion=0.1, maxlen=None,
+              sort_by_len=True):
+    '''Loads the dataset
 
     :type path: String
     :param path: The path to the dataset (here IMDB)
@@ -87,6 +88,12 @@ def load_data(path="imdb.pkl", n_words=100000, valid_portion=0.1, maxlen=None):
         the validation set.
     :type maxlen: None or positive int
     :param maxlen: the max sequence length we use in the train/valid set.
+    :type sort_by_len: bool
+    :param sort_by_len: Sort by the sequence length for the train,
+        valid and test set. This allows faster execution as it causes
+        less padding per minibatch. Another mechanism must be used to
+        shuffle the train set at each epoch.
+
     '''
 
     #############
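The docstring above notes that once the data is length-sorted, shuffling the training set can no longer be done by permuting individual examples; the order of the minibatches has to be shuffled instead. A minimal sketch of one such mechanism (the helper name and signature here are illustrative, not part of this commit): slice contiguous minibatches over the sorted indices, then shuffle only the batch order, so each batch keeps similar-length examples (little padding) while the epoch order still varies.

import numpy

def shuffled_minibatch_indices(n, minibatch_size, rng=numpy.random):
    """Return minibatches of indices over length-sorted data, in random order."""
    idx_list = numpy.arange(n, dtype="int64")
    # Contiguous slices keep similar-length examples together...
    minibatches = [idx_list[i:i + minibatch_size]
                   for i in range(0, n, minibatch_size)]
    # ...then shuffle the order of the batches, not their contents.
    rng.shuffle(minibatches)
    return minibatches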
@@ -140,6 +147,22 @@ def remove_unk(x):
     valid_set_x = remove_unk(valid_set_x)
     test_set_x = remove_unk(test_set_x)
 
+    def len_argsort(seq):
+        return sorted(range(len(seq)), key=lambda x: len(seq[x]))
+
+    if sort_by_len:
+        sorted_index = len_argsort(test_set_x)
+        test_set_x = [test_set_x[i] for i in sorted_index]
+        test_set_y = [test_set_y[i] for i in sorted_index]
+
+        sorted_index = len_argsort(valid_set_x)
+        valid_set_x = [valid_set_x[i] for i in sorted_index]
+        valid_set_y = [valid_set_y[i] for i in sorted_index]
+
+        sorted_index = len_argsort(train_set_x)
+        train_set_x = [train_set_x[i] for i in sorted_index]
+        train_set_y = [train_set_y[i] for i in sorted_index]
+
     train = (train_set_x, train_set_y)
     valid = (valid_set_x, valid_set_y)
     test = (test_set_x, test_set_y)
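To see what the len_argsort idiom added above does, here is a small standalone example (the data values are made up): it returns the index permutation that orders the sequences from shortest to longest, and applying the same permutation to the labels keeps them aligned with their inputs.

def len_argsort(seq):
    return sorted(range(len(seq)), key=lambda x: len(seq[x]))

sentences = [[4, 8, 15, 16, 23], [42], [7, 7, 7]]
labels = [1, 0, 1]
order = len_argsort(sentences)             # -> [1, 2, 0]
sentences = [sentences[i] for i in order]  # shortest to longest
labels = [labels[i] for i in order]        # labels stay aligned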