Skip to content

Commit 71d097e

Browse files
committed
Cross Entropy & KL for non-proj trees
1 parent 50bd788 commit 71d097e

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

supar/structs/tree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,10 @@ def entropy(self):
7272
return self.log_partition - (self.marginals * self.scores).sum((-1, -2))
7373

7474
def cross_entropy(self, other):
75-
raise NotImplementedError
75+
return other.log_partition - (self.marginals * other.scores).sum((-1, -2))
7676

7777
def kl(self, other):
78-
raise NotImplementedError
78+
return other.log_partition - self.log_partition + (self.marginals * (self.scores - other.scores)).sum((-1, -2))
7979

8080
def score(self, value, partial=False):
8181
arcs = value

0 commit comments

Comments
 (0)