@@ -196,18 +196,18 @@ def forward(self, semiring):
196
196
197
197
# [n, batch_size, ...]
198
198
il = ir = semiring .dot (stripe (s_c , n , w ), stripe (s_c , n , w , (w , 1 )), 1 )
199
- # I(j->i) = logsumexp( C(i->r) + C(j->r+1)) + s(j->i), i <= r < j
199
+ # I(j->i) = < C(i->r), C(j->r+1)> * s(j->i), i <= r < j
200
200
# fill the w-th diagonal of the lower triangular part of s_i with I(j->i) of n spans
201
201
s_i .diagonal (- w ).copy_ (semiring .mul (il , s_arc .diagonal (- w ).movedim (- 1 , 0 )).movedim (0 , - 1 ))
202
- # I(i->j) = logsumexp( C(i->r) + C(j->r+1)) + s(i->j), i <= r < j
202
+ # I(i->j) = < C(i->r), C(j->r+1)> * s(i->j), i <= r < j
203
203
# fill the w-th diagonal of the upper triangular part of s_i with I(i->j) of n spans
204
204
s_i .diagonal (w ).copy_ (semiring .mul (ir , s_arc .diagonal (w ).movedim (- 1 , 0 )).movedim (0 , - 1 ))
205
205
206
206
# [n, batch_size, ...]
207
- # C(j->i) = logsumexp( C(r->i) + I(j->r)) , i <= r < j
207
+ # C(j->i) = < C(r->i), I(j->r)> , i <= r < j
208
208
cl = semiring .dot (stripe (s_c , n , w , (0 , 0 ), 0 ), stripe (s_i , n , w , (w , 0 )), 1 )
209
209
s_c .diagonal (- w ).copy_ (cl .movedim (0 , - 1 ))
210
- # C(i->j) = logsumexp( I(i->r) + C(r->j)) , i < r <= j
210
+ # C(i->j) = < I(i->r), C(r->j)> , i < r <= j
211
211
cr = semiring .dot (stripe (s_i , n , w , (0 , 1 )), stripe (s_c , n , w , (1 , w ), 0 ), 1 )
212
212
s_c .diagonal (w ).copy_ (cr .movedim (0 , - 1 ))
213
213
if not self .multiroot :
@@ -310,19 +310,17 @@ def forward(self, semiring):
310
310
for w in range (1 , seq_len ):
311
311
n = seq_len - w
312
312
313
- # I(j->i) = logsum(exp(I(j->r) + S(j->r, i)) +, i < r < j
314
- # exp(C(j->j) + C(i->j-1)))
315
- # + s(j->i)
313
+ # I(j->i) = <I(j->r), S(j->r, i)> * s(j->i), i < r < j
314
+ # <C(j->j), C(i->j-1)> * s(j->i), otherwise
316
315
# [n, w, batch_size, ...]
317
316
il = semiring .times (stripe (s_i , n , w , (w , 1 )),
318
317
stripe (s_s , n , w , (1 , 0 ), 0 ),
319
318
stripe (s_sib [range (w , n + w ), range (n ), :], n , w , (0 , 1 )))
320
319
il [:, - 1 ] = semiring .mul (stripe (s_c , n , 1 , (w , w )), stripe (s_c , n , 1 , (0 , w - 1 ))).squeeze (1 )
321
320
il = semiring .sum (il , 1 )
322
321
s_i .diagonal (- w ).copy_ (semiring .mul (il , s_arc .diagonal (- w ).movedim (- 1 , 0 )).movedim (0 , - 1 ))
323
- # I(i->j) = logsum(exp(I(i->r) + S(i->r, j)) +, i < r < j
324
- # exp(C(i->i) + C(j->i+1)))
325
- # + s(i->j)
322
+ # I(i->j) = <I(i->r), S(i->r, j)> * s(i->j), i < r < j
323
+ # <C(i->i), C(j->i+1)> * s(i->j), otherwise
326
324
# [n, w, batch_size, ...]
327
325
ir = semiring .times (stripe (s_i , n , w ),
328
326
stripe (s_s , n , w , (0 , w ), 0 ),
@@ -335,16 +333,16 @@ def forward(self, semiring):
335
333
336
334
# [batch_size, ..., n]
337
335
sl = sr = semiring .dot (stripe (s_c , n , w ), stripe (s_c , n , w , (w , 1 )), 1 ).movedim (0 , - 1 )
338
- # S(j, i) = logsumexp( C(i->r) + C(j->r+1)) , i <= r < j
336
+ # S(j, i) = < C(i->r), C(j->r+1)> , i <= r < j
339
337
s_s .diagonal (- w ).copy_ (sl )
340
- # S(i, j) = logsumexp( C(i->r) + C(j->r+1)) , i <= r < j
338
+ # S(i, j) = < C(i->r), C(j->r+1)> , i <= r < j
341
339
s_s .diagonal (w ).copy_ (sr )
342
340
343
341
# [n, batch_size, ...]
344
- # C(j->i) = logsumexp( C(r->i) + I(j->r)) , i <= r < j
342
+ # C(j->i) = < C(r->i), I(j->r)> , i <= r < j
345
343
cl = semiring .dot (stripe (s_c , n , w , (0 , 0 ), 0 ), stripe (s_i , n , w , (w , 0 )), 1 )
346
344
s_c .diagonal (- w ).copy_ (cl .movedim (0 , - 1 ))
347
- # C(i->j) = logsumexp( I(i->r) + C(r->j)) , i < r <= j
345
+ # C(i->j) = < I(i->r), C(r->j)> , i < r <= j
348
346
cr = semiring .dot (stripe (s_i , n , w , (0 , 1 )), stripe (s_c , n , w , (1 , w ), 0 ), 1 )
349
347
s_c .diagonal (w ).copy_ (cr .movedim (0 , - 1 ))
350
348
return semiring .unconvert (s_c )[0 ][self .lens , range (batch_size )]
0 commit comments