@@ -30,14 +30,18 @@ class MatrixTree(StructuredDistribution):
>>> s1 = MatrixTree(torch.randn(batch_size, seq_len, seq_len), lens)
>>> s2 = MatrixTree(torch.randn(batch_size, seq_len, seq_len), lens)
>>> s1.max
- tensor([2.6816, 7.2115], grad_fn=<CopyBackwards>)
+ tensor([0.7174, 3.7910], grad_fn=<SumBackward1>)
>>> s1.argmax
- tensor([[0, 0, 3, 1, 0],
-         [0, 3, 0, 2, 3]])
+ tensor([[0, 0, 1, 1, 0],
+         [0, 4, 1, 0, 3]])
>>> s1.log_partition
- tensor([2.6816, 7.2115], grad_fn=<CopyBackwards>)
+ tensor([2.0229, 6.0558], grad_fn=<CopyBackwards>)
>>> s1.log_prob(arcs)
- tensor([-0.7524, -3.0046], grad_fn=<SubBackward0>)
+ tensor([-3.2209, -2.5756], grad_fn=<SubBackward0>)
+ >>> s1.entropy
+ tensor([1.9711, 3.4497], grad_fn=<SubBackward0>)
+ >>> s1.kl(s2)
+ tensor([1.3354, 2.6914], grad_fn=<AddBackward0>)
"""
def __init__(self, scores, lens=None, multiroot=False):
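Note: the doctest relies on setup lines from the unchanged part of the docstring that this hunk does not show. A minimal sketch of that setup follows; the concrete lens/arcs values and the import path are illustrative assumptions, not taken from the diff:

import torch
from supar.structs import MatrixTree  # assumed import path

batch_size, seq_len = 2, 5
lens = torch.tensor([3, 4])                              # sentence lengths, excluding position 0
arcs = torch.tensor([[0, 2, 0, 4, 2], [0, 3, 1, 0, 3]])  # gold heads; 0 = pseudo-root/padding

s1 = MatrixTree(torch.randn(batch_size, seq_len, seq_len), lens)
s1.log_prob(arcs)  # log-probability of the gold trees under the distribution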
@@ -56,9 +60,15 @@ def __repr__(self):
def __add__(self, other):
    return MatrixTree(torch.stack((self.scores, other.scores)), self.lens, self.multiroot)

+ @lazy_property
+ def max(self):
+     arcs = self.argmax
+     return LogSemiring.prod(LogSemiring.one_mask(self.scores.gather(-1, arcs.unsqueeze(-1)).squeeze(-1), ~self.mask), -1)
+
@lazy_property
def argmax(self):
-     return mst(self.scores, self.mask, self.multiroot)
+     with torch.no_grad():
+         return mst(self.scores, self.mask, self.multiroot)

def kmax(self, k):
    # TODO: Camerini algorithm
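The new max property scores the tree returned by argmax instead of copying log_partition: it gathers each token's score for its predicted head, fills masked positions (pseudo-root and padding) with the semiring identity, and takes the semiring product over the sequence. In the log semiring that product is a plain sum, which is why the doctest's grad_fn changed from CopyBackwards to SumBackward1; the torch.no_grad() around mst makes explicit that MST decoding is non-differentiable, so gradients flow through the gathered scores instead. A standalone sketch of the same computation in plain torch (assuming one_mask fills masked entries with the semiring one; this is not the library's exact code path):

import torch

def tree_score(scores, arcs, mask):
    # scores: [batch, seq_len, seq_len], scores[b, i, j] = score of head j for token i
    # arcs:   [batch, seq_len], head indices of the decoded tree
    # mask:   [batch, seq_len], False at the pseudo-root and at padding
    arc_scores = scores.gather(-1, arcs.unsqueeze(-1)).squeeze(-1)
    # 0 is the multiplicative identity in log space, so masked positions
    # drop out of the product (i.e. out of the sum)
    return arc_scores.masked_fill(~mask, 0).sum(-1)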
@@ -92,9 +102,8 @@ def score(self, value, partial=False):
@torch.enable_grad()
def forward(self, semiring):
    s_arc = self.scores
-   mask, lens = self.mask, self.lens
-   batch_size, seq_len, _ = s_arc.shape
-   mask = mask.index_fill(1, lens.new_tensor(0), 1)
+   batch_size, *_ = s_arc.shape
+   mask = self.mask.index_fill(1, self.lens.new_tensor(0), 1)
    s_arc = semiring.zero_mask(s_arc, ~(mask.unsqueeze(-1) & mask.unsqueeze(-2)))

    # A(i, j) = exp(s(i, j))
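The rewritten prologue drops the unused seq_len binding and reads mask and lens straight off self. Position 0 (the pseudo-root) is forced on with index_fill so the root row and column survive masking, and every entry outside a valid (dependent, head) pair is set to the semiring zero (-inf in log space), so padded tokens contribute exp(-inf) = 0 to the partition function. A small illustration with plain tensor ops (masked_fill with -inf stands in for semiring.zero_mask, an assumption about its semantics):

import torch

mask = torch.tensor([[False, True, True, True, False],
                     [False, True, True, True, True]])  # lens = [3, 4]
mask = mask.index_fill(1, torch.tensor([0]), True)      # switch the pseudo-root position on
pair = mask.unsqueeze(-1) & mask.unsqueeze(-2)          # valid (dependent, head) pairs
s_arc = torch.randn(2, 5, 5).masked_fill(~pair, float('-inf'))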
@@ -107,7 +116,11 @@ def forward(self, semiring):
    D.diagonal(0, 1, 2).copy_(A.sum(-1))
    # Laplacian matrix
    # L(i, j) = D(i, j) - A(i, j)
-   L = nn.init.eye_(torch.empty_like(A[0])).repeat(batch_size, 1, 1).masked_scatter_(mask.unsqueeze(-1), (D - A)[mask])
+   L = D - A
+   if not self.multiroot:
+       L.diagonal(0, 1, 2).add_(-A[..., 0])
+       L[..., 1] = A[..., 0]
+   L = nn.init.eye_(torch.empty_like(A[0])).repeat(batch_size, 1, 1).masked_scatter_(mask.unsqueeze(-1), L[mask])
    # Z = L^(0, 0), the minor of L w.r.t row 0 and column 0
    return L[:, 1:, 1:].slogdet()[1].float()
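Building the Laplacian as L = D - A first makes the single-root adjustment explicit: when multiroot=False, the root-arc weights A(..., 0) are removed from the degree diagonal and written into the first column that survives the minor, following the single-root construction of Koo et al. (2007), which counts only trees with exactly one root dependent. The identity-initialized masked_scatter_ then keeps padded batch entries invertible, and the log-partition is the log-determinant of the minor of L with row and column 0 removed, per the Matrix-Tree theorem. A self-contained re-derivation of the multiroot case (a sketch under those definitions, not the library's exact code path):

import torch

def log_partition_multiroot(s_arc):
    # s_arc: [batch, seq_len, seq_len], s_arc[b, i, j] = score of the arc
    # from head j to dependent i; row/column 0 is the pseudo-root
    A = s_arc.exp()                       # A(i, j) = exp(s(i, j))
    A.diagonal(0, 1, 2).fill_(0)          # no self-loops
    D = torch.zeros_like(A)
    D.diagonal(0, 1, 2).copy_(A.sum(-1))  # weighted in-degree matrix
    L = D - A                             # graph Laplacian
    # Matrix-Tree theorem: Z is the determinant of the minor of L
    # w.r.t. row 0 and column 0
    return L[:, 1:, 1:].slogdet()[1]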