Skip to content

Commit d7d3e8f

Browse files
committed
Improve comments
1 parent 5848e7c commit d7d3e8f

File tree

1 file changed

+12
-14
lines changed

1 file changed

+12
-14
lines changed

supar/structs/tree.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -196,18 +196,18 @@ def forward(self, semiring):
196196

197197
# [n, batch_size, ...]
198198
il = ir = semiring.dot(stripe(s_c, n, w), stripe(s_c, n, w, (w, 1)), 1)
199-
# I(j->i) = logsumexp(C(i->r) + C(j->r+1)) + s(j->i), i <= r < j
199+
# I(j->i) = <C(i->r), C(j->r+1)> * s(j->i), i <= r < j
200200
# fill the w-th diagonal of the lower triangular part of s_i with I(j->i) of n spans
201201
s_i.diagonal(-w).copy_(semiring.mul(il, s_arc.diagonal(-w).movedim(-1, 0)).movedim(0, -1))
202-
# I(i->j) = logsumexp(C(i->r) + C(j->r+1)) + s(i->j), i <= r < j
202+
# I(i->j) = <C(i->r), C(j->r+1)> * s(i->j), i <= r < j
203203
# fill the w-th diagonal of the upper triangular part of s_i with I(i->j) of n spans
204204
s_i.diagonal(w).copy_(semiring.mul(ir, s_arc.diagonal(w).movedim(-1, 0)).movedim(0, -1))
205205

206206
# [n, batch_size, ...]
207-
# C(j->i) = logsumexp(C(r->i) + I(j->r)), i <= r < j
207+
# C(j->i) = <C(r->i), I(j->r)>, i <= r < j
208208
cl = semiring.dot(stripe(s_c, n, w, (0, 0), 0), stripe(s_i, n, w, (w, 0)), 1)
209209
s_c.diagonal(-w).copy_(cl.movedim(0, -1))
210-
# C(i->j) = logsumexp(I(i->r) + C(r->j)), i < r <= j
210+
# C(i->j) = <I(i->r), C(r->j)>, i < r <= j
211211
cr = semiring.dot(stripe(s_i, n, w, (0, 1)), stripe(s_c, n, w, (1, w), 0), 1)
212212
s_c.diagonal(w).copy_(cr.movedim(0, -1))
213213
if not self.multiroot:
@@ -310,19 +310,17 @@ def forward(self, semiring):
310310
for w in range(1, seq_len):
311311
n = seq_len - w
312312

313-
# I(j->i) = logsum(exp(I(j->r) + S(j->r, i)) +, i < r < j
314-
# exp(C(j->j) + C(i->j-1)))
315-
# + s(j->i)
313+
# I(j->i) = <I(j->r), S(j->r, i)> * s(j->i), i < r < j
314+
# <C(j->j), C(i->j-1)> * s(j->i), otherwise
316315
# [n, w, batch_size, ...]
317316
il = semiring.times(stripe(s_i, n, w, (w, 1)),
318317
stripe(s_s, n, w, (1, 0), 0),
319318
stripe(s_sib[range(w, n+w), range(n), :], n, w, (0, 1)))
320319
il[:, -1] = semiring.mul(stripe(s_c, n, 1, (w, w)), stripe(s_c, n, 1, (0, w - 1))).squeeze(1)
321320
il = semiring.sum(il, 1)
322321
s_i.diagonal(-w).copy_(semiring.mul(il, s_arc.diagonal(-w).movedim(-1, 0)).movedim(0, -1))
323-
# I(i->j) = logsum(exp(I(i->r) + S(i->r, j)) +, i < r < j
324-
# exp(C(i->i) + C(j->i+1)))
325-
# + s(i->j)
322+
# I(i->j) = <I(i->r), S(i->r, j)> * s(i->j), i < r < j
323+
# <C(i->i), C(j->i+1)> * s(i->j), otherwise
326324
# [n, w, batch_size, ...]
327325
ir = semiring.times(stripe(s_i, n, w),
328326
stripe(s_s, n, w, (0, w), 0),
@@ -335,16 +333,16 @@ def forward(self, semiring):
335333

336334
# [batch_size, ..., n]
337335
sl = sr = semiring.dot(stripe(s_c, n, w), stripe(s_c, n, w, (w, 1)), 1).movedim(0, -1)
338-
# S(j, i) = logsumexp(C(i->r) + C(j->r+1)), i <= r < j
336+
# S(j, i) = <C(i->r), C(j->r+1)>, i <= r < j
339337
s_s.diagonal(-w).copy_(sl)
340-
# S(i, j) = logsumexp(C(i->r) + C(j->r+1)), i <= r < j
338+
# S(i, j) = <C(i->r), C(j->r+1)>, i <= r < j
341339
s_s.diagonal(w).copy_(sr)
342340

343341
# [n, batch_size, ...]
344-
# C(j->i) = logsumexp(C(r->i) + I(j->r)), i <= r < j
342+
# C(j->i) = <C(r->i), I(j->r)>, i <= r < j
345343
cl = semiring.dot(stripe(s_c, n, w, (0, 0), 0), stripe(s_i, n, w, (w, 0)), 1)
346344
s_c.diagonal(-w).copy_(cl.movedim(0, -1))
347-
# C(i->j) = logsumexp(I(i->r) + C(r->j)), i < r <= j
345+
# C(i->j) = <I(i->r), C(r->j)>, i < r <= j
348346
cr = semiring.dot(stripe(s_i, n, w, (0, 1)), stripe(s_c, n, w, (1, w), 0), 1)
349347
s_c.diagonal(w).copy_(cr.movedim(0, -1))
350348
return semiring.unconvert(s_c)[0][self.lens, range(batch_size)]

0 commit comments

Comments
 (0)