
Commit d0afef7

revert changes that do not seem to give a performance boost
1 parent 3164944 commit d0afef7

File tree

1 file changed: +10, -14 lines


official/resnet/resnet_run_loop.py

Lines changed: 10 additions & 14 deletions
@@ -61,24 +61,21 @@ def process_record_dataset(dataset, is_training, global_batch_size,

   Returns:
     Dataset of (image, label) pairs ready for iteration.
   """
-
-  # We prefetch a batch at a time, This can help smooth out the time taken to
-  # load input files as we go through shuffling and processing.
-  dataset = dataset.prefetch(buffer_size=global_batch_size)
   if is_training:
     # Shuffle the records. Note that we shuffle before repeating to ensure
     # that the shuffling respects epoch boundaries.
-    # If we are training over multiple epochs before evaluating, repeat the
-    # dataset for the appropriate number of epochs.
-    # Using the fused shuffle_and_repeat method gives better performance.
-    dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(
-        buffer_size=shuffle_buffer, count=num_epochs))
-  else:
-    dataset = dataset.repeat(num_epochs)
+    dataset = dataset.shuffle(buffer_size=shuffle_buffer)
+
+  # If we are training over multiple epochs before evaluating, repeat the
+  # dataset for the appropriate number of epochs.
+  dataset = dataset.repeat(num_epochs)

   # Parse the raw records into images and labels. Testing has shown that setting
   # num_parallel_batches > 1 produces no improvement in throughput, since
   # batch_size is almost always much greater than the number of CPU cores.
+  # num_parallel_batches=num_gpus is better in presence of dedicated threads.
+  # Otherwise, not so great.
+  # TODO(priya): Perhaps make this conditional on the flag?
   dataset = dataset.apply(
       tf.contrib.data.map_and_batch(
           lambda value: parse_record_fn(value, is_training),
@@ -90,9 +87,8 @@ def process_record_dataset(dataset, is_training, global_batch_size,
   # will happen synchronously during run time. We prefetch here again to
   # background all of the above processing work and keep it out of the
   # critical training path. Setting buffer_size to tf.contrib.data.AUTOTUNE
-  # allows TensorFlow to adjust how many batches to fetch based
-  # on how many devices are present.
-  dataset.prefetch(buffer_size=num_gpus)
+  # allows TensorFlow to determine the best setting.
+  dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)

   if datasets_num_private_threads:
     dataset = threadpool.override_threadpool(
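
For reference, a minimal sketch of what the input pipeline in process_record_dataset looks like after this revert. It assumes TensorFlow 1.x, where tf.contrib.data is still available; the map_and_batch arguments beyond the parse lambda are cut off in the hunk above, so batch_size and num_parallel_batches are filled in here as plausible values rather than the file's exact ones, and the private-threadpool handling at the end of the function is omitted.

import tensorflow as tf  # TF 1.x, where tf.contrib.data exists


def process_record_dataset_sketch(dataset, is_training, global_batch_size,
                                  shuffle_buffer, parse_record_fn, num_epochs):
  """Post-revert pipeline: plain shuffle + repeat, then fused parse/batch."""
  if is_training:
    # Shuffle before repeating so the shuffling respects epoch boundaries.
    dataset = dataset.shuffle(buffer_size=shuffle_buffer)

  # Repeat the dataset for the requested number of epochs.
  dataset = dataset.repeat(num_epochs)

  # Fused parse + batch. The batch arguments below are assumed for
  # illustration; the hunk only shows the parse lambda.
  dataset = dataset.apply(
      tf.contrib.data.map_and_batch(
          lambda value: parse_record_fn(value, is_training),
          batch_size=global_batch_size,
          num_parallel_batches=1))

  # prefetch returns a new dataset, so the result is assigned here; AUTOTUNE
  # lets TensorFlow pick the buffer size.
  dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
  return dataset

Compared with the reverted code, the fused tf.contrib.data.shuffle_and_repeat call and the up-front dataset.prefetch(buffer_size=global_batch_size) are gone; per the commit message, they did not show a measurable performance boost.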
