@@ -14,6 +14,8 @@
 from supar.utils.logging import get_logger, progress_bar
 from supar.utils.metric import AttachmentMetric
 from supar.utils.transform import CoNLL
+from torch.optim import Adam
+from torch.optim.lr_scheduler import ExponentialLR
 
 
 logger = get_logger(__name__)
@@ -168,9 +170,7 @@ def _evaluate(self, loader):
             mask[:, 0] = 0
             s_arc, s_rel = self.model(words, feats)
             loss = self.model.loss(s_arc, s_rel, arcs, rels, mask, self.args.partial)
-            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask,
-                                                     self.args.tree,
-                                                     self.args.proj)
+            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj)
             if self.args.partial:
                 mask &= arcs.ge(0)
             # ignore all punctuation if not specified
@@ -194,9 +194,7 @@ def _predict(self, loader):
             mask[:, 0] = 0
             lens = mask.sum(1).tolist()
             s_arc, s_rel = self.model(words, feats)
-            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask,
-                                                     self.args.tree,
-                                                     self.args.proj)
+            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj)
             arcs.extend(arc_preds[mask].split(lens))
             rels.extend(rel_preds[mask].split(lens))
             if self.args.prob:
@@ -211,13 +209,21 @@ def _predict(self, loader):
         return preds
 
     @classmethod
-    def build(cls, path, min_freq=2, fix_len=20, **kwargs):
+    def build(cls, path,
+              optimizer_args={'lr': 2e-3, 'betas': (.9, .9), 'eps': 1e-12},
+              scheduler_args={'gamma': .75**(1/5000)},
+              min_freq=2,
+              fix_len=20, **kwargs):
         r"""
         Build a brand-new Parser, including initialization of all data fields and model parameters.
 
         Args:
             path (str):
                 The path of the model to be saved.
+            optimizer_args (dict):
+                Arguments for creating an optimizer.
+            scheduler_args (dict):
+                Arguments for creating a scheduler.
             min_freq (int):
                 The minimum frequency needed to include a token in the vocabulary. Default: 2.
             fix_len (int):
@@ -273,9 +279,15 @@ def build(cls, path, min_freq=2, fix_len=20, **kwargs):
             'bos_index': WORD.bos_index,
             'feat_pad_index': FEAT.pad_index,
         })
+
+        logger.info("Building the model")
         model = cls.MODEL(**args)
         model.load_pretrained(WORD.embed).to(args.device)
-        return cls(args, model, transform)
+
+        optimizer = Adam(model.parameters(), **optimizer_args)
+        scheduler = ExponentialLR(optimizer, **scheduler_args)
+
+        return cls(args, model, transform, optimizer, scheduler)
 
 
 class CRFNPDependencyParser(BiaffineDependencyParser):
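
With this change, `build` constructs the Adam optimizer and `ExponentialLR` scheduler itself and passes them to the parser constructor, so callers can tune optimization purely through keyword arguments. A minimal usage sketch, assuming the package-level import works as in released supar versions; the save path is hypothetical, and any remaining keyword arguments are still forwarded via `**kwargs`:

from supar import BiaffineDependencyParser

# 'exp/ptb.biaffine' is a hypothetical save path; the two dicts override
# the defaults declared in the new build() signature above.
parser = BiaffineDependencyParser.build(
    'exp/ptb.biaffine',
    optimizer_args={'lr': 1e-3, 'betas': (.9, .9), 'eps': 1e-12},
    scheduler_args={'gamma': .75**(1/5000)})
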
@@ -584,9 +596,7 @@ def _train(self, loader):
             # ignore the first token of each sentence
             mask[:, 0] = 0
             s_arc, s_rel = self.model(words, feats)
-            loss, s_arc = self.model.loss(s_arc, s_rel, arcs, rels, mask,
-                                          self.args.mbr,
-                                          self.args.partial)
+            loss, s_arc = self.model.loss(s_arc, s_rel, arcs, rels, mask, self.args.mbr, self.args.partial)
             loss.backward()
             nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
             self.optimizer.step()
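
`self.optimizer.step()` here consumes the optimizer that `build` now creates. The scheduler does not appear in this hunk; a sketch of the assumed per-batch update order, with the `self.scheduler.step()` call inferred rather than shown in this diff (a per-batch step is what the per-step `gamma` chosen in `build` implies):

self.optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
self.optimizer.step()    # Adam update, using the optimizer built in build()
self.scheduler.step()    # assumed: multiply lr by gamma once per batch
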
@@ -612,12 +622,8 @@ def _evaluate(self, loader):
             # ignore the first token of each sentence
             mask[:, 0] = 0
             s_arc, s_rel = self.model(words, feats)
-            loss, s_arc = self.model.loss(s_arc, s_rel, arcs, rels, mask,
-                                          self.args.mbr,
-                                          self.args.partial)
-            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask,
-                                                     self.args.tree,
-                                                     self.args.proj)
+            loss, s_arc = self.model.loss(s_arc, s_rel, arcs, rels, mask, self.args.mbr, self.args.partial)
+            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj)
             if self.args.partial:
                 mask &= arcs.ge(0)
             # ignore all punctuation if not specified
@@ -643,9 +649,7 @@ def _predict(self, loader):
             s_arc, s_rel = self.model(words, feats)
             if self.args.mbr:
                 s_arc = self.model.crf(s_arc, mask, mbr=True)
-            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask,
-                                                     self.args.tree,
-                                                     self.args.proj)
+            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj)
             arcs.extend(arc_preds[mask].split(lens))
             rels.extend(rel_preds[mask].split(lens))
             if self.args.prob:
@@ -780,9 +784,7 @@ def _train(self, loader):
             # ignore the first token of each sentence
             mask[:, 0] = 0
             s_arc, s_sib, s_rel = self.model(words, feats)
-            loss, s_arc = self.model.loss(s_arc, s_sib, s_rel, arcs, sibs, rels, mask,
-                                          self.args.mbr,
-                                          self.args.partial)
+            loss, s_arc = self.model.loss(s_arc, s_sib, s_rel, arcs, sibs, rels, mask, self.args.mbr, self.args.partial)
             loss.backward()
             nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
             self.optimizer.step()
@@ -808,13 +810,8 @@ def _evaluate(self, loader):
             # ignore the first token of each sentence
             mask[:, 0] = 0
             s_arc, s_sib, s_rel = self.model(words, feats)
-            loss, s_arc = self.model.loss(s_arc, s_sib, s_rel, arcs, sibs, rels, mask,
-                                          self.args.mbr,
-                                          self.args.partial)
-            arc_preds, rel_preds = self.model.decode(s_arc, s_sib, s_rel, mask,
-                                                     self.args.tree,
-                                                     self.args.mbr,
-                                                     self.args.proj)
+            loss, s_arc = self.model.loss(s_arc, s_sib, s_rel, arcs, sibs, rels, mask, self.args.mbr, self.args.partial)
+            arc_preds, rel_preds = self.model.decode(s_arc, s_sib, s_rel, mask, self.args.tree, self.args.mbr, self.args.proj)
             if self.args.partial:
                 mask &= arcs.ge(0)
             # ignore all punctuation if not specified
@@ -840,10 +837,7 @@ def _predict(self, loader):
             s_arc, s_sib, s_rel = self.model(words, feats)
             if self.args.mbr:
                 s_arc = self.model.crf((s_arc, s_sib), mask, mbr=True)
-            arc_preds, rel_preds = self.model.decode(s_arc, s_sib, s_rel, mask,
-                                                     self.args.tree,
-                                                     self.args.mbr,
-                                                     self.args.proj)
+            arc_preds, rel_preds = self.model.decode(s_arc, s_sib, s_rel, mask, self.args.tree, self.args.mbr, self.args.proj)
             arcs.extend(arc_preds[mask].split(lens))
             rels.extend(rel_preds[mask].split(lens))
             if self.args.prob:
@@ -858,13 +852,21 @@ def _predict(self, loader):
         return preds
 
     @classmethod
-    def build(cls, path, min_freq=2, fix_len=20, **kwargs):
+    def build(cls, path,
+              optimizer_args={'lr': 2e-3, 'betas': (.9, .9), 'eps': 1e-12},
+              scheduler_args={'gamma': .75**(1/5000)},
+              min_freq=2,
+              fix_len=20, **kwargs):
         r"""
         Build a brand-new Parser, including initialization of all data fields and model parameters.
 
         Args:
             path (str):
                 The path of the model to be saved.
+            optimizer_args (dict):
+                Arguments for creating an optimizer.
+            scheduler_args (dict):
+                Arguments for creating a scheduler.
             min_freq (int):
                 The minimum frequency needed to include a token in the vocabulary. Default: 2.
             fix_len (int):
@@ -921,6 +923,12 @@ def build(cls, path, min_freq=2, fix_len=20, **kwargs):
             'bos_index': WORD.bos_index,
             'feat_pad_index': FEAT.pad_index
         })
+
+        logger.info("Building the model")
         model = cls.MODEL(**args)
         model = model.load_pretrained(WORD.embed).to(args.device)
-        return cls(args, model, transform)
+
+        optimizer = Adam(model.parameters(), **optimizer_args)
+        scheduler = ExponentialLR(optimizer, **scheduler_args)
+
+        return cls(args, model, transform, optimizer, scheduler)
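
The default `gamma` deserves a note: `ExponentialLR` multiplies the learning rate by `gamma` at every `scheduler.step()`, so `.75**(1/5000)` compounds to exactly a 0.75x decay over each 5000 steps. A self-contained check using only PyTorch:

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = Adam([param], lr=2e-3, betas=(.9, .9), eps=1e-12)
scheduler = ExponentialLR(optimizer, gamma=.75**(1/5000))

for _ in range(5000):
    optimizer.step()      # no-op here (no gradients), but keeps the call order valid
    scheduler.step()      # lr *= gamma

# after 5000 steps: 2e-3 * (0.75**(1/5000))**5000 = 2e-3 * 0.75 = 1.5e-3
print(scheduler.get_last_lr())  # ~[0.0015]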