Skip to content

Commit b7c8d04

Browse files
author
Grégoire
committed
fixes according to reviews
1 parent 21664bf commit b7c8d04

File tree

1 file changed

+21
-27
lines changed

1 file changed

+21
-27
lines changed

code/rnnslu.py

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -112,22 +112,7 @@ def get_perf(filename):
112112
return {'p': precision, 'r': recall, 'f1': f1score}
113113

114114

115-
# actual model
116-
class basemodel(object):
117-
''' load/save structure '''
118-
119-
def save(self, folder):
120-
for param in self.params:
121-
numpy.save(os.path.join(folder,
122-
param.name + '.npy'), param.get_value())
123-
124-
def load(self, folder):
125-
for param in self.params:
126-
param.set_value(numpy.load(os.path.join(folder,
127-
param.name + '.npy')))
128-
129-
130-
class model(basemodel):
115+
class RNNSLU(object):
131116
''' elman neural net model '''
132117
def __init__(self, nh, nc, ne, de, cs):
133118
'''
@@ -219,8 +204,19 @@ def train(self, x, y, window_size, learning_rate):
219204
self.sentence_train(words, labels, learning_rate)
220205
self.normalize()
221206

207+
def save(self, folder):
208+
for param in self.params:
209+
numpy.save(os.path.join(folder,
210+
param.name + '.npy'), param.get_value())
211+
212+
def load(self, folder):
213+
for param in self.params:
214+
param.set_value(numpy.load(os.path.join(folder,
215+
param.name + '.npy')))
222216

223-
def main(param, sync=None):
217+
218+
219+
def main(param):
224220

225221
folder = os.path.basename(__file__).split('.')[0]
226222
if not os.path.exists(folder):
@@ -251,11 +247,11 @@ def main(param, sync=None):
251247
numpy.random.seed(param['seed'])
252248
random.seed(param['seed'])
253249

254-
rnn = model(nh=param['nhidden'],
255-
nc=nclasses,
256-
ne=vocsize,
257-
de=param['emb_dimension'],
258-
cs=param['win'])
250+
rnn = RNNSLU(nh=param['nhidden'],
251+
nc=nclasses,
252+
ne=vocsize,
253+
de=param['emb_dimension'],
254+
cs=param['win'])
259255

260256
# train with early stopping on validation set
261257
best_f1 = -numpy.inf
@@ -293,8 +289,6 @@ def main(param, sync=None):
293289

294290
if res_valid['f1'] > best_f1:
295291

296-
if sync is not None:
297-
sync()
298292
if param['savemodel']:
299293
rnn.save(folder)
300294

@@ -332,8 +326,7 @@ def main(param, sync=None):
332326
'best test F1', param['tf1'],
333327
'with the model', folder)
334328

335-
if __name__ == '__main__':
336-
329+
def test_rnnslu(n_epochs):
337330
# best model
338331
s = {'fold': 3,
339332
# 5 folds 0,1,2,3,4
@@ -349,7 +342,8 @@ def main(param, sync=None):
349342
'seed': 345,
350343
'emb_dimension': 50,
351344
# dimension of word embedding
352-
'nepochs': 60,
345+
'nepochs': n_epochs,
346+
# 60 is recommended
353347
'savemodel': True}
354348

355349
main(s)

0 commit comments

Comments
 (0)