Skip to content

Commit a8e6f44

Browse files
committed
Optimize einsum
1 parent 163fd01 commit a8e6f44

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
'transformers>=4.0.0',
2828
'nltk',
2929
'stanza',
30+
'opt_einsum',
3031
'dill'],
3132
entry_points={
3233
'console_scripts': [

supar/modules/affine.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import torch
44
import torch.nn as nn
5+
from opt_einsum import contract
56

67

78
class Biaffine(nn.Module):
@@ -71,7 +72,7 @@ def forward(self, x, y):
7172
if self.bias_y:
7273
y = torch.cat((y, torch.ones_like(y[..., :1])), -1)
7374
# [batch_size, n_out, seq_len, seq_len]
74-
s = torch.einsum('bxi,oij,byj->boxy', x, self.weight, y) / self.n_in ** self.scale
75+
s = contract('bxi,oij,byj->boxy', x, self.weight, y) / self.n_in ** self.scale
7576
# remove dim 1 if n_out == 1
7677
s = s.squeeze(1)
7778

@@ -145,9 +146,9 @@ def forward(self, x, y, z):
145146
x = torch.cat((x, torch.ones_like(x[..., :1])), -1)
146147
if self.bias_y:
147148
y = torch.cat((y, torch.ones_like(y[..., :1])), -1)
148-
w = torch.einsum('bzk,oikj->bozij', z, self.weight)
149+
w = contract('bzk,oikj->bozij', z, self.weight)
149150
# [batch_size, n_out, seq_len, seq_len, seq_len]
150-
s = torch.einsum('bxi,bozij,byj->bozxy', x, w, y) / self.n_in ** self.scale
151+
s = contract('bxi,bozij,byj->bozxy', x, w, y) / self.n_in ** self.scale
151152
# remove dim 1 if n_out == 1
152153
s = s.squeeze(1)
153154

0 commit comments

Comments
 (0)