Skip to content

Commit 2b35bb7

Browse files
committed
Fix tensorizer on GPU
1 parent 6e5181b commit 2b35bb7

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

spacy/pipeline.pyx

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,11 @@ class Tagger(Pipe):
415415
vocab.morphology.assign_tag_id(&doc.c[j], tag_id)
416416
idx += 1
417417
if tensors is not None:
418-
doc.extend_tensor(tensors[i])
418+
if isinstance(doc.tensor, numpy.ndarray) \
419+
and not isinstance(tensors[i], numpy.ndarray):
420+
doc.extend_tensor(tensors[i].get())
421+
else:
422+
doc.extend_tensor(tensors[i])
419423
doc.is_tagged = True
420424

421425
def update(self, docs, golds, drop=0., sgd=None, losses=None):

spacy/syntax/nn_parser.pyx

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -751,7 +751,11 @@ cdef class Parser:
751751
for j in range(doc.length):
752752
doc.c[j] = state.c._sent[j]
753753
if tensors is not None:
754-
doc.extend_tensor(tensors[i])
754+
if isinstance(doc.tensor, numpy.ndarray) \
755+
and not isinstance(tensors[i], numpy.ndarray):
756+
doc.extend_tensor(tensors[i].get())
757+
else:
758+
doc.extend_tensor(tensors[i])
755759
self.moves.finalize_doc(doc)
756760

757761
for hook in self.postprocesses:

0 commit comments

Comments
 (0)