我们一起快速学会一个算法-UNet

人工智能
UNet 模型由两部分组成:编码器(Contracting Path)和解码器(Expanding Path),中间通过跳跃连接(Skip Connections)相连。

大家好,我是小寒。

今天给大家分享一个超强的算法模型,UNet

UNet 是一种专门用于图像分割任务的卷积神经网络(CNN)架构,最早由 Olaf Ronneberger 等人在 2015 年提出。

UNet 的名字来源于其结构的对称性,类似于字母“U”。UNet 模型由于其优越的分割性能,被广泛应用于各种图像分割任务,如医学图像分割等。

图片图片

Unet 模型架构

UNet 模型由两部分组成:编码器(Contracting Path)和解码器(Expanding Path),中间通过跳跃连接(Skip Connections)相连。

编码器(收缩路径)

编码器部分主要用于提取输入图像的特征。

它由一系列的卷积层、ReLU激活函数、最大池化层(Max Pooling)组成。

  • 每个卷积层通常包含两次卷积操作(使用 3x3 卷积核),每次卷积操作后接一个 ReLU 激活函数。
  • 之后,采用一个 2x2 的最大池化层(Max Pooling)进行下采样,以减少特征图的空间维度。
  • 每次下采样后,特征图的空间尺寸减小,而通道数增加,以提取更高层次的特征。

解码器(扩展路径)

解码器部分用于恢复图像的空间信息,最终输出与输入图像相同大小的分割结果。

它由上采样(up-sampling)操作和卷积层组成。

  • 上采样(Upsampling),通常通过反卷积将特征图的空间分辨率逐步恢复。
  • 上采样后,通过跳跃连接(Skip Connection)将对应层的编码器特征与解码器特征拼接在一起,这样可以保留输入图像的细节。
  • 拼接后的特征图经过两次卷积操作(同样使用 3x3 卷积核)和 ReLU 激活函数进行处理。
  • 最终,经过逐步上采样和卷积,恢复到与输入图像相同的分辨率。

跳跃连接 (Skip Connections)

在UNet中,跳跃连接将编码器中每一层的输出与解码器中相应层的输入相连,确保模型在还原图像分辨率时保留更多的细节信息。

这种连接允许网络在进行上采样时参考编码器部分的特征,从而更好地复原高分辨率特征。

UNet模型的优点

  1. 高效处理小样本数据集
    UNet 最初设计用于生物医学图像分割,具有高效利用小样本数据集的能力。
  2. 精细的分割结果
    通过跳跃连接,UNet 能够很好地保留高分辨率的细节,使得分割结果更为精确。
  3. 灵活性强
    UNet 结构简单且有效,容易扩展和调整,适应不同类型的分割任务。

案例分享

下面是一个使用 PyTorch 实现 UNet 模型的代码示例。这个示例展示了一个简化版的UNet模型,并应用于图像分割任务。

import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()
        
        # 编码器部分
        self.encoder1 = self.double_conv(in_channels, 64)
        self.encoder2 = self.double_conv(64, 128)
        self.encoder3 = self.double_conv(128, 256)
        self.encoder4 = self.double_conv(256, 512)
        
        # 最底部的卷积
        self.bottleneck = self.double_conv(512, 1024)
        
        # 解码器部分
        self.upconv4 = self.upconv(1024, 512)
        self.decoder4 = self.double_conv(1024, 512)
        self.upconv3 = self.upconv(512, 256)
        self.decoder3 = self.double_conv(512, 256)
        self.upconv2 = self.upconv(256, 128)
        self.decoder2 = self.double_conv(256, 128)
        self.upconv1 = self.upconv(128, 64)
        self.decoder1 = self.double_conv(128, 64)
        
        # 最终的1x1卷积,用于生成分割图
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)
    
    def double_conv(self, in_channels, out_channels):
        """两次卷积操作"""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def upconv(self, in_channels, out_channels):
        """上采样操作"""
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
    
    def forward(self, x):
        # 编码器部分
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(F.max_pool2d(enc1, kernel_size=2))
        enc3 = self.encoder3(F.max_pool2d(enc2, kernel_size=2))
        enc4 = self.encoder4(F.max_pool2d(enc3, kernel_size=2))
        
        # Bottleneck
        bottleneck = self.bottleneck(F.max_pool2d(enc4, kernel_size=2))
        
        # 解码器部分
        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, self.crop_tensor(enc4, dec4)), dim=1)
        dec4 = self.decoder4(dec4)
        
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, self.crop_tensor(enc3, dec3)), dim=1)
        dec3 = self.decoder3(dec3)
        
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, self.crop_tensor(enc2, dec2)), dim=1)
        dec2 = self.decoder2(dec2)
        
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, self.crop_tensor(enc1, dec1)), dim=1)
        dec1 = self.decoder1(dec1)
        
        # 最后的1x1卷积生成输出
        return self.final_conv(dec1)

    def crop_tensor(self, encoder_tensor, decoder_tensor):
        """裁剪编码器张量,使其与解码器张量大小匹配"""
        _, _, H, W = decoder_tensor.size()
        encoder_tensor = self.center_crop(encoder_tensor, H, W)
        return encoder_tensor

    def center_crop(self, tensor, target_height, target_width):
        """中心裁剪函数"""
        _, _, h, w = tensor.size()
        crop_y = (h - target_height) // 2
        crop_x = (w - target_width) // 2
        return tensor[:, :, crop_y:crop_y + target_height, crop_x:crop_x + target_width]

# 使用示例
model = UNet(in_channels=1, out_channels=1)  # 输入和输出均为1通道(例如用于灰度图像)
input_image = torch.randn(1, 1, 572, 572)    # 随机生成一个输入图像
output = model(input_image)
print(output.shape)
责任编辑:武晓燕 来源: 程序员学长
相关推荐

2024-12-19 00:16:43

2024-06-19 09:47:21

2021-11-26 07:00:05

反转整数数字

2024-07-19 08:21:24

2024-06-06 09:44:33

2024-06-03 08:09:39

2024-08-21 08:21:45

CNN算法神经网络

2024-08-02 10:28:13

算法NLP模型

2024-09-09 23:04:04

2024-08-02 09:49:35

Spring流程Tomcat

2021-10-27 06:49:34

线程池Core函数

2024-06-17 11:59:39

2022-08-29 07:48:27

文件数据参数类型

2024-08-12 15:55:51

2021-11-15 11:03:09

接口压测工具

2021-05-20 07:15:34

RSA-PSS算法签名

2023-11-13 18:36:04

知识抽取NER

2023-05-08 07:32:03

BFSDFS路径

2023-10-31 14:04:17

Rust类型编译器

2024-12-04 10:33:17

点赞
收藏

51CTO技术栈公众号