以往人们普遍认为生成图像是不可能完成的任务,因为按照传统的机器学习思路,我们根本没有真值(ground truth)可以拿来检验生成的图像是否合格。
2014年,Goodfellow等人则提出生成 对抗网络(Generative Adversarial Network, GAN) ,能够让我们完全依靠机器学习来生成极为逼真的图片。GAN的横空出世使得整个人工智能行业都为之震动,计算机视觉和图像生成领域发生了巨变。
本文将带大家了解 GAN的工作原理 ,并介绍如何 通过PyTorch简单上手GAN 。
GAN的原理
按照传统的方法,模型的预测结果可以直接与已有的真值进行比较。然而,我们却很难定义和衡量到底怎样才算作是“正确的”生成图像。
Goodfellow等人则提出了一个有趣的解决办法:我们可以先训练好一个分类工具,来自动区分生成图像和真实图像。这样一来,我们就可以用这个分类工具来训练一个生成网络,直到它能够输出完全以假乱真的图像,连分类工具自己都没有办法评判真假。
按照这一思路,我们便有了GAN:也就是一个 生成器(generator) 和一个 判别器(discriminator) 。生成器负责根据给定的数据集生成图像,判别器则负责区分图像是真是假。GAN的运作流程如上图所示。
损失函数
在GAN的运作流程中,我们可以发现一个明显的矛盾:同时优化生成器和判别器是很困难的。可以想象,这两个模型有着完全相反的目标:生成器想要尽可能伪造出真实的东西,而判别器则必须要识破生成器生成的图像。
为了说明这一点,我们设D(x)为判别器的输出,即x是真实图像的概率,并设G(z)为生成器的输出。判别器类似于一种二进制的分类器,所以其目标是使该函数的结果最大化:这一函数本质上是非负的二元交叉熵损失函数。另一方面,生成器的目标是最小化判别器做出正确判断的机率,因此它的目标是使上述函数的结果最小化。
因此,最终的损失函数将会是两个分类器之间的极小极大博弈,表示如下:理论上来说,博弈的最终结果将是让判别器判断成功的概率收敛到0.5。然而在实践中,极大极小博弈通常会导致网络不收敛,因此仔细调整模型训练的参数非常重要。
在训练GAN时,我们尤其要注意学习率等超参数,学习率比较小时能让GAN在输入噪音较多的情况下也能有较为统一的输出。
计算环境
库
本文将指导大家通过PyTorch搭建整个程序(包括torchvision)。同时,我们将会使用Matplotlib来让GAN的生成结果可视化。以下代码能够导入上述所有库:
- """
- Import necessary libraries to create a generative adversarial network
- The code is mainly developed using the PyTorch library
- """
- import time
- import torch
- import torch.nn as nn
- import torch.optim as optim
- from torch.utils.data import DataLoader
- from torchvision import datasets
- from torchvision.transforms import transforms
- from model import discriminator, generator
- import numpy as np
- import matplotlib.pyplot as plt
数据集
数据集对于训练GAN来说非常重要,尤其考虑到我们在GAN中处理的通常是非结构化数据(一般是图片、视频等),任意一class都可以有数据的分布。这种数据分布恰恰是GAN生成输出的基础。
为了更好地演示GAN的搭建流程,本文将带大家使用最简单的MNIST数据集,其中含有6万张手写阿拉伯数字的图片。
像 MNIST 这样高质量的非结构化数据集都可以在 格物钛 的 公开数据集 网站上找到。事实上,格物钛Open Datasets平台涵盖了很多优质的公开数据集,同时也可以实现 数据集托管及一站式搜索的功能 ,这对AI开发者来说,是相当实用的社区平台。
硬件需求
一般来说,虽然可以使用CPU来训练神经网络,但最佳选择其实是GPU,因为这样可以大幅提升训练速度。我们可以用下面的代码来测试自己的机器能否用GPU来训练:
- """
- Determine if any GPUs are available
- """
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
实现
网络结构
由于数字是非常简单的信息,我们可以将判别器和生成器这两层结构都组建成全连接层(fully connected layers)。
我们可以用以下代码在PyTorch中搭建判别器和生成器:
- """
- Network Architectures
- The following are the discriminator and generator architectures
- """
- class discriminator(nn.Module):
- def __init__(self):
- super(discriminator, self).__init__()
- self.fc1 = nn.Linear(784, 512)
- self.fc2 = nn.Linear(512, 1)
- self.activation = nn.LeakyReLU(0.1)
- def forward(self, x):
- x = x.view(-1, 784)
- x = self.activation(self.fc1(x))
- x = self.fc2(x)
- return nn.Sigmoid()(x)
- class generator(nn.Module):
- def __init__(self):
- super(generator, self).__init__()
- self.fc1 = nn.Linear(128, 1024)
- self.fc2 = nn.Linear(1024, 2048)
- self.fc3 = nn.Linear(2048, 784)
- self.activation = nn.ReLU()
- def forward(self, x):
- x = self.activation(self.fc1(x))
- x = self.activation(self.fc2(x))
- x = self.fc3(x)
- x = x.view(-1, 1, 28, 28)
- return nn.Tanh()(x)
训练
在训练GAN的时候,我们需要一边优化判别器,一边改进生成器,因此每次迭代我们都需要同时优化两个互相矛盾的损失函数。
对于生成器,我们将输入一些随机噪音,让生成器来根据噪音的微小改变输出的图像:
- """
- Network training procedure
- Every step both the loss for disciminator and generator is updated
- Discriminator aims to classify reals and fakes
- Generator aims to generate images as realistic as possible
- """
- for epoch in range(epochs):
- for idx, (imgs, _) in enumerate(train_loader):
- idx += 1
- # Training the discriminator
- # Real inputs are actual images of the MNIST dataset
- # Fake inputs are from the generator
- # Real inputs should be classified as 1 and fake as 0
- real_inputs = imgs.to(device)
- real_outputs = D(real_inputs)
- real_label = torch.ones(real_inputs.shape[0], 1).to(device)
- noise = (torch.rand(real_inputs.shape[0], 128) - 0.5) / 0.5
- noise = noise.to(device)
- fake_inputs = G(noise)
- fake_outputs = D(fake_inputs)
- fake_label = torch.zeros(fake_inputs.shape[0], 1).to(device)
- outputs = torch.cat((real_outputs, fake_outputs), 0)
- targets = torch.cat((real_label, fake_label), 0)
- D_loss = loss(outputs, targets)
- D_optimizer.zero_grad()
- D_loss.backward()
- D_optimizer.step()
- # Training the generator
- # For generator, goal is to make the discriminator believe everything is 1
- noise = (torch.rand(real_inputs.shape[0], 128)-0.5)/0.5
- noise = noise.to(device)
- fake_inputs = G(noise)
- fake_outputs = D(fake_inputs)
- fake_targets = torch.ones([fake_inputs.shape[0], 1]).to(device)
- G_loss = loss(fake_outputs, fake_targets)
- G_optimizer.zero_grad()
- G_loss.backward()
- G_optimizer.step()
- if idx % 100 == 0 or idx == len(train_loader):
- print('Epoch {} Iteration {}: discriminator_loss {:.3f} generator_loss {:.3f}'.format(epoch, idx, D_loss.item(), G_loss.item()))
- if (epoch+1) % 10 == 0:
- torch.save(G, 'Generator_epoch_{}.pth'.format(epoch))
- print('Model saved.')
结果
经过100个训练时期之后,我们就可以对数据集进行可视化处理,直接看到模型从随机噪音生成的数字:
我们可以看到,生成的结果和真实的数据非常相像。考虑到我们在这里只是搭建了一个非常简单的模型,实际的应用效果会有非常大的上升空间。
不仅是有样学样
GAN和以往机器视觉专家提出的想法都不一样,而利用GAN进行的具体场景应用更是让许多人赞叹深度网络的无限潜力。下面我们来看一下两个最为出名的GAN延申应用。
CycleGAN
朱俊彦等人2017年发表的CycleGAN能够在没有配对图片的情况下将一张图片从X域直接转换到Y域,比如把马变成斑马、将热夏变成隆冬、把莫奈的画变成梵高的画等等。这些看似天方夜谭的转换CycleGAN都能轻松做到,并且结果非常准确。
GauGAN
英伟达则通过GAN让人们能够只需要寥寥数笔勾勒出自己的想法,便能得到一张极为逼真的真实场景图片。虽然这种应用需要的计算成本极为高昂,但是GauGAN凭借它的转换能力探索出了前所未有的研究和应用领域。
结语
相信看到这里,你已经知道了GAN的大致工作原理,并且能够自己动手简单搭建一个GAN了。