
Commit 1e3e051

Requires torch>=1.7.0 (yzhangcs#62)
1 parent 71e806e commit 1e3e051

6 files changed, 9 additions and 11 deletions


setup.py

1 addition, 1 deletion

@@ -22,7 +22,7 @@
     setup_requires=[
         'setuptools>=18.0',
     ],
-    install_requires=['torch>=1.4.0', 'transformers>=3.1.0', 'nltk'],
+    install_requires=['torch>=1.7.0', 'transformers>=3.1.0', 'nltk'],
     entry_points={
         'console_scripts': [
             'biaffine-dependency=supar.cmds.biaffine_dependency:main',
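For context on the bump: starting with torch 1.7, pack_padded_sequence rejects length tensors that live on the GPU, and a plain Python list sidesteps the device question entirely, which is presumably why the hunks below switch from .cpu() to .tolist(). A minimal sketch of the behavior, assuming a CUDA device is available (the variable names are illustrative, not from the repo):

import torch
from torch.nn.utils.rnn import pack_padded_sequence

x = torch.randn(2, 5, 4, device='cuda')     # [batch_size, seq_len, n_embed]
lens = torch.tensor([5, 3], device='cuda')  # per-sequence lengths

# raises under torch>=1.7: lengths must be a 1D CPU int64 tensor
# pack_padded_sequence(x, lens, batch_first=True)

# either of these works; the commit opts for plain lists
packed = pack_padded_sequence(x, lens.tolist(), batch_first=True)
packed = pack_padded_sequence(x, lens.cpu(), batch_first=True)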

supar/models/constituency.py

1 addition, 1 deletion

@@ -186,7 +186,7 @@ def forward(self, words, feats):
         # concatenate the word and feat representations
         embed = torch.cat((word_embed, feat_embed), -1)

-        x = pack_padded_sequence(embed, mask.sum(1).cpu(), True, False)
+        x = pack_padded_sequence(embed, mask.sum(1).tolist(), True, False)
         x, _ = self.lstm(x)
         x, _ = pad_packed_sequence(x, True, total_length=seq_len)
         x = self.lstm_dropout(x)
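The same one-line change recurs in the model files below. A self-contained sketch of the packed-LSTM round trip these hunks touch, with toy shapes (the module setup is illustrative; only the pack/pad lines mirror the diff):

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

batch_size, seq_len, n_embed, n_hidden = 2, 5, 8, 6
embed = torch.randn(batch_size, seq_len, n_embed)   # word+feat embeddings
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]]).bool()       # valid-token mask
lstm = nn.LSTM(n_embed, n_hidden, batch_first=True)

# lengths as a Python list, as in the updated line
x = pack_padded_sequence(embed, mask.sum(1).tolist(), True, False)
x, _ = lstm(x)
# pad back out to the full seq_len so downstream shapes stay fixed
x, _ = pad_packed_sequence(x, True, total_length=seq_len)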

supar/models/dependency.py

2 additions, 2 deletions

@@ -185,7 +185,7 @@ def forward(self, words, feats):
         # concatenate the word and feat representations
         embed = torch.cat((word_embed, feat_embed), -1)

-        x = pack_padded_sequence(embed, mask.sum(1).cpu(), True, False)
+        x = pack_padded_sequence(embed, mask.sum(1).tolist(), True, False)
         x, _ = self.lstm(x)
         x, _ = pad_packed_sequence(x, True, total_length=seq_len)
         x = self.lstm_dropout(x)
@@ -645,7 +645,7 @@ def forward(self, words, feats):
         # concatenate the word and feat representations
         embed = torch.cat((word_embed, feat_embed), -1)

-        x = pack_padded_sequence(embed, mask.sum(1).cpu(), True, False)
+        x = pack_padded_sequence(embed, mask.sum(1).tolist(), True, False)
         x, _ = self.lstm(x)
         x, _ = pad_packed_sequence(x, True, total_length=seq_len)
         x = self.lstm_dropout(x)

supar/models/semantic_dependency.py

2 additions, 2 deletions

@@ -219,7 +219,7 @@ def forward(self, words, feats):
         # concatenate the word and feat representations
         embed = torch.cat((word_embed, feat_embed), -1)

-        x = pack_padded_sequence(embed, mask.sum(1).cpu(), True, False)
+        x = pack_padded_sequence(embed, mask.sum(1).tolist(), True, False)
         x, _ = self.lstm(x)
         x, _ = pad_packed_sequence(x, True, total_length=seq_len)
         x = self.lstm_dropout(x)
@@ -503,7 +503,7 @@ def forward(self, words, feats):
         # concatenate the word and feat representations
         embed = torch.cat((word_embed, feat_embed), -1)

-        x = pack_padded_sequence(embed, mask.sum(1).cpu(), True, False)
+        x = pack_padded_sequence(embed, mask.sum(1).tolist(), True, False)
         x, _ = self.lstm(x)
         x, _ = pad_packed_sequence(x, True, total_length=seq_len)
         x = self.lstm_dropout(x)

supar/modules/char_lstm.py

2 additions, 4 deletions

@@ -29,10 +29,8 @@ def __init__(self, n_chars, n_embed, n_out, pad_index=0):
         self.n_out = n_out
         self.pad_index = pad_index

-        # the embedding layer
         self.embed = nn.Embedding(num_embeddings=n_chars,
                                   embedding_dim=n_embed)
-        # the lstm layer
         self.lstm = nn.LSTM(input_size=n_embed,
                             hidden_size=n_out//2,
                             batch_first=True,
@@ -54,12 +52,12 @@ def forward(self, x):
         # [batch_size, seq_len, fix_len]
         mask = x.ne(self.pad_index)
         # [batch_size, seq_len]
-        lens = mask.sum(-1).cpu()
+        lens = mask.sum(-1)
         char_mask = lens.gt(0)

         # [n, fix_len, n_embed]
         x = self.embed(x[char_mask])
-        x = pack_padded_sequence(x, lens[char_mask], True, False)
+        x = pack_padded_sequence(x, lens[char_mask].tolist(), True, False)
         x, (h, _) = self.lstm(x)
         # [n, fix_len, n_out]
         h = torch.cat(torch.unbind(h), -1)
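The CharLSTM hunks apply the same fix at the character level, with the extra wrinkle that all-padding tokens are filtered out before packing. A rough, self-contained reconstruction of the updated forward logic (the sizes and the Embedding/LSTM setup are made up for illustration):

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

pad_index, n_chars, n_embed, n_out = 0, 20, 8, 6
embed = nn.Embedding(num_embeddings=n_chars, embedding_dim=n_embed)
lstm = nn.LSTM(n_embed, n_out // 2, batch_first=True, bidirectional=True)

# [batch_size, seq_len, fix_len] of char ids, 0 = padding
x = torch.tensor([[[3, 4, 0], [5, 0, 0], [0, 0, 0]]])
mask = x.ne(pad_index)
lens = mask.sum(-1)            # chars per token; no .cpu() needed anymore
char_mask = lens.gt(0)         # drop tokens that are pure padding

h = embed(x[char_mask])        # [n, fix_len, n_embed]
h = pack_padded_sequence(h, lens[char_mask].tolist(), True, False)
h, (hidden, _) = lstm(h)
# concatenate the two directions' final states
out = torch.cat(torch.unbind(hidden), -1)   # [n, n_out]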

supar/utils/transform.py

1 addition, 1 deletion (whitespace-only change)

@@ -106,7 +106,7 @@ def __contains__(self, key):

     def __getattr__(self, name):
         if name in self.__dict__:
-             return self.__dict__[name]
+            return self.__dict__[name]
         elif name in self.maps:
             return self.values[self.maps[name]]
         else:
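The transform.py hunk fixes indentation inside a standard attribute-delegation pattern. A stripped-down sketch of that pattern (the Record class and the AttributeError branch are stand-ins, not the repo's code):

class Record(object):

    def __init__(self, maps, values):
        self.maps = maps        # attribute name -> column index
        self.values = values    # column data

    def __getattr__(self, name):
        # invoked only when regular attribute lookup fails
        if name in self.__dict__:
            return self.__dict__[name]
        elif name in self.maps:
            return self.values[self.maps[name]]
        else:
            raise AttributeError(name)

r = Record({'words': 0}, [['I', 'saw', 'her']])
print(r.words)   # ['I', 'saw', 'her']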
