Skip to content

Commit fa87743

Browse files
committed
Make metrics addable
1 parent 7e1e37c commit fa87743

File tree

1 file changed

+44
-12
lines changed

1 file changed

+44
-12
lines changed

supar/utils/metric.py

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,21 @@
1010

1111
class Metric(object):
1212

13-
def __lt__(self, other: 'Metric') -> bool:
13+
def __lt__(self, other: Metric) -> bool:
1414
return self.score < other
1515

16-
def __le__(self, other: 'Metric') -> bool:
16+
def __le__(self, other: Metric) -> bool:
1717
return self.score <= other
1818

19-
def __ge__(self, other: 'Metric') -> bool:
19+
def __ge__(self, other: Metric) -> bool:
2020
return self.score >= other
2121

22-
def __gt__(self, other: 'Metric') -> bool:
22+
def __gt__(self, other: Metric) -> bool:
2323
return self.score > other
2424

25+
def __add__(self, other: Metric) -> Metric:
26+
raise NotImplementedError
27+
2528
@property
2629
def score(self):
2730
return 0.
@@ -68,6 +71,16 @@ def __call__(
6871
self.correct_rels += rel_mask_seq.sum().item()
6972
return self
7073

74+
def __add__(self, other: AttachmentMetric) -> AttachmentMetric:
75+
metric = AttachmentMetric(self.eps)
76+
metric.n = self.n + other.n
77+
metric.n_ucm = self.n_ucm + other.n_ucm
78+
metric.n_lcm = self.n_lcm + other.n_lcm
79+
metric.total = self.total + other.total
80+
metric.correct_arcs = self.correct_arcs + other.correct_arcs
81+
metric.correct_rels = self.correct_rels + other.correct_rels
82+
return metric
83+
7184
@property
7285
def score(self):
7386
return self.las
@@ -103,6 +116,13 @@ def __init__(self, eps: float = 1e-12) -> SpanMetric:
103116
self.gold = 0.0
104117
self.eps = eps
105118

119+
def __repr__(self):
120+
s = f"UCM: {self.ucm:6.2%} LCM: {self.lcm:6.2%} "
121+
s += f"UP: {self.up:6.2%} UR: {self.ur:6.2%} UF: {self.uf:6.2%} "
122+
s += f"LP: {self.lp:6.2%} LR: {self.lr:6.2%} LF: {self.lf:6.2%}"
123+
124+
return s
125+
106126
def __call__(self, preds: List[List[Tuple]], golds: List[List[Tuple]]) -> SpanMetric:
107127
for pred, gold in zip(preds, golds):
108128
upred, ugold = Counter([tuple(span[:-1]) for span in pred]), Counter([tuple(span[:-1]) for span in gold])
@@ -117,12 +137,16 @@ def __call__(self, preds: List[List[Tuple]], golds: List[List[Tuple]]) -> SpanMe
117137
self.gold += len(gold)
118138
return self
119139

120-
def __repr__(self):
121-
s = f"UCM: {self.ucm:6.2%} LCM: {self.lcm:6.2%} "
122-
s += f"UP: {self.up:6.2%} UR: {self.ur:6.2%} UF: {self.uf:6.2%} "
123-
s += f"LP: {self.lp:6.2%} LR: {self.lr:6.2%} LF: {self.lf:6.2%}"
124-
125-
return s
140+
def __add__(self, other: SpanMetric) -> SpanMetric:
141+
metric = SpanMetric(self.eps)
142+
metric.n = self.n + other.n
143+
metric.n_ucm = self.n_ucm + other.n_ucm
144+
metric.n_lcm = self.n_lcm + other.n_lcm
145+
metric.utp = self.utp + other.utp
146+
metric.ltp = self.ltp + other.ltp
147+
metric.pred = self.pred + other.pred
148+
metric.gold = self.gold + other.gold
149+
return metric
126150

127151
@property
128152
def score(self):
@@ -172,6 +196,9 @@ def __init__(self, eps: float = 1e-12) -> ChartMetric:
172196
self.gold = 0.0
173197
self.eps = eps
174198

199+
def __repr__(self):
200+
return f"UP: {self.up:6.2%} UR: {self.ur:6.2%} UF: {self.uf:6.2%} P: {self.p:6.2%} R: {self.r:6.2%} F: {self.f:6.2%}"
201+
175202
def __call__(self, preds: torch.Tensor, golds: torch.Tensor) -> ChartMetric:
176203
pred_mask = preds.ge(0)
177204
gold_mask = golds.ge(0)
@@ -182,8 +209,13 @@ def __call__(self, preds: torch.Tensor, golds: torch.Tensor) -> ChartMetric:
182209
self.utp += span_mask.sum().item()
183210
return self
184211

185-
def __repr__(self):
186-
return f"UP: {self.up:6.2%} UR: {self.ur:6.2%} UF: {self.uf:6.2%} P: {self.p:6.2%} R: {self.r:6.2%} F: {self.f:6.2%}"
212+
def __add__(self, other: ChartMetric) -> ChartMetric:
213+
metric = ChartMetric(self.eps)
214+
metric.tp = self.tp + other.tp
215+
metric.utp = self.utp + other.utp
216+
metric.pred = self.pred + other.pred
217+
metric.gold = self.gold + other.gold
218+
return metric
187219

188220
@property
189221
def score(self):

0 commit comments

Comments
 (0)