Skip to content

Commit 76f6be6

Browse files
committed
Backtrack via back-propagation
1 parent 57d64d5 commit 76f6be6

File tree

1 file changed

+31
-102
lines changed

1 file changed

+31
-102
lines changed

supar/utils/alg.py

Lines changed: 31 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# -*- coding: utf-8 -*-
22

33
import torch
4+
import torch.autograd as autograd
45
from supar.utils.fn import pad, stripe
56

67

@@ -284,6 +285,7 @@ def mst(scores, mask, multiroot=False):
284285
return pad(preds, total_length=seq_len).to(mask.device)
285286

286287

288+
@torch.enable_grad()
287289
def eisner(scores, mask, multiroot=False):
288290
r"""
289291
First-order Eisner algorithm for projective decoding :cite:`mcdonald-etal-2005-online`.
@@ -313,66 +315,43 @@ def eisner(scores, mask, multiroot=False):
313315

314316
lens = mask.sum(1)
315317
batch_size, seq_len, _ = scores.shape
316-
scores = scores.permute(2, 1, 0)
317-
s_i = torch.full_like(scores, float('-inf'))
318-
s_c = torch.full_like(scores, float('-inf'))
319-
p_i = scores.new_zeros(seq_len, seq_len, batch_size).long()
320-
p_c = scores.new_zeros(seq_len, seq_len, batch_size).long()
318+
scores = scores.permute(2, 1, 0).requires_grad_()
319+
s_i = torch.full_like(scores, -1e30)
320+
s_c = torch.full_like(scores, -1e30)
321321
s_c.diagonal().fill_(0)
322322

323323
for w in range(1, seq_len):
324324
n = seq_len - w
325-
starts = p_i.new_tensor(range(n)).unsqueeze(0)
326325
# ilr = C(i->r) + C(j->r+1)
327326
ilr = stripe(s_c, n, w) + stripe(s_c, n, w, (w, 1))
328327
# [batch_size, n, w]
329328
il = ir = ilr.permute(2, 0, 1)
330329
# I(j->i) = max(C(i->r) + C(j->r+1) + s(j->i)), i <= r < j
331-
il_span, il_path = il.max(-1)
330+
il_span, _ = il.max(-1)
332331
s_i.diagonal(-w).copy_(il_span + scores.diagonal(-w))
333-
p_i.diagonal(-w).copy_(il_path + starts)
334332
# I(i->j) = max(C(i->r) + C(j->r+1) + s(i->j)), i <= r < j
335-
ir_span, ir_path = ir.max(-1)
333+
ir_span, _ = ir.max(-1)
336334
s_i.diagonal(w).copy_(ir_span + scores.diagonal(w))
337-
p_i.diagonal(w).copy_(ir_path + starts)
338335

339336
# C(j->i) = max(C(r->i) + I(j->r)), i <= r < j
340337
cl = stripe(s_c, n, w, (0, 0), 0) + stripe(s_i, n, w, (w, 0))
341-
cl_span, cl_path = cl.permute(2, 0, 1).max(-1)
338+
cl_span, _ = cl.permute(2, 0, 1).max(-1)
342339
s_c.diagonal(-w).copy_(cl_span)
343-
p_c.diagonal(-w).copy_(cl_path + starts)
344340
# C(i->j) = max(I(i->r) + C(r->j)), i < r <= j
345341
cr = stripe(s_i, n, w, (0, 1)) + stripe(s_c, n, w, (1, w), 0)
346-
cr_span, cr_path = cr.permute(2, 0, 1).max(-1)
342+
cr_span, _ = cr.permute(2, 0, 1).max(-1)
347343
s_c.diagonal(w).copy_(cr_span)
348344
if not multiroot:
349345
s_c[0, w][lens.ne(w)] = float('-inf')
350-
p_c.diagonal(w).copy_(cr_path + starts + 1)
351-
352-
def backtrack(p_i, p_c, heads, i, j, complete):
353-
if i == j:
354-
return
355-
if complete:
356-
r = p_c[i, j]
357-
backtrack(p_i, p_c, heads, i, r, False)
358-
backtrack(p_i, p_c, heads, r, j, True)
359-
else:
360-
r, heads[j] = p_i[i, j], i
361-
i, j = sorted((i, j))
362-
backtrack(p_i, p_c, heads, i, r, True)
363-
backtrack(p_i, p_c, heads, j, r + 1, True)
364346

365-
preds = []
366-
p_c = p_c.permute(2, 0, 1).cpu()
367-
p_i = p_i.permute(2, 0, 1).cpu()
368-
for i, length in enumerate(lens.tolist()):
369-
heads = p_c.new_zeros(length + 1, dtype=torch.long)
370-
backtrack(p_i[i], p_c[i], heads, 0, length, True)
371-
preds.append(heads.to(mask.device))
347+
logZ = s_c[0].gather(0, lens.unsqueeze(0)).sum()
348+
marginals, = autograd.grad(logZ, scores)
349+
preds = lens.new_zeros(batch_size, seq_len).masked_scatter_(mask, marginals.permute(2, 1, 0).nonzero()[:, 2])
372350

373-
return pad(preds, total_length=seq_len).to(mask.device)
351+
return preds
374352

375353

354+
@torch.enable_grad()
376355
def eisner2o(scores, mask, multiroot=False):
377356
r"""
378357
Second-order Eisner algorithm for projective decoding :cite:`mcdonald-pereira-2006-online`.
@@ -421,7 +400,7 @@ def eisner2o(scores, mask, multiroot=False):
421400

422401
# the end position of each sentence in a batch
423402
lens = mask.sum(1)
424-
s_arc, s_sib = scores
403+
s_arc, s_sib = (s.requires_grad_() for s in scores)
425404
batch_size, seq_len, _ = s_arc.shape
426405
# [seq_len, seq_len, batch_size]
427406
s_arc = s_arc.permute(2, 1, 0)
@@ -430,16 +409,12 @@ def eisner2o(scores, mask, multiroot=False):
430409
s_i = torch.full_like(s_arc, float('-inf'))
431410
s_s = torch.full_like(s_arc, float('-inf'))
432411
s_c = torch.full_like(s_arc, float('-inf'))
433-
p_i = s_arc.new_zeros(seq_len, seq_len, batch_size).long()
434-
p_s = s_arc.new_zeros(seq_len, seq_len, batch_size).long()
435-
p_c = s_arc.new_zeros(seq_len, seq_len, batch_size).long()
436412
s_c.diagonal().fill_(0)
437413

438414
for w in range(1, seq_len):
439415
# n denotes the number of spans to iterate,
440416
# from span (0, w) to span (n, n+w) given width w
441417
n = seq_len - w
442-
starts = p_i.new_tensor(range(n)).unsqueeze(0)
443418
# I(j->i) = max(I(j->r) + S(j->r, i)), i < r < j |
444419
# C(j->j) + C(i->j-1))
445420
# + s(j->i)
@@ -450,9 +425,8 @@ def eisner2o(scores, mask, multiroot=False):
450425
il0 = stripe(s_c, n, 1, (w, w)) + stripe(s_c, n, 1, (0, w - 1))
451426
# il0[0] are set to zeros since the scores of the complete spans starting from 0 are always -inf
452427
il[:, -1] = il0.index_fill_(0, lens.new_tensor(0), 0).squeeze(1)
453-
il_span, il_path = il.permute(2, 0, 1).max(-1)
428+
il_span, _ = il.permute(2, 0, 1).max(-1)
454429
s_i.diagonal(-w).copy_(il_span + s_arc.diagonal(-w))
455-
p_i.diagonal(-w).copy_(il_path + starts + 1)
456430
# I(i->j) = max(I(i->r) + S(i->r, j), i < r < j |
457431
# C(i->i) + C(j->i+1))
458432
# + s(i->j)
@@ -463,66 +437,36 @@ def eisner2o(scores, mask, multiroot=False):
463437
# [n, 1, batch_size]
464438
ir0 = stripe(s_c, n, 1) + stripe(s_c, n, 1, (w, 1))
465439
ir[:, 0] = ir0.squeeze(1)
466-
ir_span, ir_path = ir.permute(2, 0, 1).max(-1)
440+
ir_span, _ = ir.permute(2, 0, 1).max(-1)
467441
s_i.diagonal(w).copy_(ir_span + s_arc.diagonal(w))
468-
p_i.diagonal(w).copy_(ir_path + starts)
469442

470443
# [n, w, batch_size]
471444
slr = stripe(s_c, n, w) + stripe(s_c, n, w, (w, 1))
472-
slr_span, slr_path = slr.permute(2, 0, 1).max(-1)
445+
slr_span, _ = slr.permute(2, 0, 1).max(-1)
473446
# S(j, i) = max(C(i->r) + C(j->r+1)), i <= r < j
474447
s_s.diagonal(-w).copy_(slr_span)
475-
p_s.diagonal(-w).copy_(slr_path + starts)
476448
# S(i, j) = max(C(i->r) + C(j->r+1)), i <= r < j
477449
s_s.diagonal(w).copy_(slr_span)
478-
p_s.diagonal(w).copy_(slr_path + starts)
479450

480451
# C(j->i) = max(C(r->i) + I(j->r)), i <= r < j
481452
cl = stripe(s_c, n, w, (0, 0), 0) + stripe(s_i, n, w, (w, 0))
482-
cl_span, cl_path = cl.permute(2, 0, 1).max(-1)
453+
cl_span, _ = cl.permute(2, 0, 1).max(-1)
483454
s_c.diagonal(-w).copy_(cl_span)
484-
p_c.diagonal(-w).copy_(cl_path + starts)
485455
# C(i->j) = max(I(i->r) + C(r->j)), i < r <= j
486456
cr = stripe(s_i, n, w, (0, 1)) + stripe(s_c, n, w, (1, w), 0)
487-
cr_span, cr_path = cr.permute(2, 0, 1).max(-1)
457+
cr_span, _ = cr.permute(2, 0, 1).max(-1)
488458
s_c.diagonal(w).copy_(cr_span)
489459
if not multiroot:
490460
s_c[0, w][lens.ne(w)] = float('-inf')
491-
p_c.diagonal(w).copy_(cr_path + starts + 1)
492-
493-
def backtrack(p_i, p_s, p_c, heads, i, j, flag):
494-
if i == j:
495-
return
496-
if flag == 'c':
497-
r = p_c[i, j]
498-
backtrack(p_i, p_s, p_c, heads, i, r, 'i')
499-
backtrack(p_i, p_s, p_c, heads, r, j, 'c')
500-
elif flag == 's':
501-
r = p_s[i, j]
502-
i, j = sorted((i, j))
503-
backtrack(p_i, p_s, p_c, heads, i, r, 'c')
504-
backtrack(p_i, p_s, p_c, heads, j, r + 1, 'c')
505-
elif flag == 'i':
506-
r, heads[j] = p_i[i, j], i
507-
if r == i:
508-
r = i + 1 if i < j else i - 1
509-
backtrack(p_i, p_s, p_c, heads, j, r, 'c')
510-
else:
511-
backtrack(p_i, p_s, p_c, heads, i, r, 'i')
512-
backtrack(p_i, p_s, p_c, heads, r, j, 's')
513461

514-
preds = []
515-
p_i = p_i.permute(2, 0, 1).cpu()
516-
p_s = p_s.permute(2, 0, 1).cpu()
517-
p_c = p_c.permute(2, 0, 1).cpu()
518-
for i, length in enumerate(lens.tolist()):
519-
heads = p_c.new_zeros(length + 1, dtype=torch.long)
520-
backtrack(p_i[i], p_s[i], p_c[i], heads, 0, length, 'c')
521-
preds.append(heads.to(mask.device))
462+
logZ = s_c[0].gather(0, lens.unsqueeze(0)).sum()
463+
marginals, = autograd.grad(logZ, s_arc)
464+
preds = lens.new_zeros(batch_size, seq_len).masked_scatter_(mask, marginals.permute(2, 1, 0).nonzero()[:, 2])
522465

523-
return pad(preds, total_length=seq_len).to(mask.device)
466+
return preds
524467

525468

469+
@torch.enable_grad()
526470
def cky(scores, mask):
527471
r"""
528472
The implementation of `Cocke-Kasami-Younger`_ (CKY) algorithm to parse constituency trees :cite:`zhang-etal-2020-fast`.
@@ -554,17 +498,13 @@ def cky(scores, mask):
554498
"""
555499

556500
lens = mask[:, 0].sum(-1)
557-
scores = scores.permute(1, 2, 3, 0)
501+
scores = scores.permute(1, 2, 3, 0).requires_grad_()
558502
seq_len, seq_len, n_labels, batch_size = scores.shape
559503
s = scores.new_zeros(seq_len, seq_len, batch_size)
560-
p_s = scores.new_zeros(seq_len, seq_len, batch_size).long()
561-
p_l = scores.new_zeros(seq_len, seq_len, batch_size).long()
562504

563505
for w in range(1, seq_len):
564506
n = seq_len - w
565-
starts = p_s.new_tensor(range(n)).unsqueeze(0)
566-
s_l, p = scores.diagonal(w).max(0)
567-
p_l.diagonal(w).copy_(p)
507+
s_l, _ = scores.diagonal(w).max(0)
568508

569509
if w == 1:
570510
s.diagonal(w).copy_(s_l)
@@ -574,20 +514,9 @@ def cky(scores, mask):
574514
# [batch_size, n, w]
575515
s_s = s_s.permute(2, 0, 1)
576516
# [batch_size, n]
577-
s_s, p = s_s.max(-1)
517+
s_s, _ = s_s.max(-1)
578518
s.diagonal(w).copy_(s_s + s_l)
579-
p_s.diagonal(w).copy_(p + starts + 1)
580-
581-
def backtrack(p_s, p_l, i, j):
582-
if j == i + 1:
583-
return [(i, j, p_l[i][j])]
584-
split, label = p_s[i][j], p_l[i][j]
585-
ltree = backtrack(p_s, p_l, i, split)
586-
rtree = backtrack(p_s, p_l, split, j)
587-
return [(i, j, label)] + ltree + rtree
588-
589-
p_s = p_s.permute(2, 0, 1).tolist()
590-
p_l = p_l.permute(2, 0, 1).tolist()
591-
trees = [backtrack(p_s[i], p_l[i], 0, length) for i, length in enumerate(lens.tolist())]
592519

593-
return trees
520+
logZ = s[0].gather(0, lens.unsqueeze(0)).sum()
521+
marginals, = autograd.grad(logZ, scores)
522+
return [sorted(i.nonzero().tolist(), key=lambda x:(x[0], -x[1])) for i in marginals.permute(3, 0, 1, 2)]

0 commit comments

Comments
 (0)