Skip to content

Commit fe3746e

Browse files
committed
Use AUTOTUNE, remove noop take, and comment fixes
1 parent eb37057 commit fe3746e

File tree

1 file changed

+10
-20
lines changed

1 file changed

+10
-20
lines changed

official/resnet/resnet_run_loop.py

+10-20
Original file line numberDiff line numberDiff line change
@@ -68,36 +68,26 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
6868
Dataset of (image, label) pairs ready for iteration.
6969
"""
7070

71-
# We prefetch a batch at a time. This can help smooth out the time taken to
72-
# load input files as we go through shuffling and processing.
71+
# Sets tf.data to AUTOTUNE, e.g. num_parallel_batches in map_and_batch.
72+
options = tf.data.Options()
73+
options.experimental_autotune = True
74+
dataset = dataset.with_options(options)
75+
76+
# Prefetches a batch at a time to smooth out the time taken to load input
77+
# files for shuffling and processing.
7378
dataset = dataset.prefetch(buffer_size=batch_size)
7479
if is_training:
75-
# Shuffle the records. Note that we shuffle before repeating to ensure
76-
# that the shuffling respects epoch boundaries.
80+
# Shuffles records before repeating to respect epoch boundaries.
7781
dataset = dataset.shuffle(buffer_size=shuffle_buffer)
7882

79-
# If we are training over multiple epochs before evaluating, repeat the
80-
# dataset for the appropriate number of epochs.
83+
# Repeats the dataset for the number of epochs to train.
8184
dataset = dataset.repeat(num_epochs)
8285

83-
if is_training and num_gpus and examples_per_epoch:
84-
total_examples = num_epochs * examples_per_epoch
85-
# Force the number of batches to be divisible by the number of devices.
86-
# This prevents some devices from receiving batches while others do not,
87-
# which can lead to a lockup. This case will soon be handled directly by
88-
# distribution strategies, at which point this .take() operation will no
89-
# longer be needed.
90-
total_batches = total_examples // batch_size // num_gpus * num_gpus
91-
dataset.take(total_batches * batch_size)
92-
93-
# Parse the raw records into images and labels. Testing has shown that setting
94-
# num_parallel_batches > 1 produces no improvement in throughput, since
95-
# batch_size is almost always much greater than the number of CPU cores.
86+
# Parses the raw records into images and labels.
9687
dataset = dataset.apply(
9788
tf.contrib.data.map_and_batch(
9889
lambda value: parse_record_fn(value, is_training, dtype),
9990
batch_size=batch_size,
100-
num_parallel_batches=1,
10191
drop_remainder=False))
10292

10393
# Operations between the final prefetch and the get_next call to the iterator

0 commit comments

Comments
 (0)