Skip to content

Commit 9da16df

Browse files
committed
Add guards around torch import
Torch is required for the coref/spanpred models but shouldn't be required for spaCy in general. The one tricky part of this is that one function in coref_util relied on torch, but that file was imported in several places. Since the function was only used in one place I moved it there.
1 parent e38e84a commit 9da16df

File tree

3 files changed

+25
-18
lines changed

3 files changed

+25
-18
lines changed

spacy/ml/models/__init__.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
1-
from .coref import * #noqa
2-
from .span_predictor import * #noqa
31
from .entity_linker import * # noqa
42
from .multi_task import * # noqa
53
from .parser import * # noqa
64
from .spancat import * # noqa
75
from .tagger import * # noqa
86
from .textcat import * # noqa
97
from .tok2vec import * # noqa
8+
9+
# some models require Torch
10+
try:
11+
import torch
12+
from .coref import * #noqa
13+
from .span_predictor import * #noqa
14+
except ImportError:
15+
pass
16+

spacy/ml/models/coref.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from ...tokens import Doc
1010
from ...util import registry
11-
from .coref_util import add_dummy
1211

1312

1413
@registry.architectures("spacy.Coref.v1")
@@ -186,6 +185,22 @@ def forward(
186185
return coref_scores, top_indices
187186

188187

188+
# Note this function is kept here to keep a torch dep out of coref_util.
189+
def add_dummy(tensor: torch.Tensor, eps: bool = False):
190+
"""Prepends zeros (or a very small value if eps is True)
191+
to the first (not zeroth) dimension of tensor.
192+
"""
193+
kwargs = dict(device=tensor.device, dtype=tensor.dtype)
194+
shape: List[int] = list(tensor.shape)
195+
shape[1] = 1
196+
if not eps:
197+
dummy = torch.zeros(shape, **kwargs) # type: ignore
198+
else:
199+
dummy = torch.full(shape, EPSILON, **kwargs) # type: ignore
200+
output = torch.cat((dummy, tensor), dim=1)
201+
return output
202+
203+
189204
class AnaphoricityScorer(torch.nn.Module):
190205
"""Calculates anaphoricity scores by passing the inputs into a FFNN"""
191206

spacy/ml/models/coref_util.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from spacy.tokens import Doc
33
from typing import List, Tuple, Callable, Any, Set, Dict
44
from ...util import registry
5-
import torch
65

76
# type alias to make writing this less tedious
87
MentionClusters = List[List[Tuple[int, int]]]
@@ -25,20 +24,6 @@ def __repr__(self) -> str:
2524
return str(self.id)
2625

2726

28-
def add_dummy(tensor: torch.Tensor, eps: bool = False):
29-
""" Prepends zeros (or a very small value if eps is True)
30-
to the first (not zeroth) dimension of tensor.
31-
"""
32-
kwargs = dict(device=tensor.device, dtype=tensor.dtype)
33-
shape: List[int] = list(tensor.shape)
34-
shape[1] = 1
35-
if not eps:
36-
dummy = torch.zeros(shape, **kwargs) # type: ignore
37-
else:
38-
dummy = torch.full(shape, EPSILON, **kwargs) # type: ignore
39-
output = torch.cat((dummy, tensor), dim=1)
40-
return output
41-
4227
def get_sentence_ids(doc):
4328
out = []
4429
sent_id = -1

0 commit comments

Comments
 (0)