|
15 | 15 |
|
16 | 16 | """Function to build box predictor from configuration."""
|
17 | 17 |
|
18 |
| -from object_detection.core import box_predictor |
| 18 | +from object_detection.predictors import convolutional_box_predictor |
| 19 | +from object_detection.predictors import mask_rcnn_box_predictor |
| 20 | +from object_detection.predictors import rfcn_box_predictor |
| 21 | +from object_detection.predictors.mask_rcnn_heads import box_head |
| 22 | +from object_detection.predictors.mask_rcnn_heads import class_head |
| 23 | +from object_detection.predictors.mask_rcnn_heads import mask_head |
19 | 24 | from object_detection.protos import box_predictor_pb2
|
20 | 25 |
|
21 | 26 |
|
@@ -48,92 +53,112 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes):
|
48 | 53 | box_predictor_oneof = box_predictor_config.WhichOneof('box_predictor_oneof')
|
49 | 54 |
|
50 | 55 | if box_predictor_oneof == 'convolutional_box_predictor':
|
51 |
| - conv_box_predictor = box_predictor_config.convolutional_box_predictor |
52 |
| - conv_hyperparams_fn = argscope_fn(conv_box_predictor.conv_hyperparams, |
| 56 | + config_box_predictor = box_predictor_config.convolutional_box_predictor |
| 57 | + conv_hyperparams_fn = argscope_fn(config_box_predictor.conv_hyperparams, |
53 | 58 | is_training)
|
54 |
| - box_predictor_object = box_predictor.ConvolutionalBoxPredictor( |
55 |
| - is_training=is_training, |
56 |
| - num_classes=num_classes, |
57 |
| - conv_hyperparams_fn=conv_hyperparams_fn, |
58 |
| - min_depth=conv_box_predictor.min_depth, |
59 |
| - max_depth=conv_box_predictor.max_depth, |
60 |
| - num_layers_before_predictor=(conv_box_predictor. |
61 |
| - num_layers_before_predictor), |
62 |
| - use_dropout=conv_box_predictor.use_dropout, |
63 |
| - dropout_keep_prob=conv_box_predictor.dropout_keep_probability, |
64 |
| - kernel_size=conv_box_predictor.kernel_size, |
65 |
| - box_code_size=conv_box_predictor.box_code_size, |
66 |
| - apply_sigmoid_to_scores=conv_box_predictor.apply_sigmoid_to_scores, |
67 |
| - class_prediction_bias_init=(conv_box_predictor. |
68 |
| - class_prediction_bias_init), |
69 |
| - use_depthwise=conv_box_predictor.use_depthwise |
70 |
| - ) |
| 59 | + box_predictor_object = ( |
| 60 | + convolutional_box_predictor.ConvolutionalBoxPredictor( |
| 61 | + is_training=is_training, |
| 62 | + num_classes=num_classes, |
| 63 | + conv_hyperparams_fn=conv_hyperparams_fn, |
| 64 | + min_depth=config_box_predictor.min_depth, |
| 65 | + max_depth=config_box_predictor.max_depth, |
| 66 | + num_layers_before_predictor=( |
| 67 | + config_box_predictor.num_layers_before_predictor), |
| 68 | + use_dropout=config_box_predictor.use_dropout, |
| 69 | + dropout_keep_prob=config_box_predictor.dropout_keep_probability, |
| 70 | + kernel_size=config_box_predictor.kernel_size, |
| 71 | + box_code_size=config_box_predictor.box_code_size, |
| 72 | + apply_sigmoid_to_scores=config_box_predictor. |
| 73 | + apply_sigmoid_to_scores, |
| 74 | + class_prediction_bias_init=( |
| 75 | + config_box_predictor.class_prediction_bias_init), |
| 76 | + use_depthwise=config_box_predictor.use_depthwise)) |
71 | 77 | return box_predictor_object
|
72 | 78 |
|
73 | 79 | if box_predictor_oneof == 'weight_shared_convolutional_box_predictor':
|
74 |
| - conv_box_predictor = (box_predictor_config. |
75 |
| - weight_shared_convolutional_box_predictor) |
76 |
| - conv_hyperparams_fn = argscope_fn(conv_box_predictor.conv_hyperparams, |
| 80 | + config_box_predictor = ( |
| 81 | + box_predictor_config.weight_shared_convolutional_box_predictor) |
| 82 | + conv_hyperparams_fn = argscope_fn(config_box_predictor.conv_hyperparams, |
77 | 83 | is_training)
|
78 |
| - box_predictor_object = box_predictor.WeightSharedConvolutionalBoxPredictor( |
79 |
| - is_training=is_training, |
80 |
| - num_classes=num_classes, |
81 |
| - conv_hyperparams_fn=conv_hyperparams_fn, |
82 |
| - depth=conv_box_predictor.depth, |
83 |
| - num_layers_before_predictor=( |
84 |
| - conv_box_predictor.num_layers_before_predictor), |
85 |
| - kernel_size=conv_box_predictor.kernel_size, |
86 |
| - box_code_size=conv_box_predictor.box_code_size, |
87 |
| - class_prediction_bias_init=conv_box_predictor. |
88 |
| - class_prediction_bias_init, |
89 |
| - use_dropout=conv_box_predictor.use_dropout, |
90 |
| - dropout_keep_prob=conv_box_predictor.dropout_keep_probability, |
91 |
| - share_prediction_tower=conv_box_predictor.share_prediction_tower) |
| 84 | + apply_batch_norm = config_box_predictor.conv_hyperparams.HasField( |
| 85 | + 'batch_norm') |
| 86 | + box_predictor_object = ( |
| 87 | + convolutional_box_predictor.WeightSharedConvolutionalBoxPredictor( |
| 88 | + is_training=is_training, |
| 89 | + num_classes=num_classes, |
| 90 | + conv_hyperparams_fn=conv_hyperparams_fn, |
| 91 | + depth=config_box_predictor.depth, |
| 92 | + num_layers_before_predictor=( |
| 93 | + config_box_predictor.num_layers_before_predictor), |
| 94 | + kernel_size=config_box_predictor.kernel_size, |
| 95 | + box_code_size=config_box_predictor.box_code_size, |
| 96 | + class_prediction_bias_init=config_box_predictor. |
| 97 | + class_prediction_bias_init, |
| 98 | + use_dropout=config_box_predictor.use_dropout, |
| 99 | + dropout_keep_prob=config_box_predictor.dropout_keep_probability, |
| 100 | + share_prediction_tower=config_box_predictor.share_prediction_tower, |
| 101 | + apply_batch_norm=apply_batch_norm)) |
92 | 102 | return box_predictor_object
|
93 | 103 |
|
94 | 104 | if box_predictor_oneof == 'mask_rcnn_box_predictor':
|
95 |
| - mask_rcnn_box_predictor = box_predictor_config.mask_rcnn_box_predictor |
96 |
| - fc_hyperparams_fn = argscope_fn(mask_rcnn_box_predictor.fc_hyperparams, |
| 105 | + config_box_predictor = box_predictor_config.mask_rcnn_box_predictor |
| 106 | + fc_hyperparams_fn = argscope_fn(config_box_predictor.fc_hyperparams, |
97 | 107 | is_training)
|
98 | 108 | conv_hyperparams_fn = None
|
99 |
| - if mask_rcnn_box_predictor.HasField('conv_hyperparams'): |
| 109 | + if config_box_predictor.HasField('conv_hyperparams'): |
100 | 110 | conv_hyperparams_fn = argscope_fn(
|
101 |
| - mask_rcnn_box_predictor.conv_hyperparams, is_training) |
102 |
| - box_predictor_object = box_predictor.MaskRCNNBoxPredictor( |
| 111 | + config_box_predictor.conv_hyperparams, is_training) |
| 112 | + box_prediction_head = box_head.BoxHead( |
103 | 113 | is_training=is_training,
|
104 | 114 | num_classes=num_classes,
|
105 | 115 | fc_hyperparams_fn=fc_hyperparams_fn,
|
106 |
| - use_dropout=mask_rcnn_box_predictor.use_dropout, |
107 |
| - dropout_keep_prob=mask_rcnn_box_predictor.dropout_keep_probability, |
108 |
| - box_code_size=mask_rcnn_box_predictor.box_code_size, |
109 |
| - conv_hyperparams_fn=conv_hyperparams_fn, |
110 |
| - predict_instance_masks=mask_rcnn_box_predictor.predict_instance_masks, |
111 |
| - mask_height=mask_rcnn_box_predictor.mask_height, |
112 |
| - mask_width=mask_rcnn_box_predictor.mask_width, |
113 |
| - mask_prediction_num_conv_layers=( |
114 |
| - mask_rcnn_box_predictor.mask_prediction_num_conv_layers), |
115 |
| - mask_prediction_conv_depth=( |
116 |
| - mask_rcnn_box_predictor.mask_prediction_conv_depth), |
117 |
| - masks_are_class_agnostic=( |
118 |
| - mask_rcnn_box_predictor.masks_are_class_agnostic), |
119 |
| - predict_keypoints=mask_rcnn_box_predictor.predict_keypoints, |
| 116 | + use_dropout=config_box_predictor.use_dropout, |
| 117 | + dropout_keep_prob=config_box_predictor.dropout_keep_probability, |
| 118 | + box_code_size=config_box_predictor.box_code_size, |
120 | 119 | share_box_across_classes=(
|
121 |
| - mask_rcnn_box_predictor.share_box_across_classes)) |
| 120 | + config_box_predictor.share_box_across_classes)) |
| 121 | + class_prediction_head = class_head.ClassHead( |
| 122 | + is_training=is_training, |
| 123 | + num_classes=num_classes, |
| 124 | + fc_hyperparams_fn=fc_hyperparams_fn, |
| 125 | + use_dropout=config_box_predictor.use_dropout, |
| 126 | + dropout_keep_prob=config_box_predictor.dropout_keep_probability) |
| 127 | + third_stage_heads = {} |
| 128 | + if config_box_predictor.predict_instance_masks: |
| 129 | + third_stage_heads[ |
| 130 | + mask_rcnn_box_predictor.MASK_PREDICTIONS] = mask_head.MaskHead( |
| 131 | + num_classes=num_classes, |
| 132 | + conv_hyperparams_fn=conv_hyperparams_fn, |
| 133 | + mask_height=config_box_predictor.mask_height, |
| 134 | + mask_width=config_box_predictor.mask_width, |
| 135 | + mask_prediction_num_conv_layers=( |
| 136 | + config_box_predictor.mask_prediction_num_conv_layers), |
| 137 | + mask_prediction_conv_depth=( |
| 138 | + config_box_predictor.mask_prediction_conv_depth), |
| 139 | + masks_are_class_agnostic=( |
| 140 | + config_box_predictor.masks_are_class_agnostic)) |
| 141 | + box_predictor_object = mask_rcnn_box_predictor.MaskRCNNBoxPredictor( |
| 142 | + is_training=is_training, |
| 143 | + num_classes=num_classes, |
| 144 | + box_prediction_head=box_prediction_head, |
| 145 | + class_prediction_head=class_prediction_head, |
| 146 | + third_stage_heads=third_stage_heads) |
122 | 147 | return box_predictor_object
|
123 | 148 |
|
124 | 149 | if box_predictor_oneof == 'rfcn_box_predictor':
|
125 |
| - rfcn_box_predictor = box_predictor_config.rfcn_box_predictor |
126 |
| - conv_hyperparams_fn = argscope_fn(rfcn_box_predictor.conv_hyperparams, |
| 150 | + config_box_predictor = box_predictor_config.rfcn_box_predictor |
| 151 | + conv_hyperparams_fn = argscope_fn(config_box_predictor.conv_hyperparams, |
127 | 152 | is_training)
|
128 |
| - box_predictor_object = box_predictor.RfcnBoxPredictor( |
| 153 | + box_predictor_object = rfcn_box_predictor.RfcnBoxPredictor( |
129 | 154 | is_training=is_training,
|
130 | 155 | num_classes=num_classes,
|
131 | 156 | conv_hyperparams_fn=conv_hyperparams_fn,
|
132 |
| - crop_size=[rfcn_box_predictor.crop_height, |
133 |
| - rfcn_box_predictor.crop_width], |
134 |
| - num_spatial_bins=[rfcn_box_predictor.num_spatial_bins_height, |
135 |
| - rfcn_box_predictor.num_spatial_bins_width], |
136 |
| - depth=rfcn_box_predictor.depth, |
137 |
| - box_code_size=rfcn_box_predictor.box_code_size) |
| 157 | + crop_size=[config_box_predictor.crop_height, |
| 158 | + config_box_predictor.crop_width], |
| 159 | + num_spatial_bins=[config_box_predictor.num_spatial_bins_height, |
| 160 | + config_box_predictor.num_spatial_bins_width], |
| 161 | + depth=config_box_predictor.depth, |
| 162 | + box_code_size=config_box_predictor.box_code_size) |
138 | 163 | return box_predictor_object
|
139 | 164 | raise ValueError('Unknown box predictor: {}'.format(box_predictor_oneof))
|
0 commit comments