Commit 66d650f

Author: zysite
Message: Replace LSTM with BiLSTM
1 parent: e8542fe

File tree: 3 files changed (+23, -32 lines)


parser/modules/__init__.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -1,8 +1,9 @@
 # -*- coding: utf-8 -*-
 
 from .biaffine import Biaffine
-from .lstm import LSTM
+from .bilstm import BiLSTM
+from .dropout import IndependentDropout, SharedDropout
 from .mlp import MLP
 
 
-__all__ = ['LSTM', 'MLP', 'Biaffine']
+__all__ = ['MLP', 'Biaffine', 'BiLSTM', 'IndependentDropout', 'SharedDropout']
```
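With BiLSTM and the dropout modules re-exported from the package root, downstream code can pull everything from parser.modules in one statement, which is the form parser/parser.py switches to below. A brief usage note, not part of the diff:

```python
# Everything listed in the updated __all__ is importable from the package root.
from parser.modules import MLP, Biaffine, BiLSTM, IndependentDropout, SharedDropout
```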

parser/modules/lstm.py renamed to parser/modules/bilstm.py

Lines changed: 12 additions & 21 deletions
```diff
@@ -7,28 +7,24 @@
 from torch.nn.utils.rnn import PackedSequence
 
 
-class LSTM(nn.Module):
+class BiLSTM(nn.Module):
 
-    def __init__(self, input_size, hidden_size, num_layers=1,
-                 dropout=0, bidirectional=False):
-        super(LSTM, self).__init__()
+    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0):
+        super(BiLSTM, self).__init__()
 
         self.input_size = input_size
         self.hidden_size = hidden_size
         self.num_layers = num_layers
         self.dropout = dropout
-        self.bidirectional = bidirectional
-        self.num_directions = 2 if bidirectional else 1
 
         self.f_cells = nn.ModuleList()
         self.b_cells = nn.ModuleList()
         for layer in range(self.num_layers):
             self.f_cells.append(nn.LSTMCell(input_size=input_size,
                                             hidden_size=hidden_size))
-            if bidirectional:
-                self.b_cells.append(nn.LSTMCell(input_size=input_size,
-                                                hidden_size=hidden_size))
-            input_size = hidden_size * self.num_directions
+            self.b_cells.append(nn.LSTMCell(input_size=input_size,
+                                            hidden_size=hidden_size))
+            input_size = hidden_size * 2
 
         self.reset_parameters()
 
@@ -88,17 +84,12 @@ def forward(self, x, hx=None):
                                           cell=self.f_cells[layer],
                                           batch_sizes=batch_sizes,
                                           reverse=False)
-
-            if self.bidirectional:
-                b_output = self.layer_forward(x=x,
-                                              hx=hx,
-                                              cell=self.b_cells[layer],
-                                              batch_sizes=batch_sizes,
-                                              reverse=True)
-            if self.bidirectional:
-                x = torch.cat([f_output, b_output], -1)
-            else:
-                x = f_output
+            b_output = self.layer_forward(x=x,
+                                          hx=hx,
+                                          cell=self.b_cells[layer],
+                                          batch_sizes=batch_sizes,
+                                          reverse=True)
+            x = torch.cat([f_output, b_output], -1)
             x = PackedSequence(x, batch_sizes)
 
         return x
```
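The net effect of the rename is that bidirectionality is no longer optional: every layer always owns a forward and a backward LSTMCell, their outputs are concatenated, and the next layer therefore consumes hidden_size * 2 features. Below is a standalone sketch of that per-layer scheme on a plain padded batch; the sizes are arbitrary and it omits the repository's PackedSequence handling, dropout, and initial-state logic in layer_forward:

```python
import torch
import torch.nn as nn

# One layer of the always-bidirectional scheme: two independent LSTMCells,
# one scanning left-to-right, one right-to-left, outputs concatenated.
input_size, hidden_size = 100, 400
f_cell = nn.LSTMCell(input_size, hidden_size)
b_cell = nn.LSTMCell(input_size, hidden_size)

x = torch.randn(8, 5, input_size)                    # [batch, seq_len, input_size]
f_h = f_c = b_h = b_c = torch.zeros(8, hidden_size)  # zero initial states

f_out, b_out = [], []
for t in range(x.size(1)):                            # forward direction
    f_h, f_c = f_cell(x[:, t], (f_h, f_c))
    f_out.append(f_h)
for t in reversed(range(x.size(1))):                  # backward direction
    b_h, b_c = b_cell(x[:, t], (b_h, b_c))
    b_out.insert(0, b_h)

out = torch.cat([torch.stack(f_out, 1), torch.stack(b_out, 1)], -1)
print(out.shape)  # torch.Size([8, 5, 800]): the next layer sees hidden_size * 2
```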

parser/parser.py

Lines changed: 8 additions & 9 deletions
```diff
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 
-from parser.modules import LSTM, MLP, Biaffine
-from parser.modules.dropout import IndependentDropout, SharedDropout
+from parser.modules import (MLP, Biaffine, BiLSTM, IndependentDropout,
+                            SharedDropout)
 
 import torch
 import torch.nn as nn
@@ -23,11 +23,10 @@ def __init__(self, params, embeddings):
         self.embed_dropout = IndependentDropout(p=params['embed_dropout'])
 
         # the word-lstm layer
-        self.lstm = LSTM(input_size=params['n_embed']+params['n_tag_embed'],
-                         hidden_size=params['n_lstm_hidden'],
-                         num_layers=params['n_lstm_layers'],
-                         dropout=params['lstm_dropout'],
-                         bidirectional=True)
+        self.lstm = BiLSTM(input_size=params['n_embed']+params['n_tag_embed'],
+                           hidden_size=params['n_lstm_hidden'],
+                           num_layers=params['n_lstm_layers'],
+                           dropout=params['lstm_dropout'])
         self.lstm_dropout = SharedDropout(p=params['lstm_dropout'])
 
         # the MLP layers
@@ -82,7 +81,7 @@ def forward(self, words, tags):
         x, _ = pad_packed_sequence(x, True)
         x = self.lstm_dropout(x)[inverse_indices]
 
-        # apply MLPs to the LSTM output states
+        # apply MLPs to the BiLSTM output states
         arc_h = self.mlp_arc_h(x)
         arc_d = self.mlp_arc_d(x)
         rel_h = self.mlp_rel_h(x)
@@ -94,7 +93,7 @@ def forward(self, words, tags):
         # [batch_size, seq_len, seq_len, n_rels]
         s_rel = self.rel_attn(rel_d, rel_h).permute(0, 2, 3, 1)
         # set the scores that exceed the length of each sentence to -inf
-        s_arc.masked_fill_((1 - mask).unsqueeze(1), float('-inf'))
+        s_arc.masked_fill_(~mask.unsqueeze(1), float('-inf'))
 
         return s_arc, s_rel
 
```
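The last change swaps integer mask arithmetic for a boolean complement. Assuming mask is a torch.bool tensor marking valid token positions, (1 - mask) is rejected by recent PyTorch versions, while ~mask selects the positions to fill directly. A small sketch of the same masking pattern, with illustrative shapes and names rather than the repository's:

```python
import torch

mask = torch.tensor([[True, True, False]])    # [batch, seq_len]: True = real token
s_arc = torch.zeros(1, 3, 3)                  # [batch, seq_len, seq_len] arc scores

# ~mask flips the booleans and unsqueeze(1) broadcasts it over the score rows,
# so every column past the sentence length becomes -inf.
s_arc.masked_fill_(~mask.unsqueeze(1), float('-inf'))
print(s_arc[0, 0])  # tensor([0., 0., -inf])
```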
