快速学习一个算法UNet,你学会了吗?

人工智能
编码器提取的特征包含了输入图像的丰富细节信息,但由于下采样过程,这些细节在深层次可能会被丢失。跳跃链接通过将这些细节信息直接传递给解码器,使得解码器在进行上采样时能够更好地融合这些细节,提高了分割结果的精度。

大家好,我是小寒。

今天给大家分享一个超强的算法模型,UNet

UNet 是一种卷积神经网络架构,最初由 Olaf Ronneberger 等人在 2015 年提出,主要用于生物医学图像分割任务。

https://arxiv.org/abs/1505.04597

其设计思想在于通过编码器-解码器结构,逐步提取图像特征并进行多尺度融合,从而实现高精度的像素级别分割。

由于其有效的结构以实现精确的图像分割,UNet 已在生物医学成像以外的各种图像分割任务中广受欢迎。

图片图片

Unet 架构

UNet 架构以其 U 形结构而著称,它由两个主要部分组成:收缩(下采样)路径和扩展(上采样)路径。

图片图片

收缩路径

从架构中我们可以看到,收缩路径采用传统的卷积神经网络结构,通过一系列的卷积层和池化层逐步减少空间维度(即图像尺寸)并增加特征通道的数量。

每个卷积操作通常包括两次 3x3 的卷积,紧接着是一个ReLU 激活函数和一个 2x2 的最大池化操作。

池化操作减少了图像的尺寸,使得网络能够更广泛地获取上下文信息。

扩展路径

UNET 架构的一个关键组件是扩展路径。它负责对扩展路径的特征图进行上采样并构建最终的分割掩码。

上采样层(反卷积)

每个上采样操作通过 2x2 的反卷积将图像尺寸扩大一倍,然后通过与收缩路径对应层的特征进行拼接。

反卷积本质上与常规卷积相反。它们增强空间维度而不是减少空间维度,从而允许上采样。

反卷积操作的基本原理

反卷积操作的目标是将输入特征图的空间尺寸放大,而不像普通卷积那样缩小空间尺寸。它通过在输入特征图之间插入零值,然后对插值后的特征图进行标准卷积运算,从而实现上采样。

具体步骤如下:

  • 插值

在输入特征图的每个元素之间插入零值,以增加特征图的尺寸。例如,如果插值因子为2,则每个元素之间插入一个零值,使得特征图的尺寸扩大一倍。

图片图片

  • 卷积运算
    对插值后的特征图应用标准卷积操作。这一步通过卷积核在插值后的特征图上滑动并进行卷积计算,从而生成输出特征图。

跳跃连接

图片图片

跳跃连接将编码器(收缩路径)中某一层的特征图直接连接到解码器(扩展路径)中相应层的特征图。

通过这些连接,编码器提取的低层特征(高分辨率)能够直接传递到解码器的相应层,使网络能够聚合多尺度信息以实现正确的分割。

跳跃链接的作用

  • 特征融合
    编码器提取的特征包含了输入图像的丰富细节信息,但由于下采样过程,这些细节在深层次可能会被丢失。跳跃链接通过将这些细节信息直接传递给解码器,使得解码器在进行上采样时能够更好地融合这些细节,提高了分割结果的精度。
  • 梯度流动
    在深度神经网络中,梯度在反向传播过程中可能会逐渐减小(梯度消失问题),导致训练困难。
    跳跃链接提供了额外的梯度路径,帮助梯度更有效地流动,从而缓解梯度消失问题,加速模型的训练过程。

从头开始构建 UNet 模型

让我们从头来实现一个 UNet 模型。

1.加载必要的库

from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import BatchNormalization, Activation, MaxPool2D
from tensorflow.keras.layers import  Conv2DTranspose, Concatenate, Input
from tensorflow.keras.models import Model
  • 1.
  • 2.
  • 3.
  • 4.

2.构建卷积块

卷积块将 2D 卷积层应用于输入张量,然后进行批量归一化和 ReLU 激活。然后再应用另一个卷积层、批量归一化和 ReLU 激活,然后返回输出张量。

def conv_block(input, num_filters):
    x = Conv2D(num_filters, 3, padding="same")(input)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    x = Conv2D(num_filters, 3, padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    return x
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.

3.构建编码器模块

UNet 架构中的编码器块执行下采样和特征提取。

它对输入张量应用卷积运算,然后进行最大池化以减少空间维度。

该块生成处理后的张量和下采样后的张量,后续层利用这些张量进行进一步处理和特征提取。

def encoder_block(input, num_filters):
    x = conv_block(input, num_filters)
    p = MaxPool2D((2, 2))(x)
    return x, p
  • 1.
  • 2.
  • 3.
  • 4.

4.构建解码器块

UNet 架构中的解码器块执行上采样并合并跳跃连接。

它应用反卷积将输入张量上采样 2 倍。然后将上采样的张量与来自相应编码器块的张量通过跳跃连接连接起来。

然后进一步应用卷积块来细化合并的特征。解码器块的输出张量用于后续层以进行进一步处理。

def decoder_block(input, skip_features, num_filters):
    x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(input)
    x = Concatenate()([x, skip_features])
    x = conv_block(x, num_filters)
    return x
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.

5.构建 UNET 模型

build_unet 函数构建一个用于图像分割的 UNet 模型。

它应用编码器块来下采样和捕获特征,然后使用卷积块进行高级表示,之后应用解码器块用于上采样和合并跳过连接。

该模型生成一个具有 S 形激活的输出张量,表示前景类的像素概率。

def build_unet(input_shape):
    inputs = Input(input_shape)

    s1, p1 = encoder_block(inputs, 64)
    s2, p2 = encoder_block(p1, 128)
    s3, p3 = encoder_block(p2, 256)
    s4, p4 = encoder_block(p3, 512)

    b1 = conv_block(p4, 1024)

    d1 = decoder_block(b1, s4, 512)
    d2 = decoder_block(d1, s3, 256)
    d3 = decoder_block(d2, s2, 128)
    d4 = decoder_block(d3, s1, 64)

    outputs = Conv2D(1, 1, padding="same", activatinotallow="sigmoid")(d4)

    model = Model(inputs, outputs, name="U-Net")
    return model
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.

现在我们已经从头开始构建了 UNet 模型。

责任编辑:武晓燕 来源: 程序员学长
相关推荐

2024-12-19 00:16:43

2024-08-29 09:18:55

2023-07-30 22:29:51

BDDMockitoAssert测试

2021-10-04 09:29:41

对象池线程池

2023-03-26 22:02:53

APMPR监控

2024-06-21 08:15:25

2024-09-09 23:04:04

2023-09-19 08:03:50

rebase​merge

2023-04-27 08:42:50

效果

2022-02-08 09:09:45

智能指针C++

2024-04-01 08:13:59

排行榜MySQL持久化

2024-03-28 12:20:17

2022-12-09 09:21:10

分库分表算法

2024-01-19 08:25:38

死锁Java通信

2023-01-10 08:43:15

定义DDD架构

2023-07-26 13:11:21

ChatGPT平台工具

2024-02-04 00:00:00

Effect数据组件

2024-09-26 09:10:08

2024-01-02 12:05:26

Java并发编程

2023-08-01 12:51:18

WebGPT机器学习模型
点赞
收藏

51CTO技术栈公众号