@@ -384,8 +384,8 @@ def train_lstm(
     optimizer=adadelta,  # sgd, adadelta and rmsprop available; sgd is very hard to use and not recommended (probably needs momentum and a decaying learning rate).
     encoder='lstm',  # TODO: can be removed; must be lstm.
     saveto='lstm_model.npz',  # The best model will be saved there
-    validFreq=390,  # Compute the validation error after this number of updates.
-    saveFreq=1040,  # Save the parameters after every saveFreq updates
+    validFreq=370,  # Compute the validation error after this number of updates.
+    saveFreq=1110,  # Save the parameters after every saveFreq updates
     maxlen=100,  # Sequences longer than this get ignored
     batch_size=16,  # The batch size during training.
     valid_batch_size=64,  # The batch size used for validation/test set.
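Note that the two defaults changed in this hunk are hand-tuned update counts, so they silently drift out of sync whenever the dataset or batch_size changes. A possible refinement (a sketch only, not part of this commit; the -1 sentinel and the placement inside train_lstm after the data is loaded are assumptions) is to derive both frequencies from the number of updates in one epoch:

    # Hypothetical: a value of -1 means "once per epoch"
    # instead of a hand-tuned update count.
    updates_per_epoch = len(train[0]) // batch_size
    if validFreq == -1:
        validFreq = updates_per_epoch
    if saveFreq == -1:
        saveFreq = updates_per_epoch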
@@ -467,80 +467,85 @@ def train_lstm(
     uidx = 0  # the number of updates done
     estop = False  # early stop
     start_time = time.clock()
-    for eidx in xrange(max_epochs):
-        n_samples = 0
-
-        # Get new shuffled index for the training set.
-        kf = get_minibatches_idx(len(train[0]), batch_size, shuffle=True)
-
-        for _, train_index in kf:
-            uidx += 1
-            use_noise.set_value(1.)
-
-            # Select the random examples for this minibatch
-            y = [train[1][t] for t in train_index]
-            x = [train[0][t] for t in train_index]
-
-            # Get the data in numpy.ndarray format.
-            # It returns something of shape (minibatch maxlen, n samples)
-            x, mask, y = prepare_data(x, y, maxlen=maxlen)
-            if x is None:
-                print 'Minibatch with zero samples under length ', maxlen
-                continue
-            n_samples += x.shape[1]
-
-            cost = f_grad_shared(x, mask, y)
-            f_update(lrate)
-
-            if numpy.isnan(cost) or numpy.isinf(cost):
-                print 'NaN detected'
-                return 1., 1., 1.
-
-            if numpy.mod(uidx, dispFreq) == 0:
-                print 'Epoch ', eidx, 'Update ', uidx, 'Cost ', cost
-
-            if numpy.mod(uidx, saveFreq) == 0:
-                print 'Saving...',
-
-                if best_p is not None:
-                    params = best_p
-                else:
-                    params = unzip(tparams)
-                numpy.savez(saveto, history_errs=history_errs, **params)
-                pkl.dump(model_options, open('%s.pkl' % saveto, 'wb'), -1)
-                print 'Done'
-
-            if numpy.mod(uidx, validFreq) == 0:
-                use_noise.set_value(0.)
-                train_err = pred_error(f_pred, prepare_data, train, kf)
-                valid_err = pred_error(f_pred, prepare_data, valid, kf_valid)
-                test_err = pred_error(f_pred, prepare_data, test, kf_test)
-
-                history_errs.append([valid_err, test_err])
-
-                if (uidx == 0 or
-                    valid_err <= numpy.array(history_errs)[:, 0].min()):
-
-                    best_p = unzip(tparams)
-                    bad_counter = 0
-
-                print ('Train ', train_err, 'Valid ', valid_err,
-                       'Test ', test_err)
-
-                if (len(history_errs) > patience and
-                    valid_err >= numpy.array(history_errs)[:-patience, 0].min()):
-                    bad_counter += 1
-                    if bad_counter > patience:
-                        print 'Early Stop!'
-                        estop = True
-                        break
-
-        print 'Seen %d samples' % n_samples
-
-        if estop:
-            break
+    try:
+        for eidx in xrange(max_epochs):
+            n_samples = 0
+
+            # Get new shuffled index for the training set.
+            kf = get_minibatches_idx(len(train[0]), batch_size, shuffle=True)
+
+            for _, train_index in kf:
+                uidx += 1
+                use_noise.set_value(1.)
+
+                # Select the random examples for this minibatch
+                y = [train[1][t] for t in train_index]
+                x = [train[0][t] for t in train_index]
+
+                # Get the data in numpy.ndarray format.
+                # It returns something of shape (minibatch maxlen, n samples)
+                x, mask, y = prepare_data(x, y, maxlen=maxlen)
+                if x is None:
+                    print 'Minibatch with zero samples under length ', maxlen
+                    continue
+                n_samples += x.shape[1]
+
+                cost = f_grad_shared(x, mask, y)
+                f_update(lrate)
+
+                if numpy.isnan(cost) or numpy.isinf(cost):
+                    print 'NaN detected'
+                    return 1., 1., 1.
+
+                if numpy.mod(uidx, dispFreq) == 0:
+                    print 'Epoch ', eidx, 'Update ', uidx, 'Cost ', cost
+
+                if numpy.mod(uidx, saveFreq) == 0:
+                    print 'Saving...',
+
+                    if best_p is not None:
+                        params = best_p
+                    else:
+                        params = unzip(tparams)
+                    numpy.savez(saveto, history_errs=history_errs, **params)
+                    pkl.dump(model_options, open('%s.pkl' % saveto, 'wb'), -1)
+                    print 'Done'
+
+                if numpy.mod(uidx, validFreq) == 0:
+                    use_noise.set_value(0.)
+                    train_err = pred_error(f_pred, prepare_data, train, kf)
+                    valid_err = pred_error(f_pred, prepare_data, valid, kf_valid)
+                    test_err = pred_error(f_pred, prepare_data, test, kf_test)
+
+                    history_errs.append([valid_err, test_err])
+
+                    if (uidx == 0 or
+                        valid_err <= numpy.array(history_errs)[:, 0].min()):
+
+                        best_p = unzip(tparams)
+                        bad_counter = 0
+
+                    print ('Train ', train_err, 'Valid ', valid_err,
+                           'Test ', test_err)
+
+                    if (len(history_errs) > patience and
+                        valid_err >= numpy.array(history_errs)[:-patience, 0].min()):
+                        bad_counter += 1
+                        if bad_counter > patience:
+                            print 'Early Stop!'
+                            estop = True
+                            break
+
+            print 'Seen %d samples' % n_samples
+
+            if estop:
+                break
+
+    except KeyboardInterrupt:
+        print "Training interrupted"
+
     end_time = time.clock()
     if best_p is not None:
         zipp(best_p, tparams)
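The substantive change in this hunk is the try/except wrapper: a Ctrl-C now raises KeyboardInterrupt inside the epoch loop, control falls through to the code below, and the best parameters seen so far are still restored and saved rather than lost with the process. The same pattern in isolation (a minimal sketch; run_epoch and save_best are hypothetical stand-ins for the tutorial's loop body and its numpy.savez call):

    import time

    def train(max_epochs, run_epoch, save_best):
        start_time = time.clock()
        try:
            for eidx in xrange(max_epochs):
                if run_epoch(eidx):  # True signals early stop
                    break
        except KeyboardInterrupt:
            # Ctrl-C lands here instead of killing the process,
            # so the cleanup below still runs.
            print "Training interrupted"
        end_time = time.clock()
        save_best()  # runs after normal exit, early stop, or interrupt
        print "Training took %.1fs" % (end_time - start_time)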