Skip to content

Commit c22c400

Browse files
author
zysite
committed
More efficient biaffine operation via einsum
1 parent 5886b08 commit c22c400

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

parser/modules/biaffine.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,11 @@ def reset_parameters(self):
3232

3333
def forward(self, x, y):
3434
if self.bias_x:
35-
x = torch.cat([x, x.new_ones(x.shape[:-1]).unsqueeze(-1)], -1)
35+
x = torch.cat((x, torch.ones_like(x[..., :1])), -1)
3636
if self.bias_y:
37-
y = torch.cat([y, y.new_ones(y.shape[:-1]).unsqueeze(-1)], -1)
38-
# [batch_size, 1, seq_len, d]
39-
x = x.unsqueeze(1)
40-
# [batch_size, 1, seq_len, d]
41-
y = y.unsqueeze(1)
37+
y = torch.cat((y, torch.ones_like(y[..., :1])), -1)
4238
# [batch_size, n_out, seq_len, seq_len]
43-
s = x @ self.weight @ y.transpose(-1, -2)
39+
s = torch.einsum('bxi,oij,byj->boxy', x, self.weight, y)
4440
# remove dim 1 if n_out == 1
4541
s = s.squeeze(1)
4642

0 commit comments

Comments
 (0)