Image segmentation with a U-Net-like architecture
https://keras.io/examples/vision/oxford_pets_image_segmentation/ 1/13
14/10/24, 17:36 Image segmentation with a U-Net-like architecture
Prepare paths of input images and target segmentation masks
import os
# Display auto-contrast version of corresponding target (per-pixel categories)
img = ImageOps.autocontrast(load_img(target_img_paths[9]))
display(img)
def get_dataset(
    batch_size,
    img_size,
    input_img_paths,
    target_img_paths,
    max_dataset_len=None,
):
    """Returns a TF Dataset."""
    # [...]
    target_img = tf_io.read_file(target_img_path)
    target_img = tf_io.decode_png(target_img, channels=1)
    target_img = tf_image.resize(target_img, img_size, method="nearest")
    target_img = tf_image.convert_image_dtype(target_img, "uint8")
    # [...]
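The target masks are resized with method="nearest" because each pixel holds a class index rather than an intensity; blending neighbouring pixels could produce values that are not valid labels. The plain-Python sketch below illustrates the difference on a single row of toy labels. The helpers resize_nearest and resize_linear are hypothetical names, and their index arithmetic is a simplification, not TensorFlow's exact resampling rule.

```python
def resize_nearest(row, new_len):
    """Resize a 1-D list of labels by picking the nearest source element."""
    scale = len(row) / new_len
    return [row[min(int((i + 0.5) * scale), len(row) - 1)] for i in range(new_len)]


def resize_linear(row, new_len):
    """Resize a 1-D list by linear interpolation between source elements."""
    scale = len(row) / new_len
    out = []
    for i in range(new_len):
        # Centre of output sample i, expressed in source coordinates.
        pos = min(max((i + 0.5) * scale - 0.5, 0.0), len(row) - 1.0)
        lo = int(pos)
        hi = min(lo + 1, len(row) - 1)
        frac = pos - lo
        out.append(row[lo] * (1 - frac) + row[hi] * frac)
    return out


labels = [1, 1, 3, 3, 2, 2]  # toy class indices, for illustration only
nearest = resize_nearest(labels, 4)
linear = resize_linear(labels, 4)

print(nearest)  # every value is one of the original labels
print(linear)   # contains in-between values that are not labels
```

Nearest-neighbour selection can only copy existing labels, so the resized mask stays a valid per-pixel class map.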
# Entry block
x = layers.Conv2D(32, 3, strides=2, padding="same")(inputs)
x = layers.BatchNormalization()(x)
x = layers.Activation("relu")(x)

# [...]
x = layers.Activation("relu")(x)
x = layers.SeparableConv2D(filters, 3, padding="same")(x)
x = layers.BatchNormalization()(x)

# [...]
# Project residual
residual = layers.Conv2D(filters, 1, strides=2, padding="same")(
    previous_block_activation
)
x = layers.add([x, residual])  # Add back residual
previous_block_activation = x  # Set aside next residual

# [...]
x = layers.Activation("relu")(x)
x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
x = layers.BatchNormalization()(x)

# [...]
x = layers.UpSampling2D(2)(x)

# Project residual
residual = layers.UpSampling2D(2)(previous_block_activation)
residual = layers.Conv2D(filters, 1, padding="same")(residual)
x = layers.add([x, residual])  # Add back residual
previous_block_activation = x  # Set aside next residual

# [...]
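To see how the spatial resolution changes through these blocks, the sketch below traces the height/width of a 160-pixel input. It assumes, as the model summary suggests, one stride-2 entry convolution, three downsampling blocks that each halve the resolution, and four upsampling blocks that each double it. With padding="same", a stride-s layer produces ceil(n / s) outputs. The function name trace_sizes is hypothetical.

```python
import math


def trace_sizes(size, num_down_blocks=3, num_up_blocks=4):
    """Return the spatial size after each resolution-changing stage."""
    sizes = [size]
    # Entry block: Conv2D with strides=2 and padding="same".
    size = math.ceil(size / 2)
    sizes.append(size)
    # Each downsampling block halves the resolution (stride-2 pooling).
    for _ in range(num_down_blocks):
        size = math.ceil(size / 2)
        sizes.append(size)
    # Each upsampling block doubles it again via UpSampling2D(2).
    for _ in range(num_up_blocks):
        size *= 2
        sizes.append(size)
    return sizes


sizes = trace_sizes(160)
print(sizes)  # [160, 80, 40, 20, 10, 20, 40, 80, 160]
```

The output returns to the input resolution, which is what allows the network to emit one class prediction per input pixel.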
# Build model
model = get_model(img_size, num_classes)
model.summary()
Model: "functional_1"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│ input_layer │ (None, 160, 160, │ 0 │ - │
│ (InputLayer) │ 3) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d (Conv2D) │ (None, 80, 80, │ 896 │ input_layer[0][0] │
│ │ 32) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalization │ (None, 80, 80, │ 128 │ conv2d[0][0] │
│ (BatchNormalizatio… │ 32) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation │ (None, 80, 80, │ 0 │ batch_normalization… │
│ (Activation) │ 32) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_1 │ (None, 80, 80, │ 0 │ activation[0][0] │
│ (Activation) │ 32) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ separable_conv2d │ (None, 80, 80, │ 2,400 │ activation_1[0][0] │
│ (SeparableConv2D) │ 64) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 80, 80, │ 256 │ separable_conv2d[0]… │
│ (BatchNormalizatio… │ 64) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_2 │ (None, 80, 80, │ 0 │ batch_normalization… │
│ (Activation) │ 64) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ separable_conv2d_1 │ (None, 80, 80, │ 4,736 │ activation_2[0][0] │
│ (SeparableConv2D) │ 64) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 80, 80, │ 256 │ separable_conv2d_1[… │
│ (BatchNormalizatio… │ 64) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ max_pooling2d │ (None, 40, 40, │ 0 │ batch_normalization… │
│ (MaxPooling2D) │ 64) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_1 (Conv2D) │ (None, 40, 40, │ 2,112 │ activation[0][0] │
│ │ 64) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ add (Add) │ (None, 40, 40, │ 0 │ max_pooling2d[0][0], │
│ │ 64) │ │ conv2d_1[0][0] │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_3 │ (None, 40, 40, │ 0 │ add[0][0] │
│ (Activation) │ 64) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ separable_conv2d_2 │ (None, 40, 40, │ 8,896 │ activation_3[0][0] │
│ (SeparableConv2D) │ 128) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 40, 40, │ 512 │ separable_conv2d_2[… │
│ (BatchNormalizatio… │ 128) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_4 │ (None, 40, 40, │ 0 │ batch_normalization… │
│ (Activation) │ 128) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ separable_conv2d_3 │ (None, 40, 40, │ 17,664 │ activation_4[0][0] │
│ (SeparableConv2D) │ 128) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 40, 40, │ 512 │ separable_conv2d_3[… │
│ (BatchNormalizatio… │ 128) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ max_pooling2d_1 │ (None, 20, 20, │ 0 │ batch_normalization… │
│ (MaxPooling2D) │ 128) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_2 (Conv2D) │ (None, 20, 20, │ 8,320 │ add[0][0] │
│ │ 128) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ add_1 (Add) │ (None, 20, 20, │ 0 │ max_pooling2d_1[0][… │
│ │ 128) │ │ conv2d_2[0][0] │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_5 │ (None, 20, 20, │ 0 │ add_1[0][0] │
│ (Activation) │ 128) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ separable_conv2d_4 │ (None, 20, 20, │ 34,176 │ activation_5[0][0] │
│ (SeparableConv2D) │ 256) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 20, 20, │ 1,024 │ separable_conv2d_4[… │
│ (BatchNormalizatio… │ 256) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_6 │ (None, 20, 20, │ 0 │ batch_normalization… │
│ (Activation) │ 256) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ separable_conv2d_5 │ (None, 20, 20, │ 68,096 │ activation_6[0][0] │
│ (SeparableConv2D) │ 256) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 20, 20, │ 1,024 │ separable_conv2d_5[… │
│ (BatchNormalizatio… │ 256) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ max_pooling2d_2 │ (None, 10, 10, │ 0 │ batch_normalization… │
│ (MaxPooling2D) │ 256) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_3 (Conv2D) │ (None, 10, 10, │ 33,024 │ add_1[0][0] │
│ │ 256) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ add_2 (Add) │ (None, 10, 10, │ 0 │ max_pooling2d_2[0][… │
│ │ 256) │ │ conv2d_3[0][0] │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_7 │ (None, 10, 10, │ 0 │ add_2[0][0] │
│ (Activation) │ 256) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_transpose │ (None, 10, 10, │ 590,080 │ activation_7[0][0] │
│ (Conv2DTranspose) │ 256) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 10, 10, │ 1,024 │ conv2d_transpose[0]… │
│ (BatchNormalizatio… │ 256) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_8 │ (None, 10, 10, │ 0 │ batch_normalization… │
│ (Activation) │ 256) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_transpose_1 │ (None, 10, 10, │ 590,080 │ activation_8[0][0] │
│ (Conv2DTranspose) │ 256) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 10, 10, │ 1,024 │ conv2d_transpose_1[… │
│ (BatchNormalizatio… │ 256) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ up_sampling2d_1 │ (None, 20, 20, │ 0 │ add_2[0][0] │
│ (UpSampling2D) │ 256) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ up_sampling2d │ (None, 20, 20, │ 0 │ batch_normalization… │
│ (UpSampling2D) │ 256) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_4 (Conv2D) │ (None, 20, 20, │ 65,792 │ up_sampling2d_1[0][… │
│ │ 256) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ add_3 (Add) │ (None, 20, 20, │ 0 │ up_sampling2d[0][0], │
│ │ 256) │ │ conv2d_4[0][0] │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_9 │ (None, 20, 20, │ 0 │ add_3[0][0] │
│ (Activation) │ 256) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_transpose_2 │ (None, 20, 20, │ 295,040 │ activation_9[0][0] │
│ (Conv2DTranspose) │ 128) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 20, 20, │ 512 │ conv2d_transpose_2[… │
│ (BatchNormalizatio… │ 128) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_10 │ (None, 20, 20, │ 0 │ batch_normalization… │
│ (Activation) │ 128) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_transpose_3 │ (None, 20, 20, │ 147,584 │ activation_10[0][0] │
│ (Conv2DTranspose) │ 128) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 20, 20, │ 512 │ conv2d_transpose_3[… │
│ (BatchNormalizatio… │ 128) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ up_sampling2d_3 │ (None, 40, 40, │ 0 │ add_3[0][0] │
│ (UpSampling2D) │ 256) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ up_sampling2d_2 │ (None, 40, 40, │ 0 │ batch_normalization… │
│ (UpSampling2D) │ 128) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_5 (Conv2D) │ (None, 40, 40, │ 32,896 │ up_sampling2d_3[0][… │
│ │ 128) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ add_4 (Add) │ (None, 40, 40, │ 0 │ up_sampling2d_2[0][… │
│ │ 128) │ │ conv2d_5[0][0] │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_11 │ (None, 40, 40, │ 0 │ add_4[0][0] │
│ (Activation) │ 128) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_transpose_4 │ (None, 40, 40, │ 73,792 │ activation_11[0][0] │
│ (Conv2DTranspose) │ 64) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 40, 40, │ 256 │ conv2d_transpose_4[… │
│ (BatchNormalizatio… │ 64) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_12 │ (None, 40, 40, │ 0 │ batch_normalization… │
│ (Activation) │ 64) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_transpose_5 │ (None, 40, 40, │ 36,928 │ activation_12[0][0] │
│ (Conv2DTranspose) │ 64) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 40, 40, │ 256 │ conv2d_transpose_5[… │
│ (BatchNormalizatio… │ 64) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ up_sampling2d_5 │ (None, 80, 80, │ 0 │ add_4[0][0] │
│ (UpSampling2D) │ 128) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ up_sampling2d_4 │ (None, 80, 80, │ 0 │ batch_normalization… │
│ (UpSampling2D) │ 64) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_6 (Conv2D) │ (None, 80, 80, │ 8,256 │ up_sampling2d_5[0][… │
│ │ 64) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ add_5 (Add) │ (None, 80, 80, │ 0 │ up_sampling2d_4[0][… │
│ │ 64) │ │ conv2d_6[0][0] │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_13 │ (None, 80, 80, │ 0 │ add_5[0][0] │
│ (Activation) │ 64) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_transpose_6 │ (None, 80, 80, │ 18,464 │ activation_13[0][0] │
│ (Conv2DTranspose) │ 32) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 80, 80, │ 128 │ conv2d_transpose_6[… │
│ (BatchNormalizatio… │ 32) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_14 │ (None, 80, 80, │ 0 │ batch_normalization… │
│ (Activation) │ 32) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_transpose_7 │ (None, 80, 80, │ 9,248 │ activation_14[0][0] │
│ (Conv2DTranspose) │ 32) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 80, 80, │ 128 │ conv2d_transpose_7[… │
│ (BatchNormalizatio… │ 32) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ up_sampling2d_7 │ (None, 160, 160, │ 0 │ add_5[0][0] │
│ (UpSampling2D) │ 64) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ up_sampling2d_6 │ (None, 160, 160, │ 0 │ batch_normalization… │
│ (UpSampling2D) │ 32) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_7 (Conv2D) │ (None, 160, 160, │ 2,080 │ up_sampling2d_7[0][… │
│ │ 32) │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ add_6 (Add) │ (None, 160, 160, │ 0 │ up_sampling2d_6[0][… │
│ │ 32) │ │ conv2d_7[0][0] │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_8 (Conv2D) │ (None, 160, 160, │ 867 │ add_6[0][0] │
│ │ 3) │ │ │
└─────────────────────┴───────────────────┴─────────┴──────────────────────┘
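The Param # column can be checked by hand. The sketch below applies the standard parameter-count formulas for these layer types (bias terms included, and four values per channel for batch normalization) to a few rows of the summary. The helper names are hypothetical. Kernel sizes of 3 and 1 follow the layer definitions in the model code; the kernel size of 3 for the final conv2d_8 layer is an assumption that happens to reproduce the reported count.

```python
def conv2d_params(k, c_in, c_out):
    """Conv2D / Conv2DTranspose: a k x k kernel per channel pair, plus biases."""
    return k * k * c_in * c_out + c_out


def separable_conv2d_params(k, c_in, c_out):
    """SeparableConv2D: k x k depthwise kernels, 1 x 1 pointwise mix, plus biases."""
    return k * k * c_in + c_in * c_out + c_out


def batch_norm_params(channels):
    """BatchNormalization: gamma, beta, moving mean, moving variance per channel."""
    return 4 * channels


# layer name -> (computed count, count reported in the summary)
checks = {
    "conv2d": (conv2d_params(3, 3, 32), 896),
    "batch_normalization": (batch_norm_params(32), 128),
    "separable_conv2d": (separable_conv2d_params(3, 32, 64), 2400),
    "separable_conv2d_1": (separable_conv2d_params(3, 64, 64), 4736),
    "conv2d_1": (conv2d_params(1, 32, 64), 2112),
    "conv2d_transpose": (conv2d_params(3, 256, 256), 590080),
    "conv2d_8": (conv2d_params(3, 32, 3), 867),
}
for name, (computed, reported) in checks.items():
    print(f"{name}: computed={computed}, reported={reported}")
```

The comparison also shows why the separable convolutions are cheap: separable_conv2d needs 2,400 parameters where a regular 3x3 convolution with the same channel counts would need 18,496.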
callbacks = [
    keras.callbacks.ModelCheckpoint("oxford_segmentation.keras", save_best_only=True)
]
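With save_best_only=True, ModelCheckpoint writes the file only when the monitored quantity improves (by default the validation loss, where lower is better). The plain-Python sketch below mimics that decision rule on a made-up list of validation losses. It illustrates the idea only and is not Keras's implementation; epochs_that_save is a hypothetical helper.

```python
def epochs_that_save(val_losses):
    """Return the 1-based epochs at which a 'save best only' rule would write a checkpoint."""
    best = float("inf")
    saved = []
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:  # strictly better than anything seen so far
            best = loss
            saved.append(epoch)
    return saved


# Made-up validation losses, for illustration only.
val_losses = [1.20, 0.95, 1.05, 0.90, 0.92]
saved = epochs_that_save(val_losses)
print(saved)  # [1, 2, 4]
```

The file left on disk after training therefore corresponds to the best epoch seen, not necessarily the last one.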
Epoch 1/50
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1700414690.172044 2226172 device_compiler.h:187] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Epoch 16/50
<keras.src.callbacks.history.History at 0x7fe01842dab0>
Visualize predictions
# Generate predictions for all images in the validation set
val_dataset = get_dataset(
    batch_size, img_size, val_input_img_paths, val_target_img_paths
)
val_preds = model.predict(val_dataset)


def display_mask(i):
    """Quick utility to display a model's prediction."""
    mask = np.argmax(val_preds[i], axis=-1)
    mask = np.expand_dims(mask, axis=-1)
    img = ImageOps.autocontrast(keras.utils.array_to_img(mask))
    display(img)
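Each prediction holds one score per class for every pixel, and np.argmax(..., axis=-1) reduces those scores to a single class index per pixel. A dependency-free sketch of the same reduction on a tiny 2 x 2 "image" with three classes follows; the scores are invented and argmax_mask is a hypothetical helper.

```python
def argmax_mask(pred):
    """Collapse an H x W x C nested list of class scores into H x W class indices."""
    return [[max(range(len(px)), key=px.__getitem__) for px in row] for row in pred]


# Invented softmax-like scores for a 2 x 2 image with 3 classes.
pred = [
    [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]],
    [[0.2, 0.3, 0.5], [0.6, 0.3, 0.1]],
]
mask = argmax_mask(pred)
print(mask)  # [[0, 1], [2, 0]]
```

Because the resulting indices are small integers, the mask would look almost black if shown directly, which is why display_mask stretches the contrast before displaying it.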