Skip to content

Commit 3807a1b

Browse files
authored
Merge pull request explosion#10844 from polm/feature/coref-torch-guard
Add guards around torch import for coref
2 parents e38e84a + c9233a5 commit 3807a1b

File tree

4 files changed

+26
-24
lines changed

4 files changed

+26
-24
lines changed

spacy/ml/models/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
1-
from .coref import * #noqa
2-
from .span_predictor import * #noqa
31
from .entity_linker import * # noqa
42
from .multi_task import * # noqa
53
from .parser import * # noqa
64
from .spancat import * # noqa
75
from .tagger import * # noqa
86
from .textcat import * # noqa
97
from .tok2vec import * # noqa
8+
9+
# some models require Torch
10+
from thinc.util import has_torch
11+
if has_torch:
12+
from .coref import * #noqa
13+
from .span_predictor import * #noqa
14+

spacy/ml/models/coref.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
from typing import List, Tuple
2-
import torch
32

43
from thinc.api import Model, chain
54
from thinc.api import PyTorchWrapper, ArgsKwargs
65
from thinc.types import Floats2d, Ints2d, Ints1d
7-
from thinc.util import xp2torch, torch2xp
6+
from thinc.util import torch, xp2torch, torch2xp
87

98
from ...tokens import Doc
109
from ...util import registry
11-
from .coref_util import add_dummy
1210

1311

1412
@registry.architectures("spacy.Coref.v1")
@@ -186,6 +184,23 @@ def forward(
186184
return coref_scores, top_indices
187185

188186

187+
EPSILON = 1e-7
188+
# Note this function is kept here to keep a torch dep out of coref_util.
189+
def add_dummy(tensor: torch.Tensor, eps: bool = False):
190+
"""Prepends zeros (or a very small value if eps is True)
191+
to the first (not zeroth) dimension of tensor.
192+
"""
193+
kwargs = dict(device=tensor.device, dtype=tensor.dtype)
194+
shape: List[int] = list(tensor.shape)
195+
shape[1] = 1
196+
if not eps:
197+
dummy = torch.zeros(shape, **kwargs) # type: ignore
198+
else:
199+
dummy = torch.full(shape, EPSILON, **kwargs) # type: ignore
200+
output = torch.cat((dummy, tensor), dim=1)
201+
return output
202+
203+
189204
class AnaphoricityScorer(torch.nn.Module):
190205
"""Calculates anaphoricity scores by passing the inputs into a FFNN"""
191206

spacy/ml/models/coref_util.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,12 @@
22
from spacy.tokens import Doc
33
from typing import List, Tuple, Callable, Any, Set, Dict
44
from ...util import registry
5-
import torch
65

76
# type alias to make writing this less tedious
87
MentionClusters = List[List[Tuple[int, int]]]
98

109
DEFAULT_CLUSTER_PREFIX = "coref_clusters"
1110

12-
EPSILON = 1e-7
13-
1411
class GraphNode:
1512
def __init__(self, node_id: int):
1613
self.id = node_id
@@ -25,20 +22,6 @@ def __repr__(self) -> str:
2522
return str(self.id)
2623

2724

28-
def add_dummy(tensor: torch.Tensor, eps: bool = False):
29-
""" Prepends zeros (or a very small value if eps is True)
30-
to the first (not zeroth) dimension of tensor.
31-
"""
32-
kwargs = dict(device=tensor.device, dtype=tensor.dtype)
33-
shape: List[int] = list(tensor.shape)
34-
shape[1] = 1
35-
if not eps:
36-
dummy = torch.zeros(shape, **kwargs) # type: ignore
37-
else:
38-
dummy = torch.full(shape, EPSILON, **kwargs) # type: ignore
39-
output = torch.cat((dummy, tensor), dim=1)
40-
return output
41-
4225
def get_sentence_ids(doc):
4326
out = []
4427
sent_id = -1

spacy/ml/models/span_predictor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
from typing import List, Tuple
2-
import torch
32

43
from thinc.api import Model, chain, tuplify
54
from thinc.api import PyTorchWrapper, ArgsKwargs
65
from thinc.types import Floats2d, Ints1d
7-
from thinc.util import xp2torch, torch2xp
6+
from thinc.util import torch, xp2torch, torch2xp
87

98
from ...tokens import Doc
109
from ...util import registry

0 commit comments

Comments
 (0)