|
27 | 27 | file_pattern = "./data/train_data/quick_train_data/tfrecords/quick_train_data*.tfrecords"
|
28 | 28 | anchors = utils.get_anchors('./data/yolo_anchors.txt')
|
29 | 29 |
|
30 |
| -is_training = tf.placeholder(dtype=tf.bool, name="phase_train") |
31 | 30 | dataset = tf.data.TFRecordDataset(filenames = tf.gfile.Glob(file_pattern))
|
32 | 31 | dataset = dataset.map(utils.parser(anchors, num_classes).parser_example, num_parallel_calls = 10)
|
33 | 32 | dataset = dataset.repeat().shuffle(SHUFFLE_SIZE).batch(BATCH_SIZE).prefetch(BATCH_SIZE)
|
|
36 | 35 | images, *y_true = example
|
37 | 36 | model = yolov3.yolov3(num_classes)
|
38 | 37 | with tf.variable_scope('yolov3'):
|
39 |
| - y_pred = model.forward(images, is_training=is_training) |
| 38 | + y_pred = model.forward(images, is_training=False) |
40 | 39 | loss = model.compute_loss(y_pred, y_true)
|
41 | 40 | y_pred = model.predict(y_pred)
|
42 |
| - |
43 |
| - |
44 |
| -# # train |
45 |
| -# optimizer = tf.train.AdamOptimizer(LR) |
46 |
| -# train_op = optimizer.minimize(loss[0]) |
47 |
| -# sess.run(tf.global_variables_initializer()) |
48 |
| -# for epoch in range(EPOCHS): |
49 |
| - # run_items = sess.run([train_op, y_pred, y_true] + loss, feed_dict={is_training:True}) |
50 |
| - # rec, prec, mAP = utils.evaluate(run_items[1], run_items[2], num_classes) |
51 |
| - |
52 |
| - # print("=> EPOCH: %2d\ttotal_loss:%7.4f\tloss_coord:%7.4f\tloss_sizes:%7.4f\tloss_confs:%7.4f\tloss_class:%7.4f" |
53 |
| - # "\trec:%.2f\tprec:%.2f\tmAP:%.2f" |
54 |
| - # %(epoch, run_items[3], run_items[4], run_items[5], run_items[6], run_items[7], rec, prec, mAP)) |
55 |
| - |
56 |
| - |
57 |
| - |
58 |
| -# test |
59 |
| -load_ops = utils.load_weights(tf.global_variables(scope='yolov3'), weights_path) |
60 |
| -sess.run(load_ops) |
| 41 | + load_ops = utils.load_weights(tf.global_variables(scope='yolov3'), weights_path) |
| 42 | + sess.run(load_ops) |
61 | 43 |
|
62 | 44 | for epoch in range(EPOCHS):
|
63 |
| - run_items = sess.run([y_pred, y_true] + loss, feed_dict={is_training:False}) |
| 45 | + run_items = sess.run([y_pred, y_true] + loss) |
64 | 46 | rec, prec, mAP = utils.evaluate(run_items[0], run_items[1], num_classes, score_thresh=0.3, iou_thresh=0.5)
|
65 | 47 |
|
66 | 48 | print("=> EPOCH: %2d\ttotal_loss:%7.4f\tloss_coord:%7.4f\tloss_sizes:%7.4f\tloss_confs:%7.4f\tloss_class:%7.4f"
|
67 |
| - "\trec:%.2f\tprec:%.2f\tmAP:%.2f" |
| 49 | + "\trec:%7.4f\tprec:%7.4f\tmAP:%7.4f" |
68 | 50 | %(epoch, run_items[2], run_items[3], run_items[4], run_items[5], run_items[6], rec, prec, mAP))
|
69 | 51 |
|
70 | 52 |
|
|
0 commit comments