Skip to content

Commit 196886b

Browse files
authored
Fix coref size inference (explosion#10916)
* Add explicit tok2vec_size parameter in clusterer
* Add tok2vec size to span predictor config
* Minor fixes
1 parent aa2eb27 commit 196886b

File tree

4 files changed

+14
-21
lines changed

4 files changed

+14
-21
lines changed

spacy/ml/models/coref.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,15 @@ def build_wl_coref_model(
1919
# pairs to keep per mention after rough scoring
2020
antecedent_limit: int = 50,
2121
antecedent_batch_size: int = 512,
22+
tok2vec_size: int = 768, # tok2vec size
2223
):
2324
# TODO add model return types
24-
# TODO fix this
25-
try:
26-
dim = tok2vec.get_dim("nO")
27-
except ValueError:
28-
# happens with transformer listener
29-
dim = 768
25+
# dim = tok2vec.maybe_get_dim("n0")
3026

3127
with Model.define_operators({">>": chain}):
3228
coref_clusterer = PyTorchWrapper(
3329
CorefClusterer(
34-
dim,
30+
tok2vec_size,
3531
distance_embedding_size,
3632
hidden_size,
3733
depth,
@@ -56,7 +52,7 @@ def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bo
5652
def backprop(args: ArgsKwargs) -> List[Floats2d]:
5753
# convert to xp and wrap in list
5854
gradients = torch2xp(args.args[0])
59-
assert isinstance(gradients, Floats2d)
55+
# assert isinstance(gradients, Floats2d)
6056
return [gradients]
6157

6258
return ArgsKwargs(args=(word_features,), kwargs={}), backprop
@@ -89,7 +85,7 @@ class CorefClusterer(torch.nn.Module):
8985

9086
def __init__(
9187
self,
92-
dim: int, # tok2vec size
88+
dim: int,
9389
dist_emb_size: int,
9490
hidden_size: int,
9591
n_layers: int,
@@ -109,19 +105,19 @@ def __init__(
109105
"""
110106
self.dropout = torch.nn.Dropout(dropout)
111107
self.batch_size = batch_size
112-
# Modules
113108
self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
109+
114110
pair_emb = dim * 3 + self.pw.shape
115-
self.a_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout)
111+
self.a_scorer = AnaphoricityScorer(
112+
pair_emb, hidden_size, n_layers, dropout
113+
)
116114
self.lstm = torch.nn.LSTM(
117115
input_size=dim,
118116
hidden_size=dim,
119117
batch_first=True,
120118
)
119+
121120
self.rough_scorer = RoughScorer(dim, dropout, roughk)
122-
self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
123-
pair_emb = dim * 3 + self.pw.shape
124-
self.a_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout)
125121

126122
def forward(self, word_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
127123
"""

spacy/ml/models/span_predictor.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
@registry.architectures("spacy.SpanPredictor.v1")
1414
def build_span_predictor(
1515
tok2vec: Model[List[Doc], List[Floats2d]],
16+
tok2vec_size: int = 768,
1617
hidden_size: int = 1024,
1718
distance_embedding_size: int = 64,
1819
conv_channels: int = 4,
@@ -21,17 +22,11 @@ def build_span_predictor(
2122
prefix: str = "coref_head_clusters",
2223
):
2324
# TODO add model return types
24-
# TODO fix this
25-
try:
26-
dim = tok2vec.get_dim("nO")
27-
except ValueError:
28-
# happens with transformer listener
29-
dim = 768
3025

3126
with Model.define_operators({">>": chain, "&": tuplify}):
3227
span_predictor = PyTorchWrapper(
3328
SpanPredictor(
34-
dim,
29+
tok2vec_size,
3530
hidden_size,
3631
distance_embedding_size,
3732
conv_channels,

spacy/pipeline/coref.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
default_config = """
3131
[model]
3232
@architectures = "spacy.Coref.v1"
33+
tok2vec_size = 768
3334
distance_embedding_size = 20
3435
hidden_size = 1024
3536
depth = 1

spacy/pipeline/span_predictor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
default_span_predictor_config = """
2525
[model]
2626
@architectures = "spacy.SpanPredictor.v1"
27+
tok2vec_size = 768
2728
hidden_size = 1024
2829
distance_embedding_size = 64
2930
conv_channels = 4

0 commit comments

Comments (0)