@@ -48,16 +48,16 @@ def kl(self, other):
         raise NotImplementedError
 
     def score(self, value):
+        arcs = value
         if self.partial:
             mask, lens = self.mask, self.lens
-            mask = self.mask.index_fill(1, self.lens.new_tensor(0), 1)
+            mask = mask.index_fill(1, self.lens.new_tensor(0), 1)
             mask = mask.unsqueeze(1) & mask.unsqueeze(2)
-            value = value.index_fill(1, lens.new_tensor(0), -1).unsqueeze(-1)
-            value = value.eq(lens.new_tensor(range(mask.shape[1]))) | value.lt(0)
-            value = value & mask
-            scores = LogSemiring.zero_mask(self.scores, ~value)
+            arcs = arcs.index_fill(1, lens.new_tensor(0), -1).unsqueeze(-1)
+            arcs = arcs.eq(lens.new_tensor(range(mask.shape[1]))) | arcs.lt(0)
+            scores = LogSemiring.zero_mask(self.scores, ~(arcs & mask))
             return self.__class__(scores, self.mask, **self.kwargs).log_partition
-        return LogSemiring.prod(LogSemiring.one_mask(self.scores.gather(-1, value.unsqueeze(-1)).squeeze(-1), ~self.mask), -1)
+        return LogSemiring.prod(LogSemiring.one_mask(self.scores.gather(-1, arcs.unsqueeze(-1)).squeeze(-1), ~self.mask), -1)
 
     @torch.enable_grad()
     def forward(self, semiring):
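For reference, the LogSemiring helpers used in score are assumed to behave roughly as below; this is a minimal hypothetical sketch for readers of the diff, not the library's actual implementation. Under these semantics, folding the two removed steps into the single zero_mask(self.scores, ~(arcs & mask)) call is equivalent: every arc that is not both compatible with the partial annotation and inside the sentence is pushed to the semiring zero before the log-partition is taken.

# A minimal sketch of the LogSemiring helpers assumed above; this is a
# hypothetical re-implementation for illustration, not the library code.
import torch

class LogSemiring:
    zero = float('-inf')  # additive identity in log space
    one = 0.              # multiplicative identity in log space

    @classmethod
    def zero_mask(cls, x, mask):
        # erase positions where mask is True by filling in the zero (-inf)
        return x.masked_fill(mask, cls.zero)

    @classmethod
    def one_mask(cls, x, mask):
        # neutralize positions where mask is True by filling in the one (0)
        return x.masked_fill(mask, cls.one)

    @classmethod
    def prod(cls, x, dim=-1):
        # a product of probabilities is a sum in log space
        return x.sum(dim)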
@@ -76,10 +76,9 @@ def forward(self, semiring):
 
         s_arc = self.scores
         mask, lens = self.mask, self.lens
-        lens = mask.sum(-1)
         batch_size, seq_len, _ = s_arc.shape
         mask = mask.index_fill(1, lens.new_tensor(0), 1)
-        s_arc = semiring.zero_mask(s_arc, mask.unsqueeze(-1) & mask.unsqueeze(-2))
+        s_arc = semiring.zero_mask(s_arc, ~(mask.unsqueeze(-1) & mask.unsqueeze(-2)))
 
         # A(i, j) = exp(s(i, j))
         # double precision to prevent overflows
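The added ~ is the actual bug fix in this hunk: zero_mask erases the positions where its mask argument is true, so the old call blanked out exactly the valid (dependent, head) pairs and left the padded ones untouched. Dropping lens = mask.sum(-1) simply reuses the lengths already stored in self.lens instead of recomputing them. A toy illustration of the intended masking, assuming the masked_fill-style semantics sketched above:

import torch

s_arc = torch.randn(1, 3, 3)                      # one sentence, three positions
mask = torch.tensor([[True, True, False]])        # the last position is padding
valid = mask.unsqueeze(-1) & mask.unsqueeze(-2)   # valid (dependent, head) pairs
s_arc = s_arc.masked_fill(~valid, float('-inf'))  # erase invalid arcs, keep valid ones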
@@ -93,7 +92,7 @@ def forward(self, semiring):
         # L(i, j) = D(i, j) - A(i, j)
         L = nn.init.eye_(torch.empty_like(A[0])).repeat(batch_size, 1, 1).masked_scatter_(mask.unsqueeze(-1), (D - A)[mask])
         # Z = L^(0, 0), the minor of L w.r.t row 0 and column 0
-        return L[:, 1:, 1:].logdet().float()
+        return L[:, 1:, 1:].slogdet()[1].float()
 
 
 class CRFDependency(StructuredDistribution):
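Replacing logdet with slogdet()[1] guards the partition computation against NaNs: torch.logdet returns NaN whenever the determinant is negative, which can happen for the Laplacian minor through rounding error alone, while torch.slogdet returns the pair (sign, log|det|), whose second element stays finite. A quick demonstration:

import torch

m = torch.tensor([[0., 1.],
                  [1., 0.]])  # det(m) = -1
print(torch.logdet(m))        # nan, since the determinant is negative
print(torch.slogdet(m)[1])    # 0.0, i.e. log|det| = log 1

In exact arithmetic the minor's determinant is the (positive) sum over all trees rooted at position 0 by the matrix-tree theorem, so taking log|det| changes nothing mathematically; it only removes the numerical failure mode.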