对抗生成网络GAN 原创 精华

发布于 2024-11-26 09:06
浏览
0收藏

前言

在前一阶段课程中,我们学习了图像分割中的语义分割、实例分割、旋转目标检测等。这些图像分割算法都是有监督学习,而GAN(生成对抗网络)是一种特别的学习策略,其核心思想非常值得借鉴,所以本章将以GAN网络的代码为切入口,了解掌握其核心思想。

学习策略

人工智能方面的学习策略有两种:有监督学习和无监督学习。

有监督学习

定义:有监督学习是使用带标签的数据集进行训练。每个输入数据都有对应的输出标签,模型通过学习输入与输出之间的关系来进行预测。

举个例子:孩子的个人成长,有经验的家长为期规划了发展的路线,孩子在规划下有计划地学习成长,这属于有监督学习。

无监督学习

定义:无监督学习使用没有标签的数据集进行训练。模型试图发现数据中的模式或结构,而不依赖于任何预先定义的标签。

同样的例子:孩子在无监督学习下,是没有家长为期进行规划,而是经历社会"捶打"(做得好了有加分、做不好扣分),最终学习成长起来。

GAN的基础介绍

在上述的两种学习策略中,有一种特殊的、独立的学习策略:GAN(生成对抗网络)。

它由两个网络(生成器和判别器),通过对抗在竞争中共同发展。

  • G:生成器(造假)
  • D:鉴别器(打假)
  • 训练过程:

两个网络刚开始都没有任何能力

在竞争中共同发展

最后两个网络能力都得到提升

举个例子:GAN网络就像警察和小偷,警察和小偷之间互相对抗。

GAN示例

为了对GAN网络有个直观印象,我们可以参考Github上一个开源项目,对GAN有个初步认知。

页面地址:https://poloclub.github.io/ganlab/

示例目的

  • 在页面中添加一个手写数字图像
  • 通过训练模型来模拟手写数字图像
  • 从而达到新图像与原来的风格类似,分不出真假

对抗生成网络GAN-AI.x社区

核心思想

对抗生成网络GAN-AI.x社区

论文地址:https://arxiv.org/pdf/1406.2661

生成器(Generator):

  • 作用:负责凭空编造假的数据出来。
  • 目的:通过机器生成假的数据(大部分情况下是图像),最终目的是“骗过”判别器。
  • 过程:G是一个生成图片的网络,它接收一个随机的噪声z,通过这个噪声生成图片,记做G(z)。

判别器(Discriminator):

  • 作用:负责判断传来的数据是真还是假。
  • 目的:判断这张图像是真实的还是机器生成的,目的是找出生成器做的“假数据”。
  • 过程:D是一个判别网络,判别一张图片是不是“真实的”。它的输入参数是x,x代表一张图片,输出D(x)代表x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。

接下来,我们通过详细了解GAN网络的代码,深入了解其运行过程。

引入依赖

import torch
from torch import nn
from torch.nn import functional as F

import torchvision
from torchvision import transforms
from torchvision import datasets
from torchvision.utils import save_image
from torch.utils.data importDataLoader

import os
import numpy as np
import matplotlib.pyplot as plt

# 判断当前设备是否GPU
device = torch.device('cuda'if torch.cuda.is_available()else'cpu')
device

读取数据集

# 加载并预处理图像
data = datasets.MNIST(root="data", 
                      train=True, 
                      transform = transforms.Compose(transforms=[transforms.ToTensor(),
                      transforms.Normalize(mean=[0.5], std=[0.5])]),
                      download=True)

# 封装成 DataLoader
data_loader = DataLoader(dataset=data, batch_size=100, shuffle=True)
  • 备注:上述​​transform = transforms.Compose​​ 的作用主要是进行数据增强,详细内容在补充知识部分展开介绍。

定义模型

定义生成器

"""
    定义生成器
"""

classGenerator(nn.Module):
"""
        定义一个图像生成
        输入:一个向量
        输出:一个向量(代表图像)
    """
def__init__(self, in_features=100, out_features=28 * 28):
"""
            挂载超参数
        """
# 先初始化父类,再初始化子类
super(Generator, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

# 第一个隐藏层
        self.hidden0 = nn.Linear(in_features=self.in_features, out_features=256)

# 第二个隐藏层
        self.hidden1 = nn.Linear(in_features=256, out_features=512)

# 第三个隐藏层
        self.hidden2 = nn.Linear(in_features=512, out_features=self.out_features)

defforward(self, x):

# 第一层 [b, 100] --> [b, 256]
        h = self.hidden0(x)
        h = F.leaky_relu(input=h, negative_slope=0.2)

# 第二层 [b, 256] --> [b, 512]
        h = self.hidden1(h)
        h = F.leaky_relu(input=h, negative_slope=0.2)

# 第三层 [b, 512] --> [b, 28 * 28]
        h = self.hidden2(h)

# 压缩数据的变化范围
        o = torch.tanh(h)

return o

定义鉴别器

"""
    定义一个鉴别器
"""

classDiscriminator(nn.Module):
"""
        本质:二分类分类器
        输入:一个对象
        输出:真品还是赝品
    """
def__init__(self, in_features=28*28, out_features=1):
super(Discriminator, self).__init__()

        self.in_features=in_features
        self.out_features=out_features

# 第一个隐藏层
        self.hidden0= nn.Linear(in_features=self.in_features, out_features=512)

# 第二个隐藏层
        self.hidden1= nn.Linear(in_features=512, out_features=256)

# 第三个隐藏层
        self.hidden2= nn.Linear(in_features=256, out_features=32)

# 第四个隐藏层
        self.hidden3= nn.Linear(in_features=32, out_features=self.out_features)


defforward(self, x):

# 第一层
        h = self.hidden0(x)
        h = F.leaky_relu(input=h, negative_slope=0.2)
        h = F.dropout(input=h, p=0.2)

# 第二层
        h = self.hidden1(h)
        h = F.leaky_relu(input=h, negative_slope=0.2)
        h = F.dropout(input=h, p=0.2)

# 第三层
        h = self.hidden2(h)
        h = F.leaky_relu(input=h, negative_slope=0.2)
        h = F.dropout(input=h, p=0.2)

# 第四层
        h = self.hidden3(h)

# 输出概率
        o = torch.sigmoid(h)

return o

构建模型

"""
    构建模型
"""
# 定义一个生成器
generator = Generator(in_features=100, out_features=784)
generator.to(device=device)

# 定义一个鉴别器
discriminator = Discriminator(in_features=784, out_features=1)
discriminator.to(device=device)

定义优化器

"""
    定义优化器
"""

# 定义一个生成器的优化器
g_optimizer = torch.optim.Adam(params=generator.parameters(), lr=1e-4)

# 定义一个鉴别的优化器
d_optimizer = torch.optim.Adam(params=discriminator.parameters(), lr=1e-4)

定义损失函数

"""
    定义一个损失函数
"""
loss_fn = nn.BCELoss()

筹备训练

定义训练轮次

# 定义训练轮次
num_epochs = 1000

获取数据的标签

"""
    获取数据的标签
"""

defget_real_data_labels(size):
"""
        获取真实数据的标签
    """
    labels = torch.ones(size,1, device=device)

return labels

defget_fake_data_labels(size):
"""
        获取虚假数据的标签
    """
    labels = torch.zeros(size,1, device=device)

return labels

定义噪声生成器

"""
    噪声生成器
"""
defget_noise(size):
"""
        给生成器准备数据
        - 100维度的向量
    """
    X = torch.randn(size,100, device=device)

return X

# 获取一批测试数据

num_test_samples =16
test_noise = get_noise(num_test_samples)

噪声生成器的作用:因为我们需要监控模型训练的效果,所以将噪声固定下来,在训练过程中看同样的噪声最后给出的结果是否变得越来越好。

训练模型

"""
    训练过程
"""

g_losses =[]
d_losses =[]
for epoch inrange(1, num_epochs+1):

print(f"当前正在进行 第 {epoch} 轮 ....")

# 设置训练模式
    generator.train()
    discriminator.train()

# 遍历真实的图像
for batch_idx,(batch_real_data, _)inenumerate(data_loader):
"""
        1, 先训练鉴别器
            鉴别器就是一个二分类问题
            - 给一批真数据,输出真
            - 给一批假数据,输出假
        
        """

# 1.1 准备数据
# 图像转向量 [b, 1, 28, 28] ---> [b, 784]
# 从数据集中获取100个真实的手写数字图像
        real_data = batch_real_data.view(batch_real_data.size(0),-1).to(device=device)

# 噪声[b, 100]
# 随机生成100个100维度的噪声,用于生成假图像
        noise = get_noise(real_data.size(0))

# 根据噪声,生成假数据 
# [b, 100] --> [b, 784]
        fake_data = generator(noise).detach()


# 1.2 训练过程

# 鉴别器的优化器梯度情况
        d_optimizer.zero_grad()

# 对真实数据鉴别
        real_pred = discriminator(real_data)

# 计算真实数据的误差
        real_loss = loss_fn(real_pred, get_real_data_labels(real_data.size(0)))

# 真实数据的梯度回传
        real_loss.backward()


# 对假数据鉴别
        fake_pred = discriminator(fake_data)

# 计算假数据的误差
        fake_loss = loss_fn(fake_pred, get_fake_data_labels(fake_data.size(0)))

# 假数据梯度回传
        fake_loss.backward()

# 梯度更新
        d_optimizer.step()

# ----------------
        d_losses.append((real_loss + fake_loss).item())
# print(f"鉴别器的损失:{real_loss + fake_loss}")


"""2, 再训练生成器"""

# 获取生成器的生成结果
        fake_pred = generator(get_noise(real_data.size(0)))

# 生产器梯度清空
        g_optimizer.zero_grad()

# 把假数据让鉴别器鉴别一下
# 把discriminator requires_grad = False
# 设置为不可学习
for param in discriminator.parameters():
            param.requires_grad =False

        d_pred = discriminator(fake_pred)

# 设置为可学习
for param in discriminator.parameters():
            param.requires_grad =True

# 计算损失
# 把一个假东西,给专家看,专家说是真的,这个时候,造假的水平就可以了
        g_loss = loss_fn(d_pred, get_real_data_labels(d_pred.size(0)))

# 梯度回传
        g_loss.backward()

# 参数更新
        g_optimizer.step()

# print(f"生成器误差:{g_loss}")
        g_losses.append(g_loss.item())

# 每轮训练之后,观察生成器的效果
    generator.eval()

with torch.no_grad():

# 正向推理
        img_pred = generator(test_noise)
        img_pred = img_pred.view(img_pred.size(0),28,28).cpu().data

# 画图
        display.clear_output(wait=True)

# 设置画图的大小
        fig = plt.figure(1, figsize=(12,8))
# 划分为 4 x 4 的 网格
        gs = gridspec.GridSpec(4,4)

# 遍历每一个
for i inrange(4):
for j inrange(4):
# 取每一个图
                X = img_pred[i *4+ j,:,:]
# 添加一个对应网格内的子图
                ax = fig.add_subplot(gs[i, j])
# 在子图内绘制图像
                ax.matshow(X, cmap=plt.get_cmap("Greys"))
# ax.set_xlabel(f"{label}")
                ax.set_xticks(())
                ax.set_yticks(())
        plt.show()

运行结果:

对抗生成网络GAN-AI.x社区

核心代码说明:

训练过程

  • 随机生成一组潜在向量z,并使用生成器生成一组假数据。
  • 将一组真实数据和一组假数据作为输入,训练判别器。
  • 使用生成器生成一组新的假数据,并训练判别器。
  • 重复步骤2和3,直到生成器生成的假数据与真实数据的分布相似。

对抗生成网络GAN-AI.x社区

核心代码

  • ​fake_data = generator(noise).detach()​​:

作用:是生成器生成一组假数据,并使用detach()方法将其从计算图中分离出来,防止梯度回传。

说明:(因为在训练鉴别器时,生成器只是工具人,其前向传播过程中记录的梯度信息不会被使用,所以不需要记录梯度信息)

  • ​g_loss = loss_fn(d_pred, get_real_data_labels(d_pred.size(0)))​​ 这里是体现对抗的核心代码,即:生成器训练的好不好,是要与真实数据的判别结果越接近越好。

补充知识

数据增强

在人工智能模型的训练中,采集样本是需要成本的,所以为了提升样本的丰富性,一般会采用数据增强的方式。

  • 方式:在样本固定的基础上,通过软件模拟,来生成假数据,丰富样本的多样性
  • 本质:给样本加上适当的噪声,模拟出不同场景的样本
  • 说明:数据增强只发生在模型训练中,为了增加训练样本的多样性

transform介绍

在 PyTorch 中,transform 主要用于数据预处理和增强,特别是在图像处理任务中。transform 是 torchvision 库的一部分,能够对数据集中的图像进行各种转换,以便更好地适应模型训练的需求。以下是 transform 的主要作用

import torch
from torchvision import datasets, transforms

from PIL import Image
import matplotlib.pyplot as plt

# 读取本地下载的一张图片
img = Image.open('girl.png')
img

对抗生成网络GAN-AI.x社区

重设图片尺寸

resize = transforms.Resize((300, 200))
resize_img = resize(img)
resize_img

运行效果:

对抗生成网络GAN-AI.x社区

中心裁剪

centercrop = transforms.CenterCrop(size=(200, 200))
center_img = centercrop(img)
center_img

运行效果:

对抗生成网络GAN-AI.x社区

随机调整亮度、饱和度、对比度等

color_jitter = transforms.ColorJitter(brightness=0.5, 
                                      contrast=0.5, 
                                      saturation=0.5, 
                                      hue=0.5)
color_jitter(img)

运行效果:

对抗生成网络GAN-AI.x社区

随机旋转

random_rotation = transforms.RandomRotation(degrees=10)
random_rotation(img)

运行效果:

对抗生成网络GAN-AI.x社区

组合变换

Compose:可以将多个变换组合在一起,形成一个转换管道,方便批量处理。例如:

from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),          # 将PIL Image转换为Tensor
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), # 将数据归一化到[-1, 1]之间
])

内容小结

  • GAN(生成对抗网络)是一种特殊的学习策略,它由生成器和判别器组成,生成器生成假数据,判别器判断真假。
  • 生成器(Generator)通过机器生成假的数据(大部分情况下是图像),最终目的是“骗过”判别器。
  • 判别器(Discriminator):判断这张图像是真实的还是机器生成的,目的是找出生成器做的“假数据”。
  • 训练过程是:先训练判别器,再训练生成器。
  • 训练判别器时,生成器是"工具人",所以需要使用detach()方法,将生成器生成的假数据从计算图中分离出来,防止梯度回传。
  • 训练生成器时,判别器是"工具人",为了避免整个梯度消失,需要使用param.requires_grad = False设置为不可学习,判别完之后再使用param.requires_grad = True设置为可学习。
  • 在人工智能模型训练过程中,通常会使用数据增强的方式,在样本固定的基础上,通过软件模拟,来生成假数据,丰富样本的多样性。
  • transform:在 PyTorch 中,transform 主要用于数据预处理和增强,特别是在图像处理任务中。

本文转载自公众号一起AI技术 作者:Dongming

原文链接:​​https://mp.weixin.qq.com/s/tHEfP7_rTXkXWfWJZu9vYw​

©著作权归作者所有,如需转载,请注明出处,否则将追究法律责任
收藏
回复
举报
回复
相关推荐