@@ -68,36 +68,26 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
68
68
Dataset of (image, label) pairs ready for iteration.
69
69
"""
70
70
71
- # We prefetch a batch at a time, This can help smooth out the time taken to
72
- # load input files as we go through shuffling and processing.
71
+ # Sets tf.data to AUTOTUNE, e.g. num_parallel_batches in map_and_batch.
72
+ options = tf .data .Options ()
73
+ options .experimental_autotune = True
74
+ dataset = dataset .with_options (options )
75
+
76
+ # Prefetches a batch at a time to smooth out the time taken to load input
77
+ # files for shuffling and processing.
73
78
dataset = dataset .prefetch (buffer_size = batch_size )
74
79
if is_training :
75
- # Shuffle the records. Note that we shuffle before repeating to ensure
76
- # that the shuffling respects epoch boundaries.
80
+ # Shuffles records before repeating to respect epoch boundaries.
77
81
dataset = dataset .shuffle (buffer_size = shuffle_buffer )
78
82
79
- # If we are training over multiple epochs before evaluating, repeat the
80
- # dataset for the appropriate number of epochs.
83
+ # Repeats the dataset for the number of epochs to train.
81
84
dataset = dataset .repeat (num_epochs )
82
85
83
- if is_training and num_gpus and examples_per_epoch :
84
- total_examples = num_epochs * examples_per_epoch
85
- # Force the number of batches to be divisible by the number of devices.
86
- # This prevents some devices from receiving batches while others do not,
87
- # which can lead to a lockup. This case will soon be handled directly by
88
- # distribution strategies, at which point this .take() operation will no
89
- # longer be needed.
90
- total_batches = total_examples // batch_size // num_gpus * num_gpus
91
- dataset .take (total_batches * batch_size )
92
-
93
- # Parse the raw records into images and labels. Testing has shown that setting
94
- # num_parallel_batches > 1 produces no improvement in throughput, since
95
- # batch_size is almost always much greater than the number of CPU cores.
86
+ # Parses the raw records into images and labels.
96
87
dataset = dataset .apply (
97
88
tf .contrib .data .map_and_batch (
98
89
lambda value : parse_record_fn (value , is_training , dtype ),
99
90
batch_size = batch_size ,
100
- num_parallel_batches = 1 ,
101
91
drop_remainder = False ))
102
92
103
93
# Operations between the final prefetch and the get_next call to the iterator
0 commit comments