快速学会一个算法,UNet

开发 架构
U-Net 的核心思想是通过对称的编码器-解码器结构,实现对输入图像的高效特征提取和精确的像素级分割。

今天给大家分享一个强大的算法模型,U-Net

U-Net 是一种广泛应用于图像分割任务的卷积神经网络(CNN)架构,最初由 Olaf Ronneberger 等人在 2015 年为生物医学图像分割而提出。

由于其出色的性能和灵活性,U-Net 现已广泛应用于各种图像分割领域,如医学影像分析、遥感图像处理等。

图片图片

U-Net 的架构

U-Net 的核心思想是通过对称的编码器-解码器结构,实现对输入图像的高效特征提取和精确的像素级分割。

UNet的架构由两部分组成:

  1. 收缩路径(编码器)类似于传统的卷积神经网络,用于捕捉上下文信息。通过一系列卷积层和池化层逐步降低空间分辨率,同时增加特征通道数。
  2. 扩展路径(解码器)逐步恢复空间分辨率,结合编码器中的高分辨率特征,通过上采样和卷积操作生成精细的分割图。

两部分通过跳跃连接(Skip Connections)连接,允许高分辨率特征与解码器中的对应层进行融合,从而提高分割的精确度。

编码器

编码器的主要作用是通过一系列的卷积层和池化层逐步提取图像的特征,同时减少空间维度(下采样)。

图片图片

具体结构如下:

  • 卷积层:通常采用多个 3×3 的卷积核,步幅为 1,无填充(padding)。
  • 激活函数:通常使用 ReLU 激活函数。
  • 池化层:采用 2×2 的最大池化(Max Pooling),步幅为 2,用于下采样。

每经过一个下采样步骤,特征图的空间尺寸减半,通道数增加一倍,以捕捉更高级别的特征。

解码器

解码器的主要作用是通过上采样逐步恢复图像的空间分辨率,同时结合编码器的特征进行精细的分割。

图片图片

具体结构如下:

  • 上采样:通常采用转置卷积将特征图的空间尺寸放大一倍。
  • 卷积层:与编码器类似,使用 3×3 的卷积核。
  • 激活函数:使用 ReLU 激活函数。

跳跃连接

跳跃连接是 U-Net 的关键设计,通过将编码器中每个下采样步骤的特征图与解码器中相应上采样步骤的特征图进行拼接,保留了高分辨率的信息,帮助解码器更准确地定位和分割目标区域。

图片图片

U-Net 的优势

  1. 高效的特征利用通过跳跃连接,U-Net 能够充分利用不同层次的特征信息,既包含了高层的语义信息,又保留了低层的空间信息,提高了分割的准确性。
  2. 对少量数据的有效利用U-Net 在设计上适合处理数据量较少的任务,尤其在生物医学图像处理中表现出色。
  3. 端到端训练整个 U-Net 可以通过端到端的方式进行训练,简化了模型设计和优化过程。

案例分享

下面是一个使用 UNet 架构进行图像分割的示例代码。

该示例使用 TensorFlow 和 Keras 构建 UNet 模型,并在合成数据上进行训练。

首先,我们来构建一个 Unet 模型。

import tensorflow as tf
from tensorflow.keras import layers, models

def unet_model(input_size=(128, 128, 3)):
    inputs = layers.Input(input_size)
    
    # 编码器
    c1 = layers.Conv2D(64, (3, 3), activatinotallow='relu', padding='same')(inputs)
    c1 = layers.Conv2D(64, (3, 3), activatinotallow='relu', padding='same')(c1)
    p1 = layers.MaxPooling2D((2, 2))(c1)
    
    c2 = layers.Conv2D(128, (3, 3), activatinotallow='relu', padding='same')(p1)
    c2 = layers.Conv2D(128, (3, 3), activatinotallow='relu', padding='same')(c2)
    p2 = layers.MaxPooling2D((2, 2))(c2)
    
    c3 = layers.Conv2D(256, (3, 3), activatinotallow='relu', padding='same')(p2)
    c3 = layers.Conv2D(256, (3, 3), activatinotallow='relu', padding='same')(c3)
    p3 = layers.MaxPooling2D((2, 2))(c3)
    
    c4 = layers.Conv2D(512, (3, 3), activatinotallow='relu', padding='same')(p3)
    c4 = layers.Conv2D(512, (3, 3), activatinotallow='relu', padding='same')(c4)
    p4 = layers.MaxPooling2D(pool_size=(2, 2))(c4)
    
    # 底部
    c5 = layers.Conv2D(1024, (3, 3), activatinotallow='relu', padding='same')(p4)
    c5 = layers.Conv2D(1024, (3, 3), activatinotallow='relu', padding='same')(c5)
    
    # 解码器
    u6 = layers.Conv2DTranspose(512, (2, 2), strides=(2, 2), padding='same')(c5)
    u6 = layers.concatenate([u6, c4])
    c6 = layers.Conv2D(512, (3, 3), activatinotallow='relu', padding='same')(u6)
    c6 = layers.Conv2D(512, (3, 3), activatinotallow='relu', padding='same')(c6)
    
    u7 = layers.Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(c6)
    u7 = layers.concatenate([u7, c3])
    c7 = layers.Conv2D(256, (3, 3), activatinotallow='relu', padding='same')(u7)
    c7 = layers.Conv2D(256, (3, 3), activatinotallow='relu', padding='same')(c7)
    
    u8 = layers.Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(c7)
    u8 = layers.concatenate([u8, c2])
    c8 = layers.Conv2D(128, (3, 3), activatinotallow='relu', padding='same')(u8)
    c8 = layers.Conv2D(128, (3, 3), activatinotallow='relu', padding='same')(c8)
    
    u9 = layers.Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(c8)
    u9 = layers.concatenate([u9, c1], axis=3)
    c9 = layers.Conv2D(64, (3, 3), activatinotallow='relu', padding='same')(u9)
    c9 = layers.Conv2D(64, (3, 3), activatinotallow='relu', padding='same')(c9)
    
    outputs = layers.Conv2D(1, (1, 1), activatinotallow='sigmoid')(c9)
    
    model = models.Model(inputs=[inputs], outputs=[outputs])
    return model

# 模型实例化
model = unet_model()
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.summary()

为了演示,我们将生成一些合成图像和对应的掩码。实际应用中,你应该使用真实的图像和标注数据。

import numpy as np
import matplotlib.pyplot as plt

def generate_synthetic_data(num_samples, img_size):
    X = np.zeros((num_samples, img_size, img_size, 3), dtype=np.float32)
    y = np.zeros((num_samples, img_size, img_size, 1), dtype=np.float32)
    
    for i in range(num_samples):
        # 随机生成圆形
        radius = np.random.randint(10, img_size//4)
        center_x = np.random.randint(radius, img_size - radius)
        center_y = np.random.randint(radius, img_size - radius)
        
        Y, X_grid = np.ogrid[:img_size, :img_size]
        dist_from_center = np.sqrt((X_grid - center_x)**2 + (Y - center_y)**2)
        mask = dist_from_center <= radius
        
        y[i, mask, 0] = 1.0
        # 图像为掩码的随机颜色
        X[i] = np.random.rand(img_size, img_size, 3) * mask[..., np.newaxis]
    
    return X, y

# 生成训练和验证数据
train_X, train_y = generate_synthetic_data(1000, 128)
val_X, val_y = generate_synthetic_data(200, 128)

# 可视化样本
def display_sample(X, y, index):
    plt.figure(figsize=(6,3))
    plt.subplot(1,2,1)
    plt.imshow(X[index])
    plt.title('Image')
    plt.subplot(1,2,2)
    plt.imshow(y[index].squeeze(), cmap='gray')
    plt.title('Mask')
    plt.show()

display_sample(train_X, train_y, 0)

图片图片

接下来,我们来训练模型。

# 使用回调函数保存最佳模型
checkpoint = tf.keras.callbacks.ModelCheckpoint('unet_best_model.h5', 
                                                mnotallow='val_loss', 
                                                verbose=1, 
                                                save_best_notallow=True, 
                                                mode='min')

# 训练模型
history = model.fit(train_X, train_y, 
                    validation_data=(val_X, val_y),
                    epochs=20, 
                    batch_size=16,
                    callbacks=[checkpoint])

评估和预测

# 加载最佳模型
model.load_weights('unet_best_model.h5')

# 在验证集上评估
loss, accuracy = model.evaluate(val_X, val_y)
print(f'Validation Loss: {loss}')
print(f'Validation Accuracy: {accuracy}')

# 进行预测并可视化
def predict_and_display(model, X, y, index):
    pred = model.predict(X[index:index+1])[0]
    pred_mask = pred > 0.5
    
    plt.figure(figsize=(12,4))
    plt.subplot(1,3,1)
    plt.imshow(X[index])
    plt.title('Image')
    plt.subplot(1,3,2)
    plt.imshow(y[index].squeeze(), cmap='gray')
    plt.title('True Mask')
    plt.subplot(1,3,3)
    plt.imshow(pred_mask.squeeze(), cmap='gray')
    plt.title('Predicted Mask')
    plt.show()
predict_and_display(model, val_X, val_y, 0)

图片 图片

责任编辑:武晓燕 来源: 程序员学长
相关推荐

2024-08-29 09:18:55

2024-06-19 09:47:21

2024-06-06 09:44:33

2024-06-03 08:09:39

2024-07-19 08:21:24

2024-08-21 08:21:45

CNN算法神经网络

2024-08-02 10:28:13

算法NLP模型

2024-09-09 23:04:04

2024-12-04 10:33:17

2024-11-11 00:00:02

卷积神经网络算法

2024-08-22 08:24:51

算法CNN深度学习

2024-07-30 08:08:49

2024-08-08 12:33:55

算法

2024-07-12 08:38:05

2024-08-12 00:00:05

集成学习典型算法代码

2024-06-20 08:52:10

2024-08-22 08:21:10

算法神经网络参数

2020-04-10 10:15:29

算法开源Github

2021-07-29 07:55:19

Demo 工作池

2024-09-24 07:28:10

点赞
收藏

51CTO技术栈公众号