
Commit ad6236d

Biaffine Semantic Dependency Parser
1 parent b536923 commit ad6236d

7 files changed: +595 -5 lines changed

supar/__init__.py

Lines changed: 5 additions & 2 deletions
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-

-from .parsers import (BiaffineDependencyParser, CRF2oDependencyParser,
+from .parsers import (BiaffineDependencyParser,
+                      BiaffineSemanticDependencyParser, CRF2oDependencyParser,
                       CRFConstituencyParser, CRFDependencyParser,
                       CRFNPDependencyParser, Parser)
@@ -9,14 +10,16 @@
            'CRFDependencyParser',
            'CRF2oDependencyParser',
            'CRFConstituencyParser',
+           'BiaffineSemanticDependencyParser',
            'Parser']
 __version__ = '1.0.0'

 PARSER = {parser.NAME: parser for parser in [BiaffineDependencyParser,
                                              CRFNPDependencyParser,
                                              CRFDependencyParser,
                                              CRF2oDependencyParser,
-                                             CRFConstituencyParser]}
+                                             CRFConstituencyParser,
+                                             BiaffineSemanticDependencyParser]}

 PRETRAINED = {
     'biaffine-dep-en': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ptb.biaffine.dependency.char.zip',
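With this registration, the new parser becomes reachable through the top-level PARSER dict like the existing ones. A minimal sketch (the NAME attribute is defined on the parser class itself, which is not part of this diff, so we look it up rather than hard-coding it):

import supar
from supar import BiaffineSemanticDependencyParser

# PARSER maps each class's NAME attribute to the class itself,
# so the new parser can now be resolved by name.
assert supar.PARSER[BiaffineSemanticDependencyParser.NAME] is BiaffineSemanticDependencyParser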
Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+# -*- coding: utf-8 -*-
+
+import argparse
+
+from supar import BiaffineSemanticDependencyParser
+from supar.cmds.cmd import parse
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Create Biaffine Semantic Dependency Parser.')
+    parser.set_defaults(Parser=BiaffineSemanticDependencyParser)
+    subparsers = parser.add_subparsers(title='Commands', dest='mode')
+    # train
+    subparser = subparsers.add_parser('train', help='Train a parser.')
+    subparser.add_argument('--feat', '-f', choices=['tag', 'char', 'bert'], help='choices of additional features')
+    subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first')
+    subparser.add_argument('--max-len', type=int, help='max length of the sentences')
+    subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use')
+    subparser.add_argument('--train', default='data/sdp/DM/train.conllu', help='path to train file')
+    subparser.add_argument('--dev', default='data/sdp/DM/dev.conllu', help='path to dev file')
+    subparser.add_argument('--test', default='data/sdp/DM/test.conllu', help='path to test file')
+    subparser.add_argument('--embed', default='data/glove.6B.100d.txt', help='path to pretrained embeddings')
+    subparser.add_argument('--unk', default='unk', help='unk token in pretrained embeddings')
+    subparser.add_argument('--n-embed', default=100, type=int, help='dimension of embeddings')
+    subparser.add_argument('--bert', default='bert-base-cased', help='which bert model to use')
+    # evaluate
+    subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.')
+    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
+    subparser.add_argument('--data', default='data/sdp/DM/test.conllu', help='path to dataset')
+    # predict
+    subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.')
+    subparser.add_argument('--prob', action='store_true', help='whether to output probs')
+    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
+    subparser.add_argument('--data', default='data/sdp/DM/test.conllu', help='path to dataset')
+    subparser.add_argument('--pred', default='pred.conllu', help='path to predicted result')
+    parse(parser)
+
+
+if __name__ == "__main__":
+    main()
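The subcommand wiring above follows the other supar command scripts: an argparse default carries the parser class into the shared dispatcher. A small sketch of the resulting namespace (this only re-creates a slice of the setup above; the shared flags and the actual dispatch live in supar.cmds.cmd.parse, which is not shown in this diff):

import argparse

from supar import BiaffineSemanticDependencyParser

parser = argparse.ArgumentParser(description='Create Biaffine Semantic Dependency Parser.')
parser.set_defaults(Parser=BiaffineSemanticDependencyParser)
subparsers = parser.add_subparsers(title='Commands', dest='mode')
subparser = subparsers.add_parser('train', help='Train a parser.')
subparser.add_argument('--feat', '-f', choices=['tag', 'char', 'bert'])

args = parser.parse_args(['train', '-f', 'char'])
assert args.mode == 'train' and args.feat == 'char'
assert args.Parser is BiaffineSemanticDependencyParser  # consumed by cmd.parse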

supar/models/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -3,9 +3,11 @@
 from .constituency import CRFConstituencyModel
 from .dependency import (BiaffineDependencyModel, CRF2oDependencyModel,
                          CRFDependencyModel, CRFNPDependencyModel)
+from .semantic_dependency import BiaffineSemanticDependencyModel

 __all__ = ['BiaffineDependencyModel',
            'CRFDependencyModel',
            'CRF2oDependencyModel',
            'CRFNPDependencyModel',
-           'CRFConstituencyModel']
+           'CRFConstituencyModel',
+           'BiaffineSemanticDependencyModel']

supar/models/semantic_dependency.py

Lines changed: 249 additions & 0 deletions
@@ -0,0 +1,249 @@
+# -*- coding: utf-8 -*-
+
+import torch
+import torch.nn as nn
+from supar.modules import LSTM, MLP, BertEmbedding, Biaffine, CharLSTM
+from supar.modules.dropout import IndependentDropout, SharedDropout
+from supar.utils import Config
+from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
+
+
+class BiaffineSemanticDependencyModel(nn.Module):
+    r"""
+    The implementation of Biaffine Semantic Dependency Parser.
+
+    References:
+        - Timothy Dozat and Christopher D. Manning. 2018.
+          `Simpler but More Accurate Semantic Dependency Parsing`_.
+
+    Args:
+        n_words (int):
+            The size of the word vocabulary.
+        n_feats (int):
+            The size of the feat vocabulary.
+        n_labels (int):
+            The number of labels in the treebank.
+        feat (str):
+            Specifies which type of additional feature to use: ``'char'`` | ``'bert'`` | ``'tag'``.
+            ``'char'``: Character-level representations extracted by CharLSTM.
+            ``'bert'``: BERT representations, other pretrained language models like XLNet are also feasible.
+            ``'tag'``: POS tag embeddings.
+            Default: ``'char'``.
+        n_embed (int):
+            The size of word embeddings. Default: 100.
+        n_embed_proj (int):
+            The size of linearly transformed word embeddings. Default: 125.
+        n_feat_embed (int):
+            The size of feature representations. Default: 100.
+        n_char_embed (int):
+            The size of character embeddings serving as inputs of CharLSTM, required if ``feat='char'``. Default: 50.
+        bert (str):
+            Specifies which kind of language model to use, e.g., ``'bert-base-cased'`` and ``'xlnet-base-cased'``.
+            This is required if ``feat='bert'``. The full list can be found in `transformers`_.
+            Default: ``None``.
+        n_bert_layers (int):
+            Specifies how many last layers to use. Required if ``feat='bert'``.
+            The final outputs would be the weighted sum of the hidden states of these layers.
+            Default: 4.
+        mix_dropout (float):
+            The dropout ratio of BERT layers. Required if ``feat='bert'``. Default: .0.
+        embed_dropout (float):
+            The dropout ratio of input embeddings. Default: .2.
+        n_lstm_hidden (int):
+            The size of LSTM hidden states. Default: 600.
+        n_lstm_layers (int):
+            The number of LSTM layers. Default: 3.
+        lstm_dropout (float):
+            The dropout ratio of LSTM. Default: .33.
+        n_mlp_edge (int):
+            Edge MLP size. Default: 600.
+        n_mlp_label (int):
+            Label MLP size. Default: 600.
+        edge_mlp_dropout (float):
+            The dropout ratio of edge MLP layers. Default: .25.
+        label_mlp_dropout (float):
+            The dropout ratio of label MLP layers. Default: .33.
+        feat_pad_index (int):
+            The index of the padding token in the feat vocabulary. Default: 0.
+        pad_index (int):
+            The index of the padding token in the word vocabulary. Default: 0.
+        unk_index (int):
+            The index of the unknown token in the word vocabulary. Default: 1.
+
+    .. _Simpler but More Accurate Semantic Dependency Parsing:
+        https://www.aclweb.org/anthology/P18-2077/
+    .. _transformers:
+        https://github.com/huggingface/transformers
+    """
+
+    def __init__(self,
+                 n_words,
+                 n_feats,
+                 n_labels,
+                 feat='char',
+                 n_embed=100,
+                 n_embed_proj=125,
+                 n_feat_embed=100,
+                 n_char_embed=50,
+                 bert=None,
+                 n_bert_layers=4,
+                 mix_dropout=.0,
+                 embed_dropout=.2,
+                 n_lstm_hidden=600,
+                 n_lstm_layers=3,
+                 lstm_dropout=.33,
+                 n_mlp_edge=600,
+                 n_mlp_label=600,
+                 edge_mlp_dropout=.25,
+                 label_mlp_dropout=.33,
+                 feat_pad_index=0,
+                 pad_index=0,
+                 unk_index=1,
+                 interpolation=0.1,
+                 **kwargs):
+        super().__init__()
+
+        self.args = Config().update(locals())
+        # the embedding layer
+        self.word_embed = nn.Embedding(num_embeddings=n_words,
+                                       embedding_dim=n_embed)
+        self.embed_proj = nn.Linear(n_embed, n_embed_proj)
+
+        if feat == 'char':
+            self.feat_embed = CharLSTM(n_chars=n_feats,
+                                       n_embed=n_char_embed,
+                                       n_out=n_feat_embed,
+                                       pad_index=feat_pad_index)
+        elif feat == 'bert':
+            self.feat_embed = BertEmbedding(model=bert,
+                                            n_layers=n_bert_layers,
+                                            n_out=n_feat_embed,
+                                            pad_index=feat_pad_index,
+                                            dropout=mix_dropout)
+            self.n_feat_embed = self.feat_embed.n_out
+        elif feat == 'tag':
+            self.feat_embed = nn.Embedding(num_embeddings=n_feats,
+                                           embedding_dim=n_feat_embed)
+        else:
+            raise RuntimeError("The feat type should be in ['char', 'bert', 'tag'].")
+        self.embed_dropout = IndependentDropout(p=embed_dropout)
+
+        # the lstm layer
+        self.lstm = LSTM(input_size=n_embed+n_feat_embed+n_embed_proj,
+                         hidden_size=n_lstm_hidden,
+                         num_layers=n_lstm_layers,
+                         bidirectional=True,
+                         dropout=lstm_dropout)
+        self.lstm_dropout = SharedDropout(p=lstm_dropout)
+
+        # the MLP layers
+        self.mlp_edge_d = MLP(n_in=n_lstm_hidden*2, n_out=n_mlp_edge, dropout=edge_mlp_dropout, activation=False)
+        self.mlp_edge_h = MLP(n_in=n_lstm_hidden*2, n_out=n_mlp_edge, dropout=edge_mlp_dropout, activation=False)
+        self.mlp_label_d = MLP(n_in=n_lstm_hidden*2, n_out=n_mlp_label, dropout=label_mlp_dropout, activation=False)
+        self.mlp_label_h = MLP(n_in=n_lstm_hidden*2, n_out=n_mlp_label, dropout=label_mlp_dropout, activation=False)
+
+        # the Biaffine layers
+        self.edge_attn = Biaffine(n_in=n_mlp_edge, n_out=2, bias_x=True, bias_y=True)
+        self.label_attn = Biaffine(n_in=n_mlp_label, n_out=n_labels, bias_x=True, bias_y=True)
+        self.criterion = nn.CrossEntropyLoss()
+        self.pad_index = pad_index
+        self.unk_index = unk_index
+        self.interpolation = interpolation
+
+    def load_pretrained(self, embed=None):
+        if embed is not None:
+            self.pretrained = nn.Embedding.from_pretrained(embed)
+        return self
+
+    def forward(self, words, feats):
+        r"""
+        Args:
+            words (~torch.LongTensor): ``[batch_size, seq_len]``.
+                Word indices.
+            feats (~torch.LongTensor):
+                Feat indices.
+                If feat is ``'char'`` or ``'bert'``, the size of feats should be ``[batch_size, seq_len, fix_len]``.
+                If ``'tag'``, the size is ``[batch_size, seq_len]``.
+
+        Returns:
+            ~torch.Tensor, ~torch.Tensor:
+                The first tensor of shape ``[batch_size, seq_len, seq_len, 2]`` holds scores of all possible edges.
+                The second of shape ``[batch_size, seq_len, seq_len, n_labels]`` holds
+                scores of all possible labels on each edge.
+        """
+
+        batch_size, seq_len = words.shape
+        # get the mask and lengths of given batch
+        mask = words.ne(self.pad_index)
+        ext_words = words
+        # set the indices larger than num_embeddings to unk_index
+        if hasattr(self, 'pretrained'):
+            ext_mask = words.ge(self.word_embed.num_embeddings)
+            ext_words = words.masked_fill(ext_mask, self.unk_index)
+
+        # get outputs from embedding layers
+        word_embed = self.word_embed(ext_words)
+        if hasattr(self, 'pretrained'):
+            word_embed = torch.cat((word_embed, self.embed_proj(self.pretrained(words))), -1)
+
+        feat_embed = self.feat_embed(feats)
+        word_embed, feat_embed = self.embed_dropout(word_embed, feat_embed)
+        # concatenate the word and feat representations
+        embed = torch.cat((word_embed, feat_embed), -1)
+
+        x = pack_padded_sequence(embed, mask.sum(1), True, False)
+        x, _ = self.lstm(x)
+        x, _ = pad_packed_sequence(x, True, total_length=seq_len)
+        x = self.lstm_dropout(x)
+
+        # apply MLPs to the BiLSTM output states
+        edge_d = self.mlp_edge_d(x)
+        edge_h = self.mlp_edge_h(x)
+        label_d = self.mlp_label_d(x)
+        label_h = self.mlp_label_h(x)
+
+        # [batch_size, seq_len, seq_len, 2]
+        s_edge = self.edge_attn(edge_d, edge_h).permute(0, 2, 3, 1)
+        # [batch_size, seq_len, seq_len, n_labels]
+        s_label = self.label_attn(label_d, label_h).permute(0, 2, 3, 1)
+
+        return s_edge, s_label
+
+    def loss(self, s_edge, s_label, edges, labels, mask):
+        r"""
+        Args:
+            s_edge (~torch.Tensor): ``[batch_size, seq_len, seq_len, 2]``.
+                Scores of all possible edges.
+            s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
+                Scores of all possible labels on each edge.
+            edges (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``.
+                The tensor of gold-standard edges.
+            labels (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``.
+                The tensor of gold-standard labels.
+            mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
+                The mask for covering the unpadded tokens.
+
+        Returns:
+            ~torch.Tensor:
+                The training loss.
+        """
+
+        edge_mask = edges.gt(0) & mask
+        edge_loss = self.criterion(s_edge[mask], edges[mask])
+        label_loss = self.criterion(s_label[edge_mask], labels[edge_mask])
+        return self.interpolation * label_loss + (1 - self.interpolation) * edge_loss
+
+    def decode(self, s_edge, s_label):
+        r"""
+        Args:
+            s_edge (~torch.Tensor): ``[batch_size, seq_len, seq_len, 2]``.
+                Scores of all possible edges.
+            s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
+                Scores of all possible labels on each edge.
+
+        Returns:
+            ~torch.Tensor, ~torch.Tensor:
+                Predicted edges and labels of shape ``[batch_size, seq_len, seq_len]``.
+        """
+
+        return s_edge.argmax(-1), s_label.argmax(-1)
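For orientation, a smoke-test sketch of the model's tensor shapes (all sizes here are made up; note that the BiLSTM input size is n_embed + n_feat_embed + n_embed_proj, so pretrained embeddings must be loaded via load_pretrained() before forward() can run):

import torch

from supar.models import BiaffineSemanticDependencyModel

model = BiaffineSemanticDependencyModel(n_words=100, n_feats=50, n_labels=20, feat='tag')
model.load_pretrained(torch.randn(100, 100))  # [n_words, n_embed]

words = torch.randint(1, 100, (2, 10))  # keep clear of pad_index=0
tags = torch.randint(0, 50, (2, 10))
s_edge, s_label = model(words, tags)
assert s_edge.shape == (2, 10, 10, 2)    # on/off scores for every (dep, head) pair
assert s_label.shape == (2, 10, 10, 20)  # label scores for every pair

# Training interpolates the two objectives:
#   loss = interpolation * label_loss + (1 - interpolation) * edge_loss
# so with the default interpolation=0.1 the edge loss dominates.
edges = torch.randint(0, 2, (2, 10, 10))
labels = torch.randint(0, 20, (2, 10, 10))
mask = words.ne(0).unsqueeze(1) & words.ne(0).unsqueeze(2)  # [batch, seq, seq]
loss = model.loss(s_edge, s_label, edges, labels, mask)

# Decoding is greedy and per-pair: an edge is predicted wherever its "on"
# score beats its "off" score; labels are the argmax over label scores.
pred_edges, pred_labels = model.decode(s_edge, s_label)
assert pred_edges.shape == (2, 10, 10)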

supar/parsers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -4,10 +4,12 @@
 from .dependency import (BiaffineDependencyParser, CRF2oDependencyParser,
                          CRFDependencyParser, CRFNPDependencyParser)
 from .parser import Parser
+from .semantic_dependency import BiaffineSemanticDependencyParser

 __all__ = ['BiaffineDependencyParser',
            'CRFNPDependencyParser',
            'CRFDependencyParser',
            'CRF2oDependencyParser',
            'CRFConstituencyParser',
+           'BiaffineSemanticDependencyParser',
            'Parser']
