Commit 88b1d03

zysite committed
Make the interface compatible with pytorch
1 parent 2f32fbd commit 88b1d03

File tree

1 file changed (+77 -44 lines)


parser/modules/bilstm.py

Lines changed: 77 additions & 44 deletions
@@ -4,6 +4,7 @@
 
 import torch
 import torch.nn as nn
+from torch.nn.modules.rnn import apply_permutation
 from torch.nn.utils.rnn import PackedSequence
 
 
@@ -19,7 +20,7 @@ def __init__(self, input_size, hidden_size, num_layers=1, dropout=0):
 
         self.f_cells = nn.ModuleList()
         self.b_cells = nn.ModuleList()
-        for layer in range(self.num_layers):
+        for _ in range(self.num_layers):
             self.f_cells.append(nn.LSTMCell(input_size=input_size,
                                             hidden_size=hidden_size))
             self.b_cells.append(nn.LSTMCell(input_size=input_size,
@@ -28,67 +29,99 @@ def __init__(self, input_size, hidden_size, num_layers=1, dropout=0):
 
         self.reset_parameters()
 
+    def __repr__(self):
+        s = self.__class__.__name__ + '('
+        s += f"{self.input_size}, {self.hidden_size}"
+        if self.num_layers > 1:
+            s += f", num_layers={self.num_layers}"
+        if self.dropout > 0:
+            s += f", dropout={self.dropout}"
+        s += ')'
+
+        return s
+
     def reset_parameters(self):
-        for i in self.parameters():
+        for param in self.parameters():
             # apply orthogonal_ to weight
-            if len(i.shape) > 1:
-                nn.init.orthogonal_(i)
+            if len(param.shape) > 1:
+                nn.init.orthogonal_(param)
             # apply zeros_ to bias
             else:
-                nn.init.zeros_(i)
+                nn.init.zeros_(param)
+
+    def permute_hidden(self, hx, permutation):
+        if permutation is None:
+            return hx
+        h = apply_permutation(hx[0], permutation)
+        c = apply_permutation(hx[1], permutation)
+
+        return h, c
 
     def layer_forward(self, x, hx, cell, batch_sizes, reverse=False):
-        h, c = hx
-        init_h, init_c = h, c
-        output, seq_len = [], len(x)
-        steps = reversed(range(seq_len)) if reverse else range(seq_len)
+        hx_0 = hx_i = hx
+        hx_n, output = [], []
+        steps = reversed(range(len(x))) if reverse else range(len(x))
         if self.training:
-            hid_mask = SharedDropout.get_mask(h, self.dropout)
+            hid_mask = SharedDropout.get_mask(hx_0[0], self.dropout)
 
         for t in steps:
-            last_batch_size, batch_size = len(h), batch_sizes[t]
+            last_batch_size, batch_size = len(hx_i[0]), batch_sizes[t]
             if last_batch_size < batch_size:
-                h = torch.cat((h, init_h[last_batch_size:batch_size]))
-                c = torch.cat((c, init_c[last_batch_size:batch_size]))
+                hx_i = [torch.cat((h, ih[last_batch_size:batch_size]))
+                        for h, ih in zip(hx_i, hx_0)]
             else:
-                h = h[:batch_size]
-                c = c[:batch_size]
-            h, c = cell(input=x[t], hx=(h, c))
-            output.append(h)
+                hx_n.append([h[batch_size:] for h in hx_i])
+                hx_i = [h[:batch_size] for h in hx_i]
+            hx_i = [h for h in cell(x[t], hx_i)]
+            output.append(hx_i[0])
             if self.training:
-                h = h * hid_mask[:batch_size]
+                hx_i[0] = hx_i[0] * hid_mask[:batch_size]
         if reverse:
+            hx_n = hx_i
             output.reverse()
+        else:
+            hx_n.append(hx_i)
+            hx_n.reverse()
+            hx_n = [torch.cat(h) for h in zip(*hx_n)]
         output = torch.cat(output)
 
-        return output
+        return output, hx_n
 
-    def forward(self, x, hx=None):
-        x, batch_sizes = x
+    def forward(self, sequence, hx=None):
+        x, batch_sizes = sequence.data, sequence.batch_sizes.tolist()
         batch_size = batch_sizes[0]
+        h_n, c_n = [], []
 
         if hx is None:
-            init = x.new_zeros(batch_size, self.hidden_size)
-            hx = (init, init)
-
-        for layer in range(self.num_layers):
+            ih = x.new_zeros(self.num_layers * 2, batch_size, self.hidden_size)
+            h, c = ih, ih
+        else:
+            h, c = self.permute_hidden(hx, sequence.sorted_indices)
+        h = h.view(self.num_layers, 2, batch_size, self.hidden_size)
+        c = c.view(self.num_layers, 2, batch_size, self.hidden_size)
+
+        for i in range(self.num_layers):
+            x = torch.split(x, batch_sizes)
             if self.training:
-                mask = SharedDropout.get_mask(x[:batch_size], self.dropout)
-                mask = torch.cat([mask[:batch_size]
-                                  for batch_size in batch_sizes])
-                x *= mask
-            x = torch.split(x, batch_sizes.tolist())
-            f_output = self.layer_forward(x=x,
-                                          hx=hx,
-                                          cell=self.f_cells[layer],
-                                          batch_sizes=batch_sizes,
-                                          reverse=False)
-            b_output = self.layer_forward(x=x,
-                                          hx=hx,
-                                          cell=self.b_cells[layer],
-                                          batch_sizes=batch_sizes,
-                                          reverse=True)
-            x = torch.cat([f_output, b_output], -1)
-        x = PackedSequence(x, batch_sizes)
-
-        return x
+                mask = SharedDropout.get_mask(x[0], self.dropout)
+                x = [i * mask[:len(i)] for i in x]
+            x_f, (h_f, c_f) = self.layer_forward(x=x,
+                                                 hx=(h[i, 0], c[i, 0]),
+                                                 cell=self.f_cells[i],
+                                                 batch_sizes=batch_sizes)
+            x_b, (h_b, c_b) = self.layer_forward(x=x,
+                                                 hx=(h[i, 1], c[i, 1]),
+                                                 cell=self.b_cells[i],
+                                                 batch_sizes=batch_sizes,
+                                                 reverse=True)
+            x = torch.cat((x_f, x_b), -1)
+            h_n.append(torch.stack((h_f, h_b)))
+            c_n.append(torch.stack((c_f, c_b)))
+        x = PackedSequence(x,
                           sequence.batch_sizes,
                           sequence.sorted_indices,
                           sequence.unsorted_indices)
+        hx = torch.cat(h_n, 0), torch.cat(c_n, 0)
+        hx = self.permute_hidden(hx, sequence.unsorted_indices)
+
+        return x, hx
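
For reference, a minimal usage sketch of the updated interface follows. It assumes the class defined in parser/modules/bilstm.py is named BiLSTM and is importable as shown, that SharedDropout is available in the same package, and that the installed PyTorch version still provides torch.nn.modules.rnn.apply_permutation (which the new code imports); the shapes and hyperparameters are illustrative only. After this commit the module is called like torch.nn.LSTM on a PackedSequence and returns the packed output together with the final (h_n, c_n) states.

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Assumed import path and class name, based on the file parser/modules/bilstm.py.
from parser.modules.bilstm import BiLSTM

lstm = BiLSTM(input_size=100, hidden_size=200, num_layers=3, dropout=0.33)

x = torch.randn(2, 5, 100)      # (batch, seq_len, input_size), illustrative values
lens = torch.tensor([5, 3])     # true lengths, sorted in descending order
packed = pack_padded_sequence(x, lens, batch_first=True)

# Same calling convention as torch.nn.LSTM on packed input:
# a PackedSequence in, a PackedSequence out, plus the final hidden/cell states.
output, (h_n, c_n) = lstm(packed)

padded, _ = pad_packed_sequence(output, batch_first=True)
print(padded.shape)  # torch.Size([2, 5, 400]): forward and backward halves concatenated
print(h_n.shape)     # torch.Size([6, 2, 200]): (num_layers * 2, batch, hidden_size), as in nn.LSTM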
