@@ -15,7 +15,11 @@
 def build_span_predictor(
     tok2vec: Model[List[Doc], List[Floats2d]],
     hidden_size: int = 1024,
-    dist_emb_size: int = 64,
+    distance_embedding_size: int = 64,
+    conv_channels: int = 4,
+    window_size: int = 1,
+    max_distance: int = 128,
+    prefix: str = "coref_head_clusters"
 ):
     # TODO add model return types
     # TODO fix this
@@ -27,11 +31,18 @@ def build_span_predictor(

     with Model.define_operators({">>": chain, "&": tuplify}):
         span_predictor = PyTorchWrapper(
-            SpanPredictor(dim, hidden_size, dist_emb_size),
+            SpanPredictor(
+                dim,
+                hidden_size,
+                distance_embedding_size,
+                conv_channels,
+                window_size,
+                max_distance
+            ),
             convert_inputs=convert_span_predictor_inputs,
         )
         # TODO use proper parameter for prefix
-        head_info = build_get_head_metadata("coref_head_clusters")
+        head_info = build_get_head_metadata(prefix)
         model = (tok2vec & head_info) >> span_predictor

     return model
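
For reference, a minimal usage sketch (not part of this diff) showing how the new keyword arguments would be passed when building the layer directly; `make_tok2vec()` is a hypothetical stand-in for however the pipeline provides its tok2vec model:

```python
# Hypothetical sketch: the values below mirror the updated defaults.
span_model = build_span_predictor(
    tok2vec=make_tok2vec(),  # placeholder for the real tok2vec Model
    hidden_size=1024,
    distance_embedding_size=64,
    conv_channels=4,
    window_size=1,
    max_distance=128,
    prefix="coref_head_clusters",
)
```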
@@ -122,8 +133,21 @@ def head_data_forward(model, docs, is_train):

 # TODO this should maybe have a different name from the component
 class SpanPredictor(torch.nn.Module):
-    def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int):
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        dist_emb_size: int,
+        conv_channels: int,
+        window_size: int,
+        max_distance: int
+
+    ):
         super().__init__()
+        if max_distance % 2 != 0:
+            raise ValueError(
+                "max_distance has to be an even number"
+            )
         # input size = single token size
         # 64 = probably distance emb size
         # TODO check that dist_emb_size use is correct
@@ -138,12 +162,15 @@ def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int):
             # this use of dist_emb_size looks wrong but it was 64...?
             torch.nn.Linear(256, dist_emb_size),
         )
-        # TODO make the Convs also parametrizeable
+        kernel_size = window_size * 2 + 1
         self.conv = torch.nn.Sequential(
-            torch.nn.Conv1d(64, 4, 3, 1, 1), torch.nn.Conv1d(4, 2, 3, 1, 1)
+            torch.nn.Conv1d(dist_emb_size, conv_channels, kernel_size, 1, 1),
+            torch.nn.Conv1d(conv_channels, 2, kernel_size, 1, 1)
         )
         # TODO make embeddings size a parameter
-        self.emb = torch.nn.Embedding(128, dist_emb_size)  # [-63, 63] + too_far
+        self.max_distance = max_distance
+        # handle distances between +-((max_distance - 2) / 2)
+        self.emb = torch.nn.Embedding(max_distance, dist_emb_size)

     def forward(
         self,
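
With the defaults (`dist_emb_size=64`, `conv_channels=4`, `window_size=1`, hence `kernel_size=3`), the parametrized stack reduces to the previously hardcoded `Conv1d(64, 4, 3, 1, 1)` followed by `Conv1d(4, 2, 3, 1, 1)`. A standalone sketch (not from this diff) to illustrate; note that the padding argument is still the literal `1`, so the sequence length is only preserved when `window_size == 1`:

```python
import torch

# Rebuild the conv stack with the new defaults and check that the
# sequence length is preserved (padding=1 matches kernel_size=3).
dist_emb_size, conv_channels, window_size = 64, 4, 1
kernel_size = window_size * 2 + 1  # 3

conv = torch.nn.Sequential(
    torch.nn.Conv1d(dist_emb_size, conv_channels, kernel_size, 1, 1),
    torch.nn.Conv1d(conv_channels, 2, kernel_size, 1, 1),
)
x = torch.randn(1, dist_emb_size, 50)  # (batch, embedding channels, n_words)
print(conv(x).shape)                   # torch.Size([1, 2, 50])
```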
@@ -169,10 +196,11 @@ def forward(
         relative_positions = heads_ids.unsqueeze(1) - torch.arange(
             words.shape[0]
         ).unsqueeze(0)
+        md = self.max_distance
         # make all valid distances positive
-        emb_ids = relative_positions + 63
+        emb_ids = relative_positions + (md - 2) // 2
         # "too_far"
-        emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
+        emb_ids[(emb_ids < 0) + (emb_ids > md - 2)] = md - 1
         # Obtain "same sentence" boolean mask: (n_heads x n_words)
         heads_ids = heads_ids.long()
         same_sent = sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0)
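
A quick worked check of the new index arithmetic (an illustration, not code from this diff): with the default `max_distance=128`, relative positions in `[-63, 63]` map to embedding rows `0..126` and everything else to the "too_far" row `127`, reproducing the old hardcoded constants:

```python
import torch

# max_distance=128 recovers the original 63 / 126 / 127 behaviour.
md = 128
relative_positions = torch.tensor([-100, -63, 0, 63, 100])
emb_ids = relative_positions + (md - 2) // 2          # shift by 63
emb_ids[(emb_ids < 0) + (emb_ids > md - 2)] = md - 1  # clamp to "too_far"
print(emb_ids)  # tensor([127,   0,  63, 126, 127])
```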