Commit ba8dd54

TAG/CHAR/LEMMA features
1 parent ff17441 commit ba8dd54

3 files changed, with 87 additions and 56 deletions.

3 files changed

+87
-56
lines changed

supar/cmds/biaffine_semantic_dependency.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ def main():
     subparsers = parser.add_subparsers(title='Commands', dest='mode')
     # train
     subparser = subparsers.add_parser('train', help='Train a parser.')
-    subparser.add_argument('--feat', '-f', choices=['tag', 'char', 'bert'], help='choices of additional features')
+    subparser.add_argument('--feat', '-f', default='tag,char,lemma', help='additional features to use,separated by commas.')
     subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first')
     subparser.add_argument('--max-len', type=int, help='max length of the sentences')
     subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use')
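
The `--feat` option is now a free-form comma-separated string rather than a single choice, and downstream code enables each feature with a substring test on that string. A minimal sketch of how such a value is interpreted (illustrative only, not part of the commit):

# Illustrative only: how a comma-separated --feat value is consumed downstream.
feat = 'tag,char,lemma'
enabled = [name for name in ('tag', 'char', 'lemma', 'bert') if name in feat]
print(enabled)  # ['tag', 'char', 'lemma']

Note that `'tag' in feat` is a plain substring test rather than a check against `feat.split(',')`; for the feature names used here the two behave the same.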

supar/models/semantic_dependency.py

Lines changed: 57 additions & 33 deletions
@@ -19,16 +19,21 @@ class BiaffineSemanticDependencyModel(nn.Module):
     Args:
         n_words (int):
             The size of the word vocabulary.
-        n_feats (int):
-            The size of the feat vocabulary.
         n_labels (int):
             The number of labels in the treebank.
+        n_tags (int):
+            The number of POS tags, needed if POS tag embeddings are used. Default: ``None``.
+        n_chars (int):
+            The number of characters, needed if character-level representations are used. Default: ``None``.
+        n_lemmas (int):
+            The number of lemmas, needed if lemma embeddings are used. Default: ``None``.
         feat (str):
-            Specifies which type of additional feature to use: ``'char'`` | ``'bert'`` | ``'tag'``.
+            Additional features to use,separated by commas.
+            ``'tag'``: POS tag embeddings.
             ``'char'``: Character-level representations extracted by CharLSTM.
+            ``'lemma'``: Lemma embeddings.
             ``'bert'``: BERT representations, other pretrained langugae models like XLNet are also feasible.
-            ``'tag'``: POS tag embeddings.
-            Default: ``'char'``.
+            Default: ``'tag,char,lemma'``.
         n_embed (int):
             The size of word embeddings. Default: 100.
         n_embed_proj (int):
@@ -37,6 +42,8 @@ class BiaffineSemanticDependencyModel(nn.Module):
             The size of feature representations. Default: 100.
         n_char_embed (int):
             The size of character embeddings serving as inputs of CharLSTM, required if ``feat='char'``. Default: 50.
+        char_pad_index (int):
+            The index of the padding token in the character vocabulary. Default: 0.
         bert (str):
             Specifies which kind of language model to use, e.g., ``'bert-base-cased'`` and ``'xlnet-base-cased'``.
             This is required if ``feat='bert'``. The full list can be found in `transformers`_.
@@ -47,6 +54,8 @@ class BiaffineSemanticDependencyModel(nn.Module):
             Default: 4.
         mix_dropout (float):
             The dropout ratio of BERT layers. Required if ``feat='bert'``. Default: .0.
+        bert_pad_index (int):
+            The index of the padding token in the BERT vocabulary. Default: 0.
         embed_dropout (float):
             The dropout ratio of input embeddings. Default: .2.
         n_lstm_hidden (int):
@@ -63,8 +72,8 @@ class BiaffineSemanticDependencyModel(nn.Module):
             The dropout ratio of edge MLP layers. Default: .25.
         label_mlp_dropout (float):
             The dropout ratio of label MLP layers. Default: .33.
-        feat_pad_index (int):
-            The index of the padding token in the feat vocabulary. Default: 0.
+        interpolation (int):
+            Constant to even out the label/edge loss. Default: .1.
         pad_index (int):
             The index of the padding token in the word vocabulary. Default: 0.
         unk_index (int):
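
The new ``interpolation`` argument takes the docstring slot previously used by ``feat_pad_index`` and is stored on the model (see the constructor hunks below). The loss computation itself is not shown in this commit; the following is only a sketch of how such a constant typically evens out the two losses, under the assumption that they are mixed affinely:

# Sketch only: an affine mixture of the two losses, weighted by `interpolation`.
# The actual combination lives in the model's loss(), which this diff does not show.
def combine_losses(edge_loss, label_loss, interpolation=0.1):
    return interpolation * label_loss + (1 - interpolation) * edge_loss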
@@ -78,16 +87,20 @@ class BiaffineSemanticDependencyModel(nn.Module):
 
     def __init__(self,
                  n_words,
-                 n_feats,
                  n_labels,
-                 feat='char',
+                 n_tags=None,
+                 n_chars=None,
+                 n_lemmas=None,
+                 feat='tag,char,lemma',
                  n_embed=100,
                  n_embed_proj=125,
                  n_feat_embed=100,
                  n_char_embed=50,
+                 char_pad_index=0,
                  bert=None,
                  n_bert_layers=4,
                  mix_dropout=.0,
+                 bert_pad_index=0,
                  embed_dropout=.2,
                  n_lstm_hidden=600,
                  n_lstm_layers=3,
@@ -96,10 +109,9 @@ def __init__(self,
                  n_mlp_label=600,
                  edge_mlp_dropout=.25,
                  label_mlp_dropout=.33,
-                 feat_pad_index=0,
+                 interpolation=0.1,
                  pad_index=0,
                  unk_index=1,
-                 interpolation=0.1,
                  **kwargs):
         super().__init__()
 
@@ -109,27 +121,31 @@ def __init__(self,
                                        embedding_dim=n_embed)
         self.embed_proj = nn.Linear(n_embed, n_embed_proj)
 
-        if feat == 'char':
-            self.feat_embed = CharLSTM(n_chars=n_feats,
+        self.n_input = n_embed + n_embed_proj
+        if 'tag' in feat:
+            self.tag_embed = nn.Embedding(num_embeddings=n_tags,
+                                          embedding_dim=n_feat_embed)
+            self.n_input += n_feat_embed
+        if 'char' in feat:
+            self.char_embed = CharLSTM(n_chars=n_chars,
                                        n_embed=n_char_embed,
                                        n_out=n_feat_embed,
-                                       pad_index=feat_pad_index)
-        elif feat == 'bert':
-            self.feat_embed = BertEmbedding(model=bert,
+                                       pad_index=char_pad_index)
+            self.n_input += n_feat_embed
+        if 'lemma' in feat:
+            self.lemma_embed = nn.Embedding(num_embeddings=n_lemmas,
+                                            embedding_dim=n_feat_embed)
+            self.n_input += n_feat_embed
+        if 'bert' in feat:
+            self.bert_embed = BertEmbedding(model=bert,
                                             n_layers=n_bert_layers,
-                                            n_out=n_feat_embed,
-                                            pad_index=feat_pad_index,
+                                            pad_index=bert_pad_index,
                                             dropout=mix_dropout)
-            self.n_feat_embed = self.feat_embed.n_out
-        elif feat == 'tag':
-            self.feat_embed = nn.Embedding(num_embeddings=n_feats,
-                                           embedding_dim=n_feat_embed)
-        else:
-            raise RuntimeError("The feat type should be in ['char', 'bert', 'tag'].")
+            self.n_input += self.bert_embed.n_out
         self.embed_dropout = IndependentDropout(p=embed_dropout)
 
         # the lstm layer
-        self.lstm = LSTM(input_size=n_embed+n_feat_embed+n_embed_proj,
+        self.lstm = LSTM(input_size=self.n_input,
                          hidden_size=n_lstm_hidden,
                          num_layers=n_lstm_layers,
                          bidirectional=True,
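
The LSTM input size is now accumulated from whichever features are enabled rather than fixed at ``n_embed+n_feat_embed+n_embed_proj``. A quick sanity check with the defaults shown above (illustrative arithmetic only):

# Illustrative only: LSTM input size under the default feat='tag,char,lemma'.
n_embed, n_embed_proj, n_feat_embed = 100, 125, 100
n_input = n_embed + n_embed_proj   # word embeddings + projected pretrained embeddings
n_input += 3 * n_feat_embed        # tag, char and lemma each add n_feat_embed
assert n_input == 525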
@@ -146,9 +162,9 @@ def __init__(self,
         self.edge_attn = Biaffine(n_in=n_mlp_edge, n_out=2, bias_x=True, bias_y=True)
         self.label_attn = Biaffine(n_in=n_mlp_label, n_out=n_labels, bias_x=True, bias_y=True)
         self.criterion = nn.CrossEntropyLoss()
+        self.interpolation = interpolation
         self.pad_index = pad_index
         self.unk_index = unk_index
-        self.interpolation = interpolation
 
     def load_pretrained(self, embed=None):
         if embed is not None:
@@ -160,10 +176,10 @@ def forward(self, words, feats):
         Args:
             words (~torch.LongTensor): ``[batch_size, seq_len]``.
                 Word indices.
-            feats (~torch.LongTensor):
-                Feat indices.
-                If feat is ``'char'`` or ``'bert'``, the size of feats should be ``[batch_size, seq_len, fix_len]``.
-                if ``'tag'``, the size is ``[batch_size, seq_len]``.
+            feats (list[~torch.LongTensor]):
+                A list of feat indices.
+                The size of indices is ``[batch_size, seq_len, fix_len]`` if feat is ``'char'`` or ``'bert'``,
+                or ``[batch_size, seq_len]`` otherwise.
 
         Returns:
             ~torch.Tensor, ~torch.Tensor:
@@ -172,7 +188,7 @@ def forward(self, words, feats):
                 scores of all possible labels on each edge.
         """
 
-        batch_size, seq_len = words.shape
+        _, seq_len = words.shape
         # get the mask and lengths of given batch
         mask = words.ne(self.pad_index)
         ext_words = words
@@ -186,8 +202,16 @@ def forward(self, words, feats):
         if hasattr(self, 'pretrained'):
             word_embed = torch.cat((word_embed, self.embed_proj(self.pretrained(words))), -1)
 
-        feat_embed = self.feat_embed(feats)
-        word_embed, feat_embed = self.embed_dropout(word_embed, feat_embed)
+        feat_embeds = []
+        if 'tag' in self.args.feat:
+            feat_embeds.append(self.tag_embed(feats.pop()))
+        if 'char' in self.args.feat:
+            feat_embeds.append(self.char_embed(feats.pop(0)))
+        if 'bert' in self.args.feat:
+            feat_embeds.append(self.bert_embed(feats.pop(0)))
+        if 'lemma' in self.args.feat:
+            feat_embeds.append(self.lemma_embed(feats.pop(0)))
+        word_embed, feat_embed = self.embed_dropout(word_embed, torch.cat(feat_embeds, -1))
         # concatenate the word and feat representations
         embed = torch.cat((word_embed, feat_embed), -1)
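
``forward`` now receives ``feats`` as a list and consumes it in place. Assuming the data loader yields fields in the order they are declared in the transform built below (words, chars, bert, lemmas, tags, with disabled fields omitted), the popping order lines up as in this illustrative sketch:

# Illustrative only: feats order for the default feat='tag,char,lemma' (bert disabled).
feats = ['char_ids', 'lemma_ids', 'tag_ids']   # placeholders for the actual tensors
tag_ids = feats.pop()     # tags come last in the transform, so they are popped from the end
char_ids = feats.pop(0)   # the remaining features are popped from the front
lemma_ids = feats.pop(0)
assert (tag_ids, char_ids, lemma_ids) == ('tag_ids', 'char_ids', 'lemma_ids')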

supar/parsers/semantic_dependency.py

Lines changed: 29 additions & 22 deletions
@@ -36,10 +36,9 @@ class BiaffineSemanticDependencyParser(Parser):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-        if self.args.feat in ('char', 'bert'):
-            self.WORD, self.FEAT = self.transform.FORM
-        else:
-            self.WORD, self.FEAT = self.transform.FORM, self.transform.POS
+        self.WORD, self.CHAR, self.BERT = self.transform.FORM
+        self.LEMMA = self.transform.LEMMA
+        self.TAG = self.transform.POS
         self.EDGE, self.LABEL = self.transform.PHEAD
 
     def train(self, train, dev, test, buckets=32, batch_size=5000, verbose=True, **kwargs):
@@ -108,7 +107,7 @@ def _train(self, loader):
 
         bar, metric = progress_bar(loader), ChartMetric()
 
-        for words, feats, edges, labels in bar:
+        for words, *feats, edges, labels in bar:
             self.optimizer.zero_grad()
 
             mask = words.ne(self.WORD.pad_index)
@@ -132,7 +131,7 @@ def _evaluate(self, loader):
 
         total_loss, metric = 0, ChartMetric()
 
-        for words, feats, edges, labels in loader:
+        for words, *feats, edges, labels in loader:
             mask = words.ne(self.WORD.pad_index)
             mask = mask.unsqueeze(1) & mask.unsqueeze(2)
             mask[:, 0] = 0
@@ -153,7 +152,7 @@ def _predict(self, loader):
 
         preds = {}
         charts, probs = [], []
-        for words, feats in progress_bar(loader):
+        for words, *feats in progress_bar(loader):
             mask = words.ne(self.WORD.pad_index)
             mask = mask.unsqueeze(1) & mask.unsqueeze(2)
             mask[:, 0] = 0
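
The loops above rely on Python's starred unpacking so that any number of feature tensors between the word tensor and the targets is collected into a list; a tiny illustration (placeholder strings stand in for tensors):

# Illustrative only: extended unpacking with a variable number of feature tensors.
batch = ('words', 'chars', 'lemmas', 'tags', 'edges', 'labels')
words, *feats, edges, labels = batch
assert feats == ['chars', 'lemmas', 'tags']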
@@ -211,38 +210,46 @@ def build(cls,
 
         logger.info("Building the fields")
         WORD = Field('words', pad=pad, unk=unk, bos=bos, lower=True)
-        if args.feat == 'char':
-            FEAT = SubwordField('chars', pad=pad, unk=unk, bos=bos, fix_len=args.fix_len)
-        elif args.feat == 'bert':
+        TAG, CHAR, LEMMA, BERT = None, None, None, None
+        if 'tag' in args.feat:
+            TAG = Field('tags', bos=bos)
+        if 'char' in args.feat:
+            CHAR = SubwordField('chars', pad=pad, unk=unk, bos=bos, fix_len=args.fix_len)
+        if 'lemma' in args.feat:
+            LEMMA = Field('lemmas', pad=pad, unk=unk, bos=bos, lower=True)
+        if 'bert' in args.feat:
             from transformers import AutoTokenizer
             tokenizer = AutoTokenizer.from_pretrained(args.bert)
-            FEAT = SubwordField('bert',
+            BERT = SubwordField('bert',
                                 pad=tokenizer.pad_token,
                                 unk=tokenizer.unk_token,
                                 bos=tokenizer.bos_token or tokenizer.cls_token,
                                 fix_len=args.fix_len,
                                 tokenize=tokenizer.tokenize)
-            FEAT.vocab = tokenizer.get_vocab()
-        else:
-            FEAT = Field('tags', bos=bos)
+            BERT.vocab = tokenizer.get_vocab()
         EDGE = ChartField('edges', use_vocab=False, fn=CoNLL.get_edges)
         LABEL = ChartField('labels', fn=CoNLL.get_labels)
-        if args.feat in ('char', 'bert'):
-            transform = CoNLL(FORM=(WORD, FEAT), PHEAD=(EDGE, LABEL))
-        else:
-            transform = CoNLL(FORM=WORD, POS=FEAT, PHEAD=(EDGE, LABEL))
+        transform = CoNLL(FORM=(WORD, CHAR, BERT), LEMMA=LEMMA, POS=TAG, PHEAD=(EDGE, LABEL))
 
         train = Dataset(transform, args.train)
         WORD.build(train, args.min_freq, (Embedding.load(args.embed, args.unk) if args.embed else None))
-        FEAT.build(train)
+        if TAG is not None:
+            TAG.build(train)
+        if CHAR is not None:
+            CHAR.build(train)
+        if LEMMA is not None:
+            LEMMA.build(train)
         LABEL.build(train)
         args.update({
             'n_words': WORD.vocab.n_init,
-            'n_feats': len(FEAT.vocab),
             'n_labels': len(LABEL.vocab),
+            'n_tags': len(TAG.vocab) if TAG is not None else None,
+            'n_chars': len(CHAR.vocab) if CHAR is not None else None,
+            'char_pad_index': CHAR.pad_index if CHAR is not None else None,
+            'n_lemmas': len(LEMMA.vocab) if LEMMA is not None else None,
+            'bert_pad_index': BERT.pad_index if BERT is not None else None,
             'pad_index': WORD.pad_index,
-            'unk_index': WORD.unk_index,
-            'feat_pad_index': FEAT.pad_index
+            'unk_index': WORD.unk_index
         })
         logger.info(f"{transform}")
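
Taken together, ``build`` now passes per-feature vocabulary sizes to the model instead of a single ``n_feats``. Constructing the model by hand would look roughly like the sketch below, assuming the module path shown in this diff; the vocabulary sizes are placeholders, since in practice they come from the fields built above:

# Rough sketch only; the sizes below are placeholders, not values from this commit.
from supar.models.semantic_dependency import BiaffineSemanticDependencyModel

model = BiaffineSemanticDependencyModel(n_words=40000,
                                        n_labels=60,
                                        n_tags=50,
                                        n_chars=500,
                                        n_lemmas=30000,
                                        feat='tag,char,lemma')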
