Skip to content

Commit 0365e33

Browse files
authored
Merge pull request yzhangcs#62 from KoichiYasuoka/main
torch>=1.7.0 (GPU version) support
2 parents 1dce1b8 + 9bcf10f commit 0365e33

File tree

5 files changed: +7 additions, -7 deletions

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
setup(
66
name='supar',
7-
version='1.0.0+dev20201223',
7+
version='1.0.0+dev20210310',
88
author='Yu Zhang',
99
author_email='yzhang.cs@outlook.com',
1010
description='Syntactic Parsing Models',

supar/models/constituency.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def forward(self, words, feats):
186186
# concatenate the word and feat representations
187187
embed = torch.cat((word_embed, feat_embed), -1)
188188

189-
x = pack_padded_sequence(embed, mask.sum(1), True, False)
189+
x = pack_padded_sequence(embed, mask.sum(1).cpu(), True, False)
190190
x, _ = self.lstm(x)
191191
x, _ = pad_packed_sequence(x, True, total_length=seq_len)
192192
x = self.lstm_dropout(x)

supar/models/dependency.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def forward(self, words, feats):
185185
# concatenate the word and feat representations
186186
embed = torch.cat((word_embed, feat_embed), -1)
187187

188-
x = pack_padded_sequence(embed, mask.sum(1), True, False)
188+
x = pack_padded_sequence(embed, mask.sum(1).cpu(), True, False)
189189
x, _ = self.lstm(x)
190190
x, _ = pad_packed_sequence(x, True, total_length=seq_len)
191191
x = self.lstm_dropout(x)
@@ -645,7 +645,7 @@ def forward(self, words, feats):
645645
# concatenate the word and feat representations
646646
embed = torch.cat((word_embed, feat_embed), -1)
647647

648-
x = pack_padded_sequence(embed, mask.sum(1), True, False)
648+
x = pack_padded_sequence(embed, mask.sum(1).cpu(), True, False)
649649
x, _ = self.lstm(x)
650650
x, _ = pad_packed_sequence(x, True, total_length=seq_len)
651651
x = self.lstm_dropout(x)

supar/models/semantic_dependency.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def forward(self, words, feats):
219219
# concatenate the word and feat representations
220220
embed = torch.cat((word_embed, feat_embed), -1)
221221

222-
x = pack_padded_sequence(embed, mask.sum(1), True, False)
222+
x = pack_padded_sequence(embed, mask.sum(1).cpu(), True, False)
223223
x, _ = self.lstm(x)
224224
x, _ = pad_packed_sequence(x, True, total_length=seq_len)
225225
x = self.lstm_dropout(x)
@@ -503,7 +503,7 @@ def forward(self, words, feats):
503503
# concatenate the word and feat representations
504504
embed = torch.cat((word_embed, feat_embed), -1)
505505

506-
x = pack_padded_sequence(embed, mask.sum(1), True, False)
506+
x = pack_padded_sequence(embed, mask.sum(1).cpu(), True, False)
507507
x, _ = self.lstm(x)
508508
x, _ = pad_packed_sequence(x, True, total_length=seq_len)
509509
x = self.lstm_dropout(x)

supar/modules/char_lstm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def forward(self, x):
5454
# [batch_size, seq_len, fix_len]
5555
mask = x.ne(self.pad_index)
5656
# [batch_size, seq_len]
57-
lens = mask.sum(-1)
57+
lens = mask.sum(-1).cpu()
5858
char_mask = lens.gt(0)
5959

6060
# [n, fix_len, n_embed]

Comments: 0