Piotr Rudzki · Updated: 10 Jan 2023 · 7 min read

U-Net for Image Segmentation – Architecture, Implementation & Code Example

Image segmentation is the process of partitioning an image into sets of pixels (segments) that correspond to objects in the image. This means that we’re essentially classifying each pixel in the image as either belonging to a specific object or not (of course, this object could be just the background).

At this point, image segmentation may feel like object detection with extra steps: instead of just finding the objects we want, we have to find their every pixel.

But that’s not the case. Image segmentation allows us to simplify the image to a representation that other algorithms can more easily “digest” (a binary mask instead of an RGB image) while retaining information about the objects’ sizes, shapes, and spatial positions, which is beneficial in cases such as medical imaging or self-driving cars (but more on that later).
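To make this concrete, here is a toy sketch (made-up values, purely for illustration) contrasting an RGB input with a binary mask:

import numpy as np

# An RGB image: height x width x 3 color channels, values 0-255.
image = np.random.randint(0, 256, size=(4, 4, 3), dtype=np.uint8)

# A binary mask for the same image: one value per pixel,
# 1 = the pixel belongs to the object, 0 = background.
mask = np.array([[0, 0, 1, 1],
                 [0, 1, 1, 1],
                 [0, 1, 1, 0],
                 [0, 0, 0, 0]], dtype=np.uint8)

print(image.shape, mask.shape)  # (4, 4, 3) (4, 4)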

Applications of Image Segmentation

  • medical imaging – e.g., segmenting healthy and tumor tissue in histopathological images allows for a quick assessment of a patient’s cancer stage, or automatically finding and marking bone fractures in X-ray images
  • self-driving cars – autonomous vehicles can only be as good as they are perceptive, and that requires a system that can find the exact locations of pedestrians, obstacles, road signs, and other vehicles in images
  • aerial images for civil engineering and/or agriculture – image segmentation can help with estimating building progress on construction sites or monitoring crop health by analyzing aerial images

“Image segmentation in Python is extensively used across various industries to enhance business efficiency. In retail, it enables automated inventory management by distinguishing product types and tracking stock levels. By implementing this technology, businesses can automate complex processes, reduce error rates, and improve operational efficiency, ultimately leading to better customer outcomes and cost savings.”

— Mike Jackowski, COO, ASPER BROTHERS

Neural Networks for image segmentation

Even though some interesting non-learning algorithms have been used for image segmentation, such as Otsu’s method or the Watershed algorithm, most of today’s real-life segmentation problems are solved by training neural networks (NNs). More precisely, convolutional neural networks (CNNs). And even more specifically, some kind of encoder-decoder CNN. One of the most notable architectures of this kind is U-Net, introduced by Ronneberger et al. in 2015 (the name comes from its shape, see image 1). U-Net consists of two parts:

  • encoder (left part of the “U”) – encodes the image into an abstract representation of image features by applying a sequence of convolutional blocks that gradually decrease the representation’s height and width while increasing the number of channels, which correspond to image features.
  • decoder (right part of the “U”) – decodes the image representation into a binary mask by applying a sequence of up-convolutions (NOT the same as deconvolution) that gradually increase the representation’s height and width back to the size of the original image and decrease the number of channels down to the number of classes we are segmenting.

Additionally, U-Net implements skip connections that connect corresponding levels of the encoder and decoder. They prevent the model from “losing” features extracted by earlier encoder blocks, which improves segmentation performance.

 

[Image 1: U-Net architecture]
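To see the down-sampling, up-sampling, and skip-connection mechanics in isolation, here is a minimal sketch (illustrative shapes only, using the same 64-channel first level as the implementation later in this article):

import torch
from torch import nn

x = torch.randn(1, 64, 256, 256)  # feature map produced by an encoder block
down = nn.MaxPool2d(2)(x)  # encoder step halves the resolution: (1, 64, 128, 128)
up = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)(down)  # decoder step: back to (1, 64, 256, 256)
merged = torch.cat([up, x], dim=1)  # skip connection: (1, 128, 256, 256)
print(down.shape, up.shape, merged.shape)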

 

Implementation

As for the demonstration part of the article, we’ll implement and train the U-Net architecture using PyTorch and segmentation_models_pytorch to segment buildings in aerial images.

Setup

For our project, we’ll need the following packages installed:

  • PyTorch
  • segmentation_models_pytorch
  • PIL
  • opencv
  • albumentations
  • pandas
  • matplotlib

As for data, we need to download and unpack this dataset – the Massachusetts Buildings Dataset, available on Kaggle.

When we have everything in place, we can start by importing everything we will need during the implementation:

import os
from typing import Tuple, List

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import albumentations as album
import segmentation_models_pytorch as smp

Additionally, we have to specify what device we will be training a U-Net on:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

I strongly recommend using `cuda`; otherwise, training will be very time-consuming. If you don’t have access to a GPU machine, consider using a free one on Google Colab or Kaggle.

Finally, we will define the path to our dataset and load the CSV that maps our class names to RGB colors:

DATA_DIR = "/kaggle/input/massachusetts-buildings-dataset/tiff"
class_dict = pd.read_csv("/kaggle/input/massachusetts-buildings-dataset/label_class_dict.csv", index_col=0)

Preparing data

One of the fundamental techniques we can easily apply to avoid model overfitting is augmentation of the training data. Additionally, since CNNs are input-shape agnostic (a CNN trained on one shape can be used on different shapes), we can train the model on smaller crops of the images to better utilize computing resources. Of course, we want to avoid these augmentations for validation to get a clear view of our model’s performance; there we only pad the 1500×1500 images to 1536×1536 so that their size is divisible by the network’s total downsampling factor of 16.

def get_training_augmentation():
    # Random 256x256 crops plus random flips/rotations for training.
    train_transform = [
        album.RandomCrop(height=256, width=256, always_apply=True),
        album.OneOf(
            [
                album.HorizontalFlip(p=1),
                album.VerticalFlip(p=1),
                album.RandomRotate90(p=1),
            ],
            p=0.75,
        ),
    ]
    return album.Compose(train_transform)

def get_validation_augmentation():
    # Only pad to 1536x1536 so the image size is divisible by 16.
    test_transform = [
        album.PadIfNeeded(min_height=1536, min_width=1536, always_apply=True, border_mode=0)
    ]
    return album.Compose(test_transform)
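As a quick sanity check (dummy arrays standing in for the real 1500×1500 images), the composed pipeline returns a dict with the transformed image and mask:

dummy_img = np.random.rand(1500, 1500, 3)
dummy_mask = np.zeros((1500, 1500, 3), dtype=np.uint8)
out = get_training_augmentation()(image=dummy_img, mask=dummy_mask)
print(out["image"].shape, out["mask"].shape)  # (256, 256, 3) (256, 256, 3)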

Now we can implement a Dataset class that loads images and masks from disk, converts masks from RGB to per-class binary masks, and applies the specified augmentations.

def encode_mask(mask, df_labels):
    # Convert an RGB mask (3, H, W) into a one-hot mask (n_classes, H, W):
    # each channel is 1 where the pixel matches that class's RGB color.
    channels = []
    for c in df_labels.index:
        rgb = torch.tensor(df_labels.loc[c].to_list()).view(-1, 1, 1)
        _mask = torch.all(mask == rgb, dim=0).float()
        channels.append(_mask)
    return torch.stack(channels, dim=0)

class BuildingDataset(Dataset):

    def __init__(self, split: str, data_dir: str, df_labels: pd.DataFrame, augmentation=None):
        self.img_dir = os.path.join(data_dir, split)
        self.mask_dir = os.path.join(data_dir, split + "_labels")

        self.sample_names = os.listdir(self.img_dir)

        self.df_labels = df_labels
        self.augmentation = augmentation

    def __len__(self):
        return len(self.sample_names)

    def __getitem__(self, idx):
        sample_name = self.sample_names[idx]
        img_path = os.path.join(self.img_dir, sample_name)
        # Images are stored as ".tiff" while masks use ".tif",
        # so we drop the last character of the file name.
        mask_path = os.path.join(self.mask_dir, sample_name[:-1])

        img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB) / 255
        mask = cv2.cvtColor(cv2.imread(mask_path), cv2.COLOR_BGR2RGB)
        if self.augmentation:
            sample = self.augmentation(image=img, mask=mask)
            img, mask = sample['image'], sample['mask']

        # Convert to channel-first float tensors and one-hot encode the mask.
        img = torch.tensor(img.transpose(2, 0, 1).astype('float32'))
        mask = torch.tensor(mask.transpose(2, 0, 1).astype('float32'))
        mask = encode_mask(mask, self.df_labels)

        return img, mask

train_dataset = BuildingDataset("train", DATA_DIR, class_dict, augmentation=get_training_augmentation())
val_dataset = BuildingDataset("val", DATA_DIR, class_dict, augmentation=get_validation_augmentation())
test_dataset = BuildingDataset("test", DATA_DIR, class_dict, augmentation=get_validation_augmentation())
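A quick sanity check (assuming the two-class label map from label_class_dict.csv) shows what a single training sample looks like:

img, mask = train_dataset[0]
print(img.shape)   # torch.Size([3, 256, 256]) - RGB crop
print(mask.shape)  # torch.Size([2, 256, 256]) - one channel per class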

Below you can see some example training images and their masks.

 

[Example training images and their masks]

 

Model implementation

Now we can implement the model. We’ll start with a single convolutional block that will handle the core computations.

class ConvBlock(nn.Module):

    def __init__(self, in_channels: int, out_channels: int, n_convs: int = 2,
                 kernel_size: int = 3, padding: int = 1) -> None:
        super(ConvBlock, self).__init__()
        # Input channel sizes for each conv: the first maps in_channels -> out_channels,
        # the remaining n_convs - 1 map out_channels -> out_channels.
        _in_channels = [in_channels] + [out_channels] * (n_convs - 1)
        self.model = nn.Sequential(*[self._get_single_block(ic, out_channels, kernel_size,
                                                            padding) for ic in _in_channels])

    def _get_single_block(self, in_channels: int, out_channels: int, kernel_size: int = 3,
                          padding: int = 1) -> nn.Sequential:
        return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size,
                                       padding=padding),
                             nn.BatchNorm2d(out_channels),
                             nn.ReLU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)
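A quick shape check (not part of the original article) confirms the block keeps the spatial size and only changes the channel count:

block = ConvBlock(3, 64)
print(block(torch.randn(1, 3, 256, 256)).shape)  # torch.Size([1, 64, 256, 256])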

Next, we can implement a Down block that applies the convolutions and returns two tensors: the convolution output after max pooling (for size reduction) and a skip-connection tensor whose size equals the input size.

class DownBlock(nn.Module):

    def __init__(self, in_channels: int, out_channels: int, n_convs: int = 2,
                 kernel_size: int = 3, padding: int = 1) -> None:
        super(DownBlock, self).__init__()
        self.conv = ConvBlock(in_channels, out_channels, n_convs, kernel_size, padding)
        self.down_sample = nn.MaxPool2d(2)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Return the down-sampled output and the full-resolution tensor
        # that the decoder will receive via a skip connection.
        skipped_x = self.conv(x)
        x = self.down_sample(skipped_x)
        return x, skipped_x
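Again, a small sanity check of the two outputs:

down = DownBlock(3, 64)
x, skipped = down(torch.randn(1, 3, 256, 256))
print(x.shape)        # torch.Size([1, 64, 128, 128]) - halved, continues down the encoder
print(skipped.shape)  # torch.Size([1, 64, 256, 256]) - full size, goes to the decoder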

Next in line is the Up block, which accepts two tensors: the output of the previous block, which the Up block will up-sample, and the output from the corresponding Down block. The Up block concatenates these inputs and applies convolutions to them.

class UpBlock(nn.Module):

    def __init__(self, in_channels: int, out_channels: int, n_convs: int = 2,
                 kernel_size: int = 3, padding: int = 1) -> None:
        super(UpBlock, self).__init__()
        # in_channels counts both inputs (previous block + skip connection);
        # only the previous block's tensor (in_channels - out_channels) gets up-sampled.
        self.up_sample = nn.ConvTranspose2d(in_channels - out_channels, in_channels - out_channels,
                                            kernel_size=2, stride=2)
        self.conv = ConvBlock(in_channels, out_channels, n_convs, kernel_size, padding)

    def forward(self, x: torch.Tensor, skipped_x: torch.Tensor) -> torch.Tensor:
        x = self.up_sample(x)
        x = torch.cat([x, skipped_x], dim=1)
        x = self.conv(x)
        return x
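For example, the first decoder level of the model below takes the 1024-channel bottleneck output plus a 512-channel skip tensor and produces 512 channels (shapes assume a 256×256 input image):

up = UpBlock(1024 + 512, 512)
x = torch.randn(1, 1024, 16, 16)       # bottleneck output
skipped = torch.randn(1, 512, 32, 32)  # skip tensor from the matching DownBlock
print(up(x, skipped).shape)            # torch.Size([1, 512, 32, 32])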

Finally, we can put it all together in a U-Net module:

class UNet(nn.Module):

    def __init__(self, in_channels: int = 3, out_classes: int = 2) -> None:
        super(UNet, self).__init__()
        # Encoder: each level halves the resolution and doubles the channels.
        self.down_0 = DownBlock(in_channels, 64)
        self.down_1 = DownBlock(64, 128)
        self.down_2 = DownBlock(128, 256)
        self.down_3 = DownBlock(256, 512)
        self.bottleneck = ConvBlock(512, 1024)
        # Decoder: each level doubles the resolution and halves the channels.
        self.up_0 = UpBlock(1024 + 512, 512)
        self.up_1 = UpBlock(512 + 256, 256)
        self.up_2 = UpBlock(256 + 128, 128)
        self.up_3 = UpBlock(128 + 64, 64)
        self.final_conv = nn.Conv2d(64, out_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, skipped_x_0 = self.down_0(x)
        x, skipped_x_1 = self.down_1(x)
        x, skipped_x_2 = self.down_2(x)
        x, skipped_x_3 = self.down_3(x)
        x = self.bottleneck(x)
        x = self.up_0(x, skipped_x_3)
        x = self.up_1(x, skipped_x_2)
        x = self.up_2(x, skipped_x_1)
        x = self.up_3(x, skipped_x_0)
        return self.final_conv(x)
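As a final end-to-end check (note the input height and width must be divisible by 16 because of the four pooling steps):

unet = UNet(out_classes=2)
print(unet(torch.randn(1, 3, 256, 256)).shape)  # torch.Size([1, 2, 256, 256]) - one logit map per class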

Training and Evaluation

At last, we can train our model, and for convenience’s sake, we will utilize segmentation_models_pytorch to do so. Note that it is completely viable to use raw PyTorch; smp just saves us some hassle with implementing metrics, loss functions, and training loops ourselves.

Firstly, we will initialize our model, define the training hyperparameters (feel free to tweak and play with them), and initialize the loss function, metric, optimizer, and data loaders.

model = UNet(out_classes=2)
n_epochs = 15
batch_size = 32
lr = 5e-5

loss = smp.utils.losses.DiceLoss()
metrics = [smp.utils.metrics.IoU(threshold=0.5)]
optimizer = torch.optim.Adam([
    dict(params=model.parameters(), lr=lr)
])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(val_dataset, batch_size=1)
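Optionally, we can pull a single batch to confirm the loaders produce what the model expects:

imgs, masks = next(iter(train_loader))
print(imgs.shape, masks.shape)  # torch.Size([32, 3, 256, 256]) torch.Size([32, 2, 256, 256])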

Secondly, we will leverage smp‘s functionality to initialize instances of training and validation epochs.

train_epoch = smp.utils.train.TrainEpoch(
    model,
    loss=loss,
    metrics=metrics,
    optimizer=optimizer,
    device=device,
    verbose=True,
)

valid_epoch = smp.utils.train.ValidEpoch(
    model,
    loss=loss,
    metrics=metrics,
    device=device,
    verbose=True,
)

And finally, we can train our model. Below you can see the output from my example run:

for i in range(0, n_epochs):
    print(f'\nEpoch: {i}')
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(validation_loader)

[Training log from the example run]
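One optional extension of the loop above (not part of the original article) is to checkpoint the weights whenever validation IoU improves; the "iou_score" key is the name smp.utils gives our IoU metric in the log dictionaries:

# Hypothetical extension: save the weights from the best-performing epoch.
best_iou = 0.0
for i in range(0, n_epochs):
    print(f'\nEpoch: {i}')
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(validation_loader)
    if valid_logs["iou_score"] > best_iou:
        best_iou = valid_logs["iou_score"]
        torch.save(model.state_dict(), "best_unet.pth")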

To test our model and ensure that we didn’t overfit, we can run a single validation epoch with a newly initialized test data loader:

 

test_loader = DataLoader(test_dataset, batch_size=1)
test_logs = valid_epoch.run(test_loader)

 

[Test-set metrics from the example run]

 

Below you can see how the model works on a test set:

 

[Example test images with predicted segmentation masks]
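If you want to reproduce such a figure yourself, here is a minimal sketch (assumptions: the trained model and test_dataset from above, and that channel 1 of the encoded mask corresponds to the building class):

model.eval()
img, mask = test_dataset[0]
with torch.no_grad():
    pred = model(img.unsqueeze(0).to(device)).squeeze(0).cpu()
pred_classes = pred.argmax(dim=0)  # per-pixel class index

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
axes[0].imshow(img.permute(1, 2, 0))
axes[0].set_title("image")
axes[1].imshow(mask[1], cmap="gray")
axes[1].set_title("ground truth")
axes[2].imshow((pred_classes == 1).float(), cmap="gray")
axes[2].set_title("prediction")
plt.show()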

 

Conclusion

Image segmentation continues to be one of the most critical research areas in computer vision. Today it’s dominated by convolutional neural networks, which allow us to push the boundaries of what is possible further and further. Additionally, the example above shows that training a well-performing model is relatively straightforward if you have a suitable dataset, and it doesn’t require that many resources.

 
