Skip to content

Commit fdefb7d

Browse files
authored
Merge pull request YunYang1994#416 from KoapT/master
Convert darknet weights to ckpt&pb file.
2 parents 741037a + 62c1665 commit fdefb7d

File tree

2 files changed

+101
-0
lines changed

2 files changed

+101
-0
lines changed

from_darknet_weights_to_ckpt.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import tensorflow as tf
2+
from core.yolov3 import YOLOV3
3+
4+
iput_size = 412
5+
darknet_weights = '<your darknet weights' path>'
6+
ckpt_file = './checkpoint/yolov3.ckpt'
7+
8+
def load_weights(var_list, weights_file):
9+
"""
10+
Loads and converts pre-trained weights.
11+
:param var_list: list of network variables.
12+
:param weights_file: name of the binary file.
13+
:return: list of assign ops
14+
"""
15+
with open(weights_file, "rb") as fp:
16+
_ = np.fromfile(fp, dtype=np.int32, count=5)
17+
weights = np.fromfile(fp, dtype=np.float32) # np.ndarray
18+
print('weights_num:', weights.shape[0])
19+
ptr = 0
20+
i = 0
21+
assign_ops = []
22+
while i < len(var_list) - 1:
23+
var1 = var_list[i]
24+
var2 = var_list[i + 1]
25+
# do something only if we process conv layer
26+
if 'conv' in var1.name.split('/')[-2]:
27+
# check type of next layer
28+
if 'batch_normalization' in var2.name.split('/')[-2]:
29+
# load batch norm params
30+
gamma, beta, mean, var = var_list[i + 1:i + 5]
31+
batch_norm_vars = [beta, gamma, mean, var]
32+
for vari in batch_norm_vars:
33+
shape = vari.shape.as_list()
34+
num_params = np.prod(shape)
35+
vari_weights = weights[ptr:ptr + num_params].reshape(shape)
36+
ptr += num_params
37+
assign_ops.append(
38+
tf.assign(vari, vari_weights, validate_shape=True))
39+
i += 4
40+
elif 'conv' in var2.name.split('/')[-2]:
41+
# load biases
42+
bias = var2
43+
bias_shape = bias.shape.as_list()
44+
bias_params = np.prod(bias_shape)
45+
bias_weights = weights[ptr:ptr +
46+
bias_params].reshape(bias_shape)
47+
ptr += bias_params
48+
assign_ops.append(
49+
tf.assign(bias, bias_weights, validate_shape=True))
50+
i += 1
51+
shape = var1.shape.as_list()
52+
num_params = np.prod(shape)
53+
54+
var_weights = weights[ptr:ptr + num_params].reshape(
55+
(shape[3], shape[2], shape[0], shape[1]))
56+
# remember to transpose to column-major
57+
var_weights = np.transpose(var_weights, (2, 3, 1, 0))
58+
ptr += num_params
59+
assign_ops.append(
60+
tf.assign(var1, var_weights, validate_shape=True))
61+
i += 1
62+
print('ptr:', ptr)
63+
return assign_ops
64+
65+
with tf.name_scope('input'):
66+
input_data = tf.placeholder(dtype=tf.float32,shape=(None, iput_size, iput_size, 3), name='input_data')
67+
model = YOLOV3(input_data, trainable=False)
68+
load_ops = load_weights(tf.global_variables(), darknet_weights)
69+
70+
saver = tf.train.Saver(tf.global_variables())
71+
72+
with tf.Session() as sess:
73+
sess.run(load_ops)
74+
save_path = saver.save(sess, save_path=ckpt_file)
75+
print('Model saved in path: {}'.format(save_path))

from_darknet_weights_to_pb.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import tensorflow as tf
2+
from core.yolov3 import YOLOV3
3+
from from_darknet_weights_to_ckpt import load_weights
4+
5+
input_size = 416
6+
darknet_weights = '<your darknet weights file path>'
7+
pb_file = './yolov3.pb'
8+
output_node_names = ["input/input_data", "pred_sbbox/concat_2", "pred_mbbox/concat_2", "pred_lbbox/concat_2"]
9+
10+
with tf.name_scope('input'):
11+
input_data = tf.placeholder(dtype=tf.float32, shape=(None, input_size, input_size, 3), name='input_data')
12+
model = YOLOV3(input_data, trainable=False)
13+
load_ops = load_weights(tf.global_variables(), darknet_weights)
14+
15+
with tf.Session() as sess:
16+
sess.run(load_ops)
17+
output_graph_def = tf.graph_util.convert_variables_to_constants(
18+
sess,
19+
tf.get_default_graph().as_graph_def(),
20+
output_node_names=output_node_names
21+
)
22+
23+
with tf.gfile.GFile(output_graph, "wb") as f:
24+
f.write(output_graph_def.SerializeToString())
25+
26+
print("{} ops written to {}.".format(len(output_graph_def.node), output_graph))

0 commit comments

Comments
 (0)