Skip to content

Commit 155704f

Browse files
committed
Minor modifications
1 parent dc026e1 commit 155704f

File tree

4 files changed

+71
-8
lines changed

4 files changed

+71
-8
lines changed

supar/models/dependency.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -536,14 +536,78 @@ class CRF2oDependencyModel(BiaffineDependencyModel):
536536
https://www.aclweb.org/anthology/2020.acl-main.302/
537537
"""
538538

539-
def __init__(self, n_lstm_hidden=400, n_mlp_sib=100, mlp_dropout=.33, **kwargs):
540-
super().__init__(**kwargs)
539+
def __init__(self,
540+
n_words,
541+
n_feats,
542+
n_rels,
543+
feat='char',
544+
n_embed=100,
545+
n_feat_embed=100,
546+
n_char_embed=50,
547+
bert=None,
548+
n_bert_layers=4,
549+
mix_dropout=.0,
550+
embed_dropout=.33,
551+
n_lstm_hidden=400,
552+
n_lstm_layers=3,
553+
lstm_dropout=.33,
554+
n_mlp_arc=500,
555+
n_mlp_sib=100,
556+
n_mlp_rel=100,
557+
mlp_dropout=.33,
558+
feat_pad_index=0,
559+
pad_index=0,
560+
unk_index=1,
561+
**kwargs):
562+
super().__init__(**Config().update(locals()))
541563

564+
# the embedding layer
565+
self.word_embed = nn.Embedding(num_embeddings=n_words,
566+
embedding_dim=n_embed)
567+
if feat == 'char':
568+
self.feat_embed = CharLSTM(n_chars=n_feats,
569+
n_embed=n_char_embed,
570+
n_out=n_feat_embed,
571+
pad_index=feat_pad_index)
572+
elif feat == 'bert':
573+
self.feat_embed = BertEmbedding(model=bert,
574+
n_layers=n_bert_layers,
575+
n_out=n_feat_embed,
576+
pad_index=feat_pad_index,
577+
dropout=mix_dropout)
578+
self.n_feat_embed = self.feat_embed.n_out
579+
elif feat == 'tag':
580+
self.feat_embed = nn.Embedding(num_embeddings=n_feats,
581+
embedding_dim=n_feat_embed)
582+
else:
583+
raise RuntimeError("The feat type should be in ['char', 'bert', 'tag'].")
584+
self.embed_dropout = IndependentDropout(p=embed_dropout)
585+
586+
# the lstm layer
587+
self.lstm = LSTM(input_size=n_embed+n_feat_embed,
588+
hidden_size=n_lstm_hidden,
589+
num_layers=n_lstm_layers,
590+
bidirectional=True,
591+
dropout=lstm_dropout)
592+
self.lstm_dropout = SharedDropout(p=lstm_dropout)
593+
594+
# the MLP layers
595+
self.mlp_arc_d = MLP(n_in=n_lstm_hidden*2, n_out=n_mlp_arc, dropout=mlp_dropout)
596+
self.mlp_arc_h = MLP(n_in=n_lstm_hidden*2, n_out=n_mlp_arc, dropout=mlp_dropout)
542597
self.mlp_sib_s = MLP(n_in=n_lstm_hidden*2, n_out=n_mlp_sib, dropout=mlp_dropout)
543598
self.mlp_sib_d = MLP(n_in=n_lstm_hidden*2, n_out=n_mlp_sib, dropout=mlp_dropout)
544599
self.mlp_sib_h = MLP(n_in=n_lstm_hidden*2, n_out=n_mlp_sib, dropout=mlp_dropout)
600+
self.mlp_rel_d = MLP(n_in=n_lstm_hidden*2, n_out=n_mlp_rel, dropout=mlp_dropout)
601+
self.mlp_rel_h = MLP(n_in=n_lstm_hidden*2, n_out=n_mlp_rel, dropout=mlp_dropout)
545602

603+
# the Biaffine layers
604+
self.arc_attn = Biaffine(n_in=n_mlp_arc, bias_x=True, bias_y=False)
546605
self.sib_attn = Triaffine(n_in=n_mlp_sib, bias_x=True, bias_y=True)
606+
self.rel_attn = Biaffine(n_in=n_mlp_rel, n_out=n_rels, bias_x=True, bias_y=True)
607+
self.criterion = nn.CrossEntropyLoss()
608+
self.pad_index = pad_index
609+
self.unk_index = unk_index
610+
547611
self.crf = CRF2oDependency()
548612

549613
def forward(self, words, feats):

supar/utils/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@ def __init__(self, conf=None, **kwargs):
1717
**kwargs})
1818

1919
def __repr__(self):
20-
s = line = "-" * 15 + "-+-" + "-" * 25 + "\n"
21-
s += f"{'Param':15} | {'Value':^25}\n" + line
20+
s = line = "-" * 20 + "-+-" + "-" * 30 + "\n"
21+
s += f"{'Param':20} | {'Value':^30}\n" + line
2222
for name, value in vars(self).items():
23-
s += f"{name:15} | {str(value):^25}\n"
23+
s += f"{name:20} | {str(value):^30}\n"
2424
s += line
2525

2626
return s

supar/utils/field.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,7 @@ def __init__(self, name, pad=None, unk=None, bos=None, eos=None,
8282
self.tokenize = tokenize
8383
self.fn = fn
8484

85-
self.specials = [token for token in [pad, unk, bos, eos]
86-
if token is not None]
85+
self.specials = [token for token in [pad, unk, bos, eos] if token is not None]
8786

8887
def __repr__(self):
8988
s, params = f"({self.name}): {self.__class__.__name__}(", []

supar/utils/logging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def init_logger(logger,
3232

3333
def progress_bar(iterator,
3434
ncols=None,
35-
bar_format='{l_bar}{bar:36}| {n_fmt}/{total_fmt} {elapsed}<{remaining}, {rate_fmt}{postfix}',
35+
bar_format='{l_bar}{bar:18}| {n_fmt}/{total_fmt} {elapsed}<{remaining}, {rate_fmt}{postfix}',
3636
leave=True):
3737
return tqdm(iterator,
3838
ncols=ncols,

0 commit comments

Comments
 (0)