Skip to content

Commit e38e84a

Browse files
authored
Merge pull request explosion#10812 from kadarakos/feature/coref
Feature/coref
2 parents 2e8f0e9 + 1dc3894 commit e38e84a

File tree

2 files changed

+42
-10
lines changed

2 files changed

+42
-10
lines changed

spacy/ml/models/span_predictor.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@
1515
def build_span_predictor(
1616
tok2vec: Model[List[Doc], List[Floats2d]],
1717
hidden_size: int = 1024,
18-
dist_emb_size: int = 64,
18+
distance_embedding_size: int = 64,
19+
conv_channels: int = 4,
20+
window_size: int = 1,
21+
max_distance: int = 128,
22+
prefix: str = "coref_head_clusters"
1923
):
2024
# TODO add model return types
2125
# TODO fix this
@@ -27,11 +31,18 @@ def build_span_predictor(
2731

2832
with Model.define_operators({">>": chain, "&": tuplify}):
2933
span_predictor = PyTorchWrapper(
30-
SpanPredictor(dim, hidden_size, dist_emb_size),
34+
SpanPredictor(
35+
dim,
36+
hidden_size,
37+
distance_embedding_size,
38+
conv_channels,
39+
window_size,
40+
max_distance
41+
),
3142
convert_inputs=convert_span_predictor_inputs,
3243
)
3344
# TODO use proper parameter for prefix
34-
head_info = build_get_head_metadata("coref_head_clusters")
45+
head_info = build_get_head_metadata(prefix)
3546
model = (tok2vec & head_info) >> span_predictor
3647

3748
return model
@@ -122,8 +133,21 @@ def head_data_forward(model, docs, is_train):
122133

123134
# TODO this should maybe have a different name from the component
124135
class SpanPredictor(torch.nn.Module):
125-
def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int):
136+
def __init__(
137+
self,
138+
input_size: int,
139+
hidden_size: int,
140+
dist_emb_size: int,
141+
conv_channels: int,
142+
window_size: int,
143+
max_distance: int
144+
145+
):
126146
super().__init__()
147+
if max_distance % 2 != 0:
148+
raise ValueError(
149+
"max_distance has to be an even number"
150+
)
127151
# input size = single token size
128152
# 64 = probably distance emb size
129153
# TODO check that dist_emb_size use is correct
@@ -138,12 +162,15 @@ def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int):
138162
# this use of dist_emb_size looks wrong but it was 64...?
139163
torch.nn.Linear(256, dist_emb_size),
140164
)
141-
# TODO make the Convs also parametrizeable
165+
kernel_size = window_size * 2 + 1
142166
self.conv = torch.nn.Sequential(
143-
torch.nn.Conv1d(64, 4, 3, 1, 1), torch.nn.Conv1d(4, 2, 3, 1, 1)
167+
torch.nn.Conv1d(dist_emb_size, conv_channels, kernel_size, 1, 1),
168+
torch.nn.Conv1d(conv_channels, 2, kernel_size, 1, 1)
144169
)
145170
# TODO make embeddings size a parameter
146-
self.emb = torch.nn.Embedding(128, dist_emb_size) # [-63, 63] + too_far
171+
self.max_distance = max_distance
172+
# handle distances between +-(max_distance - 2 / 2)
173+
self.emb = torch.nn.Embedding(max_distance, dist_emb_size)
147174

148175
def forward(
149176
self,
@@ -169,10 +196,11 @@ def forward(
169196
relative_positions = heads_ids.unsqueeze(1) - torch.arange(
170197
words.shape[0]
171198
).unsqueeze(0)
199+
md = self.max_distance
172200
# make all valid distances positive
173-
emb_ids = relative_positions + 63
201+
emb_ids = relative_positions + (md - 2) // 2
174202
# "too_far"
175-
emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
203+
emb_ids[(emb_ids < 0) + (emb_ids > md - 2)] = md - 1
176204
# Obtain "same sentence" boolean mask: (n_heads x n_words)
177205
heads_ids = heads_ids.long()
178206
same_sent = sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0)

spacy/pipeline/span_predictor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@
2525
[model]
2626
@architectures = "spacy.SpanPredictor.v1"
2727
hidden_size = 1024
28-
dist_emb_size = 64
28+
distance_embedding_size = 64
29+
conv_channels = 4
30+
window_size = 1
31+
max_distance = 128
32+
prefix = coref_head_clusters
2933
3034
[model.tok2vec]
3135
@architectures = "spacy.Tok2Vec.v2"

0 commit comments

Comments
 (0)