Skip to content

Commit fe123de

Browse files
YunYang1994YunYang1994
authored andcommitted
I hate tensorflow
1 parent f0d73b7 commit fe123de

File tree

1 file changed

+5
-23
lines changed

1 file changed

+5
-23
lines changed

quick_train.py

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
file_pattern = "./data/train_data/quick_train_data/tfrecords/quick_train_data*.tfrecords"
2828
anchors = utils.get_anchors('./data/yolo_anchors.txt')
2929

30-
is_training = tf.placeholder(dtype=tf.bool, name="phase_train")
3130
dataset = tf.data.TFRecordDataset(filenames = tf.gfile.Glob(file_pattern))
3231
dataset = dataset.map(utils.parser(anchors, num_classes).parser_example, num_parallel_calls = 10)
3332
dataset = dataset.repeat().shuffle(SHUFFLE_SIZE).batch(BATCH_SIZE).prefetch(BATCH_SIZE)
@@ -36,35 +35,18 @@
3635
images, *y_true = example
3736
model = yolov3.yolov3(num_classes)
3837
with tf.variable_scope('yolov3'):
39-
y_pred = model.forward(images, is_training=is_training)
38+
y_pred = model.forward(images, is_training=False)
4039
loss = model.compute_loss(y_pred, y_true)
4140
y_pred = model.predict(y_pred)
42-
43-
44-
# # train
45-
# optimizer = tf.train.AdamOptimizer(LR)
46-
# train_op = optimizer.minimize(loss[0])
47-
# sess.run(tf.global_variables_initializer())
48-
# for epoch in range(EPOCHS):
49-
# run_items = sess.run([train_op, y_pred, y_true] + loss, feed_dict={is_training:True})
50-
# rec, prec, mAP = utils.evaluate(run_items[1], run_items[2], num_classes)
51-
52-
# print("=> EPOCH: %2d\ttotal_loss:%7.4f\tloss_coord:%7.4f\tloss_sizes:%7.4f\tloss_confs:%7.4f\tloss_class:%7.4f"
53-
# "\trec:%.2f\tprec:%.2f\tmAP:%.2f"
54-
# %(epoch, run_items[3], run_items[4], run_items[5], run_items[6], run_items[7], rec, prec, mAP))
55-
56-
57-
58-
# test
59-
load_ops = utils.load_weights(tf.global_variables(scope='yolov3'), weights_path)
60-
sess.run(load_ops)
41+
load_ops = utils.load_weights(tf.global_variables(scope='yolov3'), weights_path)
42+
sess.run(load_ops)
6143

6244
for epoch in range(EPOCHS):
63-
run_items = sess.run([y_pred, y_true] + loss, feed_dict={is_training:False})
45+
run_items = sess.run([y_pred, y_true] + loss)
6446
rec, prec, mAP = utils.evaluate(run_items[0], run_items[1], num_classes, score_thresh=0.3, iou_thresh=0.5)
6547

6648
print("=> EPOCH: %2d\ttotal_loss:%7.4f\tloss_coord:%7.4f\tloss_sizes:%7.4f\tloss_confs:%7.4f\tloss_class:%7.4f"
67-
"\trec:%.2f\tprec:%.2f\tmAP:%.2f"
49+
"\trec:%7.4f\tprec:%7.4f\tmAP:%7.4f"
6850
%(epoch, run_items[2], run_items[3], run_items[4], run_items[5], run_items[6], rec, prec, mAP))
6951

7052

0 commit comments

Comments
 (0)