Skip to content

Commit ceb7774

Browse files
committed
Bug of MatrixTree
1 parent 3e14cf2 commit ceb7774

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

supar/structs/crf.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,16 @@ def kl(self, other):
4848
raise NotImplementedError
4949

5050
def score(self, value):
51+
arcs = value
5152
if self.partial:
5253
mask, lens = self.mask, self.lens
53-
mask = self.mask.index_fill(1, self.lens.new_tensor(0), 1)
54+
mask = mask.index_fill(1, self.lens.new_tensor(0), 1)
5455
mask = mask.unsqueeze(1) & mask.unsqueeze(2)
55-
value = value.index_fill(1, lens.new_tensor(0), -1).unsqueeze(-1)
56-
value = value.eq(lens.new_tensor(range(mask.shape[1]))) | value.lt(0)
57-
value = value & mask
58-
scores = LogSemiring.zero_mask(self.scores, ~value)
56+
arcs = arcs.index_fill(1, lens.new_tensor(0), -1).unsqueeze(-1)
57+
arcs = arcs.eq(lens.new_tensor(range(mask.shape[1]))) | arcs.lt(0)
58+
scores = LogSemiring.zero_mask(self.scores, ~(arcs & mask))
5959
return self.__class__(scores, self.mask, **self.kwargs).log_partition
60-
return LogSemiring.prod(LogSemiring.one_mask(self.scores.gather(-1, value.unsqueeze(-1)).squeeze(-1), ~self.mask), -1)
60+
return LogSemiring.prod(LogSemiring.one_mask(self.scores.gather(-1, arcs.unsqueeze(-1)).squeeze(-1), ~self.mask), -1)
6161

6262
@torch.enable_grad()
6363
def forward(self, semiring):
@@ -76,10 +76,9 @@ def forward(self, semiring):
7676

7777
s_arc = self.scores
7878
mask, lens = self.mask, self.lens
79-
lens = mask.sum(-1)
8079
batch_size, seq_len, _ = s_arc.shape
8180
mask = mask.index_fill(1, lens.new_tensor(0), 1)
82-
s_arc = semiring.zero_mask(s_arc, mask.unsqueeze(-1) & mask.unsqueeze(-2))
81+
s_arc = semiring.zero_mask(s_arc, ~(mask.unsqueeze(-1) & mask.unsqueeze(-2)))
8382

8483
# A(i, j) = exp(s(i, j))
8584
# double precision to prevent overflows
@@ -93,7 +92,7 @@ def forward(self, semiring):
9392
# L(i, j) = D(i, j) - A(i, j)
9493
L = nn.init.eye_(torch.empty_like(A[0])).repeat(batch_size, 1, 1).masked_scatter_(mask.unsqueeze(-1), (D - A)[mask])
9594
# Z = L^(0, 0), the minor of L w.r.t row 0 and column 0
96-
return L[:, 1:, 1:].logdet().float()
95+
return L[:, 1:, 1:].slogdet()[1].float()
9796

9897

9998
class CRFDependency(StructuredDistribution):

0 commit comments

Comments
 (0)