# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Dense Prediction Cell class that can be evolved in semantic segmentation.

DensePredictionCell is used as a `layer` in semantic segmentation whose
architecture is determined by the `config`, a list of dictionaries specifying
the architecture of each branch.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from deeplab.core import utils

slim = tf.contrib.slim

# Local constants.
_META_ARCHITECTURE_SCOPE = 'meta_architecture'
_CONCAT_PROJECTION_SCOPE = 'concat_projection'
_OP = 'op'
_CONV = 'conv'
_PYRAMID_POOLING = 'pyramid_pooling'
_KERNEL = 'kernel'
_RATE = 'rate'
_GRID_SIZE = 'grid_size'
_TARGET_SIZE = 'target_size'
_INPUT = 'input'
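
# Note: a `config` passed to DensePredictionCell is a list of per-branch
# dictionaries keyed by the constants above; a hypothetical example:
#
#   [{_INPUT: -1, _OP: _CONV, _KERNEL: 1},
#    {_INPUT: 0, _OP: _CONV, _KERNEL: 3, _RATE: [3, 3]},
#    {_INPUT: -1, _OP: _PYRAMID_POOLING, _GRID_SIZE: [2, 2]}]
#
# A negative _INPUT selects the cell's input features, while a non-negative
# _INPUT selects the output of a previously built branch (see build_cell).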


def dense_prediction_cell_hparams():
  """DensePredictionCell HParams.

  Returns:
    A dictionary of hyper-parameters used for dense prediction cell with keys:
      - reduction_size: Integer, the number of output filters for each operation
          inside the cell.
      - dropout_on_concat_features: Boolean, apply dropout on the concatenated
          features or not.
      - dropout_on_projection_features: Boolean, apply dropout on the projection
          features or not.
      - dropout_keep_prob: Float, when `dropout_on_concat_features` or
          `dropout_on_projection_features` is True, the `keep_prob` value used
          in the dropout operation.
      - concat_channels: Integer, the concatenated features will be
          channel-reduced to `concat_channels` channels.
      - conv_rate_multiplier: Integer, used to multiply the convolution rates.
          This is useful when the output_stride is changed from 16 to 8, in
          which case the convolution rates need to be doubled correspondingly.
  """
  return {
      'reduction_size': 256,
      'dropout_on_concat_features': True,
      'dropout_on_projection_features': False,
      'dropout_keep_prob': 0.9,
      'concat_channels': 256,
      'conv_rate_multiplier': 1,
  }
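

# Usage sketch (hypothetical): when features are extracted at output_stride=8
# instead of 16, the convolution rates can be doubled by overriding the
# default hyper-parameters above, e.g.
#
#   cell = DensePredictionCell(config, hparams={'conv_rate_multiplier': 2})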


class DensePredictionCell(object):
  """DensePredictionCell class used as a 'layer' in semantic segmentation."""

  def __init__(self, config, hparams=None):
    """Initializes the dense prediction cell.

    Args:
      config: A list of dictionaries storing the architecture of a dense
        prediction cell.
      hparams: A dictionary of hyper-parameters, provided by users. This
        dictionary will be used to update the default dictionary returned by
        dense_prediction_cell_hparams().

    Raises:
      ValueError: If `conv_rate_multiplier` has value < 1.
    """
    self.hparams = dense_prediction_cell_hparams()
    if hparams is not None:
      self.hparams.update(hparams)
    self.config = config

    # Check that the values in hparams are valid.
    if self.hparams['conv_rate_multiplier'] < 1:
      raise ValueError('conv_rate_multiplier cannot have value < 1.')

  def _get_pyramid_pooling_arguments(
      self, crop_size, output_stride, image_grid, image_pooling_crop_size=None):
    """Gets arguments for pyramid pooling.

    Args:
      crop_size: A list of two integers, [crop_height, crop_width] specifying
        whole patch crop size.
      output_stride: Integer, output stride value for extracted features.
      image_grid: A list of two integers, [image_grid_height, image_grid_width],
        specifying the grid size of how the pyramid pooling will be performed.
      image_pooling_crop_size: A list of two integers, [crop_height, crop_width]
        specifying the crop size for image pooling operations. Note that we
        decouple whole patch crop_size and image_pooling_crop_size as one could
        perform the image_pooling with different crop sizes.

    Returns:
      A tuple (resize_value, pooled_kernel), where resize_value is
      [resize_height, resize_width] and pooled_kernel is
      [pooled_height, pooled_width].
    """
    resize_height = utils.scale_dimension(crop_size[0], 1. / output_stride)
    resize_width = utils.scale_dimension(crop_size[1], 1. / output_stride)
    # If image_pooling_crop_size is not specified, use crop_size.
    if image_pooling_crop_size is None:
      image_pooling_crop_size = crop_size
    pooled_height = utils.scale_dimension(
        image_pooling_crop_size[0], 1. / (output_stride * image_grid[0]))
    pooled_width = utils.scale_dimension(
        image_pooling_crop_size[1], 1. / (output_stride * image_grid[1]))
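    # Worked example (illustrative, assuming utils.scale_dimension computes
    # (dim - 1) * scale + 1): with crop_size=[513, 513], output_stride=16 and
    # image_grid=[2, 2], the branch output is resized to [33, 33] and the
    # pooling kernel is [17, 17].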
    return ([resize_height, resize_width], [pooled_height, pooled_width])

  def _parse_operation(self, config, crop_size, output_stride,
                       image_pooling_crop_size=None):
    """Parses one operation.

    When 'operation' is 'pyramid_pooling', we compute the required
    hyper-parameters and save them in config.

    Args:
      config: A dictionary storing required hyper-parameters for one
        operation.
      crop_size: A list of two integers, [crop_height, crop_width] specifying
        whole patch crop size.
      output_stride: Integer, output stride value for extracted features.
      image_pooling_crop_size: A list of two integers, [crop_height, crop_width]
        specifying the crop size for image pooling operations. Note that we
        decouple whole patch crop_size and image_pooling_crop_size as one could
        perform the image_pooling with different crop sizes.

    Returns:
      A dictionary storing the related information for the operation.
    """
    if config[_OP] == _PYRAMID_POOLING:
      (config[_TARGET_SIZE],
       config[_KERNEL]) = self._get_pyramid_pooling_arguments(
           crop_size=crop_size,
           output_stride=output_stride,
           image_grid=config[_GRID_SIZE],
           image_pooling_crop_size=image_pooling_crop_size)

    return config

  def build_cell(self,
                 features,
                 output_stride=16,
                 crop_size=None,
                 image_pooling_crop_size=None,
                 weight_decay=0.00004,
                 reuse=None,
                 is_training=False,
                 fine_tune_batch_norm=False,
                 scope=None):
    """Builds the dense prediction cell based on the config.

    Args:
      features: Input feature map of size [batch, height, width, channels].
      output_stride: Int, output stride at which the features were extracted.
      crop_size: A list [crop_height, crop_width], determining the input
        features resolution.
      image_pooling_crop_size: A list of two integers, [crop_height, crop_width]
        specifying the crop size for image pooling operations. Note that we
        decouple whole patch crop_size and image_pooling_crop_size as one could
        perform the image_pooling with different crop sizes.
      weight_decay: Float, the weight decay for model variables.
      reuse: Reuse the model variables or not.
      is_training: Boolean, is training or not.
      fine_tune_batch_norm: Boolean, fine-tuning batch norm parameters or not.
      scope: Optional string, specifying the variable scope.

    Returns:
      Features after passing through the constructed dense prediction cell with
        shape = [batch, height, width, channels] where channels are determined
        by `reduction_size` returned by dense_prediction_cell_hparams().

    Raises:
      ValueError: If a convolution with kernel size other than 1x1 or 3x3 is
        used, or the operation is not recognized.
    """
    batch_norm_params = {
        'is_training': is_training and fine_tune_batch_norm,
        'decay': 0.9997,
        'epsilon': 1e-5,
        'scale': True,
    }
    hparams = self.hparams
    with slim.arg_scope(
        [slim.conv2d, slim.separable_conv2d],
        weights_regularizer=slim.l2_regularizer(weight_decay),
        activation_fn=tf.nn.relu,
        normalizer_fn=slim.batch_norm,
        padding='SAME',
        stride=1,
        reuse=reuse):
      with slim.arg_scope([slim.batch_norm], **batch_norm_params):
        with tf.variable_scope(scope, _META_ARCHITECTURE_SCOPE, [features]):
          depth = hparams['reduction_size']
          branch_logits = []
          for i, current_config in enumerate(self.config):
            scope = 'branch%d' % i
            current_config = self._parse_operation(
                config=current_config,
                crop_size=crop_size,
                output_stride=output_stride,
                image_pooling_crop_size=image_pooling_crop_size)
            tf.logging.info(current_config)
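            # A negative `input` index selects the cell's input features;
            # otherwise the output of a previously built branch is reused.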
            if current_config[_INPUT] < 0:
              operation_input = features
            else:
              operation_input = branch_logits[current_config[_INPUT]]
            if current_config[_OP] == _CONV:
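              # A 1x1 kernel is applied as a plain 1x1 convolution; larger
              # kernels use an atrous split separable convolution whose rates
              # are scaled by `conv_rate_multiplier`.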
              if current_config[_KERNEL] == [1, 1] or current_config[
                  _KERNEL] == 1:
                branch_logits.append(
                    slim.conv2d(operation_input, depth, 1, scope=scope))
              else:
                conv_rate = [r * hparams['conv_rate_multiplier']
                             for r in current_config[_RATE]]
                branch_logits.append(
                    utils.split_separable_conv2d(
                        operation_input,
                        filters=depth,
                        kernel_size=current_config[_KERNEL],
                        rate=conv_rate,
                        weight_decay=weight_decay,
                        scope=scope))
            elif current_config[_OP] == _PYRAMID_POOLING:
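              # Pyramid pooling: average-pool the input down to the computed
              # grid kernel, project it with a 1x1 convolution, then bilinearly
              # resize back to the target feature resolution.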
              pooled_features = slim.avg_pool2d(
                  operation_input,
                  kernel_size=current_config[_KERNEL],
                  stride=[1, 1],
                  padding='VALID')
              pooled_features = slim.conv2d(
                  pooled_features,
                  depth,
                  1,
                  scope=scope)
              pooled_features = tf.image.resize_bilinear(
                  pooled_features,
                  current_config[_TARGET_SIZE],
                  align_corners=True)
              # Set shape for resize_height/resize_width if they are not Tensor.
              resize_height = current_config[_TARGET_SIZE][0]
              resize_width = current_config[_TARGET_SIZE][1]
              if isinstance(resize_height, tf.Tensor):
                resize_height = None
              if isinstance(resize_width, tf.Tensor):
                resize_width = None
              pooled_features.set_shape(
                  [None, resize_height, resize_width, depth])
              branch_logits.append(pooled_features)
            else:
              raise ValueError('Unrecognized operation.')
          # Merge branch logits.
          concat_logits = tf.concat(branch_logits, 3)
          if self.hparams['dropout_on_concat_features']:
            concat_logits = slim.dropout(
                concat_logits,
                keep_prob=self.hparams['dropout_keep_prob'],
                is_training=is_training,
                scope=_CONCAT_PROJECTION_SCOPE + '_dropout')
          concat_logits = slim.conv2d(concat_logits,
                                      self.hparams['concat_channels'],
                                      1,
                                      scope=_CONCAT_PROJECTION_SCOPE)
          if self.hparams['dropout_on_projection_features']:
            concat_logits = slim.dropout(
                concat_logits,
                keep_prob=self.hparams['dropout_keep_prob'],
                is_training=is_training,
                scope=_CONCAT_PROJECTION_SCOPE + '_dropout')
          return concat_logits
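

if __name__ == '__main__':
  # Minimal usage sketch (hypothetical config, crop size, and feature shape).
  # With crop_size=[513, 513] and output_stride=16, the cell expects roughly
  # 33x33 input features.
  example_config = [
      {_INPUT: -1, _OP: _CONV, _KERNEL: 1},
      {_INPUT: -1, _OP: _CONV, _KERNEL: 3, _RATE: [6, 6]},
      {_INPUT: -1, _OP: _PYRAMID_POOLING, _GRID_SIZE: [1, 1]},
  ]
  cell = DensePredictionCell(config=example_config)
  dummy_features = tf.zeros([1, 33, 33, 1024])
  outputs = cell.build_cell(
      dummy_features,
      output_stride=16,
      crop_size=[513, 513],
      is_training=False)
  print(outputs)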