@@ -26,17 +26,20 @@
 from __future__ import division
 from __future__ import print_function
 
-import argparse
 import os
 import sys
 import time
 
-import tensorflow as tf  # pylint: disable=g-bad-import-order
-import tensorflow.contrib.eager as tfe  # pylint: disable=g-bad-import-order
+# pylint: disable=g-bad-import-order
+from absl import app as absl_app
+from absl import flags
+import tensorflow as tf
+import tensorflow.contrib.eager as tfe
+# pylint: enable=g-bad-import-order
 
 from official.mnist import dataset as mnist_dataset
 from official.mnist import mnist
-from official.utils.arg_parsers import parsers
+from official.utils.flags import core as flags_core
 
 
 def loss(logits, labels):
@@ -95,38 +98,36 @@ def test(model, dataset):
     tf.contrib.summary.scalar('accuracy', accuracy.result())
 
 
-def main(argv):
-  parser = MNISTEagerArgParser()
-  flags = parser.parse_args(args=argv[1:])
-
+def main(flags_obj):
   tf.enable_eager_execution()
 
   # Automatically determine device and data_format
   (device, data_format) = ('/gpu:0', 'channels_first')
-  if flags.no_gpu or not tf.test.is_gpu_available():
+  if flags_obj.no_gpu or not tf.test.is_gpu_available():
     (device, data_format) = ('/cpu:0', 'channels_last')
   # If data_format is defined in FLAGS, overwrite automatically set value.
-  if flags.data_format is not None:
-    data_format = flags.data_format
+  if flags_obj.data_format is not None:
+    data_format = flags_obj.data_format
   print('Using device %s, and data format %s.' % (device, data_format))
 
   # Load the datasets
-  train_ds = mnist_dataset.train(flags.data_dir).shuffle(60000).batch(
-      flags.batch_size)
-  test_ds = mnist_dataset.test(flags.data_dir).batch(flags.batch_size)
+  train_ds = mnist_dataset.train(flags_obj.data_dir).shuffle(60000).batch(
+      flags_obj.batch_size)
+  test_ds = mnist_dataset.test(flags_obj.data_dir).batch(
+      flags_obj.batch_size)
 
   # Create the model and optimizer
   model = mnist.create_model(data_format)
-  optimizer = tf.train.MomentumOptimizer(flags.lr, flags.momentum)
+  optimizer = tf.train.MomentumOptimizer(flags_obj.lr, flags_obj.momentum)
 
   # Create file writers for writing TensorBoard summaries.
-  if flags.output_dir:
+  if flags_obj.output_dir:
     # Create directories to which summaries will be written
     # tensorboard --logdir=<output_dir>
     # can then be used to see the recorded summaries.
-    train_dir = os.path.join(flags.output_dir, 'train')
-    test_dir = os.path.join(flags.output_dir, 'eval')
-    tf.gfile.MakeDirs(flags.output_dir)
+    train_dir = os.path.join(flags_obj.output_dir, 'train')
+    test_dir = os.path.join(flags_obj.output_dir, 'eval')
+    tf.gfile.MakeDirs(flags_obj.output_dir)
   else:
     train_dir = None
     test_dir = None
@@ -136,19 +137,20 @@ def main(argv):
       test_dir, flush_millis=10000, name='test')
 
   # Create and restore checkpoint (if one exists on the path)
-  checkpoint_prefix = os.path.join(flags.model_dir, 'ckpt')
+  checkpoint_prefix = os.path.join(flags_obj.model_dir, 'ckpt')
   step_counter = tf.train.get_or_create_global_step()
   checkpoint = tfe.Checkpoint(
       model=model, optimizer=optimizer, step_counter=step_counter)
   # Restore variables on creation if a checkpoint exists.
-  checkpoint.restore(tf.train.latest_checkpoint(flags.model_dir))
+  checkpoint.restore(tf.train.latest_checkpoint(flags_obj.model_dir))
 
   # Train and evaluate for a set number of epochs.
   with tf.device(device):
-    for _ in range(flags.train_epochs):
+    for _ in range(flags_obj.train_epochs):
       start = time.time()
       with summary_writer.as_default():
-        train(model, optimizer, train_ds, step_counter, flags.log_interval)
+        train(model, optimizer, train_ds, step_counter,
+              flags_obj.log_interval)
       end = time.time()
       print('\nTrain time for epoch #%d (%d total steps): %f' %
             (checkpoint.save_counter.numpy() + 1,
@@ -159,50 +161,37 @@ def main(argv):
       checkpoint.save(checkpoint_prefix)
 
 
-class MNISTEagerArgParser(argparse.ArgumentParser):
-  """Argument parser for running MNIST model with eager training loop."""
-
-  def __init__(self):
-    super(MNISTEagerArgParser, self).__init__(parents=[
-        parsers.EagerParser(),
-        parsers.ImageModelParser()])
-
-    self.add_argument(
-        '--log_interval', '-li',
-        type=int,
-        default=10,
-        metavar='N',
-        help='[default: %(default)s] batches between logging training status')
-    self.add_argument(
-        '--output_dir', '-od',
-        type=str,
-        default=None,
-        metavar='<OD>',
-        help='[default: %(default)s] Directory to write TensorBoard summaries')
-    self.add_argument(
-        '--lr', '-lr',
-        type=float,
-        default=0.01,
-        metavar='<LR>',
-        help='[default: %(default)s] learning rate')
-    self.add_argument(
-        '--momentum', '-m',
-        type=float,
-        default=0.5,
-        metavar='<M>',
-        help='[default: %(default)s] SGD momentum')
-    self.add_argument(
-        '--no_gpu', '-nogpu',
-        action='store_true',
-        default=False,
-        help='disables GPU usage even if a GPU is available')
-
-    self.set_defaults(
-        data_dir='/tmp/tensorflow/mnist/input_data',
-        model_dir='/tmp/tensorflow/mnist/checkpoints/',
-        batch_size=100,
-        train_epochs=10,
-    )
+def define_mnist_eager_flags():
+  """Defines flags and defaults for MNIST in eager mode."""
+  flags_core.define_base_eager()
+  flags_core.define_image()
+  flags.adopt_module_key_flags(flags_core)
+
+  flags.DEFINE_integer(
+      name='log_interval', short_name='li', default=10,
+      help=flags_core.help_wrap('batches between logging training status'))
+
+  flags.DEFINE_string(
+      name='output_dir', short_name='od', default=None,
+      help=flags_core.help_wrap('Directory to write TensorBoard summaries'))
+
+  flags.DEFINE_float(name='learning_rate', short_name='lr', default=0.01,
+                     help=flags_core.help_wrap('Learning rate.'))
+
+  flags.DEFINE_float(name='momentum', short_name='m', default=0.5,
+                     help=flags_core.help_wrap('SGD momentum.'))
+
+  flags.DEFINE_bool(name='no_gpu', short_name='nogpu', default=False,
+                    help=flags_core.help_wrap(
+                        'disables GPU usage even if a GPU is available'))
+
+  flags_core.set_defaults(
+      data_dir='/tmp/tensorflow/mnist/input_data',
+      model_dir='/tmp/tensorflow/mnist/checkpoints/',
+      batch_size=100,
+      train_epochs=10,
+  )
 
 if __name__ == '__main__':
-  main(argv=sys.argv)
+  define_mnist_eager_flags()
+  absl_app.run(main=main)
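
For reference, here is a minimal, standalone sketch of the absl flag pattern this change adopts. It uses plain absl only; the `flags_core` helpers from `official.utils.flags` (such as `help_wrap` and `set_defaults`) are wrappers around these primitives. Note that `absl.app.run` parses flags before invoking `main`, passing `main` the leftover positional argv; parsed values are read from the global `flags.FLAGS` object.

# absl_flags_sketch.py -- hypothetical filename; a sketch, not code from this PR.
from absl import app
from absl import flags

FLAGS = flags.FLAGS

# short_name registers an alias, so --learning_rate and --lr both work.
flags.DEFINE_float(name='learning_rate', short_name='lr', default=0.01,
                   help='Learning rate.')
flags.DEFINE_bool(name='no_gpu', short_name='nogpu', default=False,
                  help='Disables GPU usage even if a GPU is available.')


def main(argv):
  del argv  # Unused: absl passes the post-parsing positional args here.
  # Flag values are read from the global FLAGS object after parsing.
  print('learning_rate=%f, no_gpu=%s' % (FLAGS.learning_rate, FLAGS.no_gpu))


if __name__ == '__main__':
  app.run(main)

Running e.g. `python absl_flags_sketch.py --lr=0.05 --nogpu` prints the values set through the aliases, which is the behavior the `short_name` arguments in this diff rely on.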