
Commit 439e7ab

use float 16 in causal mask and masked bias (huggingface#13194)

1 parent: 8be921f

File tree

1 file changed: +2 additions, -2 deletions

src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py

Lines changed: 2 additions & 2 deletions
@@ -157,11 +157,11 @@ def convert_megatron_checkpoint(args, input_state_dict, config):
         ) and weight_or_bias == "weight":
 
             # Insert a tensor of 1x1xDxD bias.
-            causal_mask = torch.tril(torch.ones((n_embed, n_embed), dtype=torch.uint8)).view(1, 1, n_embed, n_embed)
+            causal_mask = torch.tril(torch.ones((n_embed, n_embed), dtype=torch.float16)).view(1, 1, n_embed, n_embed)
             output_state_dict[layer_name + ".attn.bias"] = causal_mask
 
             # Insert a "dummy" tensor for masked_bias.
-            masked_bias = torch.tensor(-1e4)
+            masked_bias = torch.tensor(-1e4, dtype=torch.float16)
             output_state_dict[layer_name + ".attn.masked_bias"] = masked_bias
 
             out_val = fix_query_key_value_ordering(val, checkpoint_version, 3, heads, hidden_size_per_head)
