
Commit 894db67

Wang Ming Rui (mingruimingrui) authored
Bugfix: Removal of padding_idx in BartLearnedPositionalEmbedding (huggingface#10200)
* Assumption of padding_idx < 2 might not stand
* Use offset instead of 2
* Fix with black
* Change behavior to warning instead for backward compatibility.
* Fix with black
* Remove warning
* Make padding_idx non-required
* padding_idx fix for blenderbot
* padding_idx fix for blenderbot_small
* padding_idx fix for led
* padding_idx fix for mbart
* Remove extra whitespaces
* padding_idx fix for template
* Fix padding_idx passed to nn.Embedding mistake
* Fixed padding_idx passed to positional embedding in template
* Remove padding_idx from pytorch learned positional embeddings
* Remove accidentally added quotes
* Remove padding_idx from tf learned positional embeddings
* Remove zeroing of weights in __init__

Co-authored-by: Wang Ming Rui <mingrui.wang@C02CJTUYMD6M.local>
1 parent 55fe80d commit 894db67

File tree

12 files changed: +49 / -85 lines

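Taken together, the change means the learned positional embedding classes no longer accept or require a `padding_idx` argument: callers pass only the number of positions and the embedding dimension, and the Bart/MBart classes keep adding their internal offset of 2. A minimal sketch of the new construction, based on the diffs below (the concrete sizes are illustrative, not taken from the commit):

# Before this commit the third positional argument was required:
#     BartLearnedPositionalEmbedding(config.max_position_embeddings, embed_dim, self.padding_idx)
# After this commit only two arguments are passed:
from transformers.models.bart.modeling_bart import BartLearnedPositionalEmbedding

embed_positions = BartLearnedPositionalEmbedding(
    1024,  # config.max_position_embeddings (illustrative)
    768,   # embed_dim / config.d_model (illustrative)
)
# The module still adds self.offset = 2 internally, so the weight matrix has 1024 + 2 rows.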

src/transformers/models/bart/modeling_bart.py

Lines changed: 2 additions & 5 deletions
@@ -108,12 +108,11 @@ class BartLearnedPositionalEmbedding(nn.Embedding):
     This module learns positional embeddings up to a fixed maximum size.
     """
 
-    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
-        assert padding_idx is not None, "`padding_idx` should not be None, but of type int"
+    def __init__(self, num_embeddings: int, embedding_dim: int):
         # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
         # and adjust num_embeddings appropriately. Other models dont have this hack
         self.offset = 2
-        super().__init__(num_embeddings + self.offset, embedding_dim, padding_idx=padding_idx)
+        super().__init__(num_embeddings + self.offset, embedding_dim)
 
     def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
         """`input_ids_shape` is expected to be [bsz x seqlen]."""
@@ -673,7 +672,6 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No
         self.embed_positions = BartLearnedPositionalEmbedding(
             config.max_position_embeddings,
             embed_dim,
-            self.padding_idx,
         )
         self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)])
         self.layernorm_embedding = nn.LayerNorm(embed_dim)
@@ -836,7 +834,6 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No
         self.embed_positions = BartLearnedPositionalEmbedding(
             config.max_position_embeddings,
             config.d_model,
-            self.padding_idx,
         )
         self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)])
         self.layernorm_embedding = nn.LayerNorm(config.d_model)
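The `forward` signature kept in the hunk above is where the `self.offset` set in `__init__` gets used. The body is not part of this diff; what follows is a hedged reconstruction of how a Bart-style learned positional embedding consumes the offset, written as a standalone sketch rather than a verbatim copy of the repository code:

import torch
import torch.nn as nn

class LearnedPositionalEmbeddingSketch(nn.Embedding):
    """Post-commit Bart-style module: no padding_idx, offset of 2 baked in."""

    def __init__(self, num_embeddings: int, embedding_dim: int):
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
        """`input_ids_shape` is expected to be [bsz x seqlen]."""
        bsz, seq_len = input_ids_shape[:2]
        # Positions continue after any cached keys/values, then shift by the offset so the
        # first two rows (a Bart-specific convention kept for checkpoint compatibility)
        # are never indexed by real positions.
        positions = torch.arange(
            past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
        )
        return super().forward(positions + self.offset)

# Usage: a batch of 2 sequences of length 5, with and without a 5-token cache.
pos_emb = LearnedPositionalEmbeddingSketch(1024, 16)
print(pos_emb(torch.Size([2, 5])).shape)                            # torch.Size([5, 16])
print(pos_emb(torch.Size([2, 5]), past_key_values_length=5).shape)  # torch.Size([5, 16])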

src/transformers/models/bart/modeling_tf_bart.py

Lines changed: 1 addition & 4 deletions
@@ -113,8 +113,7 @@ class TFBartLearnedPositionalEmbedding(TFSharedEmbeddings):
     This module learns positional embeddings up to a fixed maximum size.
     """
 
-    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, **kwargs):
-        assert padding_idx is not None, "padding_idx cannot be None"
+    def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
         # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
         # and adjust num_embeddings appropriately. Other models dont have this hack
         self.offset = 2
@@ -632,7 +631,6 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings
         self.embed_positions = TFBartLearnedPositionalEmbedding(
             config.max_position_embeddings,
             config.d_model,
-            self.padding_idx,
             name="embed_positions",
         )
         self.layers = [TFBartEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
@@ -793,7 +791,6 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings
         self.embed_positions = TFBartLearnedPositionalEmbedding(
             config.max_position_embeddings,
             config.d_model,
-            self.padding_idx,
             name="embed_positions",
         )
         self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
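On the TF side the companion `call` method (not shown in this hunk) builds the position ids the same way before handing them to the shared embedding. A rough sketch of just that position-id computation, assuming it mirrors the PyTorch version; the helper name and exact ops below are reconstructions, not code from the repository:

import tensorflow as tf

def learned_position_ids(input_shape: tf.TensorShape, past_key_values_length: int = 0, offset: int = 2):
    """Sketch: position ids a TF learned positional embedding would look up."""
    seq_len = input_shape[1]
    # Continue after any cached length, then apply the Bart-style offset of 2.
    position_ids = tf.range(past_key_values_length, past_key_values_length + seq_len, dtype=tf.int32)
    return position_ids + offset

# Example: batch of 2, sequence length 4, with 3 cached positions -> ids 5..8.
print(learned_position_ids(tf.TensorShape([2, 4]), past_key_values_length=3).numpy())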

src/transformers/models/blenderbot/modeling_blenderbot.py

Lines changed: 2 additions & 5 deletions
@@ -112,9 +112,8 @@ class BlenderbotLearnedPositionalEmbedding(nn.Embedding):
     This module learns positional embeddings up to a fixed maximum size.
     """
 
-    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
-        assert padding_idx is not None, "`padding_idx` should not be None, but of type int"
-        super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
+    def __init__(self, num_embeddings: int, embedding_dim: int):
+        super().__init__(num_embeddings, embedding_dim)
 
     def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
         """`input_ids_shape` is expected to be [bsz x seqlen]."""
@@ -635,7 +634,6 @@ def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding
         self.embed_positions = BlenderbotLearnedPositionalEmbedding(
             config.max_position_embeddings,
             embed_dim,
-            self.padding_idx,
         )
         self.layers = nn.ModuleList([BlenderbotEncoderLayer(config) for _ in range(config.encoder_layers)])
         self.layer_norm = nn.LayerNorm(config.d_model)
@@ -800,7 +798,6 @@ def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding
         self.embed_positions = BlenderbotLearnedPositionalEmbedding(
             config.max_position_embeddings,
             config.d_model,
-            self.padding_idx,
         )
         self.layers = nn.ModuleList([BlenderbotDecoderLayer(config) for _ in range(config.decoder_layers)])
         self.layer_norm = nn.LayerNorm(config.d_model)
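The Blenderbot classes (and, below, Blenderbot-small and LED) simply drop `padding_idx=padding_idx` from `super().__init__` with no offset bookkeeping, which makes the motivation behind the whole commit easy to see: passing `padding_idx` to `nn.Embedding` zero-initializes that row and blocks its gradient, which is wrong for a positional table where every index is a legitimate position. A short demonstration of that stock PyTorch behavior (illustrative only, not code from the repository):

import torch
import torch.nn as nn

with_pad = nn.Embedding(6, 4, padding_idx=1)   # old-style: position 1 treated as padding
without_pad = nn.Embedding(6, 4)               # post-commit style: every row trainable

print(with_pad.weight[1])     # all zeros: the padding_idx row is zero-initialized
print(without_pad.weight[1])  # normally initialized, fully trainable

# Gradients never reach the padding_idx row either:
with_pad(torch.tensor([0, 1, 2])).sum().backward()
print(with_pad.weight.grad[1])  # all zeros, so position 1 could never be learned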

src/transformers/models/blenderbot/modeling_tf_blenderbot.py

Lines changed: 1 addition & 4 deletions
@@ -118,8 +118,7 @@ class TFBlenderbotLearnedPositionalEmbedding(TFSharedEmbeddings):
     This module learns positional embeddings up to a fixed maximum size.
     """
 
-    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, **kwargs):
-        assert padding_idx is not None, "padding_idx cannot be None"
+    def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
         super().__init__(num_embeddings, embedding_dim, **kwargs)
 
     def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
@@ -629,7 +628,6 @@ def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[TFSharedEmbe
         self.embed_positions = TFBlenderbotLearnedPositionalEmbedding(
             config.max_position_embeddings,
             config.d_model,
-            self.padding_idx,
             name="embed_positions",
         )
         self.layers = [TFBlenderbotEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
@@ -797,7 +795,6 @@ def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[TFSharedEmbe
         self.embed_positions = TFBlenderbotLearnedPositionalEmbedding(
             config.max_position_embeddings,
             config.d_model,
-            self.padding_idx,
             name="embed_positions",
         )
         self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0

src/transformers/models/blenderbot_small/modeling_blenderbot_small.py

Lines changed: 2 additions & 5 deletions
@@ -110,9 +110,8 @@ class BlenderbotSmallLearnedPositionalEmbedding(nn.Embedding):
     This module learns positional embeddings up to a fixed maximum size.
     """
 
-    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
-        assert padding_idx is not None, "`padding_idx` should not be None, but of type int"
-        super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
+    def __init__(self, num_embeddings: int, embedding_dim: int):
+        super().__init__(num_embeddings, embedding_dim)
 
     def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
         """`input_ids_shape` is expected to be [bsz x seqlen]."""
@@ -636,7 +635,6 @@ def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[nn.Embe
         self.embed_positions = BlenderbotSmallLearnedPositionalEmbedding(
             config.max_position_embeddings,
             embed_dim,
-            self.padding_idx,
         )
         self.layers = nn.ModuleList([BlenderbotSmallEncoderLayer(config) for _ in range(config.encoder_layers)])
         self.layernorm_embedding = nn.LayerNorm(embed_dim)
@@ -800,7 +798,6 @@ def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[nn.Embe
         self.embed_positions = BlenderbotSmallLearnedPositionalEmbedding(
             config.max_position_embeddings,
             config.d_model,
-            self.padding_idx,
         )
         self.layers = nn.ModuleList([BlenderbotSmallDecoderLayer(config) for _ in range(config.decoder_layers)])
         self.layernorm_embedding = nn.LayerNorm(config.d_model)

src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py

Lines changed: 1 addition & 4 deletions
@@ -117,8 +117,7 @@ class TFBlenderbotSmallLearnedPositionalEmbedding(TFSharedEmbeddings):
     This module learns positional embeddings up to a fixed maximum size.
     """
 
-    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, **kwargs):
-        assert padding_idx is not None, "padding_idx cannot be None"
+    def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
         super().__init__(num_embeddings, embedding_dim, **kwargs)
 
     def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
@@ -634,7 +633,6 @@ def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[TFShare
         self.embed_positions = TFBlenderbotSmallLearnedPositionalEmbedding(
             config.max_position_embeddings,
             config.d_model,
-            self.padding_idx,
             name="embed_positions",
         )
         self.layers = [TFBlenderbotSmallEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
@@ -802,7 +800,6 @@ def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[TFShare
         self.embed_positions = TFBlenderbotSmallLearnedPositionalEmbedding(
             config.max_position_embeddings,
             config.d_model,
-            self.padding_idx,
             name="embed_positions",
         )
         self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0

src/transformers/models/led/modeling_led.py

Lines changed: 2 additions & 5 deletions
@@ -112,9 +112,8 @@ class LEDLearnedPositionalEmbedding(nn.Embedding):
     This module learns positional embeddings up to a fixed maximum size.
     """
 
-    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
-        assert padding_idx is not None, "`padding_idx` should not be None, but of type int"
-        super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
+    def __init__(self, num_embeddings: int, embedding_dim: int):
+        super().__init__(num_embeddings, embedding_dim)
 
     def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
         """`input_ids_shape` is expected to be [bsz x seqlen]."""
@@ -1622,7 +1621,6 @@ def __init__(self, config: LEDConfig, embed_tokens: Optional[nn.Embedding] = Non
         self.embed_positions = LEDLearnedPositionalEmbedding(
             self.max_source_positions,
             embed_dim,
-            self.padding_idx,
         )
         self.layers = nn.ModuleList([LEDEncoderLayer(config, i) for i in range(config.encoder_layers)])
         self.layernorm_embedding = nn.LayerNorm(embed_dim)
@@ -1891,7 +1889,6 @@ def __init__(self, config: LEDConfig, embed_tokens: Optional[nn.Embedding] = Non
         self.embed_positions = LEDLearnedPositionalEmbedding(
             self.max_target_positions,
             config.d_model,
-            self.padding_idx,
         )
         self.layers = nn.ModuleList([LEDDecoderLayer(config) for _ in range(config.decoder_layers)])
         self.layernorm_embedding = nn.LayerNorm(config.d_model)

src/transformers/models/led/modeling_tf_led.py

Lines changed: 1 addition & 4 deletions
@@ -108,8 +108,7 @@ class TFLEDLearnedPositionalEmbedding(TFSharedEmbeddings):
     This module learns positional embeddings up to a fixed maximum size.
     """
 
-    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, **kwargs):
-        assert padding_idx is not None, "padding_idx cannot be None"
+    def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
         super().__init__(num_embeddings, embedding_dim, **kwargs)
 
     def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
@@ -1612,7 +1611,6 @@ def __init__(self, config: LEDConfig, embed_tokens: Optional[TFSharedEmbeddings]
         self.embed_positions = TFLEDLearnedPositionalEmbedding(
             config.max_encoder_position_embeddings,
             config.d_model,
-            self.padding_idx,
             name="embed_positions",
         )
         self.layers = [TFLEDEncoderLayer(config, i, name=f"layers.{i}") for i in range(config.encoder_layers)]
@@ -1865,7 +1863,6 @@ def __init__(self, config: LEDConfig, embed_tokens: Optional[TFSharedEmbeddings]
         self.embed_positions = TFLEDLearnedPositionalEmbedding(
             config.max_decoder_position_embeddings,
             config.d_model,
-            self.padding_idx,
             name="embed_positions",
         )
         self.layers = [TFLEDDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)]

src/transformers/models/mbart/modeling_mbart.py

Lines changed: 2 additions & 5 deletions
@@ -114,12 +114,11 @@ class MBartLearnedPositionalEmbedding(nn.Embedding):
     This module learns positional embeddings up to a fixed maximum size.
     """
 
-    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
-        assert padding_idx is not None, "`padding_idx` should not be None, but of type int"
+    def __init__(self, num_embeddings: int, embedding_dim: int):
         # MBart is set up so that if padding_idx is specified then offset the embedding ids by 2
         # and adjust num_embeddings appropriately. Other models dont have this hack
         self.offset = 2
-        super().__init__(num_embeddings + self.offset, embedding_dim, padding_idx=padding_idx)
+        super().__init__(num_embeddings + self.offset, embedding_dim)
 
     def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
         """`input_ids_shape` is expected to be [bsz x seqlen]."""
@@ -678,7 +677,6 @@ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = N
         self.embed_positions = MBartLearnedPositionalEmbedding(
             config.max_position_embeddings,
             embed_dim,
-            self.padding_idx,
         )
         self.layers = nn.ModuleList([MBartEncoderLayer(config) for _ in range(config.encoder_layers)])
         self.layernorm_embedding = nn.LayerNorm(embed_dim)
@@ -844,7 +842,6 @@ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = N
         self.embed_positions = MBartLearnedPositionalEmbedding(
             config.max_position_embeddings,
             config.d_model,
-            self.padding_idx,
         )
         self.layers = nn.ModuleList([MBartDecoderLayer(config) for _ in range(config.decoder_layers)])
         self.layernorm_embedding = nn.LayerNorm(config.d_model)

src/transformers/models/mbart/modeling_tf_mbart.py

Lines changed: 1 addition & 4 deletions
@@ -115,8 +115,7 @@ class TFMBartLearnedPositionalEmbedding(TFSharedEmbeddings):
     This module learns positional embeddings up to a fixed maximum size.
     """
 
-    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, **kwargs):
-        assert padding_idx is not None, "padding_idx cannot be None"
+    def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
         # MBart is set up so that if padding_idx is specified then offset the embedding ids by 2
         # and adjust num_embeddings appropriately. Other models dont have this hack
         self.offset = 2
@@ -636,7 +635,6 @@ def __init__(self, config: MBartConfig, embed_tokens: Optional[TFSharedEmbedding
         self.embed_positions = TFMBartLearnedPositionalEmbedding(
             config.max_position_embeddings,
             config.d_model,
-            self.padding_idx,
             name="embed_positions",
         )
         self.layers = [TFMBartEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
@@ -806,7 +804,6 @@ def __init__(self, config: MBartConfig, embed_tokens: Optional[TFSharedEmbedding
         self.embed_positions = TFMBartLearnedPositionalEmbedding(
             config.max_position_embeddings,
             config.d_model,
-            self.padding_idx,
             name="embed_positions",
         )
         self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
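Across all of the files above the parameter shapes are untouched (Bart and MBart still add their offset of 2 to `num_embeddings`), so previously trained checkpoints should continue to load; the only behavioral difference is that the formerly frozen position rows become trainable from now on. A quick shape-compatibility sketch along those lines (illustrative sizes and padding_idx value, not a test from the repository):

import torch.nn as nn

old_style = nn.Embedding(1024 + 2, 768, padding_idx=1)  # pre-commit: padding_idx forwarded
new_style = nn.Embedding(1024 + 2, 768)                 # post-commit: no padding_idx

# The weight tensor shape is identical, so an old state dict loads cleanly;
# the previously zeroed row keeps its stored (zero) values until further training.
new_style.load_state_dict(old_style.state_dict())
print(new_style.weight.shape)                   # torch.Size([1026, 768])
print(new_style.weight[1].abs().sum().item())   # 0.0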
