Skip to content

Commit d994f8e

Browse files
committed
Simplify context layer in span ranking SRL
1 parent f6e085a commit d994f8e

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

hanlp/components/srl/span_rank/span_ranking_srl_model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ def __init__(self, config, embed: torch.nn.Module, context_layer: torch.nn.Modul
473473
self.embed = embed
474474
# Initialize context layer
475475
self.context_layer = context_layer
476-
context_layer_output_dim = context_layer.get_output_dim()
476+
context_layer_output_dim = context_layer.get_output_dim() if context_layer else self.word_embedding_dim
477477
self.decoder = SpanRankingSRLDecoder(context_layer_output_dim, label_space_size, config)
478478

479479
def forward(self,
@@ -484,9 +484,10 @@ def forward(self,
484484

485485
context_embeddings = self.embed(batch)
486486
context_embeddings = F.dropout(context_embeddings, self.lexical_dropout, self.training)
487-
contextualized_embeddings = self.context_layer(context_embeddings, masks)
487+
if self.context_layer:
488+
context_embeddings = self.context_layer(context_embeddings, masks)
488489

489-
return self.decoder.decode(contextualized_embeddings, sent_lengths, masks, gold_arg_starts, gold_arg_ends,
490+
return self.decoder.decode(context_embeddings, sent_lengths, masks, gold_arg_starts, gold_arg_ends,
490491
gold_arg_labels, gold_predicates)
491492

492493
@staticmethod

0 commit comments

Comments
 (0)