Skip to content

Commit 845bf05

Browse files
author
zysite
committed
Trivial modifications
1 parent e779bb1 commit 845bf05

File tree

5 files changed

+13
-16
lines changed

5 files changed

+13
-16
lines changed

parser/modules/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@
77
from .char_lstm import CHAR_LSTM
88
from .mlp import MLP
99

10-
__all__ = ['MLP', 'BertEmbedding', 'Biaffine',
11-
'BiLSTM', 'CHAR_LSTM', 'dropout']
10+
__all__ = ['CHAR_LSTM', 'MLP', 'BertEmbedding',
11+
'Biaffine', 'BiLSTM', 'dropout']

parser/modules/bilstm.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,7 @@ def layer_forward(self, x, hx, cell, batch_sizes, reverse=False):
8181
output.reverse()
8282
else:
8383
hx_n.append(hx_i)
84-
hx_n.reverse()
85-
hx_n = [torch.cat(h) for h in zip(*hx_n)]
84+
hx_n = [torch.cat(h) for h in zip(*reversed(hx_n))]
8685
output = torch.cat(output)
8786

8887
return output, hx_n

parser/utils/alg.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@ def kmeans(x, k):
1515
# assign labels to each datapoint based on centroids
1616
dists, y = torch.abs_(d.unsqueeze(-1) - c).min(dim=-1)
1717
# make sure number of datapoints is greater than that of clusters
18-
if len(d) < k:
19-
raise AssertionError(f"unable to assign {len(d)} datapoints to "
20-
f"{k} clusters")
18+
assert len(d) >= k, f"unable to assign {len(d)} datapoints to {k} clusters"
2119

2220
while old is None or not c.equal(old):
2321
# if an empty cluster is encountered,
@@ -59,27 +57,27 @@ def eisner(scores, mask):
5957
for w in range(1, seq_len):
6058
n = seq_len - w
6159
starts = p_i.new_tensor(range(n)).unsqueeze(0)
62-
# ilr = C(i, r) + C(j, r+1)
60+
# ilr = C(i->r) + C(j->r+1)
6361
ilr = stripe(s_c, n, w) + stripe(s_c, n, w, (w, 1))
6462
# [batch_size, n, w]
6563
ilr = ilr.permute(2, 0, 1)
6664
il = ilr + scores.diagonal(-w).unsqueeze(-1)
67-
# I(j, i) = max(C(i, r) + C(j, r+1) + S(j, i)), i <= r < j
65+
# I(j->i) = max(C(i->r) + C(j->r+1) + s(j->i)), i <= r < j
6866
il_span, il_path = il.max(-1)
6967
s_i.diagonal(-w).copy_(il_span)
7068
p_i.diagonal(-w).copy_(il_path + starts)
7169
ir = ilr + scores.diagonal(w).unsqueeze(-1)
72-
# I(i, j) = max(C(i, r) + C(j, r+1) + S(i, j)), i <= r < j
70+
# I(i->j) = max(C(i->r) + C(j->r+1) + s(i->j)), i <= r < j
7371
ir_span, ir_path = ir.max(-1)
7472
s_i.diagonal(w).copy_(ir_span)
7573
p_i.diagonal(w).copy_(ir_path + starts)
7674

77-
# C(j, i) = max(C(r, i) + I(j, r)), i <= r < j
78-
cl = stripe(s_c, n, w, dim=0) + stripe(s_i, n, w, (w, 0))
75+
# C(j->i) = max(C(r->i) + I(j->r)), i <= r < j
76+
cl = stripe(s_c, n, w, (0, 0), 0) + stripe(s_i, n, w, (w, 0))
7977
cl_span, cl_path = cl.permute(2, 0, 1).max(-1)
8078
s_c.diagonal(-w).copy_(cl_span)
8179
p_c.diagonal(-w).copy_(cl_path + starts)
82-
# C(i, j) = max(I(i, r) + C(r, j)), i < r <= j
80+
# C(i->j) = max(I(i->r) + C(r->j)), i < r <= j
8381
cr = stripe(s_i, n, w, (0, 1)) + stripe(s_c, n, w, (1, w), 0)
8482
cr_span, cr_path = cr.permute(2, 0, 1).max(-1)
8583
s_c.diagonal(w).copy_(cr_span)
@@ -136,7 +134,7 @@ def stripe(x, n, w, offset=(0, 0), dim=1):
136134
tensor([[ 0, 5, 10],
137135
[ 6, 11, 16]])
138136
'''
139-
seq_len = x.size(1)
137+
x, seq_len = x.contiguous(), x.size(1)
140138
stride, numel = list(x.stride()), x[0, 0].numel()
141139
stride[0] = (seq_len + 1) * numel
142140
stride[1] = (1 if dim == 1 else seq_len) * numel

parser/utils/corpus.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def load(cls, path, fields):
7676
lines = [line.strip() for line in f]
7777
for i, line in enumerate(lines):
7878
if not line:
79-
values = list(zip(*[l.split() for l in lines[start:i]]))
79+
values = list(zip(*[l.split('\t') for l in lines[start:i]]))
8080
sentences.append(Sentence(fields, values))
8181
start = i + 1
8282

parser/utils/metric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def __init__(self, eps=1e-5):
1212
self.correct_rels = 0.0
1313

1414
def __repr__(self):
15-
return f"UAS: {self.uas:.2%} LAS: {self.las:.2%}"
15+
return f"UAS: {self.uas:6.2%} LAS: {self.las:6.2%}"
1616

1717
def __call__(self, arc_preds, rel_preds, arc_golds, rel_golds, mask):
1818
arc_mask = arc_preds.eq(arc_golds)[mask]

0 commit comments

Comments
 (0)