超越边界:ControlNets改变医学图像生成的游戏规则

原创 精选
人工智能 机器视觉
文章介绍了如何使用ControlNets来控制Latent Diffusion Models生成医学图像的过程。首先,讨论了扩散模型的发展和其在生成过程中的控制挑战,然后介绍了ControlNets的概念和优势。

作者 | 崔皓

审校 | 重楼

摘要

文章介绍了如何使用ControlNets来控制Latent Diffusion Models生成医学图像的过程。首先讨论了扩散模型的发展和其在生成过程中的控制挑战,然后介绍了ControlNets的概念和优势。接着,文章详细解释了如何训练Latent Diffusion Model和ControlNet,以及如何使用ControlNet进行采样和评估。最后,文章展示了使用ControlNets生成的医学图像,并提供了性能评估结果。文章的目标是展示这些模型在将脑图像转换为各种对比度的能力,并鼓励读者在自己的项目中使用这些工具。

开篇

本文将介绍如何训练一个ControlNet,使用户能够精确地控制Latent Diffusion Model(如Stable Diffusion!)的生成过程。我们的目标是展示模型将脑图像转换为各种对比度的能力。为了实现这一目标,我们将利用最近推出的MONAI开源扩展,即MONAI Generative Models!

项目代码可以在这个公共仓库中找到:https://github.com/Warvito/generative_brain_controlnet

引言

近年来,文本到图像的扩散模型取得了显著的进步,人们可以根据开放领域的文本描述生成高逼真的图像。这些生成的图像具有丰富的细节,清晰的轮廓,连贯的结构和有意义的场景。然而,尽管扩散模型取得了重大的成就,但在生成过程的精确控制方面仍存在挑战。即使是内容丰富的文本描述,也可能无法准确地捕捉到用户的想法。

正如Lvmin Zhang和Maneesh Agrawala在他们的开创性论文“Adding Conditional Control to Text-to-Image Diffusion Models”(2023)中所提出的那样,ControlNets的引入显著提高了扩散模型的可控性和定制性。这些神经网络充当轻量级的适配器,可以精确地控制和定制,同时保留扩散模型的原始生成能力。通过对这些适配器进行微调,同时保持原始扩散模型不被修改,可以有效地增强文本到图像模型的多样性。

【编者:扩展模型允许我们在保持原始模型不变的同时,对模型的行为进行微调。这意味着我们可以利用原始模型的生成能力,同时通过ControlNets来调整生成过程,以满足特定的需求或改进性能。】

ControlNet的独特之处在于它解决了空间一致性的问题。与以往的方法不同,ControlNet允许对结构空间、结构和几何方面进行明确的控制,同时保留从文本标题中获得的语义控制。原始研究引入了各种模型,使得可以基于边缘、姿态、语义掩码和深度图进行条件生成,为计算机视觉领域激动人心的进步铺平了道路。

【编者:空间一致性是指在图像生成或处理过程中,生成的图像应该在空间上保持一致性,即相邻的像素或区域之间应该有合理的关系和连续性。例如,在生成一个人脸的图像时,眼睛、鼻子和嘴巴的相对位置应该是一致的,不能随机分布。在传统的生成模型中,保持空间一致性可能是一个挑战,因为模型可能会在尝试生成复杂的图像时产生不一致的结果。

例如,你正在使用一个文本到图像的模型来生成一张猫的图像,来看看如何使用边缘、姿态、语义掩码和深度图这四个条件。

1. 边缘:创建一个边缘图来表示猫的轮廓。包括猫的头部、身体、尾巴等主要部分的边缘信息。

2.姿态:创建姿态图来表示猫的姿势。例如,一只正在跳跃的猫,可以在姿态图中表示出这个跳跃的动作。

3.语义掩码:创建一个语义掩码来表示猫的各个部分。例如,在语义掩码中标出猫的眼睛、耳朵、鼻子等部分。

4.深度图:创建深度图来表示猫的三维形状。例如,表示出猫的头部比尾巴更接近观察者。

通过这四个步骤,就可以指导模型生成一只符合我们需求的猫的图像。】

在医学成像领域,经常会遇到图像转换的应用场景,因此ControlNet的使用就非常有价值。在这些应用场景中,有一个场景需要将图像在不同的领域之间进行转换,例如将计算机断层扫描(CT)转换为磁共振成像(MRI),或者将图像在不同的对比度之间进行转换,例如从T1加权到T2加权的MRI图像。在这篇文章中,我们将关注一个特定的案例:使用从FLAIR图像获取的脑图像的2D切片来生成相应的T1加权图像。我们的目标是展示MONAI扩展(MONAI Generative Models)以及ControlNets如何进行医学数据训练,并生成评估模型。通过深入研究这个例子,我们能提供关于这些技术在医学成像领域的最佳实践。

FLAIR到T1w转换FLAIR到T1w转换

Latent Diffusion Model训练

Latent Diffusion Model架构Latent Diffusion Model架构

为了从FLAIR图像生成T1加权(T1w)图像,首先需要训练一个能够生成T1w图像的扩散模型。在我们的例子中,我们使用从英国生物银行数据集(根据这个数据协议可用)中提取的脑MRI图像的2D切片。然后使用你最喜欢的方法(例如,ANTs或UniRes)将原始的3D脑部图像注册到MNI空间后,我们从脑部的中心部分提取五个2D切片。之所以选择这个区域,因为它包含各种组织,使得我们更容易评估图像转换。使用这个脚本,我们最终得到了大约190,000个切片,空间尺寸为224 × 160像素。接下来,使用这个脚本将我们的图像划分为训练集(约180,000个切片)、验证集(约5,000个切片)和测试集(约5,000个切片)。准备好数据集后,我们可以开始训练我们的Latent Diffusion Model了!

为了优化计算资源,潜在扩散模型使用一个编码器将输入图像x转换为一个低维的潜在空间z,然后可以通过一个解码器进行重构。这种方法使得即使在计算能力有限的情况下也能训练扩散模型,同时保持了它们的原始质量和灵活性。与我们在之前的文章中所做的类似(使用MONAI生成医学图像),我们使用MONAI Generative models中的KL-regularization模型来创建压缩模型。通过使用这个配置加上L1损失、KL-regularisation,感知损失以及对抗性损失,我们创建了一个能够以高保真度编码和解码脑图像的自编码器(使用这个脚本)。自编码器的重构质量对于Latent Diffusion Model的性能至关重要,因为它定义了生成图像的质量上限。如果自编码器的解码器产生模糊或低质量的图像,我们的生成模型将无法生成更高质量的图像。

【编者:KL-regularization,或称为Kullback-Leibler正则化,是一种在机器学习和统计中常用的技术,用于在模型复杂性和模型拟合数据的好坏之间找到一个平衡。这种正则化方法的名字来源于它使用的Kullback-Leibler散度,这是一种衡量两个概率分布之间差异的度量

上面这段话用一个例子来解释一下,或许会更加清楚一些。

例如,假设你是一个艺术家,你的任务是画出一系列的猫的画像。你可以自由地画任何猫,但是你的老板希望你画的猫看起来都是"普通的"猫,而不是太奇特的猫。

这就是你的任务:你需要创造新的猫的画像,同时还要确保这些画像都符合"普通的"猫的特征。这就像是变分自编码器(VAE)的任务:它需要生成新的数据,同时还要确保这些数据符合某种预设的分布(也就是"普通的"猫的分布)。

现在,假设你开始画猫。你可能会发现,有些时候你画的猫看起来太奇特了,比如它可能有六只眼睛,或者它的尾巴比普通的猫长很多。这时,你的老板可能会提醒你,让你画的猫更接近"普通的"猫。

这就像是KL-regularization的作用:它是一种"惩罚",当你生成的数据偏离预设的分布时,它就会提醒你。如果你画的猫太奇特,你的老板就会提醒你。在VAE中,如果生成的数据偏离预设的分布,KL-regularization就会通过增加损失函数的值来提醒模型。

通过这种方式,你可以在创造新的猫的画像的同时,还能确保这些画像都符合"普通的"猫的特征。同样,VAE也可以在生成新的数据的同时,确保这些数据符合预设的分布。这就是KL-regularization的主要作用。】

【编者:对三种损失函数进行简要说明:

1. L1损失:L1损失,也称为绝对值损失,是预测值和真实值之间差异的绝对值的平均。它的公式为 L1 = 1/n Σ|yi - xi|,其中yi是真实值,xi是预测值,n是样本数量。L1损失对异常值不敏感,因为它不会过度惩罚预测错误的样本。例如,如果我们预测一个房价为100万,但实际价格为110万,L1损失就是10万。

2. 感知损失:感知损失是一种在图像生成任务中常用的损失函数,它衡量的是生成图像和真实图像在感知层面的差异。感知损失通常通过比较图像在某个预训练模型(如VGG网络)的某一层的特征表示来计算。这种方法可以捕捉到图像的高级特性,如纹理和形状,而不仅仅是像素级的差异。例如,如果我们生成的猫的图像和真实的猫的图像在颜色上有细微的差异,但在形状和纹理上是一致的,那么感知损失可能就会很小。

3. 对抗性损失:对抗性损失是在生成对抗网络(GAN)中使用的一种损失函数。在GAN中,生成器的任务是生成看起来像真实数据的假数据,而判别器的任务是区分真实数据和假数据。对抗性损失就是用来衡量生成器生成的假数据能否欺骗判别器。例如,如果我们的生成器生成了一张猫的图像,而判别器几乎无法区分这张图像和真实的猫的图像,那么对抗性损失就会很小。】

使用这个脚本,我们可以通过使用原始图像和它们的重构之间的多尺度结构相似性指数测量(MS-SSIM)来量化自编码器的保真度。在这个例子中,我们得到了一个高性能的MS-SSIM指标,等于0.9876。

【编者:多尺度结构相似性指数测量(Multi-Scale Structural Similarity Index, MS-SSIM)是一种用于衡量两幅图像相似度的指标。它是结构相似性指数(Structural Similarity Index, SSIM)的扩展,考虑了图像的多尺度信息。

SSIM是一种比传统的均方误差(Mean Squared Error, MSE)或峰值信噪比(Peak Signal-to-Noise Ratio, PSNR)更符合人眼视觉感知的图像质量评价指标。它考虑了图像的亮度、对比度和结构三个方面的信息,而不仅仅是像素级的差异。

MS-SSIM则进一步考虑了图像的多尺度信息。它通过在不同的尺度(例如,不同的分辨率)上计算SSIM,然后将这些SSIM值进行加权平均,得到最终的MS-SSIM值。这样可以更好地捕捉到图像的细节和结构信息。

例如,假设我们有两幅猫的图像,一幅是原始图像,另一幅是我们的模型生成的图像。我们可以在不同的尺度(例如,原始尺度、一半尺度、四分之一尺度等)上计算这两幅图像的SSIM值,然后将这些SSIM值进行加权平均,得到MS-SSIM值。如果MS-SSIM值接近1,那么说明生成的图像与原始图像非常相似;如果MS-SSIM值远离1,那么说明生成的图像与原始图像有较大的差异。

MS-SSIM常用于图像处理和计算机视觉的任务中,例如图像压缩、图像增强、图像生成等,用于评价处理或生成的图像的质量。文中MS-SSIM被用来量化自编码器的保真度,即自编码器重构的图像与原始图像的相似度。得到的MS-SSIM指标为0.9876,接近1,说明自编码器的重构质量非常高,重构的图像与原始图像非常相似。】

在我们训练了自编码器之后,我们将在潜在空间z上训练diffusion model(扩散模型)。扩散模型是一个能够通过在一系列时间步上迭代地去噪来从纯噪声图像生成图像的模型。它通常使用一个U-Net架构(具有编码器-解码器格式),其中我们有编码器的层跳过连接到解码器部分的层(通过长跳跃连接),使得特征可重用并稳定训练和收敛。

【编者:这段话描述的是训练扩散模型的过程,以及扩散模型的基本工作原理。

首先,作者提到在训练了自编码器之后,他们将在潜在空间z上训练扩散模型。这意味着他们首先使用自编码器将输入图像编码为一个低维的潜在空间z,然后在这个潜在空间上训练扩散模型。

扩散模型是一种生成模型,它的工作原理是从纯噪声图像开始,然后通过在一系列时间步上迭代地去噪,最终生成目标图像。这个过程就像是将一个模糊的图像逐渐清晰起来,直到生成一个清晰的、与目标图像相似的图像。

扩散模型通常使用一个U-Net架构。U-Net是一种特殊的卷积神经网络,它有一个编码器部分和一个解码器部分,编码器部分用于将输入图像编码为一个潜在空间,解码器部分用于将潜在空间解码为一个输出图像。U-Net的特点是它有一些跳跃连接,这些连接将编码器部分的某些层直接连接到解码器部分的对应层。这些跳跃连接可以使得编码器部分的特征被重用在解码器部分,这有助于稳定训练过程并加速模型的收敛。】

在训练过程中,Latent Diffusion Model学习了给定这些提示的条件噪声预测。再次,我们使用MONAI来创建和训练这个网络。在这个脚本中,我们使用这个配置来实例化模型,其中训练和评估在代码的这个部分进行。由于我们在这个教程中对文本提示不太感兴趣,所以我们对所有的图像使用了相同的提示( “脑部的T1加权图像”)。

我们的Latent Diffusion Model生成的合成脑图像我们的Latent Diffusion Model生成的合成脑图像

再次,我们可以量化生成模型从而提升其性能,这次我们评估样本的质量(使用Fréchet inception distance (FID))和模型的多样性(计算一组1,000个样本的所有样本对之间的MS-SSIM)。使用这对脚本(12),我们得到了FID = 2.1986和MS-SSIM Diversity = 0.5368。

如你在前面的图像和结果中所看到的,我们现在有一个高分辨率图像的模型,质量非常好。然而,我们对图像的外观没有任何空间控制。为此,我们将使用一个ControlNet来引导我们的Latent Diffusion Model的生成。

ControlNet训练

ControlNet架构ControlNet架构

ControlNet架构包括两个主要组成部分:一个是U-Net模型的编码器(可训练版本),包括中间块,以及一个预训练的“锁定”版本的扩散模型。在这里,锁定的副本保留了生成能力,而可训练的副本在特定的图像到图像数据集上进行训练,以学习条件控制。这两个组件通过一个“零卷积”层相互连接——一个1×1的卷积层,其初始化的权重和偏置设为零。卷积权重逐渐从零过渡到优化的参数,确保在初始训练步骤中,可训练和锁定副本的输出与在没有ControlNet的情况下保持一致。换句话说,当ControlNet在任何优化之前应用到某些神经网络块时,它不会对深度神经特征引入任何额外的影响或噪声。

【编者:描述ControlNet架构的设计和工作原理。ControlNet架构包括两个主要部分:一个可训练的U-Net模型的编码器,以及一个预训练的、被"锁定"的扩散模型。

用一个例子来帮助理解。假设你正在学习画画,你有一个老师(预训练的"锁定"的扩散模型)和一个可以修改的画布(可训练的U-Net模型的编码器)。你的老师已经是一个经验丰富的艺术家,他的技能(生成能力)是固定的,不会改变。而你的画布是可以修改的,你可以在上面尝试不同的画法,学习如何画画(学习条件控制)。

这两个部分通过一个"零卷积"层相互连接。这个"零卷积"层就像是一个透明的过滤器,它最初不会改变任何东西(因为它的权重和偏置都初始化为零),所以你可以看到你的老师的原始画作。然后,随着你的学习进步,这个过滤器会逐渐改变(卷积权重从零过渡到优化的参数),开始对你的老师的画作进行修改,使其更符合你的风格。

这个设计确保了在初始训练步骤中,可训练部分和"锁定"部分的输出是一致的,即你的画作最初会和你的老师的画作一样。然后,随着训练的进行,你的画作会逐渐展现出你自己的风格,但仍然保持着你老师的基本技巧。】

通过整合这两个组件,ControlNet使我们能够控制Diffusion Model(扩散模型)的U-Net中每个级别的行为。

在我们的例子中,我们在这个脚本中实例化了ControlNet,使用了以下等效的代码片段。

import torch
from generative.networks.nets import ControlNet, DiffusionModelUNet

# Load pre-trained diffusion model
diffusion_model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=3,
    out_channels=3,
    num_res_blocks=2,
    num_channels=[256, 512, 768],
    attention_levels=[False, True, True],
    with_conditioning=True,
    cross_attention_dim=1024,
    num_head_channels=[0, 512, 768],
)
diffusion_model.load_state_dict(torch.load("diffusion_model.pt"))

# Create ControlNet
controlnet = ControlNet(
    spatial_dims=2,
    in_channels=3,
    num_res_blocks=2,
    num_channels=[256, 512, 768],
    attention_levels=[False, True, True],
    with_conditioning=True,
    cross_attention_dim=1024,
    num_head_channels=[0, 512, 768],
    conditioning_embedding_in_channels=1,
    conditioning_embedding_num_channels=[64, 128, 128, 256],
)

# Create trainable copy of the diffusion model
controlnet.load_state_dict(diffusion_model.state_dict(), strict=False)

# Lock the weighht of the diffusion model
for p in diffusion_model.parameters():
    p.requires_grad = False

由于我们使用的是Latent Diffusion Model,这需要ControlNets将基于图像的条件转换为相同的潜在空间,以匹配卷积的大小。为此,我们使用一个与完整模型一起训练的卷积网络。在我们的案例中,我们有三个下采样级别(类似于自动编码器KL),定义在“conditioning_embedding_num_channels=[64, 128, 128, 256]”。由于我们的条件图像是一个FLAIR图像,只有一个通道,我们也需要在“conditioning_embedding_in_channels=1”中指定其输入通道的数量。

初始化我们的网络后,我们像训练扩散模型一样训练它。在以下的代码片段(以及代码的这部分)中,可以看到首先将条件FLAIR图像传递给可训练的网络,并从其跳过连接中获取输出。然后,当计算预测的噪声时,这些值被输入到扩散模型中。在内部,扩散模型将ControlNets的跳过连接与自己的连接相加,然后在馈送解码器部分之前(代码)。以下是训练循环的一部分:

# Training Loop...
images = batch["t1w"].to(device)
cond = batch["flair"].to(device)
...
noise = torch.randn_like(latent_representation).to(device)
noisy_z = scheduler.add_noise(
    original_samples=latent_representation, noise=noise, timesteps=timesteps)

# Compute trainable part
down_block_res_samples, mid_block_res_sample = controlnet(
    x=noisy_z, timesteps=timesteps, context=prompt_embeds, controlnet_cond=cond)

# Using controlnet outputs to control diffusion model behaviour
noise_pred = diffusion_model(
    x=noisy_z,
    timesteps=timesteps,
    context=prompt_embeds,
    down_block_additional_residuals=down_block_res_samples,
    mid_block_additional_residual=mid_block_res_sample,)

# Then compute diffusion model loss as usual...

ControlNet采样和评估

在训练模型之后,我们可以对它们进行采样和评估。在这里,使用测试集中的FLAIR图像来生成条件化的T1w图像。与我们的训练类似,采样过程非常接近于扩散模型使用的过程,唯一的区别是将条件图像传递给训练过的ControlNet,并在每个采样时间步中使用其输出来馈送扩散模型。如下图所示,我们生成的图像具有高空间保真度的原始条件,皮层回旋遵循类似的形状,图像保留了不同组织之间的边界。

测试集中的原始FLAIR图像作为输入到ControlNet(左),生成的T1加权图像(中),和原始的T1加权图像,也就是预期的输出(右)

在我们对模型的图像进行采样之后,可以量化ControlNet在将图像在不同对比度之间转换时的性能。由于我们从测试集中得到了预期的T1w图像,我们也可以检查它们的差异,并使用平均绝对误差(MAE)、峰值信噪比(PSNR)和MS-SSIM计算真实和合成图像之间的距离。在我们的测试集中,当执行这个脚本时,我们得到了PSNR= 26.2458+-1.0092,MAE=0.02632+-0.0036和MSSIM=0.9526+-0.0111。

ControlNet为我们的扩散模型提供了令人难以置信的控制,最近的方法已经扩展了其方法,结合了不同的训练ControlNets(Multi-ControlNet),在同一模型中处理不同类型的条件(T2I适配器),甚至在模型上设置条件(使用像ControlNet 1.1这样的方法)。如果这些方法听起来很有趣,不要犹豫,尝试一下!

总结

在这篇文章中,我们展示了如何使用ControlNets来控制Latent Diffusion Models生成医学图像的过程。我们的目标是展示这些模型在将脑图像转换为各种对比度的能力。为了实现这一目标,我们利用了最近推出的MONAI的开源扩展,即MONAI Generative Models!我们希望这篇文章能帮助你理解如何使用这些工具,并鼓励你在你自己的项目中使用它们。

作者介绍

崔皓,51CTO社区编辑,资深架构师,拥有18年的软件开发和架构经验,10年分布式架构经验。

原文标题:Controllable Medical Image Generation with ControlNets,作者:Walter Hugo Lopez Pinaya

责任编辑:华轩 来源: 51CTO
相关推荐

2020-12-04 05:13:56

智能织物时尚智能

2023-07-07 11:24:04

2024-03-18 00:09:19

人工智能生成式人工智能安全

2021-10-15 11:28:06

物联网边缘计算IoT

2023-05-11 14:07:29

2012-10-25 13:46:42

2024-07-17 08:27:29

2019-07-25 06:49:26

2013-08-14 10:43:37

2020-08-19 09:45:10

IBMAIOps混合多云管理

2024-02-23 16:12:47

2011-12-28 21:12:10

移动支付

2023-06-02 10:36:59

2020-11-19 17:36:10

IT 运营

2022-09-30 14:32:23

人工智能数据隐私游戏规则

2021-01-28 12:37:40

物联网体育行业IOT

2016-09-10 08:20:09

IBM

2023-03-21 08:02:34

架构React服务器

2018-12-07 16:08:28

Aruba网络管理

2018-01-14 16:01:33

点赞
收藏

51CTO技术栈公众号