@@ -536,14 +536,78 @@ class CRF2oDependencyModel(BiaffineDependencyModel):
         https://www.aclweb.org/anthology/2020.acl-main.302/
     """

-    def __init__(self, n_lstm_hidden=400, n_mlp_sib=100, mlp_dropout=.33, **kwargs):
-        super().__init__(**kwargs)
+    def __init__(self,
+                 n_words,
+                 n_feats,
+                 n_rels,
+                 feat='char',
+                 n_embed=100,
+                 n_feat_embed=100,
+                 n_char_embed=50,
+                 bert=None,
+                 n_bert_layers=4,
+                 mix_dropout=.0,
+                 embed_dropout=.33,
+                 n_lstm_hidden=400,
+                 n_lstm_layers=3,
+                 lstm_dropout=.33,
+                 n_mlp_arc=500,
+                 n_mlp_sib=100,
+                 n_mlp_rel=100,
+                 mlp_dropout=.33,
+                 feat_pad_index=0,
+                 pad_index=0,
+                 unk_index=1,
+                 **kwargs):
+        super().__init__(**Config().update(locals()))

+        # the embedding layer
+        self.word_embed = nn.Embedding(num_embeddings=n_words,
+                                       embedding_dim=n_embed)
+        if feat == 'char':
+            self.feat_embed = CharLSTM(n_chars=n_feats,
+                                       n_embed=n_char_embed,
+                                       n_out=n_feat_embed,
+                                       pad_index=feat_pad_index)
+        elif feat == 'bert':
+            self.feat_embed = BertEmbedding(model=bert,
+                                            n_layers=n_bert_layers,
+                                            n_out=n_feat_embed,
+                                            pad_index=feat_pad_index,
+                                            dropout=mix_dropout)
+            self.n_feat_embed = self.feat_embed.n_out
+        elif feat == 'tag':
+            self.feat_embed = nn.Embedding(num_embeddings=n_feats,
+                                           embedding_dim=n_feat_embed)
+        else:
+            raise RuntimeError("The feat type should be in ['char', 'bert', 'tag'].")
+        self.embed_dropout = IndependentDropout(p=embed_dropout)
+
+        # the lstm layer
+        self.lstm = LSTM(input_size=n_embed+n_feat_embed,
+                         hidden_size=n_lstm_hidden,
+                         num_layers=n_lstm_layers,
+                         bidirectional=True,
+                         dropout=lstm_dropout)
+        self.lstm_dropout = SharedDropout(p=lstm_dropout)
+
+        # the MLP layers
+        self.mlp_arc_d = MLP(n_in=n_lstm_hidden*2, n_out=n_mlp_arc, dropout=mlp_dropout)
+        self.mlp_arc_h = MLP(n_in=n_lstm_hidden*2, n_out=n_mlp_arc, dropout=mlp_dropout)
         self.mlp_sib_s = MLP(n_in=n_lstm_hidden*2, n_out=n_mlp_sib, dropout=mlp_dropout)
         self.mlp_sib_d = MLP(n_in=n_lstm_hidden*2, n_out=n_mlp_sib, dropout=mlp_dropout)
         self.mlp_sib_h = MLP(n_in=n_lstm_hidden*2, n_out=n_mlp_sib, dropout=mlp_dropout)
+        self.mlp_rel_d = MLP(n_in=n_lstm_hidden*2, n_out=n_mlp_rel, dropout=mlp_dropout)
+        self.mlp_rel_h = MLP(n_in=n_lstm_hidden*2, n_out=n_mlp_rel, dropout=mlp_dropout)

+        # the Biaffine layers
+        self.arc_attn = Biaffine(n_in=n_mlp_arc, bias_x=True, bias_y=False)
         self.sib_attn = Triaffine(n_in=n_mlp_sib, bias_x=True, bias_y=True)
+        self.rel_attn = Biaffine(n_in=n_mlp_rel, n_out=n_rels, bias_x=True, bias_y=True)
+        self.criterion = nn.CrossEntropyLoss()
+        self.pad_index = pad_index
+        self.unk_index = unk_index
+
         self.crf = CRF2oDependency()

     def forward(self, words, feats):
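For context, a minimal usage sketch (not part of the commit): with the expanded signature, the model can be instantiated directly from vocabulary sizes, with everything else falling back to the defaults above. The sizes below are hypothetical placeholders, and this assumes the parent BiaffineDependencyModel accepts the same configuration keywords.

# assuming CRF2oDependencyModel has been imported from this module
# hypothetical vocabulary sizes for illustration only; real values come from the training data
model = CRF2oDependencyModel(n_words=10000,  # size of the word vocabulary
                             n_feats=500,    # size of the char vocabulary (feat='char')
                             n_rels=40)      # number of dependency relation labels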