Skip to content

Commit 2c610d3

Browse files
committed
made rnnslu compatible with python 3. tested on cpu for many epochs, but not to completion
1 parent 53f246d commit 2c610d3

File tree

1 file changed

+25
-16
lines changed

1 file changed

+25
-16
lines changed

code/rnnslu.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
2+
from __future__ import print_function
3+
from six.moves import xrange
4+
import six.moves.cPickle as pickle
5+
16
from collections import OrderedDict
27
import copy
3-
import cPickle
48
import gzip
59
import os
610
import urllib
@@ -66,7 +70,10 @@ def atisfold(fold):
6670
assert fold in range(5)
6771
filename = os.path.join(PREFIX, 'atis.fold'+str(fold)+'.pkl.gz')
6872
f = gzip.open(filename, 'rb')
69-
train_set, valid_set, test_set, dicts = cPickle.load(f)
73+
try:
74+
train_set, valid_set, test_set, dicts = pickle.load(f, encoding='latin1')
75+
except:
76+
train_set, valid_set, test_set, dicts = pickle.load(f)
7077
return train_set, valid_set, test_set, dicts
7178

7279

@@ -107,7 +114,7 @@ def download(origin, destination):
107114
download the corresponding atis file
108115
from http://www-etud.iro.umontreal.ca/~mesnilgr/atis/
109116
'''
110-
print 'Downloading data from %s' % origin
117+
print('Downloading data from %s' % origin)
111118
urllib.urlretrieve(origin, destination)
112119

113120

@@ -125,8 +132,10 @@ def get_perf(filename, folder):
125132
stdin=subprocess.PIPE,
126133
stdout=subprocess.PIPE)
127134

128-
stdout, _ = proc.communicate(''.join(open(filename).readlines()))
135+
stdout, _ = proc.communicate(''.join(open(filename).readlines()).encode('utf-8'))
136+
stdout = stdout.decode('utf-8')
129137
out = None
138+
130139
for line in stdout.split('\n'):
131140
if 'accuracy' in line:
132141
out = line.split()
@@ -237,7 +246,7 @@ def recurrence(x_t, h_tm1):
237246
def train(self, x, y, window_size, learning_rate):
238247

239248
cwords = contextwin(x, window_size)
240-
words = map(lambda x: numpy.asarray(x).astype('int32'), cwords)
249+
words = list(map(lambda x: numpy.asarray(x).astype('int32'), cwords))
241250
labels = y
242251

243252
self.sentence_train(words, labels, learning_rate)
@@ -274,7 +283,7 @@ def main(param=None):
274283
'nepochs': 60,
275284
# 60 is recommended
276285
'savemodel': False}
277-
print param
286+
print(param)
278287

279288
folder_name = os.path.basename(__file__).split('.')[0]
280289
folder = os.path.join(os.path.dirname(__file__), folder_name)
@@ -284,8 +293,8 @@ def main(param=None):
284293
# load the dataset
285294
train_set, valid_set, test_set, dic = atisfold(param['fold'])
286295

287-
idx2label = dict((k, v) for v, k in dic['labels2idx'].iteritems())
288-
idx2word = dict((k, v) for v, k in dic['words2idx'].iteritems())
296+
idx2label = dict((k, v) for v, k in dic['labels2idx'].items())
297+
idx2word = dict((k, v) for v, k in dic['words2idx'].items())
289298

290299
train_lex, train_ne, train_y = train_set
291300
valid_lex, valid_ne, valid_y = valid_set
@@ -323,9 +332,9 @@ def main(param=None):
323332

324333
for i, (x, y) in enumerate(zip(train_lex, train_y)):
325334
rnn.train(x, y, param['win'], param['clr'])
326-
print '[learning] epoch %i >> %2.2f%%' % (
327-
e, (i + 1) * 100. / nsentences),
328-
print 'completed in %.2f (sec) <<\r' % (timeit.default_timer() - tic),
335+
print('[learning] epoch %i >> %2.2f%%' % (
336+
e, (i + 1) * 100. / nsentences),)
337+
print('completed in %.2f (sec) <<\r' % (timeit.default_timer() - tic),)
329338
sys.stdout.flush()
330339

331340
# evaluation // back into the real world : idx -> words
@@ -374,7 +383,7 @@ def main(param=None):
374383
folder + '/best.valid.txt'])
375384
else:
376385
if param['verbose']:
377-
print ''
386+
print('')
378387

379388
# learning rate decay if no improvement in 10 epochs
380389
if param['decay'] and abs(param['be']-param['ce']) >= 10:
@@ -384,10 +393,10 @@ def main(param=None):
384393
if param['clr'] < 1e-5:
385394
break
386395

387-
print('BEST RESULT: epoch', param['be'],
388-
'valid F1', param['vf1'],
389-
'best test F1', param['tf1'],
390-
'with the model', folder)
396+
print(('BEST RESULT: epoch', param['be'],
397+
'valid F1', param['vf1'],
398+
'best test F1', param['tf1'],
399+
'with the model', folder))
391400

392401

393402
if __name__ == '__main__':

0 commit comments

Comments
 (0)