
Commit 0cbc31e

YunYang1994 authored and committed
fix issue15
1 parent ac51691 commit 0cbc31e

9 files changed: +82 −91 lines

core/__pycache__/utils.cpython-35.pyc (605 Bytes, binary file not shown)

core/utils.py

Lines changed: 78 additions & 84 deletions
@@ -163,16 +163,16 @@ def cpu_nms(boxes, scores, num_classes, max_boxes=20, score_thresh=0.4, iou_thre
 
 def resize_image_correct_bbox(image, bboxes, input_shape):
 
-    image_size = tf.to_float(tf.shape(image)[1:3])[::-1]
+    image_size = tf.to_float(tf.shape(image)[0:2])[::-1]
     image = tf.image.resize_images(image, size=input_shape)
 
     # correct bbox
-    xx1 = bboxes[..., 0] * input_shape[0] / image_size[0]
-    yy1 = bboxes[..., 1] * input_shape[1] / image_size[1]
-    xx2 = bboxes[..., 2] * input_shape[0] / image_size[0]
-    yy2 = bboxes[..., 3] * input_shape[1] / image_size[1]
+    xx1 = bboxes[:, 0] * input_shape[0] / image_size[0]
+    yy1 = bboxes[:, 1] * input_shape[1] / image_size[1]
+    xx2 = bboxes[:, 2] * input_shape[0] / image_size[0]
+    yy2 = bboxes[:, 3] * input_shape[1] / image_size[1]
 
-    bboxes = tf.stack([xx1, yy1, xx2, yy2], axis=2)
+    bboxes = tf.stack([xx1, yy1, xx2, yy2], axis=1)
     return image, bboxes
 

@@ -306,14 +306,12 @@ def load_weights(var_list, weights_file):
     return assign_ops
 
 
-
 def preprocess_true_boxes(true_boxes, true_labels, input_shape, anchors, num_classes):
     """
     Preprocess true boxes to training input format
     Parameters:
     -----------
-    :param true_boxes: numpy.ndarray of shape [N, T, 4]
-                       N: the number of images,
+    :param true_boxes: numpy.ndarray of shape [T, 4]
                        T: the number of boxes in each image.
                        4: coordinate => x_min, y_min, x_max, y_max
     :param true_labels: class id
@@ -322,65 +320,61 @@ def preprocess_true_boxes(true_boxes, true_labels, input_shape, anchors, num_cla
     :param num_classes: integer, for coco dataset, it is 80
     Returns:
     ----------
-    y_true: list(3 array), shape like yolo_outputs, [N, 13, 13, 3, 85]
+    y_true: list(3 array), shape like yolo_outputs, [13, 13, 3, 85]
             13: cell size, 3: number of anchors
             85: box_centers, box_sizes, confidence, probability
     """
-
     input_shape = np.array(input_shape, dtype=np.int32)
-    num_images = true_boxes.shape[0]
     num_layers = len(anchors) // 3
     anchor_mask = [[6,7,8], [3,4,5], [0,1,2]] if num_layers==3 else [[3,4,5], [1,2,3]]
     grid_sizes = [input_shape//32, input_shape//16, input_shape//8]
 
-    box_centers = (true_boxes[..., 0:2] + true_boxes[..., 2:4]) / 2 # the center of box
-    box_sizes = true_boxes[..., 2:4] - true_boxes[..., 0:2] # the height and width of box
+    box_centers = (true_boxes[:, 0:2] + true_boxes[:, 2:4]) / 2 # the center of box
+    box_sizes = true_boxes[:, 2:4] - true_boxes[:, 0:2] # the height and width of box
 
-    true_boxes[..., 0:2] = box_centers
-    true_boxes[..., 2:4] = box_sizes
+    true_boxes[:, 0:2] = box_centers
+    true_boxes[:, 2:4] = box_sizes
 
-    y_true_13 = np.zeros(shape=[num_images, grid_sizes[0][0], grid_sizes[0][1], 3, 5+num_classes], dtype=np.float32)
-    y_true_26 = np.zeros(shape=[num_images, grid_sizes[1][0], grid_sizes[1][1], 3, 5+num_classes], dtype=np.float32)
-    y_true_52 = np.zeros(shape=[num_images, grid_sizes[2][0], grid_sizes[2][1], 3, 5+num_classes], dtype=np.float32)
+    y_true_13 = np.zeros(shape=[grid_sizes[0][0], grid_sizes[0][1], 3, 5+num_classes], dtype=np.float32)
+    y_true_26 = np.zeros(shape=[grid_sizes[1][0], grid_sizes[1][1], 3, 5+num_classes], dtype=np.float32)
+    y_true_52 = np.zeros(shape=[grid_sizes[2][0], grid_sizes[2][1], 3, 5+num_classes], dtype=np.float32)
 
     y_true = [y_true_13, y_true_26, y_true_52]
-    anchors = np.expand_dims(anchors, 0)
     anchors_max = anchors / 2.
     anchors_min = -anchors_max
-    valid_mask = box_sizes[..., 0] > 0
-
-    for b in range(num_images): # for each image, do:
-        # Discard zero rows.
-        wh = box_sizes[b, valid_mask[b]]
-        if len(wh) == 0: continue
-        # set the center of all boxes as the origin of their coordinates
-        # and correct their coordinates
-        wh = np.expand_dims(wh, -2)
-        boxes_max = wh / 2.
-        boxes_min = -boxes_max
-
-        intersect_mins = np.maximum(boxes_min, anchors_min)
-        intersect_maxs = np.minimum(boxes_max, anchors_max)
-        intersect_wh = np.maximum(intersect_maxs - intersect_mins, 0.)
-        intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
-        box_area = wh[..., 0] * wh[..., 1]
-        anchor_area = anchors[..., 0] * anchors[..., 1]
-        iou = intersect_area / (box_area + anchor_area - intersect_area)
-        # Find best anchor for each true box
-        best_anchor = np.argmax(iou, axis=-1)
-
-        for t, n in enumerate(best_anchor):
-            for l in range(num_layers):
-                if n not in anchor_mask[l]: continue
-                i = np.floor(true_boxes[b,t,1]/input_shape[::-1]*grid_sizes[l][0]).astype('int32')
-                j = np.floor(true_boxes[b,t,0]/input_shape[::-1]*grid_sizes[l][1]).astype('int32')
-                k = anchor_mask[l].index(n)
-                c = true_labels[b,t].astype('int32')
-                y_true[l][b, i, j, k, 0:4] = true_boxes[b,t, 0:4]
-                y_true[l][b, i, j, k, 4] = 1
-                y_true[l][b, i, j, k, 5+c] = 1
-
-    return y_true
+    valid_mask = box_sizes[:, 0] > 0
+
+    # Discard zero rows.
+    wh = box_sizes[valid_mask]
+    # set the center of all boxes as the origin of their coordinates
+    # and correct their coordinates
+    wh = np.expand_dims(wh, -2)
+    boxes_max = wh / 2.
+    boxes_min = -boxes_max
+
+    intersect_mins = np.maximum(boxes_min, anchors_min)
+    intersect_maxs = np.minimum(boxes_max, anchors_max)
+    intersect_wh = np.maximum(intersect_maxs - intersect_mins, 0.)
+    intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
+    box_area = wh[..., 0] * wh[..., 1]
+
+    anchor_area = anchors[:, 0] * anchors[:, 1]
+    iou = intersect_area / (box_area + anchor_area - intersect_area)
+    # Find best anchor for each true box
+    best_anchor = np.argmax(iou, axis=-1)
+
+    for t, n in enumerate(best_anchor):
+        for l in range(num_layers):
+            if n not in anchor_mask[l]: continue
+            i = np.floor(true_boxes[t,1]/input_shape[::-1]*grid_sizes[l][0]).astype('int32')
+            j = np.floor(true_boxes[t,0]/input_shape[::-1]*grid_sizes[l][1]).astype('int32')
+            k = anchor_mask[l].index(n)
+            c = true_labels[t].astype('int32')
+            y_true[l][i, j, k, 0:4] = true_boxes[t, 0:4]
+            y_true[l][i, j, k, 4] = 1
+            y_true[l][i, j, k, 5+c] = 1
+
+    return y_true_13, y_true_26, y_true_52
 
 
 
@@ -414,44 +408,44 @@ def get_anchors(anchors_path):
     return anchors.reshape(-1, 2)
 
 
+class parser(object):
+    def __init__(self, anchors, num_classes, input_shape=[416, 416]):
+        self.anchors = anchors
+        self.num_classes = num_classes
+        self.input_shape = input_shape
 
-def parser(serialized_example):
-    features = tf.parse_single_example(
-        serialized_example,
-        features = {
-            'image' : tf.FixedLenFeature([], dtype = tf.string),
-            'bboxes': tf.FixedLenFeature([], dtype = tf.string),
-            'labels': tf.VarLenFeature(dtype = tf.int64),
-        }
-    )
-
-    image = tf.image.decode_jpeg(features['image'], channels = 3)
-    image = tf.image.convert_image_dtype(image, tf.uint8)
-
-    bboxes = tf.decode_raw(features['bboxes'], tf.float32)
-    bboxes = tf.reshape(bboxes, shape=[-1,4])
-
-    labels = features['labels'].values
-    return image, bboxes, labels
-
-
-
-
-
-
-
-
-
-
-
-
+    def preprocess(self, image, true_labels, true_boxes):
+        # resize_image_correct_bbox
+        image, true_boxes = resize_image_correct_bbox(image, true_boxes,
+                                                      input_shape=self.input_shape)
 
+        y_true_13, y_true_26, y_true_52 = tf.py_func(preprocess_true_boxes,
+                inp=[true_boxes, true_labels, self.input_shape, self.anchors, self.num_classes],
+                Tout = [tf.float32, tf.float32, tf.float32])
+        # data augmentation
+        # pass
 
+        return image, y_true_13, y_true_26, y_true_52
 
+    def parser_example(self, serialized_example):
 
+        features = tf.parse_single_example(
+            serialized_example,
+            features = {
+                'image' : tf.FixedLenFeature([], dtype = tf.string),
+                'bboxes': tf.FixedLenFeature([], dtype = tf.string),
+                'labels': tf.VarLenFeature(dtype = tf.int64),
+            }
+        )
 
+        image = tf.image.decode_jpeg(features['image'], channels = 3)
+        image = tf.image.convert_image_dtype(image, tf.uint8)
 
+        true_boxes = tf.decode_raw(features['bboxes'], tf.float32)
+        true_boxes = tf.reshape(true_boxes, shape=[-1,4])
+        true_labels = features['labels'].values
 
+        return self.preprocess(image, true_labels, true_boxes)
 
 
 
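Taken together, the utils.py changes move label encoding from batch level ([N, T, 4]) to image level ([T, 4]), so each example can be encoded inside the tf.data pipeline before batching. Below is a minimal sketch of the new per-image contract; the anchor values and the sample box are made up for illustration, and preprocess_true_boxes is assumed importable from core/utils.py:

import numpy as np
from core.utils import preprocess_true_boxes

# nine (width, height) anchors in the style of yolo_anchors.txt; values are illustrative only
anchors = np.array([[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                    [59, 119], [116, 90], [156, 198], [373, 326]], dtype=np.float32)

# one image with a single box: x_min, y_min, x_max, y_max, already scaled to 416x416
true_boxes  = np.array([[100., 120., 200., 240.]], dtype=np.float32)
true_labels = np.array([7.], dtype=np.float32)  # class id of that box

# after this commit the function takes [T, 4] boxes for a single image and
# returns the three per-scale label tensors directly, with no batch dimension
y13, y26, y52 = preprocess_true_boxes(true_boxes, true_labels,
                                      input_shape=[416, 416],
                                      anchors=anchors, num_classes=80)
print(y13.shape, y26.shape, y52.shape)  # (13, 13, 3, 85) (26, 26, 3, 85) (52, 52, 3, 85)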

quick_train.py

Lines changed: 4 additions & 7 deletions
@@ -25,16 +25,13 @@
 file_pattern = "./data/train_data/tfrecords/quick_train_data*.tfrecords"
 anchors = utils.get_anchors('./data/yolo_anchors.txt')
 
+
 dataset = tf.data.TFRecordDataset(filenames = tf.gfile.Glob(file_pattern))
-dataset = dataset.map(utils.parser, num_parallel_calls = 10)
+dataset = dataset.map(utils.parser(anchors, num_classes).parser_example, num_parallel_calls = 10)
 dataset = dataset.repeat().batch(BATCH_SIZE).prefetch(BATCH_SIZE)
 iterator = dataset.make_one_shot_iterator()
-images, true_boxes, true_labels = iterator.get_next()
-images, true_boxes = utils.resize_image_correct_bbox(images, true_boxes, [INPUT_SIZE, INPUT_SIZE])
-
-y_true = tf.py_func(utils.preprocess_true_boxes,
-                    inp=[true_boxes, true_labels, [INPUT_SIZE, INPUT_SIZE], anchors, num_classes],
-                    Tout = [tf.float32, tf.float32, tf.float32])
+example = iterator.get_next()
+images, *y_true = example
 
 model = yolov3.yolov3(num_classes)
 with tf.variable_scope('yolov3'):
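With the parser class handling resizing and label encoding per example, the map now yields (image, y_true_13, y_true_26, y_true_52), so after .batch(BATCH_SIZE) the images come out batched and `images, *y_true = example` unpacks the four tensors. A sketch of pulling one batch from the graph above, assuming TF 1.x session semantics:

with tf.Session() as sess:
    # images and the three y_true tensors come from iterator.get_next() above
    batch_images, batch_y13, batch_y26, batch_y52 = sess.run([images, *y_true])
    print(batch_images.shape)  # e.g. (BATCH_SIZE, 416, 416, 3)
    print(batch_y13.shape)     # e.g. (BATCH_SIZE, 13, 13, 3, 85)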
