1
1
# -*- coding: utf-8 -*-
2
2
3
3
import torch
4
+ import torch .autograd as autograd
4
5
from supar .utils .fn import pad , stripe
5
6
6
7
@@ -284,6 +285,7 @@ def mst(scores, mask, multiroot=False):
284
285
return pad (preds , total_length = seq_len ).to (mask .device )
285
286
286
287
288
+ @torch .enable_grad ()
287
289
def eisner (scores , mask , multiroot = False ):
288
290
r"""
289
291
First-order Eisner algorithm for projective decoding :cite:`mcdonald-etal-2005-online`.
@@ -313,66 +315,43 @@ def eisner(scores, mask, multiroot=False):
313
315
314
316
lens = mask .sum (1 )
315
317
batch_size , seq_len , _ = scores .shape
316
- scores = scores .permute (2 , 1 , 0 )
317
- s_i = torch .full_like (scores , float ('-inf' ))
318
- s_c = torch .full_like (scores , float ('-inf' ))
319
- p_i = scores .new_zeros (seq_len , seq_len , batch_size ).long ()
320
- p_c = scores .new_zeros (seq_len , seq_len , batch_size ).long ()
318
+ scores = scores .permute (2 , 1 , 0 ).requires_grad_ ()
319
+ s_i = torch .full_like (scores , - 1e30 )
320
+ s_c = torch .full_like (scores , - 1e30 )
321
321
s_c .diagonal ().fill_ (0 )
322
322
323
323
for w in range (1 , seq_len ):
324
324
n = seq_len - w
325
- starts = p_i .new_tensor (range (n )).unsqueeze (0 )
326
325
# ilr = C(i->r) + C(j->r+1)
327
326
ilr = stripe (s_c , n , w ) + stripe (s_c , n , w , (w , 1 ))
328
327
# [batch_size, n, w]
329
328
il = ir = ilr .permute (2 , 0 , 1 )
330
329
# I(j->i) = max(C(i->r) + C(j->r+1) + s(j->i)), i <= r < j
331
- il_span , il_path = il .max (- 1 )
330
+ il_span , _ = il .max (- 1 )
332
331
s_i .diagonal (- w ).copy_ (il_span + scores .diagonal (- w ))
333
- p_i .diagonal (- w ).copy_ (il_path + starts )
334
332
# I(i->j) = max(C(i->r) + C(j->r+1) + s(i->j)), i <= r < j
335
- ir_span , ir_path = ir .max (- 1 )
333
+ ir_span , _ = ir .max (- 1 )
336
334
s_i .diagonal (w ).copy_ (ir_span + scores .diagonal (w ))
337
- p_i .diagonal (w ).copy_ (ir_path + starts )
338
335
339
336
# C(j->i) = max(C(r->i) + I(j->r)), i <= r < j
340
337
cl = stripe (s_c , n , w , (0 , 0 ), 0 ) + stripe (s_i , n , w , (w , 0 ))
341
- cl_span , cl_path = cl .permute (2 , 0 , 1 ).max (- 1 )
338
+ cl_span , _ = cl .permute (2 , 0 , 1 ).max (- 1 )
342
339
s_c .diagonal (- w ).copy_ (cl_span )
343
- p_c .diagonal (- w ).copy_ (cl_path + starts )
344
340
# C(i->j) = max(I(i->r) + C(r->j)), i < r <= j
345
341
cr = stripe (s_i , n , w , (0 , 1 )) + stripe (s_c , n , w , (1 , w ), 0 )
346
- cr_span , cr_path = cr .permute (2 , 0 , 1 ).max (- 1 )
342
+ cr_span , _ = cr .permute (2 , 0 , 1 ).max (- 1 )
347
343
s_c .diagonal (w ).copy_ (cr_span )
348
344
if not multiroot :
349
345
s_c [0 , w ][lens .ne (w )] = float ('-inf' )
350
- p_c .diagonal (w ).copy_ (cr_path + starts + 1 )
351
-
352
- def backtrack (p_i , p_c , heads , i , j , complete ):
353
- if i == j :
354
- return
355
- if complete :
356
- r = p_c [i , j ]
357
- backtrack (p_i , p_c , heads , i , r , False )
358
- backtrack (p_i , p_c , heads , r , j , True )
359
- else :
360
- r , heads [j ] = p_i [i , j ], i
361
- i , j = sorted ((i , j ))
362
- backtrack (p_i , p_c , heads , i , r , True )
363
- backtrack (p_i , p_c , heads , j , r + 1 , True )
364
346
365
- preds = []
366
- p_c = p_c .permute (2 , 0 , 1 ).cpu ()
367
- p_i = p_i .permute (2 , 0 , 1 ).cpu ()
368
- for i , length in enumerate (lens .tolist ()):
369
- heads = p_c .new_zeros (length + 1 , dtype = torch .long )
370
- backtrack (p_i [i ], p_c [i ], heads , 0 , length , True )
371
- preds .append (heads .to (mask .device ))
347
+ logZ = s_c [0 ].gather (0 , lens .unsqueeze (0 )).sum ()
348
+ marginals , = autograd .grad (logZ , scores )
349
+ preds = lens .new_zeros (batch_size , seq_len ).masked_scatter_ (mask , marginals .permute (2 , 1 , 0 ).nonzero ()[:, 2 ])
372
350
373
- return pad ( preds , total_length = seq_len ). to ( mask . device )
351
+ return preds
374
352
375
353
354
+ @torch .enable_grad ()
376
355
def eisner2o (scores , mask , multiroot = False ):
377
356
r"""
378
357
Second-order Eisner algorithm for projective decoding :cite:`mcdonald-pereira-2006-online`.
@@ -421,7 +400,7 @@ def eisner2o(scores, mask, multiroot=False):
421
400
422
401
# the end position of each sentence in a batch
423
402
lens = mask .sum (1 )
424
- s_arc , s_sib = scores
403
+ s_arc , s_sib = ( s . requires_grad_ () for s in scores )
425
404
batch_size , seq_len , _ = s_arc .shape
426
405
# [seq_len, seq_len, batch_size]
427
406
s_arc = s_arc .permute (2 , 1 , 0 )
@@ -430,16 +409,12 @@ def eisner2o(scores, mask, multiroot=False):
430
409
s_i = torch .full_like (s_arc , float ('-inf' ))
431
410
s_s = torch .full_like (s_arc , float ('-inf' ))
432
411
s_c = torch .full_like (s_arc , float ('-inf' ))
433
- p_i = s_arc .new_zeros (seq_len , seq_len , batch_size ).long ()
434
- p_s = s_arc .new_zeros (seq_len , seq_len , batch_size ).long ()
435
- p_c = s_arc .new_zeros (seq_len , seq_len , batch_size ).long ()
436
412
s_c .diagonal ().fill_ (0 )
437
413
438
414
for w in range (1 , seq_len ):
439
415
# n denotes the number of spans to iterate,
440
416
# from span (0, w) to span (n, n+w) given width w
441
417
n = seq_len - w
442
- starts = p_i .new_tensor (range (n )).unsqueeze (0 )
443
418
# I(j->i) = max(I(j->r) + S(j->r, i)), i < r < j |
444
419
# C(j->j) + C(i->j-1))
445
420
# + s(j->i)
@@ -450,9 +425,8 @@ def eisner2o(scores, mask, multiroot=False):
450
425
il0 = stripe (s_c , n , 1 , (w , w )) + stripe (s_c , n , 1 , (0 , w - 1 ))
451
426
# il0[0] are set to zeros since the scores of the complete spans starting from 0 are always -inf
452
427
il [:, - 1 ] = il0 .index_fill_ (0 , lens .new_tensor (0 ), 0 ).squeeze (1 )
453
- il_span , il_path = il .permute (2 , 0 , 1 ).max (- 1 )
428
+ il_span , _ = il .permute (2 , 0 , 1 ).max (- 1 )
454
429
s_i .diagonal (- w ).copy_ (il_span + s_arc .diagonal (- w ))
455
- p_i .diagonal (- w ).copy_ (il_path + starts + 1 )
456
430
# I(i->j) = max(I(i->r) + S(i->r, j), i < r < j |
457
431
# C(i->i) + C(j->i+1))
458
432
# + s(i->j)
@@ -463,66 +437,36 @@ def eisner2o(scores, mask, multiroot=False):
463
437
# [n, 1, batch_size]
464
438
ir0 = stripe (s_c , n , 1 ) + stripe (s_c , n , 1 , (w , 1 ))
465
439
ir [:, 0 ] = ir0 .squeeze (1 )
466
- ir_span , ir_path = ir .permute (2 , 0 , 1 ).max (- 1 )
440
+ ir_span , _ = ir .permute (2 , 0 , 1 ).max (- 1 )
467
441
s_i .diagonal (w ).copy_ (ir_span + s_arc .diagonal (w ))
468
- p_i .diagonal (w ).copy_ (ir_path + starts )
469
442
470
443
# [n, w, batch_size]
471
444
slr = stripe (s_c , n , w ) + stripe (s_c , n , w , (w , 1 ))
472
- slr_span , slr_path = slr .permute (2 , 0 , 1 ).max (- 1 )
445
+ slr_span , _ = slr .permute (2 , 0 , 1 ).max (- 1 )
473
446
# S(j, i) = max(C(i->r) + C(j->r+1)), i <= r < j
474
447
s_s .diagonal (- w ).copy_ (slr_span )
475
- p_s .diagonal (- w ).copy_ (slr_path + starts )
476
448
# S(i, j) = max(C(i->r) + C(j->r+1)), i <= r < j
477
449
s_s .diagonal (w ).copy_ (slr_span )
478
- p_s .diagonal (w ).copy_ (slr_path + starts )
479
450
480
451
# C(j->i) = max(C(r->i) + I(j->r)), i <= r < j
481
452
cl = stripe (s_c , n , w , (0 , 0 ), 0 ) + stripe (s_i , n , w , (w , 0 ))
482
- cl_span , cl_path = cl .permute (2 , 0 , 1 ).max (- 1 )
453
+ cl_span , _ = cl .permute (2 , 0 , 1 ).max (- 1 )
483
454
s_c .diagonal (- w ).copy_ (cl_span )
484
- p_c .diagonal (- w ).copy_ (cl_path + starts )
485
455
# C(i->j) = max(I(i->r) + C(r->j)), i < r <= j
486
456
cr = stripe (s_i , n , w , (0 , 1 )) + stripe (s_c , n , w , (1 , w ), 0 )
487
- cr_span , cr_path = cr .permute (2 , 0 , 1 ).max (- 1 )
457
+ cr_span , _ = cr .permute (2 , 0 , 1 ).max (- 1 )
488
458
s_c .diagonal (w ).copy_ (cr_span )
489
459
if not multiroot :
490
460
s_c [0 , w ][lens .ne (w )] = float ('-inf' )
491
- p_c .diagonal (w ).copy_ (cr_path + starts + 1 )
492
-
493
- def backtrack (p_i , p_s , p_c , heads , i , j , flag ):
494
- if i == j :
495
- return
496
- if flag == 'c' :
497
- r = p_c [i , j ]
498
- backtrack (p_i , p_s , p_c , heads , i , r , 'i' )
499
- backtrack (p_i , p_s , p_c , heads , r , j , 'c' )
500
- elif flag == 's' :
501
- r = p_s [i , j ]
502
- i , j = sorted ((i , j ))
503
- backtrack (p_i , p_s , p_c , heads , i , r , 'c' )
504
- backtrack (p_i , p_s , p_c , heads , j , r + 1 , 'c' )
505
- elif flag == 'i' :
506
- r , heads [j ] = p_i [i , j ], i
507
- if r == i :
508
- r = i + 1 if i < j else i - 1
509
- backtrack (p_i , p_s , p_c , heads , j , r , 'c' )
510
- else :
511
- backtrack (p_i , p_s , p_c , heads , i , r , 'i' )
512
- backtrack (p_i , p_s , p_c , heads , r , j , 's' )
513
461
514
- preds = []
515
- p_i = p_i .permute (2 , 0 , 1 ).cpu ()
516
- p_s = p_s .permute (2 , 0 , 1 ).cpu ()
517
- p_c = p_c .permute (2 , 0 , 1 ).cpu ()
518
- for i , length in enumerate (lens .tolist ()):
519
- heads = p_c .new_zeros (length + 1 , dtype = torch .long )
520
- backtrack (p_i [i ], p_s [i ], p_c [i ], heads , 0 , length , 'c' )
521
- preds .append (heads .to (mask .device ))
462
+ logZ = s_c [0 ].gather (0 , lens .unsqueeze (0 )).sum ()
463
+ marginals , = autograd .grad (logZ , s_arc )
464
+ preds = lens .new_zeros (batch_size , seq_len ).masked_scatter_ (mask , marginals .permute (2 , 1 , 0 ).nonzero ()[:, 2 ])
522
465
523
- return pad ( preds , total_length = seq_len ). to ( mask . device )
466
+ return preds
524
467
525
468
469
+ @torch .enable_grad ()
526
470
def cky (scores , mask ):
527
471
r"""
528
472
The implementation of `Cocke-Kasami-Younger`_ (CKY) algorithm to parse constituency trees :cite:`zhang-etal-2020-fast`.
@@ -554,17 +498,13 @@ def cky(scores, mask):
554
498
"""
555
499
556
500
lens = mask [:, 0 ].sum (- 1 )
557
- scores = scores .permute (1 , 2 , 3 , 0 )
501
+ scores = scores .permute (1 , 2 , 3 , 0 ). requires_grad_ ()
558
502
seq_len , seq_len , n_labels , batch_size = scores .shape
559
503
s = scores .new_zeros (seq_len , seq_len , batch_size )
560
- p_s = scores .new_zeros (seq_len , seq_len , batch_size ).long ()
561
- p_l = scores .new_zeros (seq_len , seq_len , batch_size ).long ()
562
504
563
505
for w in range (1 , seq_len ):
564
506
n = seq_len - w
565
- starts = p_s .new_tensor (range (n )).unsqueeze (0 )
566
- s_l , p = scores .diagonal (w ).max (0 )
567
- p_l .diagonal (w ).copy_ (p )
507
+ s_l , _ = scores .diagonal (w ).max (0 )
568
508
569
509
if w == 1 :
570
510
s .diagonal (w ).copy_ (s_l )
@@ -574,20 +514,9 @@ def cky(scores, mask):
574
514
# [batch_size, n, w]
575
515
s_s = s_s .permute (2 , 0 , 1 )
576
516
# [batch_size, n]
577
- s_s , p = s_s .max (- 1 )
517
+ s_s , _ = s_s .max (- 1 )
578
518
s .diagonal (w ).copy_ (s_s + s_l )
579
- p_s .diagonal (w ).copy_ (p + starts + 1 )
580
-
581
- def backtrack (p_s , p_l , i , j ):
582
- if j == i + 1 :
583
- return [(i , j , p_l [i ][j ])]
584
- split , label = p_s [i ][j ], p_l [i ][j ]
585
- ltree = backtrack (p_s , p_l , i , split )
586
- rtree = backtrack (p_s , p_l , split , j )
587
- return [(i , j , label )] + ltree + rtree
588
-
589
- p_s = p_s .permute (2 , 0 , 1 ).tolist ()
590
- p_l = p_l .permute (2 , 0 , 1 ).tolist ()
591
- trees = [backtrack (p_s [i ], p_l [i ], 0 , length ) for i , length in enumerate (lens .tolist ())]
592
519
593
- return trees
520
+ logZ = s [0 ].gather (0 , lens .unsqueeze (0 )).sum ()
521
+ marginals , = autograd .grad (logZ , scores )
522
+ return [sorted (i .nonzero ().tolist (), key = lambda x :(x [0 ], - x [1 ])) for i in marginals .permute (3 , 0 , 1 , 2 )]
0 commit comments