@@ -19,16 +19,21 @@ class BiaffineSemanticDependencyModel(nn.Module):
    Args:
        n_words (int):
            The size of the word vocabulary.
-        n_feats (int):
-            The size of the feat vocabulary.
        n_labels (int):
            The number of labels in the treebank.
+        n_tags (int):
+            The number of POS tags, needed if POS tag embeddings are used. Default: ``None``.
+        n_chars (int):
+            The number of characters, needed if character-level representations are used. Default: ``None``.
+        n_lemmas (int):
+            The number of lemmas, needed if lemma embeddings are used. Default: ``None``.
        feat (str):
-            Specifies which type of additional feature to use: ``'char'`` | ``'bert'`` | ``'tag'``.
+            Additional features to use, separated by commas.
+            ``'tag'``: POS tag embeddings.
            ``'char'``: Character-level representations extracted by CharLSTM.
+            ``'lemma'``: Lemma embeddings.
            ``'bert'``: BERT representations, other pretrained language models like XLNet are also feasible.
-            ``'tag'``: POS tag embeddings.
-            Default: ``'char'``.
+            Default: ``'tag,char,lemma'``.
        n_embed (int):
            The size of word embeddings. Default: 100.
        n_embed_proj (int):
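The comma-separated ``feat`` string documented above is what drives which feature embeddings get built in ``__init__`` further down; a stand-alone sketch of the membership check it relies on (values here are purely illustrative):

    # stand-alone sketch: a feature is enabled when its name occurs in the
    # comma-separated `feat` string, mirroring the "'tag' in feat" checks below
    feat = 'tag,char,lemma'
    for name in ('tag', 'char', 'lemma', 'bert'):
        print(name, name in feat)   # tag/char/lemma -> True, bert -> False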
@@ -37,6 +42,8 @@ class BiaffineSemanticDependencyModel(nn.Module):
            The size of feature representations. Default: 100.
        n_char_embed (int):
            The size of character embeddings serving as inputs of CharLSTM, required if ``feat='char'``. Default: 50.
+        char_pad_index (int):
+            The index of the padding token in the character vocabulary. Default: 0.
        bert (str):
            Specifies which kind of language model to use, e.g., ``'bert-base-cased'`` and ``'xlnet-base-cased'``.
            This is required if ``feat='bert'``. The full list can be found in `transformers`_.
@@ -47,6 +54,8 @@ class BiaffineSemanticDependencyModel(nn.Module):
            Default: 4.
        mix_dropout (float):
            The dropout ratio of BERT layers. Required if ``feat='bert'``. Default: .0.
+        bert_pad_index (int):
+            The index of the padding token in the BERT vocabulary. Default: 0.
        embed_dropout (float):
            The dropout ratio of input embeddings. Default: .2.
        n_lstm_hidden (int):
@@ -63,8 +72,8 @@ class BiaffineSemanticDependencyModel(nn.Module):
            The dropout ratio of edge MLP layers. Default: .25.
        label_mlp_dropout (float):
            The dropout ratio of label MLP layers. Default: .33.
-        feat_pad_index (int):
-            The index of the padding token in the feat vocabulary. Default: 0.
+        interpolation (float):
+            Constant to balance the label loss against the edge loss. Default: .1.
        pad_index (int):
            The index of the padding token in the word vocabulary. Default: 0.
        unk_index (int):
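The loss computation itself is not part of this diff; presumably ``interpolation`` weights the two criteria roughly as in this sketch (an assumption, with made-up scalar losses):

    # assumed weighting of the two losses by `interpolation`; not shown in this diff
    interpolation = 0.1
    edge_loss, label_loss = 0.8, 0.5
    loss = interpolation * label_loss + (1 - interpolation) * edge_loss
    print(loss)   # ~0.77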
@@ -78,16 +87,20 @@ class BiaffineSemanticDependencyModel(nn.Module):

    def __init__(self,
                 n_words,
-                 n_feats,
                 n_labels,
-                 feat='char',
+                 n_tags=None,
+                 n_chars=None,
+                 n_lemmas=None,
+                 feat='tag,char,lemma',
                 n_embed=100,
                 n_embed_proj=125,
                 n_feat_embed=100,
                 n_char_embed=50,
+                 char_pad_index=0,
                 bert=None,
                 n_bert_layers=4,
                 mix_dropout=.0,
+                 bert_pad_index=0,
                 embed_dropout=.2,
                 n_lstm_hidden=600,
                 n_lstm_layers=3,
@@ -96,10 +109,9 @@ def __init__(self,
                 n_mlp_label=600,
                 edge_mlp_dropout=.25,
                 label_mlp_dropout=.33,
-                 feat_pad_index=0,
+                 interpolation=0.1,
                 pad_index=0,
                 unk_index=1,
-                 interpolation=0.1,
                 **kwargs):
        super().__init__()

@@ -109,27 +121,31 @@ def __init__(self,
                                       embedding_dim=n_embed)
        self.embed_proj = nn.Linear(n_embed, n_embed_proj)

-        if feat == 'char':
-            self.feat_embed = CharLSTM(n_chars=n_feats,
+        self.n_input = n_embed + n_embed_proj
+        if 'tag' in feat:
+            self.tag_embed = nn.Embedding(num_embeddings=n_tags,
+                                          embedding_dim=n_feat_embed)
+            self.n_input += n_feat_embed
+        if 'char' in feat:
+            self.char_embed = CharLSTM(n_chars=n_chars,
                                       n_embed=n_char_embed,
                                       n_out=n_feat_embed,
-                                       pad_index=feat_pad_index)
-        elif feat == 'bert':
-            self.feat_embed = BertEmbedding(model=bert,
+                                       pad_index=char_pad_index)
+            self.n_input += n_feat_embed
+        if 'lemma' in feat:
+            self.lemma_embed = nn.Embedding(num_embeddings=n_lemmas,
+                                            embedding_dim=n_feat_embed)
+            self.n_input += n_feat_embed
+        if 'bert' in feat:
+            self.bert_embed = BertEmbedding(model=bert,
                                            n_layers=n_bert_layers,
-                                            n_out=n_feat_embed,
-                                            pad_index=feat_pad_index,
+                                            pad_index=bert_pad_index,
                                            dropout=mix_dropout)
-            self.n_feat_embed = self.feat_embed.n_out
-        elif feat == 'tag':
-            self.feat_embed = nn.Embedding(num_embeddings=n_feats,
-                                           embedding_dim=n_feat_embed)
-        else:
-            raise RuntimeError("The feat type should be in ['char', 'bert', 'tag'].")
+            self.n_input += self.bert_embed.n_out
        self.embed_dropout = IndependentDropout(p=embed_dropout)

        # the lstm layer
-        self.lstm = LSTM(input_size=n_embed+n_feat_embed+n_embed_proj,
+        self.lstm = LSTM(input_size=self.n_input,
                         hidden_size=n_lstm_hidden,
                         num_layers=n_lstm_layers,
                         bidirectional=True,
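As a sanity check on the new ``self.n_input`` bookkeeping, here is a stand-alone sketch using the default sizes from the signature above (purely illustrative arithmetic, not taken from a run):

    # back-of-the-envelope check of the LSTM input size under the defaults
    n_embed, n_embed_proj, n_feat_embed = 100, 125, 100
    n_input = n_embed + n_embed_proj          # 225: word + projected pretrained embeddings
    for _ in 'tag,char,lemma'.split(','):     # each enabled feature adds n_feat_embed
        n_input += n_feat_embed               # ('bert' would add self.bert_embed.n_out instead)
    print(n_input)                            # 525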
@@ -146,9 +162,9 @@ def __init__(self,
        self.edge_attn = Biaffine(n_in=n_mlp_edge, n_out=2, bias_x=True, bias_y=True)
        self.label_attn = Biaffine(n_in=n_mlp_label, n_out=n_labels, bias_x=True, bias_y=True)
        self.criterion = nn.CrossEntropyLoss()
+        self.interpolation = interpolation
        self.pad_index = pad_index
        self.unk_index = unk_index
-        self.interpolation = interpolation

    def load_pretrained(self, embed=None):
        if embed is not None:
@@ -160,10 +176,10 @@ def forward(self, words, feats):
        Args:
            words (~torch.LongTensor): ``[batch_size, seq_len]``.
                Word indices.
-            feats (~torch.LongTensor):
-                Feat indices.
-                If feat is ``'char'`` or ``'bert'``, the size of feats should be ``[batch_size, seq_len, fix_len]``.
-                If ``'tag'``, the size is ``[batch_size, seq_len]``.
+            feats (list[~torch.LongTensor]):
+                A list of feat indices.
+                The size of each index tensor is ``[batch_size, seq_len, fix_len]`` if the feat is ``'char'`` or ``'bert'``,
+                or ``[batch_size, seq_len]`` otherwise.

        Returns:
            ~torch.Tensor, ~torch.Tensor:
@@ -172,7 +188,7 @@ def forward(self, words, feats):
                scores of all possible labels on each edge.
        """

-        batch_size, seq_len = words.shape
+        _, seq_len = words.shape
        # get the mask and lengths of given batch
        mask = words.ne(self.pad_index)
        ext_words = words
@@ -186,8 +202,16 @@ def forward(self, words, feats):
        if hasattr(self, 'pretrained'):
            word_embed = torch.cat((word_embed, self.embed_proj(self.pretrained(words))), -1)

-        feat_embed = self.feat_embed(feats)
-        word_embed, feat_embed = self.embed_dropout(word_embed, feat_embed)
+        feat_embeds = []
+        if 'tag' in self.args.feat:
+            feat_embeds.append(self.tag_embed(feats.pop()))
+        if 'char' in self.args.feat:
+            feat_embeds.append(self.char_embed(feats.pop(0)))
+        if 'bert' in self.args.feat:
+            feat_embeds.append(self.bert_embed(feats.pop(0)))
+        if 'lemma' in self.args.feat:
+            feat_embeds.append(self.lemma_embed(feats.pop(0)))
+        word_embed, feat_embed = self.embed_dropout(word_embed, torch.cat(feat_embeds, -1))
        # concatenate the word and feat representations
        embed = torch.cat((word_embed, feat_embed), -1)
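The pop pattern above implies a particular ordering of the ``feats`` list. A stand-alone sketch of that ordering for the default ``feat='tag,char,lemma'`` (the real ordering comes from the data transform, which is outside this diff, so the names here are only illustrative placeholders):

    # mimic the pops above with placeholder entries instead of real tensors
    feats = ['char_ids', 'lemma_ids', 'tag_ids']   # assumed transform output order
    tag = feats.pop()        # last entry  -> 'tag_ids'
    char = feats.pop(0)      # first entry -> 'char_ids'
    lemma = feats.pop(0)     # remaining   -> 'lemma_ids'
    print(tag, char, lemma)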