Skip to content

Commit 57b9931

Browse files
authored
Merge pull request tensorflow#3820 from joel-shor/master
Add cyclegan to open source tensorflow/models
2 parents ba2b8e0 + f3a542b commit 57b9931

File tree

10 files changed

+877
-3
lines changed

10 files changed

+877
-3
lines changed
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Contains code for loading and preprocessing image data."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
22+
import numpy as np
23+
import tensorflow as tf
24+
25+
26+
def normalize_image(image):
27+
"""Rescale from range [0, 255] to [-1, 1]."""
28+
return (tf.to_float(image) - 127.5) / 127.5
29+
30+
31+
def undo_normalize_image(normalized_image):
32+
"""Convert to a numpy array that can be read by PIL."""
33+
# Convert from NHWC to HWC.
34+
normalized_image = np.squeeze(normalized_image, axis=0)
35+
return np.uint8(normalized_image * 127.5 + 127.5)
36+
37+
38+
def _sample_patch(image, patch_size):
39+
"""Crop image to square shape and resize it to `patch_size`.
40+
41+
Args:
42+
image: A 3D `Tensor` of HWC format.
43+
patch_size: A Python scalar. The output image size.
44+
45+
Returns:
46+
A 3D `Tensor` of HWC format which has the shape of
47+
[patch_size, patch_size, 3].
48+
"""
49+
image_shape = tf.shape(image)
50+
height, width = image_shape[0], image_shape[1]
51+
target_size = tf.minimum(height, width)
52+
image = tf.image.resize_image_with_crop_or_pad(image, target_size,
53+
target_size)
54+
# tf.image.resize_area only accepts 4D tensor, so expand dims first.
55+
image = tf.expand_dims(image, axis=0)
56+
image = tf.image.resize_images(image, [patch_size, patch_size])
57+
image = tf.squeeze(image, axis=0)
58+
# Force image num_channels = 3
59+
image = tf.tile(image, [1, 1, tf.maximum(1, 4 - tf.shape(image)[2])])
60+
image = tf.slice(image, [0, 0, 0], [patch_size, patch_size, 3])
61+
return image
62+
63+
64+
def full_image_to_patch(image, patch_size):
65+
image = normalize_image(image)
66+
# Sample a patch of fixed size.
67+
image_patch = _sample_patch(image, patch_size)
68+
image_patch.shape.assert_is_compatible_with([patch_size, patch_size, 3])
69+
return image_patch
70+
71+
72+
def _provide_custom_dataset(image_file_pattern,
73+
batch_size,
74+
shuffle=True,
75+
num_threads=1,
76+
patch_size=128):
77+
"""Provides batches of custom image data.
78+
79+
Args:
80+
image_file_pattern: A string of glob pattern of image files.
81+
batch_size: The number of images in each batch.
82+
shuffle: Whether to shuffle the read images. Defaults to True.
83+
num_threads: Number of prefetching threads. Defaults to 1.
84+
patch_size: Size of the path to extract from the image. Defaults to 128.
85+
86+
Returns:
87+
A float `Tensor` of shape [batch_size, patch_size, patch_size, 3]
88+
representing a batch of images.
89+
"""
90+
filename_queue = tf.train.string_input_producer(
91+
tf.train.match_filenames_once(image_file_pattern),
92+
shuffle=shuffle,
93+
capacity=5 * batch_size)
94+
image_reader = tf.WholeFileReader()
95+
96+
_, image_bytes = image_reader.read(filename_queue)
97+
image = tf.image.decode_image(image_bytes)
98+
image_patch = full_image_to_patch(image, patch_size)
99+
100+
if shuffle:
101+
return tf.train.shuffle_batch(
102+
[image_patch],
103+
batch_size=batch_size,
104+
num_threads=num_threads,
105+
capacity=5 * batch_size,
106+
min_after_dequeue=batch_size)
107+
else:
108+
return tf.train.batch(
109+
[image_patch],
110+
batch_size=batch_size,
111+
num_threads=1, # no threads so it's deterministic
112+
capacity=5 * batch_size)
113+
114+
115+
def provide_custom_datasets(image_file_patterns,
116+
batch_size,
117+
shuffle=True,
118+
num_threads=1,
119+
patch_size=128):
120+
"""Provides multiple batches of custom image data.
121+
122+
Args:
123+
image_file_patterns: A list of glob patterns of image files.
124+
batch_size: The number of images in each batch.
125+
shuffle: Whether to shuffle the read images. Defaults to True.
126+
num_threads: Number of prefetching threads. Defaults to 1.
127+
patch_size: Size of the patch to extract from the image. Defaults to 128.
128+
129+
Returns:
130+
A list of float `Tensor`s with the same size of `image_file_patterns`.
131+
Each of the `Tensor` in the list has a shape of
132+
[batch_size, patch_size, patch_size, 3] representing a batch of images.
133+
134+
Raises:
135+
ValueError: If image_file_patterns is not a list or tuple.
136+
"""
137+
if not isinstance(image_file_patterns, (list, tuple)):
138+
raise ValueError(
139+
'`image_file_patterns` should be either list or tuple, but was {}.'.
140+
format(type(image_file_patterns)))
141+
custom_datasets = []
142+
for pattern in image_file_patterns:
143+
custom_datasets.append(
144+
_provide_custom_dataset(
145+
pattern,
146+
batch_size=batch_size,
147+
shuffle=shuffle,
148+
num_threads=num_threads,
149+
patch_size=patch_size))
150+
return custom_datasets
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Tests for data_provider."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
import os
22+
23+
import numpy as np
24+
25+
import tensorflow as tf
26+
27+
import data_provider
28+
29+
mock = tf.test.mock
30+
31+
32+
class DataProviderTest(tf.test.TestCase):
33+
34+
def test_normalize_image(self):
35+
image = tf.random_uniform(shape=(8, 8, 3), maxval=256, dtype=tf.int32)
36+
rescaled_image = data_provider.normalize_image(image)
37+
self.assertEqual(tf.float32, rescaled_image.dtype)
38+
self.assertListEqual(image.shape.as_list(), rescaled_image.shape.as_list())
39+
with self.test_session(use_gpu=True) as sess:
40+
rescaled_image_out = sess.run(rescaled_image)
41+
self.assertTrue(np.all(np.abs(rescaled_image_out) <= 1.0))
42+
43+
def test_sample_patch(self):
44+
image = tf.zeros(shape=(8, 8, 3))
45+
patch1 = data_provider._sample_patch(image, 7)
46+
patch2 = data_provider._sample_patch(image, 10)
47+
image = tf.zeros(shape=(8, 8, 1))
48+
patch3 = data_provider._sample_patch(image, 10)
49+
with self.test_session(use_gpu=True) as sess:
50+
self.assertTupleEqual((7, 7, 3), sess.run(patch1).shape)
51+
self.assertTupleEqual((10, 10, 3), sess.run(patch2).shape)
52+
self.assertTupleEqual((10, 10, 3), sess.run(patch3).shape)
53+
54+
def _get_testdata_dir(self):
55+
return os.path.join(
56+
tf.flags.FLAGS.test_srcdir,
57+
'google3/third_party/tensorflow_models/gan/cyclegan/testdata')
58+
59+
def test_custom_dataset_provider(self):
60+
file_pattern = os.path.join(self._get_testdata_dir(), '*.jpg')
61+
batch_size = 3
62+
patch_size = 8
63+
images = data_provider._provide_custom_dataset(
64+
file_pattern, batch_size=batch_size, patch_size=patch_size)
65+
self.assertListEqual([batch_size, patch_size, patch_size, 3],
66+
images.shape.as_list())
67+
self.assertEqual(tf.float32, images.dtype)
68+
69+
with self.test_session(use_gpu=True) as sess:
70+
sess.run(tf.local_variables_initializer())
71+
with tf.contrib.slim.queues.QueueRunners(sess):
72+
images_out = sess.run(images)
73+
self.assertTupleEqual((batch_size, patch_size, patch_size, 3),
74+
images_out.shape)
75+
self.assertTrue(np.all(np.abs(images_out) <= 1.0))
76+
77+
def test_custom_datasets_provider(self):
78+
file_pattern = os.path.join(self._get_testdata_dir(), '*.jpg')
79+
batch_size = 3
80+
patch_size = 8
81+
images_list = data_provider.provide_custom_datasets(
82+
[file_pattern, file_pattern],
83+
batch_size=batch_size,
84+
patch_size=patch_size)
85+
for images in images_list:
86+
self.assertListEqual([batch_size, patch_size, patch_size, 3],
87+
images.shape.as_list())
88+
self.assertEqual(tf.float32, images.dtype)
89+
90+
with self.test_session(use_gpu=True) as sess:
91+
sess.run(tf.local_variables_initializer())
92+
with tf.contrib.slim.queues.QueueRunners(sess):
93+
images_out_list = sess.run(images_list)
94+
for images_out in images_out_list:
95+
self.assertTupleEqual((batch_size, patch_size, patch_size, 3),
96+
images_out.shape)
97+
self.assertTrue(np.all(np.abs(images_out) <= 1.0))
98+
99+
100+
if __name__ == '__main__':
101+
tf.test.main()

0 commit comments

Comments
 (0)