Skip to content

Commit 0853b5d

Browse files
authored
Sigmoid in place of CrossEntropy
1 parent f0fa3e1 commit 0853b5d

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

supar/models/sdp.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# -*- coding: utf-8 -*-
22

33
import torch.nn as nn
4+
import torch.nn.functional as F
45
from supar.models.model import Model
56
from supar.modules import MLP, Biaffine, Triaffine
67
from supar.structs import LBPSemanticDependency, MFVISemanticDependency
@@ -132,7 +133,7 @@ def __init__(self,
132133
self.label_mlp_d = MLP(n_in=self.args.n_hidden, n_out=n_label_mlp, dropout=label_mlp_dropout, activation=False)
133134
self.label_mlp_h = MLP(n_in=self.args.n_hidden, n_out=n_label_mlp, dropout=label_mlp_dropout, activation=False)
134135

135-
self.edge_attn = Biaffine(n_in=n_edge_mlp, n_out=2, bias_x=True, bias_y=True)
136+
self.edge_attn = Biaffine(n_in=n_edge_mlp, bias_x=True, bias_y=True)
136137
self.label_attn = Biaffine(n_in=n_label_mlp, n_out=n_labels, bias_x=True, bias_y=True)
137138
self.criterion = nn.CrossEntropyLoss()
138139

@@ -156,7 +157,7 @@ def forward(self, words, feats=None):
156157
157158
Returns:
158159
~torch.Tensor, ~torch.Tensor:
159-
The first tensor of shape ``[batch_size, seq_len, seq_len, 2]`` holds scores of all possible edges.
160+
The first tensor of shape ``[batch_size, seq_len, seq_len]`` holds scores of all possible edges.
160161
The second of shape ``[batch_size, seq_len, seq_len, n_labels]`` holds
161162
scores of all possible labels on each edge.
162163
"""
@@ -168,8 +169,8 @@ def forward(self, words, feats=None):
168169
label_d = self.label_mlp_d(x)
169170
label_h = self.label_mlp_h(x)
170171

171-
# [batch_size, seq_len, seq_len, 2]
172-
s_edge = self.edge_attn(edge_d, edge_h).permute(0, 2, 3, 1)
172+
# [batch_size, seq_len, seq_len]
173+
s_edge = self.edge_attn(edge_d, edge_h)
173174
# [batch_size, seq_len, seq_len, n_labels]
174175
s_label = self.label_attn(label_d, label_h).permute(0, 2, 3, 1)
175176

@@ -178,7 +179,7 @@ def forward(self, words, feats=None):
178179
def loss(self, s_edge, s_label, labels, mask):
179180
r"""
180181
Args:
181-
s_edge (~torch.Tensor): ``[batch_size, seq_len, seq_len, 2]``.
182+
s_edge (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
182183
Scores of all possible edges.
183184
s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
184185
Scores of all possible labels on each edge.
@@ -193,24 +194,24 @@ def loss(self, s_edge, s_label, labels, mask):
193194
"""
194195

195196
edge_mask = labels.ge(0) & mask
196-
edge_loss = self.criterion(s_edge[mask], edge_mask[mask].long())
197+
edge_loss = F.binary_cross_entropy_with_logits(s_edge[mask], edge_mask[mask].float())
197198
label_loss = self.criterion(s_label[edge_mask], labels[edge_mask])
198199
return self.args.interpolation * label_loss + (1 - self.args.interpolation) * edge_loss
199200

200201
def decode(self, s_edge, s_label):
201202
r"""
202203
Args:
203-
s_edge (~torch.Tensor): ``[batch_size, seq_len, seq_len, 2]``.
204+
s_edge (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
204205
Scores of all possible edges.
205206
s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
206207
Scores of all possible labels on each edge.
207208
208209
Returns:
209-
~torch.BoolTensor:
210+
~torch.LongTensor:
210211
Predicted labels of shape ``[batch_size, seq_len, seq_len]``.
211212
"""
212213

213-
return s_label.argmax(-1).masked_fill_(s_edge.argmax(-1).lt(1), -1)
214+
return s_label.argmax(-1).masked_fill_(s_edge.lt(0), -1)
214215

215216

216217
class VISemanticDependencyModel(BiaffineSemanticDependencyModel):

0 commit comments

Comments
 (0)