# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
from supar.modules import LSTM, MLP, BertEmbedding, Biaffine, CharLSTM
from supar.modules.dropout import IndependentDropout, SharedDropout
from supar.utils import Config
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class BiaffineSemanticDependencyModel(nn.Module):
    r"""
    The implementation of Biaffine Semantic Dependency Parser.

    References:
        - Timothy Dozat and Christopher D. Manning. 2018.
          `Simpler but More Accurate Semantic Dependency Parsing`_.

    Args:
        n_words (int):
            The size of the word vocabulary.
        n_feats (int):
            The size of the feat vocabulary.
        n_labels (int):
            The number of labels in the treebank.
        feat (str):
            Specifies which type of additional feature to use: ``'char'`` | ``'bert'`` | ``'tag'``.
            ``'char'``: Character-level representations extracted by CharLSTM.
            ``'bert'``: BERT representations, other pretrained language models like XLNet are also feasible.
            ``'tag'``: POS tag embeddings.
            Default: ``'char'``.
        n_embed (int):
            The size of word embeddings. Default: 100.
        n_embed_proj (int):
            The size of linearly transformed word embeddings. Default: 125.
        n_feat_embed (int):
            The size of feature representations. Default: 100.
        n_char_embed (int):
            The size of character embeddings serving as inputs of CharLSTM, required if ``feat='char'``. Default: 50.
        bert (str):
            Specifies which kind of language model to use, e.g., ``'bert-base-cased'`` and ``'xlnet-base-cased'``.
            This is required if ``feat='bert'``. The full list can be found in `transformers`_.
            Default: ``None``.
        n_bert_layers (int):
            Specifies how many last layers to use. Required if ``feat='bert'``.
            The final outputs would be the weighted sum of the hidden states of these layers.
            Default: 4.
        mix_dropout (float):
            The dropout ratio of BERT layers. Required if ``feat='bert'``. Default: .0.
        embed_dropout (float):
            The dropout ratio of input embeddings. Default: .2.
        n_lstm_hidden (int):
            The size of LSTM hidden states. Default: 600.
        n_lstm_layers (int):
            The number of LSTM layers. Default: 3.
        lstm_dropout (float):
            The dropout ratio of LSTM. Default: .33.
        n_mlp_edge (int):
            Edge MLP size. Default: 600.
        n_mlp_label (int):
            Label MLP size. Default: 600.
        edge_mlp_dropout (float):
            The dropout ratio of edge MLP layers. Default: .25.
        label_mlp_dropout (float):
            The dropout ratio of label MLP layers. Default: .33.
        feat_pad_index (int):
            The index of the padding token in the feat vocabulary. Default: 0.
        pad_index (int):
            The index of the padding token in the word vocabulary. Default: 0.
        unk_index (int):
            The index of the unknown token in the word vocabulary. Default: 1.
        interpolation (float):
            Constant balancing the label loss against the edge loss. Default: .1.

    .. _Simpler but More Accurate Semantic Dependency Parsing:
        https://www.aclweb.org/anthology/P18-2077/
    .. _transformers:
        https://github.com/huggingface/transformers
    """

    def __init__(self,
                 n_words,
                 n_feats,
                 n_labels,
                 feat='char',
                 n_embed=100,
                 n_embed_proj=125,
                 n_feat_embed=100,
                 n_char_embed=50,
                 bert=None,
                 n_bert_layers=4,
                 mix_dropout=.0,
                 embed_dropout=.2,
                 n_lstm_hidden=600,
                 n_lstm_layers=3,
                 lstm_dropout=.33,
                 n_mlp_edge=600,
                 n_mlp_label=600,
                 edge_mlp_dropout=.25,
                 label_mlp_dropout=.33,
                 feat_pad_index=0,
                 pad_index=0,
                 unk_index=1,
                 interpolation=0.1,
                 **kwargs):
        super().__init__()

        self.args = Config().update(locals())
        # the embedding layer
        self.word_embed = nn.Embedding(num_embeddings=n_words,
                                       embedding_dim=n_embed)
        self.embed_proj = nn.Linear(n_embed, n_embed_proj)

        if feat == 'char':
            self.feat_embed = CharLSTM(n_chars=n_feats,
                                       n_embed=n_char_embed,
                                       n_out=n_feat_embed,
                                       pad_index=feat_pad_index)
        elif feat == 'bert':
            self.feat_embed = BertEmbedding(model=bert,
                                            n_layers=n_bert_layers,
                                            n_out=n_feat_embed,
                                            pad_index=feat_pad_index,
                                            dropout=mix_dropout)
            self.n_feat_embed = self.feat_embed.n_out
        elif feat == 'tag':
            self.feat_embed = nn.Embedding(num_embeddings=n_feats,
                                           embedding_dim=n_feat_embed)
        else:
            raise RuntimeError("The feat type should be in ['char', 'bert', 'tag'].")
        self.embed_dropout = IndependentDropout(p=embed_dropout)

        # the lstm layer
        self.lstm = LSTM(input_size=n_embed+n_feat_embed+n_embed_proj,
                         hidden_size=n_lstm_hidden,
                         num_layers=n_lstm_layers,
                         bidirectional=True,
                         dropout=lstm_dropout)
        self.lstm_dropout = SharedDropout(p=lstm_dropout)

        # the MLP layers
        self.mlp_edge_d = MLP(n_in=n_lstm_hidden*2, n_out=n_mlp_edge, dropout=edge_mlp_dropout, activation=False)
        self.mlp_edge_h = MLP(n_in=n_lstm_hidden*2, n_out=n_mlp_edge, dropout=edge_mlp_dropout, activation=False)
        self.mlp_label_d = MLP(n_in=n_lstm_hidden*2, n_out=n_mlp_label, dropout=label_mlp_dropout, activation=False)
        self.mlp_label_h = MLP(n_in=n_lstm_hidden*2, n_out=n_mlp_label, dropout=label_mlp_dropout, activation=False)

        # the Biaffine layers
        self.edge_attn = Biaffine(n_in=n_mlp_edge, n_out=2, bias_x=True, bias_y=True)
        self.label_attn = Biaffine(n_in=n_mlp_label, n_out=n_labels, bias_x=True, bias_y=True)
        self.criterion = nn.CrossEntropyLoss()
        self.pad_index = pad_index
        self.unk_index = unk_index
        self.interpolation = interpolation

    def load_pretrained(self, embed=None):
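        # NOTE (editorial comment): `embed` is expected to have `n_embed` columns,
        # since the pretrained vectors are projected by `embed_proj`
        # (Linear(n_embed, n_embed_proj)) and concatenated with the trainable word
        # embeddings in `forward`.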
        if embed is not None:
            self.pretrained = nn.Embedding.from_pretrained(embed)
        return self

    def forward(self, words, feats):
        r"""
        Args:
            words (~torch.LongTensor): ``[batch_size, seq_len]``.
                Word indices.
            feats (~torch.LongTensor):
                Feat indices.
                If feat is ``'char'`` or ``'bert'``, the size of feats should be ``[batch_size, seq_len, fix_len]``.
                If ``'tag'``, the size is ``[batch_size, seq_len]``.

        Returns:
            ~torch.Tensor, ~torch.Tensor:
                The first tensor of shape ``[batch_size, seq_len, seq_len, 2]`` holds scores of all possible edges.
                The second of shape ``[batch_size, seq_len, seq_len, n_labels]`` holds
                scores of all possible labels on each edge.
        """

        batch_size, seq_len = words.shape
        # get the mask and lengths of given batch
        mask = words.ne(self.pad_index)
        ext_words = words
        # set the indices larger than num_embeddings to unk_index
        if hasattr(self, 'pretrained'):
            ext_mask = words.ge(self.word_embed.num_embeddings)
            ext_words = words.masked_fill(ext_mask, self.unk_index)

        # get outputs from embedding layers
        word_embed = self.word_embed(ext_words)
        if hasattr(self, 'pretrained'):
            word_embed = torch.cat((word_embed, self.embed_proj(self.pretrained(words))), -1)

        feat_embed = self.feat_embed(feats)
        word_embed, feat_embed = self.embed_dropout(word_embed, feat_embed)
        # concatenate the word and feat representations
        embed = torch.cat((word_embed, feat_embed), -1)

        # lengths must live on the CPU (here a plain list) for pack_padded_sequence
        x = pack_padded_sequence(embed, mask.sum(1).tolist(), True, False)
        x, _ = self.lstm(x)
        x, _ = pad_packed_sequence(x, True, total_length=seq_len)
        x = self.lstm_dropout(x)

        # apply MLPs to the BiLSTM output states
        edge_d = self.mlp_edge_d(x)
        edge_h = self.mlp_edge_h(x)
        label_d = self.mlp_label_d(x)
        label_h = self.mlp_label_h(x)

        # [batch_size, seq_len, seq_len, 2]
        s_edge = self.edge_attn(edge_d, edge_h).permute(0, 2, 3, 1)
        # [batch_size, seq_len, seq_len, n_labels]
        s_label = self.label_attn(label_d, label_h).permute(0, 2, 3, 1)

        return s_edge, s_label

    def loss(self, s_edge, s_label, edges, labels, mask):
        r"""
        Args:
            s_edge (~torch.Tensor): ``[batch_size, seq_len, seq_len, 2]``.
                Scores of all possible edges.
            s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
                Scores of all possible labels on each edge.
            edges (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``.
                The tensor of gold-standard edges.
            labels (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``.
                The tensor of gold-standard labels.
            mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
                The mask covering the unpadded token pairs.

        Returns:
            ~torch.Tensor:
                The training loss.
        """

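        # `mask` is expected to already be the pairwise token mask; in the accompanying
        # parser it is typically built from the token-level mask as (illustrative sketch):
        #   token_mask = words.ne(pad_index)
        #   mask = token_mask.unsqueeze(1) & token_mask.unsqueeze(2)
        # `edge_mask` below additionally keeps only gold edges, so the label loss is
        # computed over gold edges alone.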
        edge_mask = edges.gt(0) & mask
        edge_loss = self.criterion(s_edge[mask], edges[mask])
        label_loss = self.criterion(s_label[edge_mask], labels[edge_mask])
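        # with the default interpolation of .1, the total loss is
        # .1 * label_loss + .9 * edge_loss, i.e. edge detection is weighted more heavily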
        return self.interpolation * label_loss + (1 - self.interpolation) * edge_loss

    def decode(self, s_edge, s_label):
        r"""
        Args:
            s_edge (~torch.Tensor): ``[batch_size, seq_len, seq_len, 2]``.
                Scores of all possible edges.
            s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
                Scores of all possible labels on each edge.

        Returns:
            ~torch.Tensor, ~torch.Tensor:
                Predicted edges and labels of shape ``[batch_size, seq_len, seq_len]``.
        """

        return s_edge.argmax(-1), s_label.argmax(-1)
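

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (editorial addition, not part of the original module).
# It assumes `supar` and its dependencies are installed; the vocabulary sizes,
# the 'tag' feature type, and all random tensors below are arbitrary illustrative
# values, not defaults prescribed by the library.
if __name__ == '__main__':
    model = BiaffineSemanticDependencyModel(n_words=100, n_feats=50, n_labels=16, feat='tag')
    # the LSTM input size assumes pretrained embeddings of width n_embed are attached
    model = model.load_pretrained(torch.randn(100, 100))

    words = torch.randint(1, 100, (2, 5))  # [batch_size, seq_len], no padding tokens
    feats = torch.randint(1, 50, (2, 5))   # POS tag indices, same shape as words
    s_edge, s_label = model(words, feats)  # [2, 5, 5, 2] and [2, 5, 5, 16]

    # pairwise mask over unpadded token pairs, as required by `loss`
    token_mask = words.ne(model.pad_index)
    mask = token_mask.unsqueeze(1) & token_mask.unsqueeze(2)
    edges = torch.randint(0, 2, (2, 5, 5))    # random stand-in for gold edges
    labels = torch.randint(0, 16, (2, 5, 5))  # random stand-in for gold labels
    print(model.loss(s_edge, s_label, edges, labels, mask))

    edge_preds, label_preds = model.decode(s_edge, s_label)
    print(edge_preds.shape, label_preds.shape)  # both [2, 5, 5]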