Skip to content

Commit 5483a63

Browse files
authored
Convert weights of darknet to checkpoint of tf
One can use this code to convert weights file of darknet to ckpt file of tensorflow, for detection or continue training. Then you can use the author's code 'freeze_graph.py' to convert the ckpt file to pb file.
1 parent 75e2ab0 commit 5483a63

File tree

1 file changed

+75
-0
lines changed

1 file changed

+75
-0
lines changed

from_darknet_weights_to_ckpt.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import tensorflow as tf
2+
from core.yolov3 import YOLOV3
3+
4+
iput_size = 412
5+
darknet_weights = '<your darknet weights' path>'
6+
ckpt_file = './checkpoint/yolov3.ckpt'
7+
8+
def load_weights(var_list, weights_file):
9+
"""
10+
Loads and converts pre-trained weights.
11+
:param var_list: list of network variables.
12+
:param weights_file: name of the binary file.
13+
:return: list of assign ops
14+
"""
15+
with open(weights_file, "rb") as fp:
16+
_ = np.fromfile(fp, dtype=np.int32, count=5)
17+
weights = np.fromfile(fp, dtype=np.float32) # np.ndarray
18+
print('weights_num:', weights.shape[0])
19+
ptr = 0
20+
i = 0
21+
assign_ops = []
22+
while i < len(var_list) - 1:
23+
var1 = var_list[i]
24+
var2 = var_list[i + 1]
25+
# do something only if we process conv layer
26+
if 'conv' in var1.name.split('/')[-2]:
27+
# check type of next layer
28+
if 'batch_normalization' in var2.name.split('/')[-2]:
29+
# load batch norm params
30+
gamma, beta, mean, var = var_list[i + 1:i + 5]
31+
batch_norm_vars = [beta, gamma, mean, var]
32+
for vari in batch_norm_vars:
33+
shape = vari.shape.as_list()
34+
num_params = np.prod(shape)
35+
vari_weights = weights[ptr:ptr + num_params].reshape(shape)
36+
ptr += num_params
37+
assign_ops.append(
38+
tf.assign(vari, vari_weights, validate_shape=True))
39+
i += 4
40+
elif 'conv' in var2.name.split('/')[-2]:
41+
# load biases
42+
bias = var2
43+
bias_shape = bias.shape.as_list()
44+
bias_params = np.prod(bias_shape)
45+
bias_weights = weights[ptr:ptr +
46+
bias_params].reshape(bias_shape)
47+
ptr += bias_params
48+
assign_ops.append(
49+
tf.assign(bias, bias_weights, validate_shape=True))
50+
i += 1
51+
shape = var1.shape.as_list()
52+
num_params = np.prod(shape)
53+
54+
var_weights = weights[ptr:ptr + num_params].reshape(
55+
(shape[3], shape[2], shape[0], shape[1]))
56+
# remember to transpose to column-major
57+
var_weights = np.transpose(var_weights, (2, 3, 1, 0))
58+
ptr += num_params
59+
assign_ops.append(
60+
tf.assign(var1, var_weights, validate_shape=True))
61+
i += 1
62+
print('ptr:', ptr)
63+
return assign_ops
64+
65+
with tf.name_scope('input'):
66+
input_data = tf.placeholder(dtype=tf.float32,shape=(None, iput_size, iput_size, 3), name='input_data')
67+
model = YOLOV3(input_data, trainable=False)
68+
load_ops = load_weights(tf.global_variables(), darknet_weights)
69+
70+
saver = tf.train.Saver(tf.global_variables())
71+
72+
with tf.Session() as sess:
73+
sess.run(load_ops)
74+
save_path = saver.save(sess, save_path=ckpt_file)
75+
print('Model saved in path: {}'.format(save_path))

0 commit comments

Comments
 (0)