Skip to content

Commit 310f70d

Browse files
authored
Adding stop threshold logic (tensorflow#3863)
* Adding tests * Adding tests * Repackaging * Adding logging * Linting
1 parent aad56e4 commit 310f70d

File tree

9 files changed

+167
-4
lines changed

9 files changed

+167
-4
lines changed

official/mnist/mnist.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from official.mnist import dataset
2626
from official.utils.arg_parsers import parsers
2727
from official.utils.logs import hooks_helper
28+
from official.utils.misc import model_helpers
2829

2930
LEARNING_RATE = 1e-4
3031

@@ -231,6 +232,10 @@ def eval_input_fn():
231232
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
232233
print('\nEvaluation results:\n\t%s\n' % eval_results)
233234

235+
if model_helpers.past_stop_threshold(
236+
flags.stop_threshold, eval_results['accuracy']):
237+
break
238+
234239
# Export the model
235240
if flags.export_dir is not None:
236241
image = tf.placeholder(tf.float32, [None, 28, 28])

official/mnist/mnist_eager.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,7 @@ class MNISTEagerArgParser(argparse.ArgumentParser):
164164

165165
def __init__(self):
166166
super(MNISTEagerArgParser, self).__init__(parents=[
167-
parsers.BaseParser(
168-
epochs_between_evals=False, multi_gpu=False, hooks=False),
167+
parsers.EagerParser(),
169168
parsers.ImageModelParser()])
170169

171170
self.add_argument(

official/resnet/imagenet_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,5 +318,6 @@ def test_imagenet_end_to_end_synthetic_v2_huge(self):
318318
extra_flags=['-v', '2', '-rs', '200']
319319
)
320320

321+
321322
if __name__ == '__main__':
322323
tf.test.main()

official/resnet/resnet_run_loop.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from official.utils.export import export
3434
from official.utils.logs import hooks_helper
3535
from official.utils.logs import logger
36+
from official.utils.misc import model_helpers
3637

3738

3839
################################################################################
@@ -438,6 +439,10 @@ def input_fn_eval():
438439
if benchmark_logger:
439440
benchmark_logger.log_estimator_evaluation_result(eval_results)
440441

442+
if model_helpers.past_stop_threshold(
443+
flags.stop_threshold, eval_results['accuracy']):
444+
break
445+
441446
if flags.export_dir is not None:
442447
warn_on_multi_gpu_export(flags.multi_gpu)
443448

official/utils/arg_parsers/parsers.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,17 @@ class BaseParser(argparse.ArgumentParser):
9999
model_dir: Create a flag for specifying the model file directory.
100100
train_epochs: Create a flag to specify the number of training epochs.
101101
epochs_between_evals: Create a flag to specify the frequency of testing.
102+
stop_threshold: Create a flag to specify a threshold accuracy or other
103+
eval metric which should trigger the end of training.
102104
batch_size: Create a flag to specify the batch size.
103105
multi_gpu: Create a flag to allow the use of all available GPUs.
104106
hooks: Create a flag to specify hooks for logging.
105107
"""
106108

107109
def __init__(self, add_help=False, data_dir=True, model_dir=True,
108-
train_epochs=True, epochs_between_evals=True, batch_size=True,
109-
multi_gpu=True, hooks=True):
110+
train_epochs=True, epochs_between_evals=True,
111+
stop_threshold=True, batch_size=True, multi_gpu=True,
112+
hooks=True):
110113
super(BaseParser, self).__init__(add_help=add_help)
111114

112115
if data_dir:
@@ -139,6 +142,15 @@ def __init__(self, add_help=False, data_dir=True, model_dir=True,
139142
metavar="<EBE>"
140143
)
141144

145+
if stop_threshold:
146+
self.add_argument(
147+
"--stop_threshold", "-st", type=float, default=None,
148+
help="[default: %(default)s] If passed, training will stop at "
149+
"the earlier of train_epochs and when the evaluation metric is "
150+
"greater than or equal to stop_threshold.",
151+
metavar="<ST>"
152+
)
153+
142154
if batch_size:
143155
self.add_argument(
144156
"--batch_size", "-bs", type=int, default=32,
@@ -345,3 +357,15 @@ def __init__(self, add_help=False, benchmark_log_dir=True,
345357
" benchmark metric information will be uploaded.",
346358
metavar="<BMT>"
347359
)
360+
361+
362+
class EagerParser(BaseParser):
363+
"""Remove options not relevant for Eager from the BaseParser."""
364+
365+
def __init__(self, add_help=False, data_dir=True, model_dir=True,
366+
train_epochs=True, batch_size=True):
367+
super(EagerParser, self).__init__(
368+
add_help=add_help, data_dir=data_dir, model_dir=model_dir,
369+
train_epochs=train_epochs, epochs_between_evals=False,
370+
stop_threshold=False, batch_size=batch_size, multi_gpu=False,
371+
hooks=False)

official/utils/misc/__init__.py

Whitespace-only changes.

official/utils/misc/model_helpers.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Miscellaneous functions that can be called by models."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
import numbers
22+
23+
import tensorflow as tf
24+
25+
26+
def past_stop_threshold(stop_threshold, eval_metric):
27+
"""Return a boolean representing whether a model should be stopped.
28+
29+
Args:
30+
stop_threshold: float, the threshold above which a model should stop
31+
training.
32+
eval_metric: float, the current value of the relevant metric to check.
33+
34+
Returns:
35+
True if training should stop, False otherwise.
36+
37+
Raises:
38+
ValueError: if either stop_threshold or eval_metric is not a number
39+
"""
40+
if stop_threshold is None:
41+
return False
42+
43+
if not isinstance(stop_threshold, numbers.Number):
44+
raise ValueError("Threshold for checking stop conditions must be a number.")
45+
if not isinstance(eval_metric, numbers.Number):
46+
raise ValueError("Eval metric being checked against stop conditions "
47+
"must be a number.")
48+
49+
if eval_metric >= stop_threshold:
50+
tf.logging.info(
51+
"Stop threshold of {} was passed with metric value {}.".format(
52+
stop_threshold, eval_metric))
53+
return True
54+
55+
return False
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
""" Tests for Model Helper functions."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
import tensorflow as tf # pylint: disable=g-bad-import-order
22+
23+
from official.utils.misc import model_helpers
24+
25+
26+
class PastStopThresholdTest(tf.test.TestCase):
27+
"""Tests for past_stop_threshold."""
28+
29+
def test_past_stop_threshold(self):
30+
"""Tests for normal operating conditions."""
31+
self.assertTrue(model_helpers.past_stop_threshold(0.54, 1))
32+
self.assertTrue(model_helpers.past_stop_threshold(54, 100))
33+
self.assertFalse(model_helpers.past_stop_threshold(0.54, 0.1))
34+
self.assertFalse(model_helpers.past_stop_threshold(-0.54, -1.5))
35+
self.assertTrue(model_helpers.past_stop_threshold(-0.54, 0))
36+
self.assertTrue(model_helpers.past_stop_threshold(0, 0))
37+
self.assertTrue(model_helpers.past_stop_threshold(0.54, 0.54))
38+
39+
def test_past_stop_threshold_none_false(self):
40+
"""Tests that check None returns false."""
41+
self.assertFalse(model_helpers.past_stop_threshold(None, -1.5))
42+
self.assertFalse(model_helpers.past_stop_threshold(None, None))
43+
self.assertFalse(model_helpers.past_stop_threshold(None, 1.5))
44+
# Zero should be okay, though.
45+
self.assertTrue(model_helpers.past_stop_threshold(0, 1.5))
46+
47+
def test_past_stop_threshold_not_number(self):
48+
"""Tests for error conditions."""
49+
with self.assertRaises(ValueError):
50+
model_helpers.past_stop_threshold("str", 1)
51+
52+
with self.assertRaises(ValueError):
53+
model_helpers.past_stop_threshold("str", tf.constant(5))
54+
55+
with self.assertRaises(ValueError):
56+
model_helpers.past_stop_threshold("str", "another")
57+
58+
with self.assertRaises(ValueError):
59+
model_helpers.past_stop_threshold(0, None)
60+
61+
with self.assertRaises(ValueError):
62+
model_helpers.past_stop_threshold(0.7, "str")
63+
64+
with self.assertRaises(ValueError):
65+
model_helpers.past_stop_threshold(tf.constant(4), None)
66+
67+
68+
if __name__ == "__main__":
69+
tf.test.main()

official/wide_deep/wide_deep.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
from official.utils.arg_parsers import parsers
2828
from official.utils.logs import hooks_helper
29+
from official.utils.misc import model_helpers
2930

3031
_CSV_COLUMNS = [
3132
'age', 'workclass', 'fnlwgt', 'education', 'education_num',
@@ -211,6 +212,10 @@ def eval_input_fn():
211212
for key in sorted(results):
212213
print('%s: %s' % (key, results[key]))
213214

215+
if model_helpers.past_stop_threshold(
216+
flags.stop_threshold, results['accuracy']):
217+
break
218+
214219

215220
class WideDeepArgParser(argparse.ArgumentParser):
216221
"""Argument parser for running the wide deep model."""

0 commit comments

Comments
 (0)