|
2 | 2 |
|
3 | 3 | import torch
|
4 | 4 | import torch.nn as nn
|
| 5 | +from opt_einsum import contract |
5 | 6 |
|
6 | 7 |
|
7 | 8 | class Biaffine(nn.Module):
|
@@ -71,7 +72,7 @@ def forward(self, x, y):
|
71 | 72 | if self.bias_y:
|
72 | 73 | y = torch.cat((y, torch.ones_like(y[..., :1])), -1)
|
73 | 74 | # [batch_size, n_out, seq_len, seq_len]
|
74 |
| - s = torch.einsum('bxi,oij,byj->boxy', x, self.weight, y) / self.n_in ** self.scale |
| 75 | + s = contract('bxi,oij,byj->boxy', x, self.weight, y) / self.n_in ** self.scale |
75 | 76 | # remove dim 1 if n_out == 1
|
76 | 77 | s = s.squeeze(1)
|
77 | 78 |
|
@@ -145,9 +146,9 @@ def forward(self, x, y, z):
|
145 | 146 | x = torch.cat((x, torch.ones_like(x[..., :1])), -1)
|
146 | 147 | if self.bias_y:
|
147 | 148 | y = torch.cat((y, torch.ones_like(y[..., :1])), -1)
|
148 |
| - w = torch.einsum('bzk,oikj->bozij', z, self.weight) |
| 149 | + w = contract('bzk,oikj->bozij', z, self.weight) |
149 | 150 | # [batch_size, n_out, seq_len, seq_len, seq_len]
|
150 |
| - s = torch.einsum('bxi,bozij,byj->bozxy', x, w, y) / self.n_in ** self.scale |
| 151 | + s = contract('bxi,bozij,byj->bozxy', x, w, y) / self.n_in ** self.scale |
151 | 152 | # remove dim 1 if n_out == 1
|
152 | 153 | s = s.squeeze(1)
|
153 | 154 |
|
|
0 commit comments