@@ -61,24 +61,21 @@ def process_record_dataset(dataset, is_training, global_batch_size,
   Returns:
     Dataset of (image, label) pairs ready for iteration.
   """
-
-  # We prefetch a batch at a time, This can help smooth out the time taken to
-  # load input files as we go through shuffling and processing.
-  dataset = dataset.prefetch(buffer_size=global_batch_size)
   if is_training:
     # Shuffle the records. Note that we shuffle before repeating to ensure
     # that the shuffling respects epoch boundaries.
-    # If we are training over multiple epochs before evaluating, repeat the
-    # dataset for the appropriate number of epochs.
-    # Using the fused shuffle_and_repeat method gives better performance.
-    dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(
-        buffer_size=shuffle_buffer, count=num_epochs))
-  else:
-    dataset = dataset.repeat(num_epochs)
+    dataset = dataset.shuffle(buffer_size=shuffle_buffer)
+
+  # If we are training over multiple epochs before evaluating, repeat the
+  # dataset for the appropriate number of epochs.
+  dataset = dataset.repeat(num_epochs)
 
   # Parse the raw records into images and labels. Testing has shown that setting
   # num_parallel_batches > 1 produces no improvement in throughput, since
   # batch_size is almost always much greater than the number of CPU cores.
+  # num_parallel_batches=num_gpus is better in presence of dedicated threads.
+  # Otherwise, not so great.
+  # TODO(priya): Perhaps make this conditional on the flag?
   dataset = dataset.apply(
       tf.contrib.data.map_and_batch(
           lambda value: parse_record_fn(value, is_training),
@@ -90,9 +87,8 @@ def process_record_dataset(dataset, is_training, global_batch_size,
   # will happen synchronously during run time. We prefetch here again to
   # background all of the above processing work and keep it out of the
   # critical training path. Setting buffer_size to tf.contrib.data.AUTOTUNE
-  # allows TensorFlow to adjust how many batches to fetch based
-  # on how many devices are present.
-  dataset.prefetch(buffer_size=num_gpus)
+  # allows TensorFlow to determine the best setting.
+  dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
 
   if datasets_num_private_threads:
     dataset = threadpool.override_threadpool(
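For context, here is a minimal standalone sketch of the pipeline ordering this change produces: shuffle only when training, then repeat, then parse and batch, then a single tail prefetch. It is written against the TF 2.x `tf.data` API rather than the `tf.contrib.data` calls used in this file, and `parse_record_fn` plus the `process_record_dataset_sketch` wrapper are toy stand-ins rather than the repository's real functions, so treat it as illustrative only.

```python
# Sketch only: mirrors the ordering in the diff (shuffle -> repeat ->
# map/batch -> prefetch) using the TF 2.x tf.data API instead of
# tf.contrib.data. parse_record_fn is a placeholder parser.
import tensorflow as tf


def parse_record_fn(value, is_training):
  # Toy parser: real code would decode a serialized tf.Example here.
  image = tf.cast(value, tf.float32)
  label = tf.cast(value % 10, tf.int32)
  return image, label


def process_record_dataset_sketch(dataset, is_training, batch_size,
                                  shuffle_buffer, num_epochs):
  if is_training:
    # Shuffle before repeating so shuffling respects epoch boundaries.
    dataset = dataset.shuffle(buffer_size=shuffle_buffer)

  # Repeat for the requested number of epochs (training and eval alike).
  dataset = dataset.repeat(num_epochs)

  # Parse and batch; AUTOTUNE stands in for a fixed num_parallel_batches.
  dataset = dataset.map(
      lambda value: parse_record_fn(value, is_training),
      num_parallel_calls=tf.data.AUTOTUNE)
  dataset = dataset.batch(batch_size)

  # Note: prefetch returns a new dataset, so the result must be reassigned.
  dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
  return dataset


if __name__ == "__main__":
  ds = process_record_dataset_sketch(
      tf.data.Dataset.range(100), is_training=True, batch_size=8,
      shuffle_buffer=100, num_epochs=2)
  for images, labels in ds.take(2):
    print(images.shape, labels.shape)
```

Dropping the up-front `prefetch(buffer_size=global_batch_size)` and keeping only the tail `prefetch` with `AUTOTUNE` leaves one prefetch stage covering all of the shuffling, parsing, and batching work above it.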