|
| 1 | +import numpy |
| 2 | +import time |
| 3 | +import sys |
| 4 | +import subprocess |
| 5 | +import os |
| 6 | +import random |
| 7 | +import copy |
| 8 | +import gzip |
| 9 | +import cPickle |
| 10 | + |
| 11 | +from collections import OrderedDict |
| 12 | + |
| 13 | +import theano |
| 14 | +from theano import tensor as T |
| 15 | + |
| 16 | +PREFIX = os.getenv('ATISDATA', '') |
| 17 | + |
| 18 | + |
| 19 | +# utils functions |
| 20 | +def shuffle(lol, seed): |
| 21 | + ''' |
| 22 | + lol :: list of list as input |
| 23 | + seed :: seed the shuffling |
| 24 | +
|
| 25 | + shuffle inplace each list in the same order |
| 26 | + ''' |
| 27 | + for l in lol: |
| 28 | + random.seed(seed) |
| 29 | + random.shuffle(l) |
| 30 | + |
| 31 | + |
| 32 | +def contextwin(l, win): |
| 33 | + ''' |
| 34 | + win :: int corresponding to the size of the window |
| 35 | + given a list of indexes composing a sentence |
| 36 | + it will return a list of list of indexes corresponding |
| 37 | + to context windows surrounding each word in the sentence |
| 38 | + ''' |
| 39 | + assert (win % 2) == 1 |
| 40 | + assert win >= 1 |
| 41 | + l = list(l) |
| 42 | + |
| 43 | + lpadded = win//2 * [-1] + l + win//2 * [-1] |
| 44 | + out = [lpadded[i:i+win] for i in range(len(l))] |
| 45 | + |
| 46 | + assert len(out) == len(l) |
| 47 | + return out |
| 48 | + |
| 49 | + |
| 50 | +# data loading functions |
| 51 | +def atisfold(fold): |
| 52 | + assert fold in range(5) |
| 53 | + filename = os.path.join(PREFIX, 'atis.fold'+str(fold)+'.pkl.gz') |
| 54 | + f = gzip.open(filename, 'rb') |
| 55 | + train_set, valid_set, test_set, dicts = cPickle.load(f) |
| 56 | + return train_set, valid_set, test_set, dicts |
| 57 | + |
| 58 | + |
| 59 | +# metrics function using conlleval.pl |
| 60 | +def conlleval(p, g, w, filename): |
| 61 | + ''' |
| 62 | + INPUT: |
| 63 | + p :: predictions |
| 64 | + g :: groundtruth |
| 65 | + w :: corresponding words |
| 66 | +
|
| 67 | + OUTPUT: |
| 68 | + filename :: name of the file where the predictions |
| 69 | + are written. it will be the input of conlleval.pl script |
| 70 | + for computing the performance in terms of precision |
| 71 | + recall and f1 score |
| 72 | + ''' |
| 73 | + out = '' |
| 74 | + for sl, sp, sw in zip(g, p, w): |
| 75 | + out += 'BOS O O\n' |
| 76 | + for wl, wp, w in zip(sl, sp, sw): |
| 77 | + out += w + ' ' + wl + ' ' + wp + '\n' |
| 78 | + out += 'EOS O O\n\n' |
| 79 | + |
| 80 | + f = open(filename, 'w') |
| 81 | + f.writelines(out) |
| 82 | + f.close() |
| 83 | + |
| 84 | + return get_perf(filename) |
| 85 | + |
| 86 | + |
| 87 | +def get_perf(filename): |
| 88 | + ''' run conlleval.pl perl script to obtain |
| 89 | + precision/recall and F1 score ''' |
| 90 | + _conlleval = PREFIX + 'conlleval.pl' |
| 91 | + if not os.path.isfile(_conlleval): |
| 92 | + # url = 'http://www-etud.iro.umontreal.ca/ |
| 93 | + # ~mesnilgr/atis/conlleval.pl' |
| 94 | + download(url) |
| 95 | + chmod('conlleval.pl', stat.S_IRWXU) # give the execute permissions |
| 96 | + |
| 97 | + proc = subprocess.Popen(["perl", |
| 98 | + _conlleval], |
| 99 | + stdin=subprocess.PIPE, |
| 100 | + stdout=subprocess.PIPE) |
| 101 | + |
| 102 | + stdout, _ = proc.communicate(''.join(open(filename).readlines())) |
| 103 | + for line in stdout.split('\n'): |
| 104 | + if 'accuracy' in line: |
| 105 | + out = line.split() |
| 106 | + break |
| 107 | + |
| 108 | + precision = float(out[6][:-2]) |
| 109 | + recall = float(out[8][:-2]) |
| 110 | + f1score = float(out[10]) |
| 111 | + |
| 112 | + return {'p': precision, 'r': recall, 'f1': f1score} |
| 113 | + |
| 114 | + |
| 115 | +# actual model |
| 116 | +class basemodel(object): |
| 117 | + ''' load/save structure ''' |
| 118 | + |
| 119 | + def save(self, folder): |
| 120 | + for param in self.params: |
| 121 | + numpy.save(os.path.join(folder, |
| 122 | + param.name + '.npy'), param.get_value()) |
| 123 | + |
| 124 | + def load(self, folder): |
| 125 | + for param in self.params: |
| 126 | + param.set_value(numpy.load(os.path.join(folder, |
| 127 | + param.name + '.npy'))) |
| 128 | + |
| 129 | + |
| 130 | +class model(basemodel): |
| 131 | + ''' elman neural net model ''' |
| 132 | + def __init__(self, nh, nc, ne, de, cs): |
| 133 | + ''' |
| 134 | + nh :: dimension of the hidden layer |
| 135 | + nc :: number of classes |
| 136 | + ne :: number of word embeddings in the vocabulary |
| 137 | + de :: dimension of the word embeddings |
| 138 | + cs :: word window context size |
| 139 | + ''' |
| 140 | + # parameters of the model |
| 141 | + self.emb = theano.shared(name='embeddings', |
| 142 | + value=0.2 * numpy.random.uniform(-1.0, 1.0, |
| 143 | + (ne+1, de)) |
| 144 | + # add one for padding at the end |
| 145 | + .astype(theano.config.floatX)) |
| 146 | + self.wx = theano.shared(name='wx', |
| 147 | + value=0.2 * numpy.random.uniform(-1.0, 1.0, |
| 148 | + (de * cs, nh)) |
| 149 | + .astype(theano.config.floatX)) |
| 150 | + self.wh = theano.shared(name='wh', |
| 151 | + value=0.2 * numpy.random.uniform(-1.0, 1.0, |
| 152 | + (nh, nh)) |
| 153 | + .astype(theano.config.floatX)) |
| 154 | + self.w = theano.shared(name='w', |
| 155 | + value=0.2 * numpy.random.uniform(-1.0, 1.0, |
| 156 | + (nh, nc)) |
| 157 | + .astype(theano.config.floatX)) |
| 158 | + self.bh = theano.shared(name='bh', |
| 159 | + value=numpy.zeros(nh, |
| 160 | + dtype=theano.config.floatX)) |
| 161 | + self.b = theano.shared(name='b', |
| 162 | + value=numpy.zeros(nc, |
| 163 | + dtype=theano.config.floatX)) |
| 164 | + self.h0 = theano.shared(name='h0', |
| 165 | + value=numpy.zeros(nh, |
| 166 | + dtype=theano.config.floatX)) |
| 167 | + |
| 168 | + # bundle |
| 169 | + self.params = [self.emb, self.wx, self.wh, self.w, |
| 170 | + self.bh, self.b, self.h0] |
| 171 | + # as many columns as context window size |
| 172 | + # as many lines as words in the sentence |
| 173 | + idxs = T.imatrix() |
| 174 | + x = self.emb[idxs].reshape((idxs.shape[0], de*cs)) |
| 175 | + y_sentence = T.ivector('y_sentence') # labels |
| 176 | + |
| 177 | + def recurrence(x_t, h_tm1): |
| 178 | + h_t = T.nnet.sigmoid(T.dot(x_t, self.wx) |
| 179 | + + T.dot(h_tm1, self.wh) + self.bh) |
| 180 | + s_t = T.nnet.softmax(T.dot(h_t, self.w) + self.b) |
| 181 | + return [h_t, s_t] |
| 182 | + |
| 183 | + [h, s], _ = theano.scan(fn=recurrence, |
| 184 | + sequences=x, |
| 185 | + outputs_info=[self.h0, None], |
| 186 | + n_steps=x.shape[0]) |
| 187 | + |
| 188 | + p_y_given_x_sentence = s[:, 0, :] |
| 189 | + y_pred = T.argmax(p_y_given_x_sentence, axis=1) |
| 190 | + |
| 191 | + # cost and gradients and learning rate |
| 192 | + lr = T.scalar('lr') |
| 193 | + |
| 194 | + sentence_nll = -T.mean(T.log(p_y_given_x_sentence) |
| 195 | + [T.arange(x.shape[0]), y_sentence]) |
| 196 | + sentence_gradients = T.grad(sentence_nll, self.params) |
| 197 | + sentence_updates = OrderedDict((p, p - lr*g) |
| 198 | + for p, g in |
| 199 | + zip(self.params, sentence_gradients)) |
| 200 | + |
| 201 | + # theano functions to compile |
| 202 | + self.classify = theano.function(inputs=[idxs], outputs=y_pred) |
| 203 | + self.sentence_train = theano.function(inputs=[idxs, y_sentence, lr], |
| 204 | + outputs=sentence_nll, |
| 205 | + updates=sentence_updates) |
| 206 | + self.normalize = theano.function(inputs=[], |
| 207 | + updates={self.emb: |
| 208 | + self.emb / |
| 209 | + T.sqrt((self.emb**2) |
| 210 | + .sum(axis=1)) |
| 211 | + .dimshuffle(0, 'x')}) |
| 212 | + |
| 213 | + def train(self, x, y, window_size, learning_rate): |
| 214 | + |
| 215 | + cwords = contextwin(x, window_size) |
| 216 | + words = map(lambda x: numpy.asarray(x).astype('int32'), cwords) |
| 217 | + labels = y |
| 218 | + |
| 219 | + self.sentence_train(words, labels, learning_rate) |
| 220 | + self.normalize() |
| 221 | + |
| 222 | + |
| 223 | +def main(param, sync=None): |
| 224 | + |
| 225 | + folder = os.path.basename(__file__).split('.')[0] |
| 226 | + if not os.path.exists(folder): |
| 227 | + os.mkdir(folder) |
| 228 | + |
| 229 | + # load the dataset |
| 230 | + train_set, valid_set, test_set, dic = atisfold(param['fold']) |
| 231 | + |
| 232 | + idx2label = dict((k, v) for v, k in dic['labels2idx'].iteritems()) |
| 233 | + idx2word = dict((k, v) for v, k in dic['words2idx'].iteritems()) |
| 234 | + |
| 235 | + train_lex, train_ne, train_y = train_set |
| 236 | + valid_lex, valid_ne, valid_y = valid_set |
| 237 | + test_lex, test_ne, test_y = test_set |
| 238 | + |
| 239 | + vocsize = len(set(reduce(lambda x, y: list(x) + list(y), |
| 240 | + train_lex + valid_lex + test_lex))) |
| 241 | + nclasses = len(set(reduce(lambda x, y: list(x)+list(y), |
| 242 | + train_y + test_y + valid_y))) |
| 243 | + nsentences = len(train_lex) |
| 244 | + |
| 245 | + groundtruth_valid = [map(lambda x: idx2label[x], y) for y in valid_y] |
| 246 | + words_valid = [map(lambda x: idx2word[x], w) for w in valid_lex] |
| 247 | + groundtruth_test = [map(lambda x: idx2label[x], y) for y in test_y] |
| 248 | + words_test = [map(lambda x: idx2word[x], w) for w in test_lex] |
| 249 | + |
| 250 | + # instanciate the model |
| 251 | + numpy.random.seed(param['seed']) |
| 252 | + random.seed(param['seed']) |
| 253 | + |
| 254 | + rnn = model(nh=param['nhidden'], |
| 255 | + nc=nclasses, |
| 256 | + ne=vocsize, |
| 257 | + de=param['emb_dimension'], |
| 258 | + cs=param['win']) |
| 259 | + |
| 260 | + # train with early stopping on validation set |
| 261 | + best_f1 = -numpy.inf |
| 262 | + param['clr'] = param['lr'] |
| 263 | + for e in xrange(param['nepochs']): |
| 264 | + |
| 265 | + # shuffle |
| 266 | + shuffle([train_lex, train_ne, train_y], param['seed']) |
| 267 | + |
| 268 | + param['ce'] = e |
| 269 | + tic = time.time() |
| 270 | + |
| 271 | + for i, (x, y) in enumerate(zip(train_lex, train_y)): |
| 272 | + rnn.train(x, y, param['win'], param['clr']) |
| 273 | + |
| 274 | + # evaluation // back into the real world : idx -> words |
| 275 | + predictions_test = [map(lambda x: idx2label[x], |
| 276 | + rnn.classify(numpy.asarray( |
| 277 | + contextwin(x, param['win'])).astype('int32'))) |
| 278 | + for x in test_lex] |
| 279 | + predictions_valid = [map(lambda x: idx2label[x], |
| 280 | + rnn.classify(numpy.asarray( |
| 281 | + contextwin(x, param['win'])).astype('int32'))) |
| 282 | + for x in valid_lex] |
| 283 | + |
| 284 | + # evaluation // compute the accuracy using conlleval.pl |
| 285 | + res_test = conlleval(predictions_test, |
| 286 | + groundtruth_test, |
| 287 | + words_test, |
| 288 | + folder + '/current.test.txt') |
| 289 | + res_valid = conlleval(predictions_valid, |
| 290 | + groundtruth_valid, |
| 291 | + words_valid, |
| 292 | + folder + '/current.valid.txt') |
| 293 | + |
| 294 | + if res_valid['f1'] > best_f1: |
| 295 | + |
| 296 | + if sync is not None: |
| 297 | + sync() |
| 298 | + if param['savemodel']: |
| 299 | + rnn.save(folder) |
| 300 | + |
| 301 | + best_rnn = copy.deepcopy(rnn) |
| 302 | + best_f1 = res_valid['f1'] |
| 303 | + |
| 304 | + if param['verbose']: |
| 305 | + print('NEW BEST: epoch', e, |
| 306 | + 'valid F1', res_valid['f1'], |
| 307 | + 'best test F1', res_test['f1']) |
| 308 | + |
| 309 | + param['vf1'], param['tf1'] = res_valid['f1'], res_test['f1'] |
| 310 | + param['vp'], param['tp'] = res_valid['p'], res_test['p'] |
| 311 | + param['vr'], param['tr'] = res_valid['r'], res_test['r'] |
| 312 | + param['be'] = e |
| 313 | + |
| 314 | + subprocess.call(['mv', folder + '/current.test.txt', |
| 315 | + folder + '/best.test.txt']) |
| 316 | + subprocess.call(['mv', folder + '/current.valid.txt', |
| 317 | + folder + '/best.valid.txt']) |
| 318 | + else: |
| 319 | + if param['verbose']: |
| 320 | + print '' |
| 321 | + |
| 322 | + # learning rate decay if no improvement in 10 epochs |
| 323 | + if param['decay'] and abs(param['be']-param['ce']) >= 10: |
| 324 | + param['clr'] *= 0.5 |
| 325 | + rnn = best_rnn |
| 326 | + |
| 327 | + if param['clr'] < 1e-5: |
| 328 | + break |
| 329 | + |
| 330 | + print('BEST RESULT: epoch', param['be'], |
| 331 | + 'valid F1', param['vf1'], |
| 332 | + 'best test F1', param['tf1'], |
| 333 | + 'with the model', folder) |
| 334 | + |
| 335 | +if __name__ == '__main__': |
| 336 | + |
| 337 | + # best model |
| 338 | + s = {'fold': 3, |
| 339 | + # 5 folds 0,1,2,3,4 |
| 340 | + 'data': 'atis', |
| 341 | + 'lr': 0.0970806646812754, |
| 342 | + 'verbose': 1, |
| 343 | + 'decay': True, |
| 344 | + # decay on the learning rate if improvement stops |
| 345 | + 'win': 7, |
| 346 | + # number of words in the context window |
| 347 | + 'nhidden': 200, |
| 348 | + # number of hidden units |
| 349 | + 'seed': 345, |
| 350 | + 'emb_dimension': 50, |
| 351 | + # dimension of word embedding |
| 352 | + 'nepochs': 60, |
| 353 | + 'savemodel': True} |
| 354 | + |
| 355 | + main(s) |
| 356 | + print s |
0 commit comments