Commit 7a11d43

topk for linear-chain CRF
1 parent 18ffc1e commit 7a11d43

File tree

1 file changed (+30 −11 lines)

supar/structs/linearchain.py

Lines changed: 30 additions & 11 deletions
@@ -8,26 +8,41 @@
 
 class LinearChainCRF(StructuredDistribution):
     r"""
+    Linear-chain CRFs (:cite:`lafferty-etal-2001-crf`).
+
+    Args:
+        scores (~torch.Tensor): ``[batch_size, seq_len, n_tags]``.
+            Log potentials.
+        trans (~torch.Tensor): ``[n_tags+1, n_tags+1]``.
+            Transition scores.
+            ``trans[-1, :-1]``/``trans[:-1, -1]`` represent transitions for start/end positions respectively.
+        lens (~torch.LongTensor): ``[batch_size]``.
+            Sentence lengths for masking. Default: ``None``.
 
     Examples:
         >>> from supar import LinearChainCRF
-        >>> batch_size, seq_len, n_tags = 3, 5, 4
-        >>> lens = torch.tensor([3, 4, 5])
+        >>> batch_size, seq_len, n_tags = 2, 5, 4
+        >>> lens = torch.tensor([3, 4])
         >>> value = torch.randint(n_tags, (batch_size, seq_len))
-        >>> s1 = LinearChainCRF(torch.randn(batch_size, seq_len, n_tags), torch.randn(n_tags+1, n_tags+1), lens)
-        >>> s2 = LinearChainCRF(torch.randn(batch_size, seq_len, n_tags), torch.randn(n_tags+1, n_tags+1), lens)
+        >>> s1 = LinearChainCRF(torch.randn(batch_size, seq_len, n_tags),
                                 torch.randn(n_tags+1, n_tags+1),
                                 lens)
+        >>> s2 = LinearChainCRF(torch.randn(batch_size, seq_len, n_tags),
                                 torch.randn(n_tags+1, n_tags+1),
                                 lens)
         >>> s1.max
-        tensor([2.4978, 5.7460, 4.9088], grad_fn=<MaxBackward0>)
+        tensor([4.4120, 8.9672], grad_fn=<MaxBackward0>)
         >>> s1.argmax
-        tensor([[3, 1, 3, 0, 0],
-                [1, 0, 1, 0, 0],
-                [2, 0, 1, 1, 0]])
+        tensor([[2, 0, 3, 0, 0],
+                [3, 3, 3, 2, 0]])
         >>> s1.log_partition
-        tensor([3.7812, 7.9180, 7.8031], grad_fn=<LogsumexpBackward>)
+        tensor([ 6.3486, 10.9106], grad_fn=<LogsumexpBackward>)
         >>> s1.log_prob(value)
-        tensor([ -8.9096, -11.3473, -9.6189], grad_fn=<SubBackward0>)
+        tensor([ -8.1515, -10.5572], grad_fn=<SubBackward0>)
+        >>> s1.entropy
+        tensor([3.4150, 3.6549], grad_fn=<SelectBackward>)
         >>> s1.kl(s2)
-        tensor([1.9768, 5.1978, 8.6055], grad_fn=<SelectBackward>)
+        tensor([4.0333, 4.3807], grad_fn=<SelectBackward>)
     """
 
     def __init__(self, scores, trans=None, lens=None):
@@ -51,6 +66,10 @@ def __add__(self, other):
     def argmax(self):
         return self.lens.new_zeros(self.mask.shape).masked_scatter_(self.mask, torch.where(self.backward(self.max.sum()))[2])
 
+    def topk(self, k):
+        preds = torch.stack([torch.where(self.backward(i))[2] for i in self.kmax(k).sum(0)], -1)
+        return self.lens.new_zeros(*self.mask.shape, k).masked_scatter_(self.mask.unsqueeze(-1), preds)
+
     def score(self, value):
         scores, mask, value = self.scores.transpose(0, 1), self.mask.t(), value.t()
         prev, succ = torch.cat((torch.full_like(value[:1], -1), value[:-1]), 0), value
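
For context, a minimal usage sketch of the new method (not part of the commit), set up like the docstring examples above. The `[batch_size, seq_len, k]` output shape follows from the `masked_scatter_` call in the diff, with padded positions left as 0; the tag values themselves depend on the random potentials, and the choice of k = 3 here is arbitrary.

import torch
from supar import LinearChainCRF

batch_size, seq_len, n_tags, k = 2, 5, 4, 3
lens = torch.tensor([3, 4])
s1 = LinearChainCRF(torch.randn(batch_size, seq_len, n_tags),
                    torch.randn(n_tags+1, n_tags+1),
                    lens)

# topk(k) stacks one decoded tag sequence per k-best score along a new
# trailing dimension; positions past each sentence length remain 0.
tags = s1.topk(k)
print(tags.shape)          # torch.Size([2, 5, 3])
print(tags[..., 0])        # expected to coincide with s1.argmax (the 1-best decode)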
