Commit 2da9122

catch ctrl-C
1 parent 43adeff commit 2da9122

1 file changed

code/lstm.py: 81 additions, 76 deletions
@@ -384,8 +384,8 @@ def train_lstm(
     optimizer=adadelta,  # sgd, adadelta and rmsprop available, sgd very hard to use, not recommanded (probably need momentum and decaying learning rate).
     encoder='lstm',  # TODO: can be removed must be lstm.
     saveto='lstm_model.npz',  # The best model will be saved there
-    validFreq=390,  # Compute the validation error after this number of update.
-    saveFreq=1040,  # Save the parameters after every saveFreq updates
+    validFreq=370,  # Compute the validation error after this number of update.
+    saveFreq=1110,  # Save the parameters after every saveFreq updates
     maxlen=100,  # Sequence longer then this get ignored
     batch_size=16,  # The batch size during training.
     valid_batch_size=64,  # The batch size used for validation/test set.
@@ -467,80 +467,85 @@ def train_lstm(
     uidx = 0  # the number of update done
     estop = False  # early stop
     start_time = time.clock()
-    for eidx in xrange(max_epochs):
-        n_samples = 0
-
-        # Get new shuffled index for the training set.
-        kf = get_minibatches_idx(len(train[0]), batch_size, shuffle=True)
-
-        for _, train_index in kf:
-            uidx += 1
-            use_noise.set_value(1.)
-
-            # Select the random examples for this minibatch
-            y = [train[1][t] for t in train_index]
-            x = [train[0][t]for t in train_index]
-
-            # Get the data in numpy.ndarray formet.
-            # It return something of the shape (minibatch maxlen, n samples)
-            x, mask, y = prepare_data(x, y, maxlen=maxlen)
-            if x is None:
-                print 'Minibatch with zero sample under length ', maxlen
-                continue
-            n_samples += x.shape[1]
-
-            cost = f_grad_shared(x, mask, y)
-            f_update(lrate)
-
-            if numpy.isnan(cost) or numpy.isinf(cost):
-                print 'NaN detected'
-                return 1., 1., 1.
-
-            if numpy.mod(uidx, dispFreq) == 0:
-                print 'Epoch ', eidx, 'Update ', uidx, 'Cost ', cost
-
-            if numpy.mod(uidx, saveFreq) == 0:
-                print 'Saving...',
-
-                if best_p is not None:
-                    params = best_p
-                else:
-                    params = unzip(tparams)
-                numpy.savez(saveto, history_errs=history_errs, **params)
-                pkl.dump(model_options, open('%s.pkl' % saveto, 'wb'), -1)
-                print 'Done'
-
-            if numpy.mod(uidx, validFreq) == 0:
-                use_noise.set_value(0.)
-                train_err = pred_error(f_pred, prepare_data, train, kf)
-                valid_err = pred_error(f_pred, prepare_data, valid, kf_valid)
-                test_err = pred_error(f_pred, prepare_data, test, kf_test)
-
-                history_errs.append([valid_err, test_err])
-
-                if (uidx == 0 or
-                    valid_err <= numpy.array(history_errs)[:,
-                                                           0].min()):
-
-                    best_p = unzip(tparams)
-                    bad_counter = 0
-
-                print ('Train ', train_err, 'Valid ', valid_err,
-                       'Test ', test_err)
-
-                if (len(history_errs) > patience and
-                    valid_err >= numpy.array(history_errs)[:-patience,
-                                                           0].min()):
-                    bad_counter += 1
-                    if bad_counter > patience:
-                        print 'Early Stop!'
-                        estop = True
-                        break
-
-        print 'Seen %d samples' % n_samples
-
-        if estop:
-            break
+    try:
+        for eidx in xrange(max_epochs):
+            n_samples = 0
+
+            # Get new shuffled index for the training set.
+            kf = get_minibatches_idx(len(train[0]), batch_size, shuffle=True)
+
+            for _, train_index in kf:
+                uidx += 1
+                use_noise.set_value(1.)
+
+                # Select the random examples for this minibatch
+                y = [train[1][t] for t in train_index]
+                x = [train[0][t]for t in train_index]
+
+                # Get the data in numpy.ndarray formet.
+                # It return something of the shape (minibatch maxlen, n samples)
+                x, mask, y = prepare_data(x, y, maxlen=maxlen)
+                if x is None:
+                    print 'Minibatch with zero sample under length ', maxlen
+                    continue
+                n_samples += x.shape[1]
+
+                cost = f_grad_shared(x, mask, y)
+                f_update(lrate)
+
+                if numpy.isnan(cost) or numpy.isinf(cost):
+                    print 'NaN detected'
+                    return 1., 1., 1.
+
+                if numpy.mod(uidx, dispFreq) == 0:
+                    print 'Epoch ', eidx, 'Update ', uidx, 'Cost ', cost
+
+                if numpy.mod(uidx, saveFreq) == 0:
+                    print 'Saving...',
+
+                    if best_p is not None:
+                        params = best_p
+                    else:
+                        params = unzip(tparams)
+                    numpy.savez(saveto, history_errs=history_errs, **params)
+                    pkl.dump(model_options, open('%s.pkl' % saveto, 'wb'), -1)
+                    print 'Done'
+
+                if numpy.mod(uidx, validFreq) == 0:
+                    use_noise.set_value(0.)
+                    train_err = pred_error(f_pred, prepare_data, train, kf)
+                    valid_err = pred_error(f_pred, prepare_data, valid, kf_valid)
+                    test_err = pred_error(f_pred, prepare_data, test, kf_test)
+
+                    history_errs.append([valid_err, test_err])
+
+                    if (uidx == 0 or
+                        valid_err <= numpy.array(history_errs)[:,
+                                                               0].min()):
+
+                        best_p = unzip(tparams)
+                        bad_counter = 0
+
+                    print ('Train ', train_err, 'Valid ', valid_err,
+                           'Test ', test_err)
+
+                    if (len(history_errs) > patience and
+                        valid_err >= numpy.array(history_errs)[:-patience,
+                                                               0].min()):
+                        bad_counter += 1
+                        if bad_counter > patience:
+                            print 'Early Stop!'
+                            estop = True
+                            break
+
+            print 'Seen %d samples' % n_samples
+
+            if estop:
+                break
+
+    except KeyboardInterrupt:
+        print "Training interupted"
+
     end_time = time.clock()
     if best_p is not None:
         zipp(best_p, tparams)
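
What the change does: the whole epoch loop of train_lstm is wrapped in try/except KeyboardInterrupt, so pressing ctrl-C during training no longer kills the process mid-update. Control instead falls through to the code after the loop (end_time, zipp(best_p, tparams), the final save), keeping the best parameters seen so far. Below is a minimal, self-contained sketch of the same pattern, assuming nothing from lstm.py; the train() function and its stand-in cost computation are hypothetical, written in Python 3 rather than the tutorial's Python 2.

    # Sketch of the ctrl-C pattern this commit introduces. Only the
    # try/except KeyboardInterrupt shape mirrors lstm.py; everything
    # else is a placeholder.
    import random
    import time

    def train(max_epochs=1000):
        best_cost = float('inf')
        best_params = None
        try:
            for epoch in range(max_epochs):
                for update in range(100):
                    time.sleep(0.01)        # stand-in for one gradient update
                    cost = random.random()  # stand-in for f_grad_shared(x, mask, y)
                    if cost < best_cost:
                        # stand-in for best_p = unzip(tparams)
                        best_cost, best_params = cost, {'best_cost': cost}
        except KeyboardInterrupt:
            # ctrl-C lands here instead of unwinding out of train(),
            # so the saving code below still runs.
            print('Training interrupted')

        # Reached after normal completion *or* ctrl-C, like the code after
        # the epoch loop in lstm.py that restores and saves best_p.
        if best_params is not None:
            print('saving best parameters, cost = %.4f' % best_cost)

    if __name__ == '__main__':
        train()

One caveat the commit inherits: if ctrl-C arrives while the saveFreq branch is inside numpy.savez or pkl.dump, the checkpoint on disk can still be left half-written; the except block only guarantees that the post-loop cleanup runs.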
