Skip to content

Commit 4ed8b6e

Browse files
committed
Fix bug in KMaxSemiring for CKY
1 parent 9e41f4e commit 4ed8b6e

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

supar/structs/crf.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,8 +308,11 @@ def forward(self, semiring):
308308

309309
for w in range(1, seq_len):
310310
n = seq_len - w
311+
if w == 1:
312+
s.diagonal(w, -2, -1).copy_(scores.diagonal(w, -2, -1))
313+
continue
311314
# [..., batch_size, n]
312-
s_s = semiring.dot(stripe(s, n, w-1, (0, 1)), stripe(s, n, w-1, (1, w), 0), -1) if w > 1 else semiring.one
315+
s_s = semiring.dot(stripe(s, n, w-1, (0, 1)), stripe(s, n, w-1, (1, w), 0), -1)
313316
s.diagonal(w, -2, -1).copy_(semiring.mul(s_s, scores.diagonal(w, -2, -1)))
314317
# [..., batch_size, seq_len, seq_len]
315318
s = semiring.unconvert(s)

0 commit comments

Comments
 (0)