
Commit 6d202fb

YunYang1994 authored and committed
I hate tensorflow
1 parent 8d5178e commit 6d202fb

File tree

6 files changed: +132, -37 lines changed

Binary file changed (0 Bytes): binary file not shown.

core/__pycache__/utils.cpython-35.pyc (1.32 KB): binary file not shown.

Binary file changed (0 Bytes): binary file not shown.

core/utils.py

Lines changed: 64 additions & 26 deletions
@@ -345,6 +345,7 @@ def preprocess_true_boxes(true_boxes, true_labels, input_shape, anchors, num_cla
     anchors_min = -anchors_max
     valid_mask = box_sizes[:, 0] > 0
 
+
     # Discard zero rows.
     wh = box_sizes[valid_mask]
     # set the center of all boxes as the origin of their coordinates
@@ -355,7 +356,7 @@ def preprocess_true_boxes(true_boxes, true_labels, input_shape, anchors, num_cla
 
     intersect_mins = np.maximum(boxes_min, anchors_min)
     intersect_maxs = np.minimum(boxes_max, anchors_max)
-    intersect_wh = np.maximum(intersect_maxs - intersect_mins, 0.)
+    intersect_wh = np.maximum(intersect_maxs - intersect_mins, 0.)
     intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
     box_area = wh[..., 0] * wh[..., 1]
 
@@ -448,43 +449,80 @@ def parser_example(self, serialized_example):
 
         return self.preprocess(image, true_labels, true_boxes)
 
+def bbox_iou(A, B):
+
+    intersect_mins = np.maximum(A[:, 0:2], B[:, 0:2])
+    intersect_maxs = np.minimum(A[:, 2:4], B[:, 2:4])
+    intersect_wh = np.maximum(intersect_maxs - intersect_mins, 0.)
+    intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
+
+    A_area = np.prod(A[:, 2:4] - A[:, 0:2], axis=1)
+    B_area = np.prod(B[:, 2:4] - B[:, 0:2], axis=1)
+
+    iou = intersect_area / (A_area + B_area - intersect_area)
+
+    return iou
+
 def evaluate(y_pred, y_true, num_classes, score_thresh=0.5, iou_thresh=0.5):
 
     num_images = y_true[0].shape[0]
-    true_labels = {i:0 for i in range(num_classes)} # {class: count}
-    pred_labels = {i:0 for i in range(num_classes)}
-    true_positive = {i:0 for i in range(num_classes)}
+    true_labels_dict = {i:0 for i in range(num_classes)} # {class: count}
+    pred_labels_dict = {i:0 for i in range(num_classes)}
+    true_positive_dict = {i:0 for i in range(num_classes)}
 
     for i in range(num_images):
-        true_labels_list = []
+        true_labels_list, true_boxes_list = [], []
         for j in range(3): # three feature maps
-            true_probs_temp = y_true[j][i][...,5:]
-            true_probs_temp = true_probs_temp[true_probs_temp.sum(axis=-1) > 0]
-            true_labels_list += list(np.argmax(true_probs_temp, axis=-1))
+            true_probs_temp = y_true[j][i][...,5: ]
+            true_boxes_temp = y_true[j][i][...,0:4]
+
+            object_mask = true_probs_temp.sum(axis=-1) > 0
+
+            true_probs_temp = true_probs_temp[object_mask]
+            true_boxes_temp = true_boxes_temp[object_mask]
+
+            true_labels_list += np.argmax(true_probs_temp, axis=-1).tolist()
+            true_boxes_list += true_boxes_temp.tolist()
 
         if len(true_labels_list) != 0:
-            print(true_labels_list)
-            for cls, count in Counter(true_labels_list).items(): true_labels[cls] += count
+            for cls, count in Counter(true_labels_list).items(): true_labels_dict[cls] += count
 
         pred_boxes = y_pred[0][i:i+1]
        pred_confs = y_pred[1][i:i+1]
         pred_probs = y_pred[2][i:i+1]
-        pred_labels_list = cpu_nms(pred_boxes, pred_confs*pred_probs, num_classes,
-                                   100, score_thresh, iou_thresh)[2]
-        pred_labels_list = [] if pred_labels_list is None else list(pred_labels_list)
-
-        if len(pred_labels_list) != 0:
-            for cls, count in Counter(pred_labels_list).items(): pred_labels[cls] += count
-
-        for k in range(num_classes):
-            t = true_labels_list.count(k)
-            p = pred_labels_list.count(k)
-            true_positive[k] += p if t >= p else t
-
-    recall = sum(true_positive.values()) / (sum(true_labels.values()) + 1e-6)
-    precision = sum(true_positive.values()) / (sum(pred_labels.values()) + 1e-6)
-    avg_prec = [true_positive[i] / (true_labels[i] + 1e-6) for i in range(num_classes)]
-    mAP = sum(avg_prec) / num_classes
+
+        pred_boxes, pred_confs, pred_labels = cpu_nms(pred_boxes, pred_confs*pred_probs, num_classes,
+                                                      score_thresh=score_thresh, iou_thresh=iou_thresh)
+
+        true_boxes = np.array(true_boxes_list)
+        box_centers, box_sizes = true_boxes[:,0:2], true_boxes[:,2:4]
+
+        true_boxes[:,0:2] = box_centers - box_sizes / 2.
+        true_boxes[:,2:4] = true_boxes[:,0:2] + box_sizes
+
+        pred_labels_list = [] if pred_labels is None else pred_labels.tolist()
+        if pred_labels_list == []: continue
+
+        detected = []
+        for k in range(len(true_labels_list)):
+            # compute iou between predicted box and ground_truth boxes
+            iou = bbox_iou(true_boxes[k:k+1], pred_boxes)
+            # Extract index of largest overlap
+            m = np.argmax(iou)
+            if iou[m] >= iou_thresh and true_labels_list[k] == pred_labels_list[m] and m not in detected:
+                pred_labels_dict[true_labels_list[k]] += 1
+                detected.append(m)
+        pred_labels_list = [pred_labels_list[m] for m in detected]
+
+        for c in range(num_classes):
+            t = true_labels_list.count(c)
+            p = pred_labels_list.count(c)
+            true_positive_dict[c] += p if t >= p else t
+
+    recall = sum(true_positive_dict.values()) / (sum(true_labels_dict.values()) + 1e-6)
+    precision = sum(true_positive_dict.values()) / (sum(pred_labels_dict.values()) + 1e-6)
+    avg_prec = [true_positive_dict[i] / (true_labels_dict[i] + 1e-6) for i in range(num_classes)]
+    mAP = sum(avg_prec) / (sum([avg_prec[i] != 0 for i in range(num_classes)]) + 1e-6)
 
     return recall, precision, mAP
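
For reference, the rewritten evaluate() now matches every ground-truth box to its best-overlapping prediction through the new bbox_iou helper, and only counts a true positive when that best IoU clears iou_thresh and the class labels agree. A minimal standalone NumPy sketch of that matching step on made-up boxes (the coordinates and printed values are illustrative only):

import numpy as np

def bbox_iou(A, B):
    # A: (1, 4) ground-truth box, B: (n, 4) predicted boxes, both as (x1, y1, x2, y2)
    intersect_mins = np.maximum(A[:, 0:2], B[:, 0:2])
    intersect_maxs = np.minimum(A[:, 2:4], B[:, 2:4])
    intersect_wh = np.maximum(intersect_maxs - intersect_mins, 0.)
    intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
    A_area = np.prod(A[:, 2:4] - A[:, 0:2], axis=1)
    B_area = np.prod(B[:, 2:4] - B[:, 0:2], axis=1)
    return intersect_area / (A_area + B_area - intersect_area)

# Made-up example: one ground-truth box against two predictions.
true_box = np.array([[10., 10., 50., 50.]])
pred_boxes = np.array([[12., 12., 48., 48.],    # large overlap
                       [60., 60., 90., 90.]])   # no overlap
iou = bbox_iou(true_box, pred_boxes)
m = np.argmax(iou)
print(iou, m)   # [0.81 0.  ] 0 -> prediction 0 is the match if 0.81 >= iou_thresh and classes agree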

test.py

Lines changed: 49 additions & 0 deletions
@@ -12,3 +12,52 @@
 #================================================================
 
 # continue to work
+import tensorflow as tf
+from core import utils, yolov3
+
+INPUT_SIZE = 416
+BATCH_SIZE = 1
+EPOCHS = 20
+LR = 0.001
+SHUFFLE_SIZE = 1
+
+sess = tf.Session()
+classes = utils.read_coco_names('./data/coco.names')
+num_classes = len(classes)
+# file_pattern = "../COCO/tfrecords/coco*.tfrecords"
+file_pattern = "./data/train_data/quick_train_data/tfrecords/quick_train_data*.tfrecords"
+anchors = utils.get_anchors('./data/yolo_anchors.txt')
+
+is_training = tf.placeholder(dtype=tf.bool, name="phase_train")
+dataset = tf.data.TFRecordDataset(filenames = tf.gfile.Glob(file_pattern))
+dataset = dataset.map(utils.parser(anchors, num_classes).parser_example, num_parallel_calls = 10)
+dataset = dataset.repeat().shuffle(SHUFFLE_SIZE).batch(BATCH_SIZE).prefetch(BATCH_SIZE)
+iterator = dataset.make_one_shot_iterator()
+example = iterator.get_next()
+
+images, *y_true = example
+model = yolov3.yolov3(num_classes)
+with tf.variable_scope('yolov3'):
+    y_pred = model.forward(images, is_training=is_training)
+    y_pred = model.predict(y_pred)
+
+load_ops = utils.load_weights(tf.global_variables(scope='yolov3'), "/home/yang/test/yolov3.weights")
+sess.run(load_ops)
+
+
+for epoch in range(EPOCHS):
+    run_items = sess.run([y_pred, y_true], feed_dict={is_training:False})
+    rec, prec, mAP = utils.evaluate(run_items[0], run_items[1], num_classes)
+    print("=> EPOCH: %2d recall: %.2f precision: %.2f mAP: %.2f" %(epoch, rec, prec, mAP))
+
+
+
+
+
+
+
+
+
+
+
+
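
Aside: the script above drives the network with the standard TF 1.x tf.data one-shot-iterator pattern (TFRecords in, one batch per sess.run). A stripped-down sketch of just that pattern, with a placeholder glob and parse function rather than the repository's utils.parser class:

import tensorflow as tf   # TF 1.x API, matching the script above

# Placeholder parse function: decode one serialized tf.train.Example into an image tensor.
def parse_fn(serialized_example):
    features = tf.parse_single_example(
        serialized_example,
        features={"image": tf.FixedLenFeature([], tf.string)})
    return tf.image.decode_jpeg(features["image"], channels=3)

dataset = tf.data.TFRecordDataset(tf.gfile.Glob("./data/*.tfrecords"))   # placeholder pattern
dataset = dataset.map(parse_fn, num_parallel_calls=10)
dataset = dataset.repeat().shuffle(1).batch(1).prefetch(1)
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()   # graph tensor; each sess.run() pulls the next batch

with tf.Session() as sess:
    image = sess.run(next_batch)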

train.py

Lines changed: 19 additions & 11 deletions
@@ -40,37 +40,45 @@
 with tf.variable_scope('yolov3'):
     y_pred = model.forward(images, is_training=is_training)
     loss = model.compute_loss(y_pred, y_true)
-
     y_pred = model.predict(y_pred)
 
-
 optimizer = tf.train.AdamOptimizer(LR)
 train_op = optimizer.minimize(loss[0])
 saver = tf.train.Saver(max_to_keep=2)
-sess.run(tf.global_variables_initializer())
 
-tf.summary.scalar("yolov3/total_loss", loss[0])
+rec_tensor = tf.Variable(0.)
+prec_tensor = tf.Variable(0.)
+mAP_tensor = tf.Variable(0.)
 
+tf.summary.scalar("yolov3/recall", rec_tensor)
+tf.summary.scalar("yolov3/precision", prec_tensor)
+tf.summary.scalar("yolov3/mAP", mAP_tensor)
+tf.summary.scalar("yolov3/total_loss", loss[0])
 
 tf.summary.scalar("loss/coord_loss", loss[1])
 tf.summary.scalar("loss/sizes_loss", loss[2])
 tf.summary.scalar("loss/confs_loss", loss[3])
 tf.summary.scalar("loss/class_loss", loss[4])
 write_op = tf.summary.merge_all()
 writer_train = tf.summary.FileWriter("./data/log/train", graph=sess.graph)
+sess.run(tf.global_variables_initializer())
 
 for epoch in range(EPOCHS):
-    run_items = sess.run([train_op, write_op, y_pred, y_true] + loss, feed_dict={is_training:True})
-
-    data = utils.evaluate(run_items[2], run_items[3], num_classes)
+    run_items = sess.run([train_op, y_pred, y_true] + loss, feed_dict={is_training:True})
+    rec, prec, mAP = utils.evaluate(run_items[1], run_items[2], num_classes)
+    _, _, _, summary = sess.run([tf.assign(rec_tensor, rec),
+                                 tf.assign(prec_tensor, prec),
+                                 tf.assign(mAP_tensor, mAP), write_op], feed_dict={is_training:True})
 
-    writer_train.add_summary(summary=run_items[1], global_step=epoch)
+    writer_train.add_summary(summary, global_step=epoch)
     writer_train.flush() # Flushes the event file to disk
     if epoch%1000 == 0: saver.save(sess, save_path="./checkpoint/yolov3.ckpt", global_step=epoch)
 
-    print("=> EPOCH:%10d | total_loss:%7.4f\tloss_coord:%7.4f\tloss_sizes:%7.4f\tloss_confs:%7.4f\tloss_class:%7.4f"
-          %(epoch, run_items[4], run_items[5], run_items[6], run_items[7], run_items[8]))
-    break
+    print("=> EPOCH:%10d\ttotal_loss:%7.4f\tloss_coord:%7.4f\tloss_sizes:%7.4f\tloss_confs:%7.4f\tloss_class:%7.4f"
+          "\trec:%.2f\tprec:%.2f\tmAP:%.2f"
+          %(epoch, run_items[3], run_items[4], run_items[5], run_items[6], run_items[7], rec, prec, mAP))
+
+
 
 
 
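Because recall, precision and mAP come back from utils.evaluate() as plain Python floats, train.py routes them into TensorBoard by assigning them to tf.Variable nodes (rec_tensor, prec_tensor, mAP_tensor) before evaluating the merged summary. A minimal sketch of that pattern, written independently of the repository's code (the variable name and log directory below are illustrative):

import tensorflow as tf   # TF 1.x API

recall_var = tf.Variable(0., trainable=False)     # holds a metric computed outside the graph
tf.summary.scalar("metrics/recall", recall_var)
write_op = tf.summary.merge_all()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    writer = tf.summary.FileWriter("./data/log/example")   # illustrative log dir
    for step in range(3):
        recall = 0.5 + 0.1 * step                 # stand-in for a NumPy-computed metric
        sess.run(tf.assign(recall_var, recall))   # push the Python value into the variable
        writer.add_summary(sess.run(write_op), global_step=step)
    writer.flush()

One caveat: constructing tf.assign inside the loop (as both this sketch and the diff above do) adds a new op to the graph on every iteration; creating a single assign op fed through a tf.placeholder once, before the loop, avoids that graph growth.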
