Skip to content

Commit b9ca525

Browse files
author
Jonathan Huang
authored
Merge pull request tensorflow#4232 from pkulzc/master
Release ssdlite mobilenet v2 coco trained model, add quantized training and minor fixes.
2 parents 0270cac + 324d6dc commit b9ca525

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+1088
-299
lines changed

research/object_detection/box_coders/mean_stddev_box_coder.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,14 @@
2525
class MeanStddevBoxCoder(box_coder.BoxCoder):
2626
"""Mean stddev box coder."""
2727

28+
def __init__(self, stddev=0.01):
29+
"""Constructor for MeanStddevBoxCoder.
30+
31+
Args:
32+
stddev: The standard deviation used to encode and decode boxes.
33+
"""
34+
self._stddev = stddev
35+
2836
@property
2937
def code_size(self):
3038
return 4
@@ -34,37 +42,38 @@ def _encode(self, boxes, anchors):
3442
3543
Args:
3644
boxes: BoxList holding N boxes to be encoded.
37-
anchors: BoxList of N anchors. We assume that anchors has an associated
38-
stddev field.
45+
anchors: BoxList of N anchors.
3946
4047
Returns:
4148
a tensor representing N anchor-encoded boxes
49+
4250
Raises:
43-
ValueError: if the anchors BoxList does not have a stddev field
51+
ValueError: if the anchors still have deprecated stddev field.
4452
"""
45-
if not anchors.has_field('stddev'):
46-
raise ValueError('anchors must have a stddev field')
4753
box_corners = boxes.get()
54+
if anchors.has_field('stddev'):
55+
raise ValueError("'stddev' is a parameter of MeanStddevBoxCoder and "
56+
"should not be specified in the box list.")
4857
means = anchors.get()
49-
stddev = anchors.get_field('stddev')
50-
return (box_corners - means) / stddev
58+
return (box_corners - means) / self._stddev
5159

5260
def _decode(self, rel_codes, anchors):
5361
"""Decode.
5462
5563
Args:
5664
rel_codes: a tensor representing N anchor-encoded boxes.
57-
anchors: BoxList of anchors. We assume that anchors has an associated
58-
stddev field.
65+
anchors: BoxList of anchors.
5966
6067
Returns:
6168
boxes: BoxList holding N bounding boxes
69+
6270
Raises:
63-
ValueError: if the anchors BoxList does not have a stddev field
71+
ValueError: if the anchors still have deprecated stddev field and expects
72+
the decode method to use stddev value from that field.
6473
"""
65-
if not anchors.has_field('stddev'):
66-
raise ValueError('anchors must have a stddev field')
6774
means = anchors.get()
68-
stddevs = anchors.get_field('stddev')
69-
box_corners = rel_codes * stddevs + means
75+
if anchors.has_field('stddev'):
76+
raise ValueError("'stddev' is a parameter of MeanStddevBoxCoder and "
77+
"should not be specified in the box list.")
78+
box_corners = rel_codes * self._stddev + means
7079
return box_list.BoxList(box_corners)

research/object_detection/box_coders/mean_stddev_box_coder_test.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,9 @@ def testGetCorrectRelativeCodesAfterEncoding(self):
2828
boxes = box_list.BoxList(tf.constant(box_corners))
2929
expected_rel_codes = [[0.0, 0.0, 0.0, 0.0], [-5.0, -5.0, -5.0, -3.0]]
3030
prior_means = tf.constant([[0.0, 0.0, 0.5, 0.5], [0.5, 0.5, 1.0, 0.8]])
31-
prior_stddevs = tf.constant(2 * [4 * [.1]])
3231
priors = box_list.BoxList(prior_means)
33-
priors.add_field('stddev', prior_stddevs)
3432

35-
coder = mean_stddev_box_coder.MeanStddevBoxCoder()
33+
coder = mean_stddev_box_coder.MeanStddevBoxCoder(stddev=0.1)
3634
rel_codes = coder.encode(boxes, priors)
3735
with self.test_session() as sess:
3836
rel_codes_out = sess.run(rel_codes)
@@ -42,11 +40,9 @@ def testGetCorrectBoxesAfterDecoding(self):
4240
rel_codes = tf.constant([[0.0, 0.0, 0.0, 0.0], [-5.0, -5.0, -5.0, -3.0]])
4341
expected_box_corners = [[0.0, 0.0, 0.5, 0.5], [0.0, 0.0, 0.5, 0.5]]
4442
prior_means = tf.constant([[0.0, 0.0, 0.5, 0.5], [0.5, 0.5, 1.0, 0.8]])
45-
prior_stddevs = tf.constant(2 * [4 * [.1]])
4643
priors = box_list.BoxList(prior_means)
47-
priors.add_field('stddev', prior_stddevs)
4844

49-
coder = mean_stddev_box_coder.MeanStddevBoxCoder()
45+
coder = mean_stddev_box_coder.MeanStddevBoxCoder(stddev=0.1)
5046
decoded_boxes = coder.decode(rel_codes, priors)
5147
decoded_box_corners = decoded_boxes.get()
5248
with self.test_session() as sess:

research/object_detection/builders/box_coder_builder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ def build(box_coder_config):
5555
])
5656
if (box_coder_config.WhichOneof('box_coder_oneof') ==
5757
'mean_stddev_box_coder'):
58-
return mean_stddev_box_coder.MeanStddevBoxCoder()
58+
return mean_stddev_box_coder.MeanStddevBoxCoder(
59+
stddev=box_coder_config.mean_stddev_box_coder.stddev)
5960
if box_coder_config.WhichOneof('box_coder_oneof') == 'square_box_coder':
6061
return square_box_coder.SquareBoxCoder(scale_factors=[
6162
box_coder_config.square_box_coder.y_scale,
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Functions for quantized training and evaluation."""
16+
17+
import tensorflow as tf
18+
19+
20+
def build(graph_rewriter_config, is_training):
21+
"""Returns a function that modifies default graph based on options.
22+
23+
Args:
24+
graph_rewriter_config: graph_rewriter_pb2.GraphRewriter proto.
25+
is_training: whether in training of eval mode.
26+
"""
27+
def graph_rewrite_fn():
28+
"""Function to quantize weights and activation of the default graph."""
29+
if (graph_rewriter_config.quantization.weight_bits != 8 or
30+
graph_rewriter_config.quantization.activation_bits != 8):
31+
raise ValueError('Only 8bit quantization is supported')
32+
33+
# Quantize the graph by inserting quantize ops for weights and activations
34+
if is_training:
35+
tf.contrib.quantize.create_training_graph(
36+
input_graph=tf.get_default_graph(),
37+
quant_delay=graph_rewriter_config.quantization.delay)
38+
else:
39+
tf.contrib.quantize.create_eval_graph(input_graph=tf.get_default_graph())
40+
41+
tf.contrib.layers.summarize_collection('quant_vars')
42+
return graph_rewrite_fn
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Tests for graph_rewriter_builder."""
16+
import mock
17+
import tensorflow as tf
18+
from object_detection.builders import graph_rewriter_builder
19+
from object_detection.protos import graph_rewriter_pb2
20+
21+
22+
class QuantizationBuilderTest(tf.test.TestCase):
23+
24+
def testQuantizationBuilderSetsUpCorrectTrainArguments(self):
25+
with mock.patch.object(
26+
tf.contrib.quantize, 'create_training_graph') as mock_quant_fn:
27+
with mock.patch.object(tf.contrib.layers,
28+
'summarize_collection') as mock_summarize_col:
29+
graph_rewriter_proto = graph_rewriter_pb2.GraphRewriter()
30+
graph_rewriter_proto.quantization.delay = 10
31+
graph_rewriter_proto.quantization.weight_bits = 8
32+
graph_rewriter_proto.quantization.activation_bits = 8
33+
graph_rewrite_fn = graph_rewriter_builder.build(
34+
graph_rewriter_proto, is_training=True)
35+
graph_rewrite_fn()
36+
_, kwargs = mock_quant_fn.call_args
37+
self.assertEqual(kwargs['input_graph'], tf.get_default_graph())
38+
self.assertEqual(kwargs['quant_delay'], 10)
39+
mock_summarize_col.assert_called_with('quant_vars')
40+
41+
def testQuantizationBuilderSetsUpCorrectEvalArguments(self):
42+
with mock.patch.object(tf.contrib.quantize,
43+
'create_eval_graph') as mock_quant_fn:
44+
with mock.patch.object(tf.contrib.layers,
45+
'summarize_collection') as mock_summarize_col:
46+
graph_rewriter_proto = graph_rewriter_pb2.GraphRewriter()
47+
graph_rewriter_proto.quantization.delay = 10
48+
graph_rewrite_fn = graph_rewriter_builder.build(
49+
graph_rewriter_proto, is_training=False)
50+
graph_rewrite_fn()
51+
_, kwargs = mock_quant_fn.call_args
52+
self.assertEqual(kwargs['input_graph'], tf.get_default_graph())
53+
mock_summarize_col.assert_called_with('quant_vars')
54+
55+
56+
if __name__ == '__main__':
57+
tf.test.main()

research/object_detection/builders/losses_builder.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
"""A function to build localization and classification losses from config."""
1717

18+
from object_detection.core import balanced_positive_negative_sampler as sampler
1819
from object_detection.core import losses
1920
from object_detection.protos import losses_pb2
2021

@@ -34,9 +35,12 @@ def build(loss_config):
3435
classification_weight: Classification loss weight.
3536
localization_weight: Localization loss weight.
3637
hard_example_miner: Hard example miner object.
38+
random_example_sampler: BalancedPositiveNegativeSampler object.
3739
3840
Raises:
3941
ValueError: If hard_example_miner is used with sigmoid_focal_loss.
42+
ValueError: If random_example_sampler is getting non-positive value as
43+
desired positive example fraction.
4044
"""
4145
classification_loss = _build_classification_loss(
4246
loss_config.classification_loss)
@@ -54,9 +58,16 @@ def build(loss_config):
5458
loss_config.hard_example_miner,
5559
classification_weight,
5660
localization_weight)
57-
return (classification_loss, localization_loss,
58-
classification_weight,
59-
localization_weight, hard_example_miner)
61+
random_example_sampler = None
62+
if loss_config.HasField('random_example_sampler'):
63+
if loss_config.random_example_sampler.positive_sample_fraction <= 0:
64+
raise ValueError('RandomExampleSampler should not use non-positive'
65+
'value as positive sample fraction.')
66+
random_example_sampler = sampler.BalancedPositiveNegativeSampler(
67+
positive_fraction=loss_config.random_example_sampler.
68+
positive_sample_fraction)
69+
return (classification_loss, localization_loss, classification_weight,
70+
localization_weight, hard_example_miner, random_example_sampler)
6071

6172

6273
def build_hard_example_miner(config,

0 commit comments

Comments
 (0)