8
8
9
9
class LinearChainCRF (StructuredDistribution ):
10
10
    r"""
    Linear-chain CRFs (:cite:`lafferty-etal-2001-crf`).

    Args:
        scores (~torch.Tensor): ``[batch_size, seq_len, n_tags]``.
            Log potentials.
        trans (~torch.Tensor): ``[n_tags+1, n_tags+1]``.
            Transition scores.
            ``trans[-1, :-1]``/``trans[:-1, -1]`` represent transitions for start/end positions respectively.
        lens (~torch.LongTensor): ``[batch_size]``.
            Sentence lengths for masking. Default: ``None``.

    Examples:
        >>> from supar import LinearChainCRF
        >>> batch_size, seq_len, n_tags = 2, 5, 4
        >>> lens = torch.tensor([3, 4])
        >>> value = torch.randint(n_tags, (batch_size, seq_len))
        >>> s1 = LinearChainCRF(torch.randn(batch_size, seq_len, n_tags),
                                torch.randn(n_tags+1, n_tags+1),
                                lens)
        >>> s2 = LinearChainCRF(torch.randn(batch_size, seq_len, n_tags),
                                torch.randn(n_tags+1, n_tags+1),
                                lens)
        >>> s1.max
        tensor([4.4120, 8.9672], grad_fn=<MaxBackward0>)
        >>> s1.argmax
        tensor([[2, 0, 3, 0, 0],
                [3, 3, 3, 2, 0]])
        >>> s1.log_partition
        tensor([ 6.3486, 10.9106], grad_fn=<LogsumexpBackward>)
        >>> s1.log_prob(value)
        tensor([ -8.1515, -10.5572], grad_fn=<SubBackward0>)
        >>> s1.entropy
        tensor([3.4150, 3.6549], grad_fn=<SelectBackward>)
        >>> s1.kl(s2)
        tensor([4.0333, 4.3807], grad_fn=<SelectBackward>)
    """
32
47
33
48
def __init__ (self , scores , trans = None , lens = None ):
@@ -51,6 +66,10 @@ def __add__(self, other):
51
66
def argmax (self ):
52
67
return self .lens .new_zeros (self .mask .shape ).masked_scatter_ (self .mask , torch .where (self .backward (self .max .sum ()))[2 ])
53
68
69
+ def topk (self , k ):
70
+ preds = torch .stack ([torch .where (self .backward (i ))[2 ] for i in self .kmax (k ).sum (0 )], - 1 )
71
+ return self .lens .new_zeros (* self .mask .shape , k ).masked_scatter_ (self .mask .unsqueeze (- 1 ), preds )
72
+
54
73
def score (self , value ):
55
74
scores , mask , value = self .scores .transpose (0 , 1 ), self .mask .t (), value .t ()
56
75
prev , succ = torch .cat ((torch .full_like (value [:1 ], - 1 ), value [:- 1 ]), 0 ), value
0 commit comments