Skip to content

Commit 84db377

Browse files
committed
Remove one-hot labels, add drop_remainder to batch, use parallel interleave in ImageNet dataset.
1 parent e79232f commit 84db377

File tree

4 files changed

+13
-12
lines changed

4 files changed

+13
-12
lines changed

official/resnet/cifar10_main.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ def parse_record(raw_record, is_training):
7373
# The first byte represents the label, which we convert from uint8 to int32
7474
# and then to one-hot.
7575
label = tf.cast(record_vector[0], tf.int32)
76-
label = tf.one_hot(label, _NUM_CLASSES)
7776

7877
# The remaining bytes after the label represent the image, which we reshape
7978
# from [depth * height * width] to [depth, height, width].

official/resnet/cifar10_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,13 @@ def test_dataset_input_fn(self):
6464
lambda val: cifar10_main.parse_record(val, False))
6565
image, label = fake_dataset.make_one_shot_iterator().get_next()
6666

67-
self.assertAllEqual(label.shape, (10,))
67+
self.assertAllEqual(label.shape, ())
6868
self.assertAllEqual(image.shape, (_HEIGHT, _WIDTH, _NUM_CHANNELS))
6969

7070
with self.test_session() as sess:
7171
image, label = sess.run([image, label])
7272

73-
self.assertAllEqual(label, np.array([int(i == 7) for i in range(10)]))
73+
self.assertEqual(label, 7)
7474

7575
for row in image:
7676
for pixel in row:

official/resnet/imagenet_main.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
}
4040

4141
_NUM_TRAIN_FILES = 1024
42-
_SHUFFLE_BUFFER = 1500
42+
_SHUFFLE_BUFFER = 10000
4343

4444
DATASET_NAME = 'ImageNet'
4545

@@ -152,8 +152,6 @@ def parse_record(raw_record, is_training):
152152
num_channels=_NUM_CHANNELS,
153153
is_training=is_training)
154154

155-
label = tf.one_hot(tf.reshape(label, shape=[]), _NUM_CLASSES)
156-
157155
return image, label
158156

159157

@@ -177,6 +175,10 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
177175
dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES)
178176

179177
# Convert to individual records
178+
# TODO(guptapriya): Should we make this cycle_length a flag similar to
179+
# num_parallel_calls?
180+
dataset = dataset.apply(tf.contrib.data.parallel_interleave(
181+
tf.data.TFRecordDataset, cycle_length=10))
180182
dataset = dataset.flat_map(tf.data.TFRecordDataset)
181183

182184
return resnet_run_loop.process_record_dataset(

official/resnet/resnet_run_loop.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
7979
tf.contrib.data.map_and_batch(
8080
lambda value: parse_record_fn(value, is_training),
8181
batch_size=batch_size,
82-
num_parallel_batches=1))
82+
num_parallel_batches=1,
83+
drop_remainder=True))
8384

8485
# Operations between the final prefetch and the get_next call to the iterator
8586
# will happen synchronously during run time. We prefetch here again to
@@ -111,7 +112,7 @@ def get_synth_input_fn(height, width, num_channels, num_classes):
111112
"""
112113
def input_fn(is_training, data_dir, batch_size, *args, **kwargs): # pylint: disable=unused-argument
113114
images = tf.zeros((batch_size, height, width, num_channels), tf.float32)
114-
labels = tf.zeros((batch_size, num_classes), tf.int32)
115+
labels = tf.zeros((batch_size), tf.int32)
115116
return tf.data.Dataset.from_tensors((images, labels)).repeat()
116117

117118
return input_fn
@@ -227,8 +228,8 @@ def resnet_model_fn(features, labels, mode, model_class,
227228
})
228229

229230
# Calculate loss, which includes softmax cross entropy and L2 regularization.
230-
cross_entropy = tf.losses.softmax_cross_entropy(
231-
logits=logits, onehot_labels=labels)
231+
cross_entropy = tf.losses.sparse_softmax_cross_entropy(
232+
logits=logits, labels=labels)
232233

233234
# Create a tensor named cross_entropy for logging purposes.
234235
tf.identity(cross_entropy, name='cross_entropy')
@@ -282,8 +283,7 @@ def exclude_batch_norm(name):
282283
train_op = None
283284

284285
if not tf.contrib.distribute.has_distribution_strategy():
285-
accuracy = tf.metrics.accuracy(
286-
tf.argmax(labels, axis=1), predictions['classes'])
286+
accuracy = tf.metrics.accuracy(labels, predictions['classes'])
287287
else:
288288
# Metrics are currently not compatible with distribution strategies during
289289
# training. This does not affect the overall performance of the model.

0 commit comments

Comments
 (0)