Skip to content

Commit 3eea052

Browse files
authored
Use this code to translate weights file of darknet to ckpt file of tensorflow.
Use this code to translate your weights file of darknet to ckpt file of tensorflow. Then you can use the author's code 'freeze_graph.py' to freeze the ckpt file to .pb model, and use the pb file path in 'image_demo.py' or 'video_demo.py' .
1 parent 42cd567 commit 3eea052

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed

from_darknet_weights_to_ckpt

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import tensorflow as tf
2+
from core.yolov3 import YOLOV3
3+
4+
darknet_weights = '<your darknet weights' path>'
5+
ckpt_file = './checkpoint/yolov3.ckpt'
6+
7+
def load_weights(var_list, weights_file):
8+
"""
9+
Loads and converts pre-trained weights.
10+
:param var_list: list of network variables.
11+
:param weights_file: name of the binary file.
12+
:return: list of assign ops
13+
"""
14+
with open(weights_file, "rb") as fp:
15+
_ = np.fromfile(fp, dtype=np.int32, count=5)
16+
weights = np.fromfile(fp, dtype=np.float32) # np.ndarray
17+
print('weights_num:', weights.shape[0])
18+
ptr = 0
19+
i = 0
20+
assign_ops = []
21+
while i < len(var_list) - 1:
22+
var1 = var_list[i]
23+
var2 = var_list[i + 1]
24+
# do something only if we process conv layer
25+
if 'conv' in var1.name.split('/')[-2]:
26+
# check type of next layer
27+
if 'batch_normalization' in var2.name.split('/')[-2]:
28+
# load batch norm params
29+
gamma, beta, mean, var = var_list[i + 1:i + 5]
30+
batch_norm_vars = [beta, gamma, mean, var]
31+
for vari in batch_norm_vars:
32+
shape = vari.shape.as_list()
33+
num_params = np.prod(shape)
34+
vari_weights = weights[ptr:ptr + num_params].reshape(shape)
35+
ptr += num_params
36+
assign_ops.append(
37+
tf.assign(vari, vari_weights, validate_shape=True))
38+
i += 4
39+
elif 'conv' in var2.name.split('/')[-2]:
40+
# load biases
41+
bias = var2
42+
bias_shape = bias.shape.as_list()
43+
bias_params = np.prod(bias_shape)
44+
bias_weights = weights[ptr:ptr +
45+
bias_params].reshape(bias_shape)
46+
ptr += bias_params
47+
assign_ops.append(
48+
tf.assign(bias, bias_weights, validate_shape=True))
49+
i += 1
50+
shape = var1.shape.as_list()
51+
num_params = np.prod(shape)
52+
53+
var_weights = weights[ptr:ptr + num_params].reshape(
54+
(shape[3], shape[2], shape[0], shape[1]))
55+
# remember to transpose to column-major
56+
var_weights = np.transpose(var_weights, (2, 3, 1, 0))
57+
ptr += num_params
58+
assign_ops.append(
59+
tf.assign(var1, var_weights, validate_shape=True))
60+
i += 1
61+
print('ptr:', ptr)
62+
return assign_ops
63+
64+
with tf.name_scope('input'):
65+
input_data = tf.placeholder(dtype=tf.float32,shape=(None, 608, 608, 3), name='input_data')
66+
model = YOLOV3(input_data, trainable=False)
67+
load_ops = load_weights(tf.global_variables(), darknet_weights)
68+
69+
saver = tf.train.Saver(tf.global_variables())
70+
71+
with tf.Session() as sess:
72+
sess.run(load_ops)
73+
save_path = saver.save(sess, save_path=ckpt_file)
74+
print('Model saved in path: {}'.format(save_path))

0 commit comments

Comments
 (0)