Skip to content

Commit 2d458b2

Browse files
ConvBERT fix torch <> tf weights conversion (huggingface#10314)
* convbert conversion test
* fin
* fin
* fin
* clean up tf<->pt conversion
* remove from_pt

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
1 parent 3437d12 commit 2d458b2

File tree

4 files changed

+14
-9
lines changed

4 files changed

+14
-9
lines changed

src/transformers/modeling_tf_pytorch_utils.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,11 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="")
5656
tf_name = tf_name[1:] # Remove level zero
5757

5858
# When should we transpose the weights
59-
transpose = bool(tf_name[-1] == "kernel" or "emb_projs" in tf_name or "out_projs" in tf_name)
59+
transpose = bool(
60+
tf_name[-1] in ["kernel", "pointwise_kernel", "depthwise_kernel"]
61+
or "emb_projs" in tf_name
62+
or "out_projs" in tf_name
63+
)
6064

6165
# Convert standard TF2.0 names in PyTorch names
6266
if tf_name[-1] == "kernel" or tf_name[-1] == "embeddings" or tf_name[-1] == "gamma":

src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch.py renamed to src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch_and_tf2.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616

1717
import argparse
1818

19-
from transformers import ConvBertConfig, ConvBertModel, load_tf_weights_in_convbert
19+
from transformers import ConvBertConfig, ConvBertModel, TFConvBertModel, load_tf_weights_in_convbert
2020
from transformers.utils import logging
2121

2222

@@ -30,6 +30,9 @@ def convert_orig_tf1_checkpoint_to_pytorch(tf_checkpoint_path, convbert_config_f
3030
model = load_tf_weights_in_convbert(model, conf, tf_checkpoint_path)
3131
model.save_pretrained(pytorch_dump_path)
3232

33+
tf_model = TFConvBertModel.from_pretrained(pytorch_dump_path, from_pt=True)
34+
tf_model.save_pretrained(pytorch_dump_path)
35+
3336

3437
if __name__ == "__main__":
3538
parser = argparse.ArgumentParser()

src/transformers/models/convbert/modeling_tf_convbert.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -343,7 +343,7 @@ def __init__(self, input_size, output_size, num_groups, kernel_initializer, **kw
343343
def build(self, input_shape):
344344
self.kernel = self.add_weight(
345345
"kernel",
346-
shape=[self.num_groups, self.group_in_dim, self.group_out_dim],
346+
shape=[self.group_out_dim, self.group_in_dim, self.num_groups],
347347
initializer=self.kernel_initializer,
348348
trainable=True,
349349
)
@@ -355,7 +355,7 @@ def build(self, input_shape):
355355
def call(self, hidden_states):
356356
batch_size = shape_list(hidden_states)[0]
357357
x = tf.transpose(tf.reshape(hidden_states, [-1, self.num_groups, self.group_in_dim]), [1, 0, 2])
358-
x = tf.matmul(x, self.kernel)
358+
x = tf.matmul(x, tf.transpose(self.kernel, [2, 1, 0]))
359359
x = tf.transpose(x, [1, 0, 2])
360360
x = tf.reshape(x, [batch_size, -1, self.output_size])
361361
x = tf.nn.bias_add(value=x, bias=self.bias)

tests/test_modeling_tf_convbert.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -399,14 +399,12 @@ def test_inference_masked_lm(self):
399399
expected_shape = [1, 6, 768]
400400
self.assertEqual(output.shape, expected_shape)
401401

402-
print(output[:, :3, :3])
403-
404402
expected_slice = tf.constant(
405403
[
406404
[
407-
[-0.10334751, -0.37152207, -0.2682219],
408-
[0.20078957, -0.3918426, -0.78811496],
409-
[0.08000169, -0.509474, -0.59314483],
405+
[-0.03475493, -0.4686034, -0.30638832],
406+
[0.22637248, -0.26988646, -0.7423424],
407+
[0.10324868, -0.45013508, -0.58280784],
410408
]
411409
]
412410
)

0 commit comments

Comments
 (0)