# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains code for loading and preprocessing image data."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import numpy as np
import tensorflow as tf


def normalize_image(image):
  """Rescale from range [0, 255] to [-1, 1]."""
  return (tf.to_float(image) - 127.5) / 127.5


def undo_normalize_image(normalized_image):
  """Convert a normalized image back to a uint8 numpy array readable by PIL.

  Expects a single image in NHWC format with a batch dimension of 1; the
  `np.squeeze` below raises a ValueError for larger batches.
  """
  # Convert from NHWC to HWC by dropping the batch dimension.
  normalized_image = np.squeeze(normalized_image, axis=0)
  return np.uint8(normalized_image * 127.5 + 127.5)


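# Illustrative round trip through the two helpers above (a sketch, not part
# of the pipeline; assumes TF1-style graph mode and a hypothetical `sess`):
#
#   raw = tf.zeros([1, 64, 64, 3], dtype=tf.uint8)
#   scaled = normalize_image(raw)                       # float32 in [-1, 1]
#   pil_array = undo_normalize_image(sess.run(scaled))  # uint8, HWC

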
def _sample_patch(image, patch_size):
  """Crop image to a square shape and resize it to `patch_size`.

  Args:
    image: A 3D `Tensor` of HWC format.
    patch_size: A Python scalar. The output image size.

  Returns:
    A 3D `Tensor` of HWC format with shape [patch_size, patch_size, 3].
  """
  image_shape = tf.shape(image)
  height, width = image_shape[0], image_shape[1]
  target_size = tf.minimum(height, width)
  # Center-crop to the largest square that fits in the image.
  image = tf.image.resize_image_with_crop_or_pad(image, target_size,
                                                 target_size)
  # Resize through a temporary batch dimension, then squeeze it back out.
  image = tf.expand_dims(image, axis=0)
  image = tf.image.resize_images(image, [patch_size, patch_size])
  image = tf.squeeze(image, axis=0)
  # Force num_channels = 3: tile grayscale up to at least 3 channels, then
  # slice off any extra (e.g. alpha) channels.
  image = tf.tile(image, [1, 1, tf.maximum(1, 4 - tf.shape(image)[2])])
  image = tf.slice(image, [0, 0, 0], [patch_size, patch_size, 3])
  return image


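# A minimal sketch of the channel handling above (illustrative only):
# a 1-channel image is tiled to 3 channels, a 4-channel image loses its
# alpha channel, and a 3-channel image passes through unchanged.
#
#   gray = tf.zeros([40, 60, 1])
#   patch = _sample_patch(gray, 32)   # -> shape [32, 32, 3]
#   rgba = tf.zeros([40, 60, 4])
#   patch = _sample_patch(rgba, 32)   # -> shape [32, 32, 3]

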
def full_image_to_patch(image, patch_size):
  """Normalize an image and extract a fixed-size square patch from it.

  Args:
    image: A 3D `Tensor` of HWC format with values in [0, 255].
    patch_size: A Python scalar. The output image size.

  Returns:
    A 3D `Tensor` of HWC format with shape [patch_size, patch_size, 3] and
    values in [-1, 1].
  """
  image = normalize_image(image)
  # Sample a patch of fixed size.
  image_patch = _sample_patch(image, patch_size)
  image_patch.shape.assert_is_compatible_with([patch_size, patch_size, 3])
  return image_patch


def _provide_custom_dataset(image_file_pattern,
                            batch_size,
                            shuffle=True,
                            num_threads=1,
                            patch_size=128):
  """Provides batches of custom image data.

  Args:
    image_file_pattern: A string glob pattern for matching image files.
    batch_size: The number of images in each batch.
    shuffle: Whether to shuffle the read images. Defaults to True.
    num_threads: Number of prefetching threads. Defaults to 1.
    patch_size: Size of the patch to extract from the image. Defaults to 128.

  Returns:
    A float `Tensor` of shape [batch_size, patch_size, patch_size, 3]
    representing a batch of images.
  """
  filename_queue = tf.train.string_input_producer(
      tf.train.match_filenames_once(image_file_pattern),
      shuffle=shuffle,
      capacity=5 * batch_size)
  image_reader = tf.WholeFileReader()

  _, image_bytes = image_reader.read(filename_queue)
  image = tf.image.decode_image(image_bytes)
  image_patch = full_image_to_patch(image, patch_size)

  if shuffle:
    return tf.train.shuffle_batch(
        [image_patch],
        batch_size=batch_size,
        num_threads=num_threads,
        capacity=5 * batch_size,
        min_after_dequeue=batch_size)
  else:
    return tf.train.batch(
        [image_patch],
        batch_size=batch_size,
        num_threads=1,  # A single thread keeps the output deterministic.
        capacity=5 * batch_size)


def provide_custom_datasets(image_file_patterns,
                            batch_size,
                            shuffle=True,
                            num_threads=1,
                            patch_size=128):
  """Provides multiple batches of custom image data.

  Args:
    image_file_patterns: A list of glob patterns of image files.
    batch_size: The number of images in each batch.
    shuffle: Whether to shuffle the read images. Defaults to True.
    num_threads: Number of prefetching threads. Defaults to 1.
    patch_size: Size of the patch to extract from the image. Defaults to 128.

  Returns:
    A list of float `Tensor`s of the same length as `image_file_patterns`.
    Each `Tensor` in the list has shape
    [batch_size, patch_size, patch_size, 3] and represents a batch of images.

  Raises:
    ValueError: If image_file_patterns is not a list or tuple.
  """
  if not isinstance(image_file_patterns, (list, tuple)):
    raise ValueError(
        '`image_file_patterns` should be either list or tuple, but was {}.'.
        format(type(image_file_patterns)))
  custom_datasets = []
  for pattern in image_file_patterns:
    custom_datasets.append(
        _provide_custom_dataset(
            pattern,
            batch_size=batch_size,
            shuffle=shuffle,
            num_threads=num_threads,
            patch_size=patch_size))
  return custom_datasets
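

# Minimal usage sketch (illustrative; the glob patterns below are
# hypothetical). The queue-based readers above require queue runners, which
# `tf.train.MonitoredTrainingSession` starts automatically, along with
# initializing the local variable created by `tf.train.match_filenames_once`:
#
#   images_x, images_y = provide_custom_datasets(
#       ['/tmp/x/*.jpg', '/tmp/y/*.jpg'], batch_size=4)
#   with tf.train.MonitoredTrainingSession() as sess:
#     batch_x, batch_y = sess.run([images_x, images_y])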