# -*- coding: utf-8 -*-

+import os
+from datetime import datetime, timedelta
from parser import BiaffineParser, Model
-from parser.utils import Corpus, Embedding, TextDataset, Vocab, collate_fn
+from parser.metric import Metric
+from parser.utils import Corpus, Embedding, Vocab
+from parser.utils.data import TextDataset, batchify

import torch
-from torch.utils.data import DataLoader
-
-from config import Config
+from torch.optim import Adam
+from torch.optim.lr_scheduler import ExponentialLR


class Train(object):
@@ -15,78 +18,110 @@ def add_subparser(self, name, parser):
        subparser = parser.add_parser(
            name, help='Train a model.'
        )
-        subparser.add_argument('--ftrain', default='data/train.conllx',
+        subparser.add_argument('--buckets', default=64, type=int,
+                               help='max num of buckets to use')
+        subparser.add_argument('--punct', action='store_true',
+                               help='whether to include punctuation')
+        subparser.add_argument('--ftrain', default='data/ptb/train.conllx',
                               help='path to train file')
-        subparser.add_argument('--fdev', default='data/dev.conllx',
+        subparser.add_argument('--fdev', default='data/ptb/dev.conllx',
                               help='path to dev file')
-        subparser.add_argument('--ftest', default='data/test.conllx',
+        subparser.add_argument('--ftest', default='data/ptb/test.conllx',
                               help='path to test file')
        subparser.add_argument('--fembed', default='data/glove.6B.100d.txt',
-                               help='path to pretrained embedding file')
-        subparser.set_defaults(func=self)
+                               help='path to pretrained embeddings')
+        subparser.add_argument('--unk', default='unk',
+                               help='unk token in pretrained embeddings')

        return subparser

-    def __call__(self, args):
+    def __call__(self, config):
        print("Preprocess the data")
-        train = Corpus.load(args.ftrain)
-        dev = Corpus.load(args.fdev)
-        test = Corpus.load(args.ftest)
-        embed = Embedding.load(args.fembed)
-        vocab = Vocab.from_corpus(corpus=train, min_freq=2)
-        vocab.read_embeddings(embed=embed, unk='unk')
-        torch.save(vocab, args.vocab)
+        train = Corpus.load(config.ftrain)
+        dev = Corpus.load(config.fdev)
+        test = Corpus.load(config.ftest)
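+        # reuse a cached vocab if one already exists; otherwise build it from the training corpus and save it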
+        if os.path.exists(config.vocab):
+            vocab = torch.load(config.vocab)
+        else:
+            vocab = Vocab.from_corpus(corpus=train, min_freq=2)
+            vocab.read_embeddings(Embedding.load(config.fembed, config.unk))
+            torch.save(vocab, config.vocab)
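+        # record the vocab sizes and special token indices in the config for building the parser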
+        config.update({
+            'n_words': vocab.n_train_words,
+            'n_tags': vocab.n_tags,
+            'n_rels': vocab.n_rels,
+            'pad_index': vocab.pad_index,
+            'unk_index': vocab.unk_index
+        })
        print(vocab)

        print("Load the dataset")
        trainset = TextDataset(vocab.numericalize(train))
        devset = TextDataset(vocab.numericalize(dev))
        testset = TextDataset(vocab.numericalize(test))
        # set the data loaders
-        train_loader = DataLoader(dataset=trainset,
-                                  batch_size=Config.batch_size,
-                                  shuffle=True,
-                                  collate_fn=collate_fn)
-        dev_loader = DataLoader(dataset=devset,
-                                batch_size=Config.batch_size,
-                                collate_fn=collate_fn)
-        test_loader = DataLoader(dataset=testset,
-                                 batch_size=Config.batch_size,
-                                 collate_fn=collate_fn)
-        print(f" size of trainset: {len(trainset)}")
-        print(f" size of devset: {len(devset)}")
-        print(f" size of testset: {len(testset)}")
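+        # batchify groups sentences of similar length into at most `buckets` buckets
+        # and returns batches drawn from them (its len() is the number of batches)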
+        train_loader = batchify(dataset=trainset,
+                                batch_size=config.batch_size,
+                                n_buckets=config.buckets,
+                                shuffle=True)
+        dev_loader = batchify(dataset=devset,
+                              batch_size=config.batch_size,
+                              n_buckets=config.buckets)
+        test_loader = batchify(dataset=testset,
+                               batch_size=config.batch_size,
+                               n_buckets=config.buckets)
+        print(f"{'train:':6} {len(trainset):5} sentences in total, "
+              f"{len(train_loader):3} batches provided")
+        print(f"{'dev:':6} {len(devset):5} sentences in total, "
+              f"{len(dev_loader):3} batches provided")
+        print(f"{'test:':6} {len(testset):5} sentences in total, "
+              f"{len(test_loader):3} batches provided")

        print("Create the model")
-        params = {
-            'n_words': vocab.n_train_words,
-            'n_embed': Config.n_embed,
-            'n_tags': vocab.n_tags,
-            'n_tag_embed': Config.n_tag_embed,
-            'embed_dropout': Config.embed_dropout,
-            'n_lstm_hidden': Config.n_lstm_hidden,
-            'n_lstm_layers': Config.n_lstm_layers,
-            'lstm_dropout': Config.lstm_dropout,
-            'n_mlp_arc': Config.n_mlp_arc,
-            'n_mlp_rel': Config.n_mlp_rel,
-            'mlp_dropout': Config.mlp_dropout,
-            'n_rels': vocab.n_rels,
-            'pad_index': vocab.pad_index,
-            'unk_index': vocab.unk_index
-        }
-        for k, v in params.items():
-            print(f" {k}: {v}")
-        network = BiaffineParser(params, vocab.embeddings)
+        parser = BiaffineParser(config, vocab.embeddings)
        if torch.cuda.is_available():
-            network = network.cuda()
-        print(f"{network}\n")
+            parser = parser.cuda()
+        print(f"{parser}\n")
+
+        model = Model(vocab, parser)
+
+        total_time = timedelta()
+        best_e, best_metric = 1, Metric()
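+        # Adam with the configured betas and epsilon; the ExponentialLR gamma of
+        # decay ** (1 / steps) multiplies the lr by `decay` over every `steps` scheduler steps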
+        model.optimizer = Adam(model.parser.parameters(),
+                               config.lr,
+                               (config.beta_1, config.beta_2),
+                               config.epsilon)
+        model.scheduler = ExponentialLR(model.optimizer,
+                                        config.decay ** (1 / config.steps))
+
+        for epoch in range(1, config.epochs + 1):
+            start = datetime.now()
+            # train one epoch and update the parameters
+            model.train(train_loader)
+
+            print(f"Epoch {epoch} / {config.epochs}:")
+            loss, train_metric = model.evaluate(train_loader, config.punct)
+            print(f"{'train:':6} Loss: {loss:.4f} {train_metric}")
+            loss, dev_metric = model.evaluate(dev_loader, config.punct)
+            print(f"{'dev:':6} Loss: {loss:.4f} {dev_metric}")
+            loss, test_metric = model.evaluate(test_loader, config.punct)
+            print(f"{'test:':6} Loss: {loss:.4f} {test_metric}")
+
+            t = datetime.now() - start
+            # save the model if it is the best so far
+            if dev_metric > best_metric and epoch > config.patience:
+                best_e, best_metric = epoch, dev_metric
+                model.parser.save(config.model + f".{best_e}")
+                print(f"{t}s elapsed (saved)\n")
+            else:
+                print(f"{t}s elapsed\n")
+            total_time += t
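+            # stop early once the dev metric has not improved for `patience` epochs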
+            if epoch - best_e >= config.patience:
+                break
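+        # reload the best checkpoint and report its final test performance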
+        model.parser = BiaffineParser.load(config.model + f".{best_e}")
+        loss, metric = model.evaluate(test_loader, config.punct)

-        model = Model(vocab, network)
-        model(loaders=(train_loader, dev_loader, test_loader),
-              epochs=Config.epochs,
-              patience=Config.patience,
-              lr=Config.lr,
-              betas=(Config.beta_1, Config.beta_2),
-              epsilon=Config.epsilon,
-              annealing=lambda x: Config.decay ** (x / Config.decay_steps),
-              file=args.file)
+        print(f"max score of dev is {best_metric.score:.2%} at epoch {best_e}")
+        print(f"the score of test at epoch {best_e} is {metric.score:.2%}")
+        print(f"average time of each epoch is {total_time / epoch}s")
+        print(f"{total_time}s elapsed")