终于把 Unet 算法搞懂了!!

人工智能
UNet 的成功源于其有效的特征提取与恢复机制,特别是跳跃连接的设计,使得编码过程中丢失的细节能够通过解码阶段恢复。UNet 在医学图像分割等任务上有着广泛的应用,能够生成高精度的像素级分割结果。

今天给大家分享一个超强的算法模型,Unet

UNet 是一种经典的卷积神经网络(CNN)架构,最初由 Olaf Ronneberger 等人在 2015 年提出,专为生物医学图像分割设计。

它的独特之处在于其编码器-解码器对称结构,能够有效地在多尺度上提取特征并生成精确的像素级分割结果。

UNet 算法在图像分割任务中表现优异,尤其是在需要精细边界的场景中广泛应用,如医学影像分割、卫星图像分割等。

图片图片

UNet 架构

UNet 模型由两部分组成:编码器和解码器,中间通过跳跃连接(Skip Connections)相连。

UNet 的设计理念是将输入图像经过一系列卷积和下采样操作逐渐提取高层次特征(编码路径),然后通过上采样逐步恢复原始的分辨率(解码路径),并将编码路径中对应的特征与解码路径进行跳跃连接(skip connection)。这种跳跃连接能够帮助网络结合低层次细节信息和高层次语义信息,实现精确的像素级分割。

编码器

类似传统的卷积神经网络,编码器的主要任务是逐渐压缩输入图像的空间分辨率,提取更高层次的特征。

这个部分包含一系列卷积层和最大池化层(max pooling),每次池化操作都会将图像的空间维度减少一半。

图片图片

解码器

解码器的任务是通过逐渐恢复图像的空间分辨率,将编码器部分提取到的高层次特征映射回原始的图像分辨率。

解码器包含反卷积(上采样)操作,并结合来自编码器的相应特征层,以实现精细的边界恢复。

图片图片

跳跃连接

跳跃连接是 UNet 的一个关键创新点。

每个编码器层的输出特征图与解码器中对应层的特征图进行拼接,形成跳跃连接。

这样可以将编码器中的局部信息和解码器中的全局信息进行融合,从而提高分割结果的精度。

图片图片

UNet 算法工作流程

  • 输入图像

  • 编码阶段

每个编码块包含两个 3x3 卷积层(带有 ReLU 激活函数)和一个 2x2 最大池化层,池化层用于下采样。

经过每个编码块后,特征图的空间尺寸减少一半,但通道数量翻倍。

  • 瓶颈层

在网络的最底部,这部分用来提取最深层次的特征。

  • 解码阶段

每个解码块包含一个 2x2 转置卷积(或上采样操作)和两个 3x3 卷积层(带有 ReLU 激活函数)。

  • 与编码路径不同的是,解码过程中每次上采样时,还将相应的编码层的特征拼接(跳跃连接)到解码层。
  1. 输出层

最后一层通过 1x1 卷积将输出通道数映射为类别数,用于生成分割掩码。

最终输出的是一个大小与输入图像相同的分割图。

代码示例

下面是一个使用 UNet 进行图像分割的简单示例代码。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision import datasets
import matplotlib.pyplot as plt

# UNet 模型定义
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        def conv_block(in_channels, out_channels):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )

        self.encoder1 = conv_block(1, 64)
        self.encoder2 = conv_block(64, 128)
        self.encoder3 = conv_block(128, 256)
        self.encoder4 = conv_block(256, 512)

        self.pool = nn.MaxPool2d(2)

        self.bottleneck = conv_block(512, 1024)

        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.decoder4 = conv_block(1024, 512)

        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder3 = conv_block(512, 256)

        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder2 = conv_block(256, 128)

        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder1 = conv_block(128, 64)

        self.conv_last = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, x):
        # Encoder
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool(enc1))
        enc3 = self.encoder3(self.pool(enc2))
        enc4 = self.encoder4(self.pool(enc3))

        # Bottleneck
        bottleneck = self.bottleneck(self.pool(enc4))

        # Decoder
        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)

        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)

        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)

        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)

        return torch.sigmoid(self.conv_last(dec1))

# 创建数据集
class RandomDataset(Dataset):
    def __init__(self, num_samples, image_size):
        self.num_samples = num_samples
        self.image_size = image_size
        self.transform = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        image = torch.randn(1, self.image_size, self.image_size)  # 随机生成图像
        mask = (image > 0).float()  # 随机生成掩码
        return image, mask

# 训练模型
def train_model():
    image_size = 128
    batch_size = 8
    num_epochs = 10
    learning_rate = 1e-3

    # 实例化模型、损失函数和优化器
    model = UNet()
    criterion = nn.BCELoss()  # 使用二元交叉熵损失
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
  
    dataset = RandomDataset(num_samples=100, image_size=image_size)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
  
    for epoch in range(num_epochs):
        for images, masks in dataloader:          
            outputs = model(images)
            loss = criterion(outputs, masks)
         
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

    # 测试一个随机样本
    test_image, test_mask = dataset[0]
    model.eval()
    with torch.no_grad():
        prediction = model(test_image.unsqueeze(0))

    
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 3, 1)
    plt.title('Input Image')
    plt.imshow(test_image.squeeze().numpy(), cmap='gray')

    plt.subplot(1, 3, 2)
    plt.title('Ground Truth Mask')
    plt.imshow(test_mask.squeeze().numpy(), cmap='gray')

    plt.subplot(1, 3, 3)
    plt.title('Predicted Mask')
    plt.imshow(prediction.squeeze().numpy(), cmap='gray')

    plt.show()


train_model()

UNet 的成功源于其有效的特征提取与恢复机制,特别是跳跃连接的设计,使得编码过程中丢失的细节能够通过解码阶段恢复。

UNet 在医学图像分割等任务上有着广泛的应用,能够生成高精度的像素级分割结果。

责任编辑:武晓燕 来源: 程序员学长
相关推荐

2024-10-16 07:58:48

2024-09-12 08:28:32

2024-10-17 13:05:35

神经网络算法机器学习深度学习

2024-10-05 23:00:35

2024-11-14 00:16:46

Seq2Seq算法RNN

2024-09-20 07:36:12

2024-10-28 00:38:10

2024-11-15 13:20:02

2024-07-17 09:32:19

2024-08-01 08:41:08

2024-08-23 09:06:35

机器学习混淆矩阵预测

2024-11-05 12:56:06

机器学习函数MSE

2024-10-14 14:02:17

机器学习评估指标人工智能

2024-09-18 16:42:58

机器学习评估指标模型

2024-10-28 00:00:10

机器学习模型程度

2024-10-28 15:52:38

机器学习特征工程数据集

2024-10-08 15:09:17

2024-10-08 10:16:22

2024-10-30 08:23:07

2024-10-31 10:00:39

注意力机制核心组件
点赞
收藏

51CTO技术栈公众号