Skip to content

Commit 2c37330

Browse files
committed
Revert to torch.einsum
1 parent 97a646d commit 2c37330

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

supar/modules/affine.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import torch
44
import torch.nn as nn
5-
from opt_einsum import contract
65

76

87
class Biaffine(nn.Module):
@@ -72,7 +71,7 @@ def forward(self, x, y):
7271
if self.bias_y:
7372
y = torch.cat((y, torch.ones_like(y[..., :1])), -1)
7473
# [batch_size, n_out, seq_len, seq_len]
75-
s = contract('bxi,oij,byj->boxy', x, self.weight, y) / self.n_in ** self.scale
74+
s = torch.einsum('bxi,oij,byj->boxy', x, self.weight, y) / self.n_in ** self.scale
7675
# remove dim 1 if n_out == 1
7776
s = s.squeeze(1)
7877

@@ -146,9 +145,9 @@ def forward(self, x, y, z):
146145
x = torch.cat((x, torch.ones_like(x[..., :1])), -1)
147146
if self.bias_y:
148147
y = torch.cat((y, torch.ones_like(y[..., :1])), -1)
149-
w = contract('bzk,oikj->bozij', z, self.weight)
148+
w = torch.einsum('bzk,oikj->bozij', z, self.weight)
150149
# [batch_size, n_out, seq_len, seq_len, seq_len]
151-
s = contract('bxi,bozij,byj->bozxy', x, w, y) / self.n_in ** self.scale
150+
s = torch.einsum('bxi,bozij,byj->bozxy', x, w, y) / self.n_in ** self.scale
152151
# remove dim 1 if n_out == 1
153152
s = s.squeeze(1)
154153

0 commit comments

Comments
 (0)