|
2 | 2 |
|
3 | 3 | import torch
|
4 | 4 | import torch.nn as nn
|
5 |
| -from opt_einsum import contract |
6 | 5 |
|
7 | 6 |
|
8 | 7 | class Biaffine(nn.Module):
|
@@ -72,7 +71,7 @@ def forward(self, x, y):
|
72 | 71 | if self.bias_y:
|
73 | 72 | y = torch.cat((y, torch.ones_like(y[..., :1])), -1)
|
74 | 73 | # [batch_size, n_out, seq_len, seq_len]
|
75 |
| - s = contract('bxi,oij,byj->boxy', x, self.weight, y) / self.n_in ** self.scale |
| 74 | + s = torch.einsum('bxi,oij,byj->boxy', x, self.weight, y) / self.n_in ** self.scale |
76 | 75 | # remove dim 1 if n_out == 1
|
77 | 76 | s = s.squeeze(1)
|
78 | 77 |
|
@@ -146,9 +145,9 @@ def forward(self, x, y, z):
|
146 | 145 | x = torch.cat((x, torch.ones_like(x[..., :1])), -1)
|
147 | 146 | if self.bias_y:
|
148 | 147 | y = torch.cat((y, torch.ones_like(y[..., :1])), -1)
|
149 |
| - w = contract('bzk,oikj->bozij', z, self.weight) |
| 148 | + w = torch.einsum('bzk,oikj->bozij', z, self.weight) |
150 | 149 | # [batch_size, n_out, seq_len, seq_len, seq_len]
|
151 |
| - s = contract('bxi,bozij,byj->bozxy', x, w, y) / self.n_in ** self.scale |
| 150 | + s = torch.einsum('bxi,bozij,byj->bozxy', x, w, y) / self.n_in ** self.scale |
152 | 151 | # remove dim 1 if n_out == 1
|
153 | 152 | s = s.squeeze(1)
|
154 | 153 |
|
|
0 commit comments