
Commit 40fe60a

Author: zysite
Message: Trivial modifications
Parent: 66d650f

3 files changed: 32 additions & 45 deletions


parser/metric.py

Lines changed: 15 additions & 28 deletions
@@ -3,39 +3,17 @@
 
 class Metric(object):
 
-    def __lt__(self, other):
-        return self.score < other
-
-    def __le__(self, other):
-        return self.score <= other
-
-    def __eq__(self, other):
-        return self.score == other
-
-    def __ge__(self, other):
-        return self.score >= other
-
-    def __gt__(self, other):
-        return self.score > other
-
-    def __ne__(self, other):
-        return self.score != other
-
-    @property
-    def score(self):
-        raise AttributeError
-
-
-class AttachmentMethod(Metric):
-
     def __init__(self, eps=1e-5):
-        super(AttachmentMethod, self).__init__()
+        super(Metric, self).__init__()
 
         self.eps = eps
         self.total = 0.0
         self.correct_arcs = 0.0
         self.correct_rels = 0.0
 
+    def __repr__(self):
+        return f"UAS: {self.uas:.2%} LAS: {self.las:.2%}"
+
     def __call__(self, pred_arcs, pred_rels, gold_arcs, gold_rels):
         arc_mask = pred_arcs.eq(gold_arcs)
         rel_mask = pred_rels.eq(gold_rels) & arc_mask
@@ -44,8 +22,17 @@ def __call__(self, pred_arcs, pred_rels, gold_arcs, gold_rels):
         self.correct_arcs += arc_mask.sum().item()
         self.correct_rels += rel_mask.sum().item()
 
-    def __repr__(self):
-        return f"UAS: {self.uas:.2%} LAS: {self.las:.2%}"
+    def __lt__(self, other):
+        return self.score < other
+
+    def __le__(self, other):
+        return self.score <= other
+
+    def __ge__(self, other):
+        return self.score >= other
+
+    def __gt__(self, other):
+        return self.score > other
 
     @property
     def score(self):
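
Since the rich-comparison operators now live on Metric itself, a metric object can be compared directly against a plain number when tracking the best epoch. A minimal, self-contained sketch of the same pattern (TinyMetric and its fields are illustrative, not the repository's API):

import torch


class TinyMetric(object):
    """Toy stand-in for the refactored Metric class (not the repo's code)."""

    def __init__(self, eps=1e-5):
        self.eps = eps
        self.total = 0.0
        self.correct_arcs = 0.0

    def __call__(self, pred_arcs, gold_arcs):
        mask = pred_arcs.eq(gold_arcs)
        self.total += len(mask)
        self.correct_arcs += mask.sum().item()

    def __gt__(self, other):
        # comparisons delegate to the scalar score,
        # so `metric > best` reads naturally in a training loop
        return self.score > other

    @property
    def score(self):
        # eps guards against division by zero before any batch is seen
        return self.correct_arcs / (self.total + self.eps)


metric = TinyMetric()
metric(torch.tensor([2, 0, 2, 5, 3]), torch.tensor([2, 0, 2, 5, 5]))

best = 0.0
if metric > best:  # True here: 4 of 5 arcs match
    best = metric.score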

parser/modules/bilstm.py

Lines changed: 2 additions & 3 deletions
@@ -46,8 +46,8 @@ def layer_forward(self, x, hx, cell, batch_sizes, reverse=False):
         hid_mask = SharedDropout.get_mask(h, self.dropout)
 
         for t in steps:
-            batch_size = batch_sizes[t]
-            if len(h) < batch_size:
+            last_batch_size, batch_size = len(h), batch_sizes[t]
+            if last_batch_size < batch_size:
                 h = torch.cat((h, init_h[last_batch_size:batch_size]))
                 c = torch.cat((c, init_c[last_batch_size:batch_size]))
             else:
@@ -57,7 +57,6 @@ def layer_forward(self, x, hx, cell, batch_sizes, reverse=False):
             output.append(h)
             if self.training:
                 h = h * hid_mask[:batch_size]
-            last_batch_size = batch_size
         if reverse:
             output.reverse()
         output = torch.cat(output)
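
These two hunks fold the trailing `last_batch_size = batch_size` bookkeeping into the loop header: `h` always carries the hidden states of the previous timestep, so `len(h)` is that step's batch size and can simply be read back. For context, a small sketch (not from the repository) of what `batch_sizes` contains for a packed batch:

import torch
from torch.nn.utils.rnn import pack_sequence

# Three sequences of lengths 4, 2 and 1; batch_sizes[t] is the number
# of sequences still alive at timestep t, which is what the loop above
# reads as batch_sizes[t].
packed = pack_sequence([torch.ones(4), torch.ones(2), torch.ones(1)])
print(packed.batch_sizes)  # tensor([3, 2, 1, 1])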

parser/utils/vocab.py

Lines changed: 15 additions & 14 deletions
@@ -32,11 +32,10 @@ def __init__(self, words, tags, rels):
         self.n_train_words = self.n_words
 
     def __repr__(self):
-        info = f"{self.__class__.__name__}(\n"
-        info += f" num of words: {self.n_words}\n"
-        info += f" num of tags: {self.n_tags}\n"
-        info += f" num of rels: {self.n_rels}\n"
-        info += f")"
+        info = f"{self.__class__.__name__}: "
+        info += f"{self.n_words} words, "
+        info += f"{self.n_tags} tags, "
+        info += f"{self.n_rels} rels"
 
         return info

@@ -55,20 +54,20 @@ def rel2id(self, sequence):
     def id2rel(self, ids):
         return [self.rels[i] for i in ids]
 
-    def read_embeddings(self, embed, unk=None):
-        words = embed.words
-        # if the UNK token has existed in pretrained vocab,
-        # then replace it with a self-defined one
-        if unk in embed:
-            words[words.index(unk)] = self.UNK
+    def read_embeddings(self, embed, smooth=True):
+        # if the UNK token has existed in the pretrained,
+        # then use it to replace the one in the vocab
+        if embed.unk:
+            self.UNK = embed.unk
 
-        self.extend(words)
+        self.extend(embed.tokens)
         self.embeddings = torch.zeros(self.n_words, embed.dim)
 
         for i, word in enumerate(self.words):
             if word in embed:
                 self.embeddings[i] = embed[word]
-        self.embeddings /= torch.std(self.embeddings)
+        if smooth:
+            self.embeddings /= torch.std(self.embeddings)
 
     def extend(self, words):
         self.words.extend(sorted(set(words).difference(self.word_dict)))
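
The new `smooth` flag makes the variance rescaling optional rather than unconditional. Dividing the pretrained matrix by its own standard deviation brings it to roughly unit variance, which keeps pretrained vectors on a scale comparable to freshly initialized embeddings. A quick illustrative check (the 10000×100 matrix is made up):

import torch

embed = torch.randn(10000, 100) * 0.3  # stand-in for small-norm pretrained vectors
print(torch.std(embed).item())         # roughly 0.3
smoothed = embed / torch.std(embed)
print(torch.std(smoothed).item())      # roughly 1.0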
@@ -77,9 +76,11 @@ def extend(self, words):
                             if regex.match(r'\p{P}+$', word))
         self.n_words = len(self.words)
 
-    def numericalize(self, corpus):
+    def numericalize(self, corpus, training=True):
         words = [self.word2id(seq) for seq in corpus.words]
         tags = [self.tag2id(seq) for seq in corpus.tags]
+        if not training:
+            return words, tags
         arcs = [torch.tensor(seq) for seq in corpus.heads]
         rels = [self.rel2id(seq) for seq in corpus.rels]
 
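
The `training` flag lets numericalize handle raw input that carries no gold annotations: it returns only word and tag ids and skips building the arc and rel tensors. Hypothetical call sites (`vocab` and the corpora are assumed, not shown in this commit):

# training data has gold heads and relations, so all four views are built
words, tags, arcs, rels = vocab.numericalize(train_corpus)

# raw text at prediction time has no gold heads/rels to numericalize
words, tags = vocab.numericalize(test_corpus, training=False)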
