|
18 | 18 | BATCH_SIZE = 1
|
19 | 19 | EPOCHS = 20
|
20 | 20 | LR = 0.001
|
| 21 | +SHUFFLE_SIZE = 1 |
| 22 | +weights_path = "/home/yang/test/yolov3.weights" |
21 | 23 |
|
22 | 24 | sess = tf.Session()
|
23 | 25 | classes = utils.read_coco_names('./data/coco.names')
|
24 | 26 | num_classes = len(classes)
|
25 | 27 | file_pattern = "./data/train_data/quick_train_data/tfrecords/quick_train_data*.tfrecords"
|
26 | 28 | anchors = utils.get_anchors('./data/yolo_anchors.txt')
|
27 | 29 |
|
28 |
| - |
| 30 | +is_training = tf.placeholder(dtype=tf.bool, name="phase_train") |
29 | 31 | dataset = tf.data.TFRecordDataset(filenames = tf.gfile.Glob(file_pattern))
|
30 | 32 | dataset = dataset.map(utils.parser(anchors, num_classes).parser_example, num_parallel_calls = 10)
|
31 |
| -dataset = dataset.repeat().batch(BATCH_SIZE).prefetch(BATCH_SIZE) |
| 33 | +dataset = dataset.repeat().shuffle(SHUFFLE_SIZE).batch(BATCH_SIZE).prefetch(BATCH_SIZE) |
32 | 34 | iterator = dataset.make_one_shot_iterator()
|
33 | 35 | example = iterator.get_next()
|
34 | 36 | images, *y_true = example
|
| 37 | +model = yolov3.yolov3(num_classes) |
| 38 | +with tf.variable_scope('yolov3'): |
| 39 | + y_pred = model.forward(images, is_training=is_training) |
| 40 | + loss = model.compute_loss(y_pred, y_true) |
| 41 | + y_pred = model.predict(y_pred) |
35 | 42 |
|
36 |
| -# model = yolov3.yolov3(num_classes) |
37 |
| -# with tf.variable_scope('yolov3'): |
38 |
| - # y_pred = model.forward(images, is_training=True) |
39 |
| - # result = model.compute_loss(y_pred, y_true) |
40 | 43 |
|
| 44 | +# # train |
41 | 45 | # optimizer = tf.train.AdamOptimizer(LR)
|
42 |
| -# train_op = optimizer.minimize(result[3]) |
| 46 | +# train_op = optimizer.minimize(loss[0]) |
43 | 47 | # sess.run(tf.global_variables_initializer())
|
44 |
| - |
45 | 48 | # for epoch in range(EPOCHS):
|
46 |
| - # run_items = sess.run([train_op] + result) |
47 |
| - # print("=> EPOCH:%4d\t| prec_50:%.4f\trec_50:%.4f\tavg_iou:%.4f\t | total_loss:%7.4f\tloss_coord:%7.4f" |
48 |
| - # "\tloss_sizes:%7.4f\tloss_confs:%7.4f\tloss_class:%7.4f" %(epoch, run_items[1], run_items[2], |
49 |
| - # run_items[3], run_items[4], run_items[5], run_items[6], run_items[7], run_items[8])) |
| 49 | + # run_items = sess.run([train_op, y_pred, y_true] + loss, feed_dict={is_training:True}) |
| 50 | + # rec, prec, mAP = utils.evaluate(run_items[1], run_items[2], num_classes) |
50 | 51 |
|
51 |
| -#************************ test with yolov3.weights ****************************# |
| 52 | + # print("=> EPOCH: %2d\ttotal_loss:%7.4f\tloss_coord:%7.4f\tloss_sizes:%7.4f\tloss_confs:%7.4f\tloss_class:%7.4f" |
| 53 | + # "\trec:%.2f\tprec:%.2f\tmAP:%.2f" |
| 54 | + # %(epoch, run_items[3], run_items[4], run_items[5], run_items[6], run_items[7], rec, prec, mAP)) |
52 | 55 |
|
53 |
| -model = yolov3.yolov3(num_classes) |
54 |
| -with tf.variable_scope('yolov3'): |
55 |
| - y_pred = model.forward(images, is_training=False) |
56 |
| - load_ops = utils.load_weights(tf.global_variables(scope='yolov3'), "./checkpoint/yolov3.weights") |
57 |
| - sess.run(load_ops) |
58 |
| - result = model.compute_loss(y_pred, y_true) |
| 56 | + |
| 57 | + |
| 58 | +# test |
| 59 | +load_ops = utils.load_weights(tf.global_variables(scope='yolov3'), weights_path) |
| 60 | +sess.run(load_ops) |
59 | 61 |
|
60 | 62 | for epoch in range(EPOCHS):
|
61 |
| - run_items = sess.run(result) |
62 |
| - print("=> EPOCH:%4d\t| prec_50:%.4f\trec_50:%.4f\tavg_iou:%.4f\t | total_loss:%7.4f\tloss_coord:%7.4f" |
63 |
| - "\tloss_sizes:%7.4f\tloss_confs:%7.4f\tloss_class:%7.4f" %(epoch, run_items[0], run_items[1], |
64 |
| - run_items[2], run_items[3], run_items[4], run_items[5], run_items[6], run_items[7])) |
| 63 | + run_items = sess.run([y_pred, y_true] + loss, feed_dict={is_training:True}) |
| 64 | + rec, prec, mAP = utils.evaluate(run_items[0], run_items[1], num_classes, score_thresh=0.3, iou_thresh=0.5) |
| 65 | + |
| 66 | + print("=> EPOCH: %2d\ttotal_loss:%7.4f\tloss_coord:%7.4f\tloss_sizes:%7.4f\tloss_confs:%7.4f\tloss_class:%7.4f" |
| 67 | + "\trec:%.2f\tprec:%.2f\tmAP:%.2f" |
| 68 | + %(epoch, run_items[2], run_items[3], run_items[4], run_items[5], run_items[6], rec, prec, mAP)) |
| 69 | + |
65 | 70 |
|
66 | 71 |
|
0 commit comments