
Commit 5f9f6b8

Author: Taylor Robie

Move argparsing from builtin argparse to absl (tensorflow#4099)

* squash of modular absl usage commits
* delint
* address PR comments
* change hooks to comma-separated list, as absl behavior for space-separated lists is not as expected
1 parent 6ec3452 · commit 5f9f6b8

22 files changed: +891 −636 lines
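The last bullet in the commit message reflects how absl parses list flags: flags.DEFINE_list splits its value on commas, so --hooks=loggingtensorhook,profilerhook arrives in Python as a list, while a space-separated string would be treated as a single element. Below is a minimal standalone sketch of the plain absl pattern the diffs in this commit move toward; the flag names are illustrative, not the repo's own definitions.

# Minimal sketch of the absl.flags pattern adopted in this commit.
# The flag names here are illustrative, not the repo's definitions.
from absl import app
from absl import flags

flags.DEFINE_integer(name='batch_size', short_name='bs', default=100,
                     help='Batch size for training and evaluation.')
flags.DEFINE_list(name='hooks', default='loggingtensorhook',
                  help='Comma-separated list of training hook names.')

FLAGS = flags.FLAGS


def main(_):
  # absl parses sys.argv before invoking main; parsed values live on FLAGS.
  # --hooks=loggingtensorhook,profilerhook yields a Python list here.
  print('batch_size: %d' % FLAGS.batch_size)
  print('hooks: %s' % FLAGS.hooks)


if __name__ == '__main__':
  app.run(main)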

official/mnist/mnist.py
Lines changed: 32 additions & 37 deletions

@@ -20,13 +20,16 @@
 import argparse
 import sys
 
+from absl import app as absl_app
+from absl import flags
 import tensorflow as tf  # pylint: disable=g-bad-import-order
 
 from official.mnist import dataset
-from official.utils.arg_parsers import parsers
+from official.utils.flags import core as flags_core
 from official.utils.logs import hooks_helper
 from official.utils.misc import model_helpers
 
+
 LEARNING_RATE = 1e-4
 
 

@@ -86,6 +89,16 @@ def create_model(data_format):
   ])
 
 
+def define_mnist_flags():
+  flags_core.define_base()
+  flags_core.define_image()
+  flags.adopt_module_key_flags(flags_core)
+  flags_core.set_defaults(data_dir='/tmp/mnist_data',
+                          model_dir='/tmp/mnist_model',
+                          batch_size=100,
+                          train_epochs=40)
+
+
 def model_fn(features, labels, mode, params):
   """The model_fn argument for creating an Estimator."""
   model = create_model(params['data_format'])

@@ -172,31 +185,28 @@ def validate_batch_size_for_multi_gpu(batch_size):
     raise ValueError(err)
 
 
-def main(argv):
-  parser = MNISTArgParser()
-  flags = parser.parse_args(args=argv[1:])
-
+def main(flags_obj):
   model_function = model_fn
 
-  if flags.multi_gpu:
-    validate_batch_size_for_multi_gpu(flags.batch_size)
+  if flags_obj.multi_gpu:
+    validate_batch_size_for_multi_gpu(flags_obj.batch_size)
 
     # There are two steps required if using multi-GPU: (1) wrap the model_fn,
     # and (2) wrap the optimizer. The first happens here, and (2) happens
     # in the model_fn itself when the optimizer is defined.
     model_function = tf.contrib.estimator.replicate_model_fn(
         model_fn, loss_reduction=tf.losses.Reduction.MEAN)
 
-  data_format = flags.data_format
+  data_format = flags_obj.data_format
   if data_format is None:
     data_format = ('channels_first'
                    if tf.test.is_built_with_cuda() else 'channels_last')
   mnist_classifier = tf.estimator.Estimator(
       model_fn=model_function,
-      model_dir=flags.model_dir,
+      model_dir=flags_obj.model_dir,
       params={
           'data_format': data_format,
-          'multi_gpu': flags.multi_gpu
+          'multi_gpu': flags_obj.multi_gpu
       })
 
   # Set up training and evaluation input functions.

@@ -206,57 +216,42 @@ def train_input_fn():
     # When choosing shuffle buffer sizes, larger sizes result in better
     # randomness, while smaller sizes use less memory. MNIST is a small
     # enough dataset that we can easily shuffle the full epoch.
-    ds = dataset.train(flags.data_dir)
-    ds = ds.cache().shuffle(buffer_size=50000).batch(flags.batch_size)
+    ds = dataset.train(flags_obj.data_dir)
+    ds = ds.cache().shuffle(buffer_size=50000).batch(flags_obj.batch_size)
 
     # Iterate through the dataset a set number (`epochs_between_evals`) of times
     # during each training session.
-    ds = ds.repeat(flags.epochs_between_evals)
+    ds = ds.repeat(flags_obj.epochs_between_evals)
     return ds
 
   def eval_input_fn():
-    return dataset.test(flags.data_dir).batch(
-        flags.batch_size).make_one_shot_iterator().get_next()
+    return dataset.test(flags_obj.data_dir).batch(
+        flags_obj.batch_size).make_one_shot_iterator().get_next()
 
   # Set up hook that outputs training logs every 100 steps.
   train_hooks = hooks_helper.get_train_hooks(
-      flags.hooks, batch_size=flags.batch_size)
+      flags_obj.hooks, batch_size=flags_obj.batch_size)
 
   # Train and evaluate model.
-  for _ in range(flags.train_epochs // flags.epochs_between_evals):
+  for _ in range(flags_obj.train_epochs // flags_obj.epochs_between_evals):
     mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
     eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
    print('\nEvaluation results:\n\t%s\n' % eval_results)
 
-    if model_helpers.past_stop_threshold(flags.stop_threshold,
+    if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                          eval_results['accuracy']):
       break
 
   # Export the model
-  if flags.export_dir is not None:
+  if flags_obj.export_dir is not None:
     image = tf.placeholder(tf.float32, [None, 28, 28])
     input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
         'image': image,
     })
-    mnist_classifier.export_savedmodel(flags.export_dir, input_fn)
-
-
-class MNISTArgParser(argparse.ArgumentParser):
-  """Argument parser for running MNIST model."""
-
-  def __init__(self):
-    super(MNISTArgParser, self).__init__(parents=[
-        parsers.BaseParser(),
-        parsers.ImageModelParser(),
-    ])
-
-    self.set_defaults(
-        data_dir='/tmp/mnist_data',
-        model_dir='/tmp/mnist_model',
-        batch_size=100,
-        train_epochs=40)
+    mnist_classifier.export_savedmodel(flags_obj.export_dir, input_fn)
 
 
 if __name__ == '__main__':
   tf.logging.set_verbosity(tf.logging.INFO)
-  main(argv=sys.argv)
+  define_mnist_flags()
+  absl_app.run(main)
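The new define_mnist_flags() leans on flags.adopt_module_key_flags(flags_core): absl keys each flag to the module whose code defines it, and a script's --help lists only that script's own key flags. Adopting flags_core's key flags re-keys the shared flags to mnist.py so they appear in its help output. A hedged two-file sketch of the mechanism, using hypothetical modules shared_flags.py and train.py rather than the repo's files:

# --- shared_flags.py (hypothetical stand-in for official.utils.flags.core) ---
from absl import flags


def define_base():
  # This flag is keyed to shared_flags, the module where DEFINE_* executes.
  flags.DEFINE_string(name='data_dir', default='/tmp',
                      help='Directory containing input data.')

# --- train.py (hypothetical main script) ---
from absl import app
from absl import flags

import shared_flags

shared_flags.define_base()
# Without this call, --data_dir still parses but is omitted from train.py's
# --help output; adopting shared_flags' key flags surfaces it there.
flags.adopt_module_key_flags(shared_flags)


def main(_):
  print(flags.FLAGS.data_dir)


if __name__ == '__main__':
  app.run(main)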

official/mnist/mnist_eager.py
Lines changed: 57 additions & 68 deletions

@@ -26,17 +26,20 @@
 from __future__ import division
 from __future__ import print_function
 
-import argparse
 import os
 import sys
 import time
 
-import tensorflow as tf  # pylint: disable=g-bad-import-order
-import tensorflow.contrib.eager as tfe  # pylint: disable=g-bad-import-order
+# pylint: disable=g-bad-import-order
+from absl import app as absl_app
+from absl import flags
+import tensorflow as tf
+import tensorflow.contrib.eager as tfe
+# pylint: enable=g-bad-import-order
 
 from official.mnist import dataset as mnist_dataset
 from official.mnist import mnist
-from official.utils.arg_parsers import parsers
+from official.utils.flags import core as flags_core
 
 

@@ -95,38 +98,36 @@ def test(model, dataset):
     tf.contrib.summary.scalar('accuracy', accuracy.result())
 
 
-def main(argv):
-  parser = MNISTEagerArgParser()
-  flags = parser.parse_args(args=argv[1:])
-
+def main(flags_obj):
   tf.enable_eager_execution()
 
   # Automatically determine device and data_format
   (device, data_format) = ('/gpu:0', 'channels_first')
-  if flags.no_gpu or not tf.test.is_gpu_available():
+  if flags_obj.no_gpu or not tf.test.is_gpu_available():
     (device, data_format) = ('/cpu:0', 'channels_last')
   # If data_format is defined in FLAGS, overwrite automatically set value.
-  if flags.data_format is not None:
-    data_format = flags.data_format
+  if flags_obj.data_format is not None:
+    data_format = flags_obj.data_format
   print('Using device %s, and data format %s.' % (device, data_format))
 
   # Load the datasets
-  train_ds = mnist_dataset.train(flags.data_dir).shuffle(60000).batch(
-      flags.batch_size)
-  test_ds = mnist_dataset.test(flags.data_dir).batch(flags.batch_size)
+  train_ds = mnist_dataset.train(flags_obj.data_dir).shuffle(60000).batch(
+      flags_obj.batch_size)
+  test_ds = mnist_dataset.test(flags_obj.data_dir).batch(
+      flags_obj.batch_size)
 
   # Create the model and optimizer
   model = mnist.create_model(data_format)
-  optimizer = tf.train.MomentumOptimizer(flags.lr, flags.momentum)
+  optimizer = tf.train.MomentumOptimizer(flags_obj.lr, flags_obj.momentum)
 
   # Create file writers for writing TensorBoard summaries.
-  if flags.output_dir:
+  if flags_obj.output_dir:
     # Create directories to which summaries will be written
     # tensorboard --logdir=<output_dir>
     # can then be used to see the recorded summaries.
-    train_dir = os.path.join(flags.output_dir, 'train')
-    test_dir = os.path.join(flags.output_dir, 'eval')
-    tf.gfile.MakeDirs(flags.output_dir)
+    train_dir = os.path.join(flags_obj.output_dir, 'train')
+    test_dir = os.path.join(flags_obj.output_dir, 'eval')
+    tf.gfile.MakeDirs(flags_obj.output_dir)
   else:
     train_dir = None
     test_dir = None

@@ -136,19 +137,20 @@ def main(argv):
       test_dir, flush_millis=10000, name='test')
 
   # Create and restore checkpoint (if one exists on the path)
-  checkpoint_prefix = os.path.join(flags.model_dir, 'ckpt')
+  checkpoint_prefix = os.path.join(flags_obj.model_dir, 'ckpt')
   step_counter = tf.train.get_or_create_global_step()
   checkpoint = tfe.Checkpoint(
       model=model, optimizer=optimizer, step_counter=step_counter)
   # Restore variables on creation if a checkpoint exists.
-  checkpoint.restore(tf.train.latest_checkpoint(flags.model_dir))
+  checkpoint.restore(tf.train.latest_checkpoint(flags_obj.model_dir))
 
   # Train and evaluate for a set number of epochs.
   with tf.device(device):
-    for _ in range(flags.train_epochs):
+    for _ in range(flags_obj.train_epochs):
       start = time.time()
       with summary_writer.as_default():
-        train(model, optimizer, train_ds, step_counter, flags.log_interval)
+        train(model, optimizer, train_ds, step_counter,
+              flags_obj.log_interval)
       end = time.time()
       print('\nTrain time for epoch #%d (%d total steps): %f' %
             (checkpoint.save_counter.numpy() + 1,

@@ -159,50 +161,37 @@ def main(argv):
     checkpoint.save(checkpoint_prefix)
 
 
-class MNISTEagerArgParser(argparse.ArgumentParser):
-  """Argument parser for running MNIST model with eager training loop."""
-
-  def __init__(self):
-    super(MNISTEagerArgParser, self).__init__(parents=[
-        parsers.EagerParser(),
-        parsers.ImageModelParser()])
-
-    self.add_argument(
-        '--log_interval', '-li',
-        type=int,
-        default=10,
-        metavar='N',
-        help='[default: %(default)s] batches between logging training status')
-    self.add_argument(
-        '--output_dir', '-od',
-        type=str,
-        default=None,
-        metavar='<OD>',
-        help='[default: %(default)s] Directory to write TensorBoard summaries')
-    self.add_argument(
-        '--lr', '-lr',
-        type=float,
-        default=0.01,
-        metavar='<LR>',
-        help='[default: %(default)s] learning rate')
-    self.add_argument(
-        '--momentum', '-m',
-        type=float,
-        default=0.5,
-        metavar='<M>',
-        help='[default: %(default)s] SGD momentum')
-    self.add_argument(
-        '--no_gpu', '-nogpu',
-        action='store_true',
-        default=False,
-        help='disables GPU usage even if a GPU is available')
-
-    self.set_defaults(
-        data_dir='/tmp/tensorflow/mnist/input_data',
-        model_dir='/tmp/tensorflow/mnist/checkpoints/',
-        batch_size=100,
-        train_epochs=10,
-    )
+def define_mnist_eager_flags():
+  """Defines flags and defaults for MNIST in eager mode."""
+  flags_core.define_base_eager()
+  flags_core.define_image()
+  flags.adopt_module_key_flags(flags_core)
+
+  flags.DEFINE_integer(
+      name='log_interval', short_name='li', default=10,
+      help=flags_core.help_wrap('batches between logging training status'))
+
+  flags.DEFINE_string(
+      name='output_dir', short_name='od', default=None,
+      help=flags_core.help_wrap('Directory to write TensorBoard summaries'))
+
+  flags.DEFINE_float(name='learning_rate', short_name='lr', default=0.01,
+                     help=flags_core.help_wrap('Learning rate.'))
+
+  flags.DEFINE_float(name='momentum', short_name='m', default=0.5,
+                     help=flags_core.help_wrap('SGD momentum.'))
+
+  flags.DEFINE_bool(name='no_gpu', short_name='nogpu', default=False,
+                    help=flags_core.help_wrap(
+                        'disables GPU usage even if a GPU is available'))
+
+  flags_core.set_defaults(
+      data_dir='/tmp/tensorflow/mnist/input_data',
+      model_dir='/tmp/tensorflow/mnist/checkpoints/',
+      batch_size=100,
+      train_epochs=10,
+  )
 
 if __name__ == '__main__':
-  main(argv=sys.argv)
+  define_mnist_eager_flags()
+  absl_app.run(main=main)
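Note how main() above keeps reading flags_obj.lr even though the flag is now registered as learning_rate: an absl short_name registers a second name for the same flag, so both attributes resolve to one value. A standalone sketch, assuming only absl-py:

# Standalone sketch: a short_name registers an alias for the same flag, so
# FLAGS.lr and FLAGS.learning_rate read the same value. This is why the eager
# main above can keep using flags_obj.lr after the rename.
from absl import app
from absl import flags

flags.DEFINE_float(name='learning_rate', short_name='lr', default=0.01,
                   help='Learning rate.')

FLAGS = flags.FLAGS


def main(_):
  assert FLAGS.lr == FLAGS.learning_rate
  print('learning rate: %f' % FLAGS.lr)


if __name__ == '__main__':
  app.run(main)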

official/resnet/cifar10_main.py
Lines changed: 18 additions & 13 deletions

@@ -21,8 +21,11 @@
 import os
 import sys
 
+from absl import app as absl_app
+from absl import flags
 import tensorflow as tf  # pylint: disable=g-bad-import-order
 
+from official.utils.flags import core as flags_core
 from official.resnet import resnet_model
 from official.resnet import resnet_run_loop
 

@@ -224,25 +227,27 @@ def loss_filter_fn(_):
   )
 
 
-def main(argv):
-  parser = resnet_run_loop.ResnetArgParser()
-  # Set defaults that are reasonable for this model.
-  parser.set_defaults(data_dir='/tmp/cifar10_data',
-                      model_dir='/tmp/cifar10_model',
-                      resnet_size=32,
-                      train_epochs=250,
-                      epochs_between_evals=10,
-                      batch_size=128)
+def define_cifar_flags():
+  resnet_run_loop.define_resnet_flags()
+  flags.adopt_module_key_flags(resnet_run_loop)
+  flags_core.set_defaults(data_dir='/tmp/cifar10_data',
+                          model_dir='/tmp/cifar10_model',
+                          resnet_size='32',
+                          train_epochs=250,
+                          epochs_between_evals=10,
+                          batch_size=128)
 
-  flags = parser.parse_args(args=argv[1:])
 
-  input_function = flags.use_synthetic_data and get_synth_input_fn() or input_fn
+def main(flags_obj):
+  input_function = (flags_obj.use_synthetic_data and get_synth_input_fn()
+                    or input_fn)
 
   resnet_run_loop.resnet_main(
-      flags, cifar10_model_fn, input_function,
+      flags_obj, cifar10_model_fn, input_function,
       shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS])
 
 
 if __name__ == '__main__':
   tf.logging.set_verbosity(tf.logging.INFO)
-  main(argv=sys.argv)
+  define_cifar_flags()
+  absl_app.run(main)
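The input_function selection keeps the old `and ... or ...` idiom, which only works because get_synth_input_fn() returns a truthy value (a function). A conditional expression states the intent directly; a tiny standalone illustration with stand-in functions:

# Stand-in functions; get_synth_input_fn returns a callable, hence truthy,
# which is the only reason the `and ... or ...` form above behaves.
def get_synth_input_fn():
  return lambda: 'synthetic batches'


def input_fn():
  return 'real batches'


use_synthetic_data = False

# Equivalent, more explicit form of the selection in cifar10_main.py:
input_function = (get_synth_input_fn() if use_synthetic_data else input_fn)
print(input_function())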

official/resnet/cifar10_test.py
Lines changed: 5 additions & 0 deletions

@@ -37,6 +37,11 @@ class BaseTest(tf.test.TestCase):
   """Tests for the Cifar10 version of Resnet.
   """
 
+  @classmethod
+  def setUpClass(cls):  # pylint: disable=invalid-name
+    super(BaseTest, cls).setUpClass()
+    cifar10_main.define_cifar_flags()
+
   def tearDown(self):
     super(BaseTest, self).tearDown()
     tf.gfile.DeleteRecursively(self.get_temp_dir())
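Defining the flags in setUpClass rather than in each test's setUp matters because absl flags are process-global: registering the same flag twice raises flags.DuplicateFlagError. A standalone sketch:

# Standalone sketch: absl flags are process-global and may be defined only
# once, which is why the test defines them once per class in setUpClass.
from absl import flags

flags.DEFINE_integer(name='batch_size', default=128, help='Batch size.')

try:
  flags.DEFINE_integer(name='batch_size', default=128, help='Batch size.')
except flags.DuplicateFlagError:
  print('a flag can only be defined once per process')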
