Skip to content

Commit fe4ae6f

Browse files
committed
Fix datatype issue
1 parent 2c0cf0d commit fe4ae6f

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

supar/modules/transformer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def reset_parameters(self):
5959
self.embed.weight.copy_(w)
6060

6161
def forward(self, x):
62-
pos = x.new_tensor(range(x.shape[1]))
62+
pos = x.new_tensor(range(x.shape[1])).long()
6363
offset = sum(divmod(self.embed.weight.shape[0], 2))
6464
return self.embed(pos - pos.unsqueeze(-1) + offset)
6565

@@ -179,7 +179,7 @@ def __init__(self, n_heads, n_model, n_embed, dropout=0.1):
179179
self.n_embed = n_embed
180180
self.scale = n_embed**0.5
181181

182-
self.pos_embed = SinusoidRelativePositionalEmbedding()
182+
self.pos_embed = RelativePositionalEmbedding(n_model=n_embed)
183183
self.wq = nn.Parameter(torch.zeros(n_model, n_embed, n_heads))
184184
self.wk = nn.Parameter(torch.zeros(n_model, n_embed, n_heads))
185185
self.wv = nn.Parameter(torch.zeros(n_model, n_embed, n_heads))

0 commit comments

Comments
 (0)