Skip to content

Commit 3164944

Browse files
committed
Fix tests and adjust input functions
1 parent 2b47f77 commit 3164944

File tree

6 files changed: 35 additions (+35) and 27 deletions (-27) lines changed

official/resnet/cifar10_main.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -106,16 +106,16 @@ def preprocess_image(image, is_training):
106106
return image
107107

108108

109-
def input_fn(is_training, data_dir, global_batch_size, num_epochs=1,
110-
num_gpus=1, datasets_num_private_threads=None):
109+
def input_fn(is_training, data_dir, global_batch_size,
110+
num_gpus, num_epochs=1, datasets_num_private_threads=None):
111111
"""Input_fn using the tf.data input pipeline for CIFAR-10 dataset.
112112
113113
Args:
114114
is_training: A boolean denoting whether the input is for training.
115115
data_dir: The directory containing the input data.
116116
global_batch_size: The number of samples per batch.
117-
num_epochs: The number of epochs to repeat the dataset.
118117
num_gpus: The number of GPUs.
118+
num_epochs: The number of epochs to repeat the dataset.
119119
datasets_num_private_threads: Number of threads for a private
120120
threadpool created for all datasets computation.
121121
@@ -128,7 +128,7 @@ def input_fn(is_training, data_dir, global_batch_size, num_epochs=1,
128128

129129
return resnet_run_loop.process_record_dataset(
130130
dataset, is_training, global_batch_size, _NUM_IMAGES['train'],
131-
parse_record, num_epochs, num_gpus, datasets_num_private_threads
131+
parse_record, num_gpus, num_epochs, datasets_num_private_threads
132132
)
133133

134134

official/resnet/cifar10_test.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -64,21 +64,21 @@ def test_dataset_input_fn(self):
6464
lambda val: cifar10_main.parse_record(val, False))
6565
image, label = fake_dataset.make_one_shot_iterator().get_next()
6666

67-
self.assertAllEqual(label.shape, (10,))
67+
self.assertAllEqual(label.shape, ())
6868
self.assertAllEqual(image.shape, (_HEIGHT, _WIDTH, _NUM_CHANNELS))
6969

7070
with self.test_session() as sess:
7171
image, label = sess.run([image, label])
7272

73-
self.assertAllEqual(label, np.array([int(i == 7) for i in range(10)]))
73+
self.assertEqual(label, 7)
7474

7575
for row in image:
7676
for pixel in row:
7777
self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3)
7878

7979
def cifar10_model_fn_helper(self, mode, resnet_version, dtype):
8080
input_fn = cifar10_main.get_synth_input_fn()
81-
dataset = input_fn(True, '', _BATCH_SIZE)
81+
dataset = input_fn(True, '', _BATCH_SIZE, 1)
8282
iterator = dataset.make_one_shot_iterator()
8383
features, labels = iterator.get_next()
8484
spec = cifar10_main.cifar10_model_fn(

official/resnet/imagenet_main.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -155,16 +155,16 @@ def parse_record(raw_record, is_training):
155155
return image, label
156156

157157

158-
def input_fn(is_training, data_dir, global_batch_size, num_epochs=1,
159-
num_gpus=1, datasets_num_private_threads=None):
158+
def input_fn(is_training, data_dir, global_batch_size,
159+
num_gpus, num_epochs=1, datasets_num_private_threads=None):
160160
"""Input function which provides batches for train or eval.
161161
162162
Args:
163163
is_training: A boolean denoting whether the input is for training.
164164
data_dir: The directory containing the input data.
165165
global_batch_size: The number of samples per batch.
166-
num_epochs: The number of epochs to repeat the dataset.
167166
num_gpus: The number of GPUs.
167+
num_epochs: The number of epochs to repeat the dataset.
168168
datasets_num_private_threads: Number of threads for a private
169169
threadpool created for all datasets computation.
170170
@@ -184,7 +184,7 @@ def input_fn(is_training, data_dir, global_batch_size, num_epochs=1,
184184

185185
return resnet_run_loop.process_record_dataset(
186186
dataset, is_training, global_batch_size, _SHUFFLE_BUFFER, parse_record,
187-
num_epochs, num_gpus, datasets_num_private_threads
187+
num_gpus, num_epochs, datasets_num_private_threads
188188
)
189189

190190

official/resnet/imagenet_test.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -192,7 +192,7 @@ def resnet_model_fn_helper(self, mode, resnet_version, dtype):
192192
tf.train.create_global_step()
193193

194194
input_fn = imagenet_main.get_synth_input_fn()
195-
dataset = input_fn(True, '', _BATCH_SIZE)
195+
dataset = input_fn(True, '', _BATCH_SIZE, 1)
196196
iterator = dataset.make_one_shot_iterator()
197197
features, labels = iterator.get_next()
198198
spec = imagenet_main.imagenet_model_fn(

official/resnet/resnet_run_loop.py

Lines changed: 20 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -40,8 +40,8 @@
4040
# Functions for input processing.
4141
################################################################################
4242
def process_record_dataset(dataset, is_training, global_batch_size,
43-
shuffle_buffer, parse_record_fn, num_epochs=1,
44-
num_gpus=1, datasets_num_private_threads=None):
43+
shuffle_buffer, parse_record_fn, num_gpus,
44+
num_epochs=1, datasets_num_private_threads=None):
4545
"""Given a Dataset with raw records, return an iterator over the records.
4646
4747
Args:
@@ -53,8 +53,8 @@ def process_record_dataset(dataset, is_training, global_batch_size,
5353
time and use less memory.
5454
parse_record_fn: A function that takes a raw record and returns the
5555
corresponding (image, label) pair.
56-
num_epochs: The number of epochs to repeat the dataset.
5756
num_gpus: The number of GPUs.
57+
num_epochs: The number of epochs to repeat the dataset.
5858
datasets_num_private_threads: Number of threads for a private
5959
threadpool created for all datasets computation.
6060
@@ -121,7 +121,9 @@ def get_synth_input_fn(height, width, num_channels, num_classes):
121121
An input_fn that can be used in place of a real one to return a dataset
122122
that can be used for iteration.
123123
"""
124-
def input_fn(is_training, data_dir, batch_size, *args, **kwargs): # pylint: disable=unused-argument
124+
def input_fn(is_training, data_dir, global_batch_size, num_gpus,
125+
*args, **kwargs): # pylint: disable=unused-argument
126+
batch_size=per_device_batch_size(global_batch_size, num_gpus)
125127
images = tf.zeros((batch_size, height, width, num_channels), tf.float32)
126128
labels = tf.zeros((batch_size), tf.int32)
127129
return tf.data.Dataset.from_tensors((images, labels)).repeat()
@@ -366,7 +368,7 @@ def resnet_main(
366368
# Using the Winograd non-fused algorithms provides a small performance boost.
367369
os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
368370
os.environ['TF_GPU_THREAD_MODE'] = flags_obj.tf_gpu_thread_mode
369-
os.environ['TF_GPU_THREAD_COUNT'] = flags_obj.tf_gpu_thread_count
371+
os.environ['TF_GPU_THREAD_COUNT'] = str(flags_obj.tf_gpu_thread_count)
370372

371373

372374
# Create session config based on values of inter_op_parallelism_threads and
@@ -378,13 +380,15 @@ def resnet_main(
378380
intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
379381
allow_soft_placement=True)
380382

381-
if flags_core.get_num_gpus(flags_obj) == 0:
383+
num_gpus = flags_core.get_num_gpus(flags_obj)
384+
385+
if num_gpus == 0:
382386
distribution = tf.contrib.distribute.OneDeviceStrategy('device:CPU:0')
383-
elif flags_core.get_num_gpus(flags_obj) == 1:
387+
elif num_gpus == 1:
384388
distribution = tf.contrib.distribute.OneDeviceStrategy('device:GPU:0')
385389
else:
386390
distribution = tf.contrib.distribute.MirroredStrategy(
387-
num_gpus=flags_core.get_num_gpus(flags_obj)
391+
num_gpus=num_gpus
388392
)
389393

390394
run_config = tf.estimator.RunConfig(train_distribute=distribution,
@@ -419,17 +423,21 @@ def resnet_main(
419423

420424
def input_fn_train():
421425
return input_function(
422-
is_training=True, data_dir=flags_obj.data_dir,
426+
is_training=True,
427+
data_dir=flags_obj.data_dir,
423428
global_batch_size=flags_obj.batch_size,
429+
num_gpus=num_gpus,
424430
num_epochs=flags_obj.epochs_between_evals,
425-
num_gpus=flags_core.get_num_gpus(flags_obj))
431+
datasets_num_private_threads=flags_obj.datasets_num_private_threads)
426432

427433
def input_fn_eval():
428434
return input_function(
429-
is_training=False, data_dir=flags_obj.data_dir,
435+
is_training=False,
436+
data_dir=flags_obj.data_dir,
430437
global_batch_size=flags_obj.batch_size,
438+
num_gpus=num_gpus,
431439
num_epochs=1,
432-
num_gpus=flags_core.get_num_gpus(flags_obj))
440+
datasets_num_private_threads=flags_obj.datasets_num_private_threads)
433441

434442

435443
total_training_cycle = (flags_obj.train_epochs //

official/utils/flags/_performance.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -135,7 +135,7 @@ def _check_loss_scale(loss_scale): # pylint: disable=unused-variable
135135
flags.DEFINE_string(
136136
name="tf_gpu_thread_mode", short_name="gt_mode", default="global",
137137
help=help_wrap(
138-
"Whether and how the GPU device uses its own threadpool.")
138+
"Whether and how the GPU device uses its own threadpool.")
139139
)
140140

141141
if tf_gpu_thread_count:
@@ -149,8 +149,8 @@ def _check_loss_scale(loss_scale): # pylint: disable=unused-variable
149149
name="datasets_num_private_threads", short_name="dataset_thread_count",
150150
default=None,
151151
help=help_wrap(
152-
"Number of threads for a private threadpool created for all"
153-
"datasets computation..")
152+
"Number of threads for a private threadpool created for all"
153+
"datasets computation..")
154154
)
155155

156156
return key_flags

0 commit comments

Comments
 (0)