@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 
 import torch.nn as nn
+import torch.nn.functional as F
 from supar.models.model import Model
 from supar.modules import MLP, Biaffine, Triaffine
 from supar.structs import LBPSemanticDependency, MFVISemanticDependency
@@ -132,7 +133,7 @@ def __init__(self,
         self.label_mlp_d = MLP(n_in=self.args.n_hidden, n_out=n_label_mlp, dropout=label_mlp_dropout, activation=False)
         self.label_mlp_h = MLP(n_in=self.args.n_hidden, n_out=n_label_mlp, dropout=label_mlp_dropout, activation=False)
 
-        self.edge_attn = Biaffine(n_in=n_edge_mlp, n_out=2, bias_x=True, bias_y=True)
+        self.edge_attn = Biaffine(n_in=n_edge_mlp, bias_x=True, bias_y=True)
         self.label_attn = Biaffine(n_in=n_label_mlp, n_out=n_labels, bias_x=True, bias_y=True)
         self.criterion = nn.CrossEntropyLoss()
 
@@ -156,7 +157,7 @@ def forward(self, words, feats=None):
 
         Returns:
             ~torch.Tensor, ~torch.Tensor:
-                The first tensor of shape ``[batch_size, seq_len, seq_len, 2]`` holds scores of all possible edges.
+                The first tensor of shape ``[batch_size, seq_len, seq_len]`` holds scores of all possible edges.
                 The second of shape ``[batch_size, seq_len, seq_len, n_labels]`` holds
                 scores of all possible labels on each edge.
         """
@@ -168,8 +169,8 @@ def forward(self, words, feats=None):
         label_d = self.label_mlp_d(x)
         label_h = self.label_mlp_h(x)
 
-        # [batch_size, seq_len, seq_len, 2]
-        s_edge = self.edge_attn(edge_d, edge_h).permute(0, 2, 3, 1)
+        # [batch_size, seq_len, seq_len]
+        s_edge = self.edge_attn(edge_d, edge_h)
         # [batch_size, seq_len, seq_len, n_labels]
         s_label = self.label_attn(label_d, label_h).permute(0, 2, 3, 1)
 
@@ -178,7 +179,7 @@ def forward(self, words, feats=None):
178
179
def loss (self , s_edge , s_label , labels , mask ):
179
180
r"""
180
181
Args:
181
- s_edge (~torch.Tensor): ``[batch_size, seq_len, seq_len, 2 ]``.
182
+ s_edge (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
182
183
Scores of all possible edges.
183
184
s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
184
185
Scores of all possible labels on each edge.
@@ -193,24 +194,24 @@ def loss(self, s_edge, s_label, labels, mask):
         """
 
         edge_mask = labels.ge(0) & mask
-        edge_loss = self.criterion(s_edge[mask], edge_mask[mask].long())
+        edge_loss = F.binary_cross_entropy_with_logits(s_edge[mask], edge_mask[mask].float())
         label_loss = self.criterion(s_label[edge_mask], labels[edge_mask])
         return self.args.interpolation * label_loss + (1 - self.args.interpolation) * edge_loss
 
     def decode(self, s_edge, s_label):
         r"""
         Args:
-            s_edge (~torch.Tensor): ``[batch_size, seq_len, seq_len, 2]``.
+            s_edge (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
                 Scores of all possible edges.
             s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
                 Scores of all possible labels on each edge.
 
         Returns:
-            ~torch.BoolTensor:
+            ~torch.LongTensor:
                 Predicted labels of shape ``[batch_size, seq_len, seq_len]``.
         """
 
-        return s_label.argmax(-1).masked_fill_(s_edge.argmax(-1).lt(1), -1)
+        return s_label.argmax(-1).masked_fill_(s_edge.lt(0), -1)
 
 
 class VISemanticDependencyModel(BiaffineSemanticDependencyModel):
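
The hunks above replace the two-class softmax edge classifier with a single biaffine logit per (dependent, head) pair: training uses binary cross-entropy with logits, and decoding keeps an edge when its logit is positive (sigmoid probability above 0.5). Below is a minimal, self-contained sketch of that loss/decode scheme on toy tensors; the names mirror the diff, but the `interpolation` weight, the all-true pair mask, and the toy shapes are assumptions for illustration, not the supar API.

```python
# Sketch of the single-logit edge scoring scheme from the diff (assumed toy setup).
import torch
import torch.nn.functional as F

batch_size, seq_len, n_labels = 2, 4, 5

# Edge logits, one per (dependent, head) pair: [batch_size, seq_len, seq_len]
s_edge = torch.randn(batch_size, seq_len, seq_len)
# Label scores per pair: [batch_size, seq_len, seq_len, n_labels]
s_label = torch.randn(batch_size, seq_len, seq_len, n_labels)
# Gold labels: -1 marks "no edge", otherwise a label index in [0, n_labels)
labels = torch.randint(-1, n_labels, (batch_size, seq_len, seq_len))
# Pair mask over valid token pairs (assumed all valid here)
mask = torch.ones(batch_size, seq_len, seq_len, dtype=torch.bool)

# Loss: BCE-with-logits on edge presence, cross-entropy on labels of gold edges
edge_mask = labels.ge(0) & mask
edge_loss = F.binary_cross_entropy_with_logits(s_edge[mask], edge_mask[mask].float())
label_loss = F.cross_entropy(s_label[edge_mask], labels[edge_mask])
interpolation = 0.1  # assumed stand-in for self.args.interpolation
loss = interpolation * label_loss + (1 - interpolation) * edge_loss

# Decode: keep the best label where the edge logit is positive, -1 elsewhere
pred = s_label.argmax(-1).masked_fill_(s_edge.lt(0), -1)
print(loss.item(), pred.shape)
```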