Skip to content

Commit 6705207

Browse files
committed
Handle single-root case in MatrixTree
1 parent 4f8a0d2 commit 6705207

File tree

1 file changed

+23
-10
lines changed

1 file changed

+23
-10
lines changed

supar/structs/tree.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,18 @@ class MatrixTree(StructuredDistribution):
         >>> s1 = MatrixTree(torch.randn(batch_size, seq_len, seq_len), lens)
         >>> s2 = MatrixTree(torch.randn(batch_size, seq_len, seq_len), lens)
         >>> s1.max
-        tensor([2.6816, 7.2115], grad_fn=<CopyBackwards>)
+        tensor([0.7174, 3.7910], grad_fn=<SumBackward1>)
         >>> s1.argmax
-        tensor([[0, 0, 3, 1, 0],
-                [0, 3, 0, 2, 3]])
+        tensor([[0, 0, 1, 1, 0],
+                [0, 4, 1, 0, 3]])
         >>> s1.log_partition
-        tensor([2.6816, 7.2115], grad_fn=<CopyBackwards>)
+        tensor([2.0229, 6.0558], grad_fn=<CopyBackwards>)
         >>> s1.log_prob(arcs)
-        tensor([-0.7524, -3.0046], grad_fn=<SubBackward0>)
+        tensor([-3.2209, -2.5756], grad_fn=<SubBackward0>)
+        >>> s1.entropy
+        tensor([1.9711, 3.4497], grad_fn=<SubBackward0>)
+        >>> s1.kl(s2)
+        tensor([1.3354, 2.6914], grad_fn=<AddBackward0>)
     """
4246

4347
def __init__(self, scores, lens=None, multiroot=False):
@@ -56,9 +60,15 @@ def __repr__(self):

     def __add__(self, other):
         return MatrixTree(torch.stack((self.scores, other.scores)), self.lens, self.multiroot)

+    @lazy_property
+    def max(self):
+        arcs = self.argmax
+        return LogSemiring.prod(LogSemiring.one_mask(self.scores.gather(-1, arcs.unsqueeze(-1)).squeeze(-1), ~self.mask), -1)
+
     @lazy_property
     def argmax(self):
-        return mst(self.scores, self.mask, self.multiroot)
+        with torch.no_grad():
+            return mst(self.scores, self.mask, self.multiroot)

     def kmax(self, k):
         # TODO: Camerini algorithm
@@ -92,9 +102,8 @@ def score(self, value, partial=False):
     @torch.enable_grad()
     def forward(self, semiring):
         s_arc = self.scores
-        mask, lens = self.mask, self.lens
-        batch_size, seq_len, _ = s_arc.shape
-        mask = mask.index_fill(1, lens.new_tensor(0), 1)
+        batch_size, *_ = s_arc.shape
+        mask = self.mask.index_fill(1, self.lens.new_tensor(0), 1)
         s_arc = semiring.zero_mask(s_arc, ~(mask.unsqueeze(-1) & mask.unsqueeze(-2)))

         # A(i, j) = exp(s(i, j))
@@ -107,7 +116,11 @@ def forward(self, semiring):
         D.diagonal(0, 1, 2).copy_(A.sum(-1))
         # Laplacian matrix
         # L(i, j) = D(i, j) - A(i, j)
-        L = nn.init.eye_(torch.empty_like(A[0])).repeat(batch_size, 1, 1).masked_scatter_(mask.unsqueeze(-1), (D - A)[mask])
+        L = D - A
+        if not self.multiroot:
+            L.diagonal(0, 1, 2).add_(-A[..., 0])
+            L[..., 1] = A[..., 0]
+        L = nn.init.eye_(torch.empty_like(A[0])).repeat(batch_size, 1, 1).masked_scatter_(mask.unsqueeze(-1), L[mask])
         # Z = L^(0, 0), the minor of L w.r.t row 0 and column 0
         return L[:, 1:, 1:].slogdet()[1].float()

0 commit comments

Comments
 (0)