从零开始,用英伟达T4、A10训练小型文生视频模型,几小时搞定

人工智能 新闻
在这篇博客中,作者将展示如何将从头开始构建一个小规模的文本生成视频模型,涵盖了从理解理论概念、到编写整个架构再到生成最终结果的所有内容。

本文经计算机视觉研究院公众号授权转载,转载请联系出处。

OpenAI 的 Sora、Stability AI 的 Stable Video Diffusion 以及许多其他已经发布或未来将出现的文本生成视频模型,是继大语言模型 (LLM) 之后 2024 年最流行的 AI 趋势之一。

在这篇博客中,作者将展示如何将从头开始构建一个小规模的文本生成视频模型,涵盖了从理解理论概念、到编写整个架构再到生成最终结果的所有内容。

由于作者没有大算力的 GPU,所以仅编写了小规模架构。以下是在不同处理器上训练模型所需时间的比较。

图片

作者表示,在 CPU 上运行显然需要更长的时间来训练模型。如果你需要快速测试代码中的更改并查看结果,CPU 不是最佳选择。因此建议使用 Colab 或 Kaggle 的 T4 GPU 进行更高效、更快速的训练。

构建目标

我们采用了与传统机器学习或深度学习模型类似的方法,即在数据集上进行训练,然后在未见过数据上进行测试。在文本转视频的背景下,假设有一个包含 10 万个狗捡球和猫追老鼠视频的训练数据集,然后训练模型来生成猫捡球或狗追老鼠的视频。

图片

图源:iStock, GettyImages

虽然此类训练数据集在互联网上很容易获得,但所需的算力极高。因此,我们将使用由 Python 代码生成的移动对象视频数据集。同时使用 GAN(生成对抗网络)架构来创建模型,而不是 OpenAI Sora 使用的扩散模型。

我们也尝试使用扩散模型,但内存要求超出了自己的能力。另一方面,GAN 可以更容易、更快地进行训练和测试。

准备条件

我们将使用 OOP(面向对象编程),因此必须对它以及神经网络有基本的了解。此外 GAN(生成对抗网络)的知识不是必需的,因为这里简单介绍它们的架构。

  • OOP:https://www.youtube.com/watch?v=q2SGW2VgwAM
  • 神经网络理论:https://www.youtube.com/watch?v=Jy4wM2X21u0
  • GAN 架构:https://www.youtube.com/watch?v=TpMIssRdhco
  • Python 基础:https://www.youtube.com/watch?v=eWRfhZUzrAc

了解 GAN 架构

什么是 GAN?

生成对抗网络是一种深度学习模型,其中两个神经网络相互竞争:一个从给定的数据集创建新数据(如图像或音乐),另一个则判断数据是真实的还是虚假的。这个过程一直持续到生成的数据与原始数据无法区分。

真实世界应用

  • 生成图像:GAN 根据文本 prompt 创建逼真的图像或修改现有图像,例如增强分辨率或为黑白照片添加颜色。
  • 数据增强:GAN 生成合成数据来训练其他机器学习模型,例如为欺诈检测系统创建欺诈交易数据。
  • 补充缺失信息:GAN 可以填充缺失数据,例如根据地形图生成地下图像以用于能源应用。
  • 生成 3D 模型:GAN 将 2D 图像转换为 3D 模型,在医疗保健等领域非常有用,可用于为手术规划创建逼真的器官图像。

GAN 工作原理

GAN 由两个深度神经网络组成:生成器和判别器。这两个网络在对抗设置中一起训练,其中一个网络生成新数据,另一个网络评估数据是真是假。

图片

GAN 训练示例

让我们以图像到图像的转换为例,解释一下 GAN 模型,重点是修改人脸。

1. 输入图像:输入图像是一张真实的人脸图像。

2. 属性修改:生成器会修改人脸的属性,比如给眼睛加上墨镜。

3. 生成图像:生成器会创建一组添加了太阳镜的图像。

4. 判别器的任务:判别器接收到混合的真实图像(带有太阳镜的人)和生成的图像(添加了太阳镜的人脸)。 

5. 评估:判别器尝试区分真实图像和生成图像。 

6. 反馈回路:如果判别器正确识别出假图像,生成器会调整其参数以生成更逼真的图像。如果生成器成功欺骗了判别器,判别器会更新其参数以提高检测能力。 

通过这一对抗过程,两个网络都在不断改进。生成器越来越善于生成逼真的图像,而判别器则越来越善于识别假图像,直到达到平衡,判别器再也无法区分真实图像和生成的图像。此时,GAN 已成功学会生成逼真的修改图像。

设置背景

我们将使用一系列 Python 库,让我们导入它们。

# Operating System module for interacting with the operating system
import os


# Module for generating random numbers
import random


# Module for numerical operations
import numpy as np


# OpenCV library for image processing
import cv2


# Python Imaging Library for image processing
from PIL import Image, ImageDraw, ImageFont


# PyTorch library for deep learning
import torch


# Dataset class for creating custom datasets in PyTorch
from torch.utils.data import Dataset


# Module for image transformations
import torchvision.transforms as transforms


# Neural network module in PyTorch
import torch.nn as nn


# Optimization algorithms in PyTorch
import torch.optim as optim


# Function for padding sequences in PyTorch
from torch.nn.utils.rnn import pad_sequence


# Function for saving images in PyTorch
from torchvision.utils import save_image


# Module for plotting graphs and images
import matplotlib.pyplot as plt


# Module for displaying rich content in IPython environments
from IPython.display import clear_output, display, HTML


# Module for encoding and decoding binary data to text
import base64

现在我们已经导入了所有的库,下一步就是定义我们的训练数据,用于训练 GAN 架构。

对训练数据进行编码

我们需要至少 10000 个视频作为训练数据。为什么呢?因为我测试了较小数量的视频,结果非常糟糕,几乎没有任何效果。下一个重要问题是:这些视频内容是什么?  我们的训练视频数据集包括一个圆圈以不同方向和不同运动方式移动的视频。让我们来编写代码并生成 10,000 个视频,看看它的效果如何。 

# Create a directory named 'training_dataset'
os.makedirs('training_dataset', exist_ok=True)


# Define the number of videos to generate for the dataset
num_videos = 10000


# Define the number of frames per video (1 Second Video)
frames_per_video = 10


# Define the size of each image in the dataset
img_size = (64, 64)


# Define the size of the shapes (Circle)
shape_size = 10

设置一些基本参数后,接下来我们需要定义训练数据集的文本 prompt,并据此生成训练视频。

# Define text prompts and corresponding movements for circles
prompts_and_movements = [
 ("circle moving down", "circle", "down"), # Move circle downward
 ("circle moving left", "circle", "left"), # Move circle leftward
 ("circle moving right", "circle", "right"), # Move circle rightward
 ("circle moving diagonally up-right", "circle", "diagonal_up_right"), # Move circle diagonally up-right
 ("circle moving diagonally down-left", "circle", "diagonal_down_left"), # Move circle diagonally down-left
 ("circle moving diagonally up-left", "circle", "diagonal_up_left"), # Move circle diagonally up-left
 ("circle moving diagonally down-right", "circle", "diagonal_down_right"), # Move circle diagonally down-right
 ("circle rotating clockwise", "circle", "rotate_clockwise"), # Rotate circle clockwise
 ("circle rotating counter-clockwise", "circle", "rotate_counter_clockwise"), # Rotate circle counter-clockwise
 ("circle shrinking", "circle", "shrink"), # Shrink circle
 ("circle expanding", "circle", "expand"), # Expand circle
 ("circle bouncing vertically", "circle", "bounce_vertical"), # Bounce circle vertically
 ("circle bouncing horizontally", "circle", "bounce_horizontal"), # Bounce circle horizontally
 ("circle zigzagging vertically", "circle", "zigzag_vertical"), # Zigzag circle vertically
 ("circle zigzagging horizontally", "circle", "zigzag_horizontal"), # Zigzag circle horizontally
 ("circle moving up-left", "circle", "up_left"), # Move circle up-left
 ("circle moving down-right", "circle", "down_right"), # Move circle down-right
 ("circle moving down-left", "circle", "down_left"), # Move circle down-left
]

我们已经利用这些 prompt 定义了圆的几个运动轨迹。现在,我们需要编写一些数学公式,以便根据 prompt 移动圆。

# Define function with parameters
def create_image_with_moving_shape(size, frame_num, shape, direction):
 
 # Create a new RGB image with specified size and white background
 img = Image.new('RGB', size, color=(255, 255, 255)) 


 # Create a drawing context for the image
 draw = ImageDraw.Draw(img) 


 # Calculate the center coordinates of the image
 center_x, center_y = size[0] // 2, size[1] // 2 


 # Initialize position with center for all movements
 position = (center_x, center_y) 


 # Define a dictionary mapping directions to their respective position adjustments or image transformations
 direction_map = { 
 # Adjust position downwards based on frame number
 "down": (0, frame_num * 5 % size[1]), 
 # Adjust position to the left based on frame number
 "left": (-frame_num * 5 % size[0], 0), 
 # Adjust position to the right based on frame number
 "right": (frame_num * 5 % size[0], 0), 
 # Adjust position diagonally up and to the right
 "diagonal_up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]), 
 # Adjust position diagonally down and to the left
 "diagonal_down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1]), 
 # Adjust position diagonally up and to the left
 "diagonal_up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]), 
 # Adjust position diagonally down and to the right
 "diagonal_down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]), 
 # Rotate the image clockwise based on frame number
 "rotate_clockwise": img.rotate(frame_num * 10 % 360, center=(center_x, center_y), fillcolor=(255, 255, 255)), 
 # Rotate the image counter-clockwise based on frame number
 "rotate_counter_clockwise": img.rotate(-frame_num * 10 % 360, center=(center_x, center_y), fillcolor=(255, 255, 255)), 
 # Adjust position for a bouncing effect vertically
 "bounce_vertical": (0, center_y - abs(frame_num * 5 % size[1] - center_y)), 
 # Adjust position for a bouncing effect horizontally
 "bounce_horizontal": (center_x - abs(frame_num * 5 % size[0] - center_x), 0), 
 # Adjust position for a zigzag effect vertically
 "zigzag_vertical": (0, center_y - frame_num * 5 % size[1]) if frame_num % 2 == 0 else (0, center_y + frame_num * 5 % size[1]), 
 # Adjust position for a zigzag effect horizontally
 "zigzag_horizontal": (center_x - frame_num * 5 % size[0], center_y) if frame_num % 2 == 0 else (center_x + frame_num * 5 % size[0], center_y), 
 # Adjust position upwards and to the right based on frame number
 "up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]), 
 # Adjust position upwards and to the left based on frame number
 "up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]), 
 # Adjust position downwards and to the right based on frame number
 "down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]), 
 # Adjust position downwards and to the left based on frame number
 "down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1]) 
 }


 # Check if direction is in the direction map
 if direction in direction_map: 
 # Check if the direction maps to a position adjustment
 if isinstance(direction_map[direction], tuple): 
 # Update position based on the adjustment
 position = tuple(np.add(position, direction_map[direction])) 
 else: # If the direction maps to an image transformation
 # Update the image based on the transformation
 img = direction_map[direction] 


 # Return the image as a numpy array
 return np.array(img)

上述函数用于根据所选方向在每一帧中移动我们的圆。我们只需在其上运行一个循环,直至生成所有视频的次数。

# Iterate over the number of videos to generate
for i in range(num_videos):
    # Randomly choose a prompt and movement from the predefined list
    prompt, shape, direction = random.choice(prompts_and_movements)
    
    # Create a directory for the current video
    video_dir = f'training_dataset/video_{i}'
    os.makedirs(video_dir, exist_ok=True)
    
    # Write the chosen prompt to a text file in the video directory
    with open(f'{video_dir}/prompt.txt', 'w') as f:
        f.write(prompt)
    
    # Generate frames for the current video
    for frame_num in range(frames_per_video):
        # Create an image with a moving shape based on the current frame number, shape, and direction
        img = create_image_with_moving_shape(img_size, frame_num, shape, direction)
        
        # Save the generated image as a PNG file in the video directory
        cv2.imwrite(f'{video_dir}/frame_{frame_num}.png', img)

运行上述代码后,就会生成整个训练数据集。以下是训练数据集文件的结构。

图片

每个训练视频文件夹包含其帧以及对应的文本 prompt。让我们看一下我们的训练数据集样本。

在我们的训练数据集中,我们没有包含圆圈先向上移动然后向右移动的运动。我们将使用这个作为测试 prompt,来评估我们训练的模型在未见过的数据上的表现。

图片

还有一个重要的要点需要注意,我们的训练数据包含许多物体从场景中移出或部分出现在摄像机前方的样本,类似于我们在 OpenAI Sora 演示视频中观察到的情况。 

图片

在我们的训练数据中包含此类样本的原因是为了测试当圆圈从角落进入场景时,模型是否能够保持一致性而不会破坏其形状。

现在我们的训练数据已经生成,需要将训练视频转换为张量,这是 PyTorch 等深度学习框架中使用的主要数据类型。此外,通过将数据缩放到较小的范围,执行归一化等转换有助于提高训练架构的收敛性和稳定性。

预处理训练数据

我们必须为文本转视频任务编写一个数据集类,它可以从训练数据集目录中读取视频帧及其相应的文本 prompt,使其可以在 PyTorch 中使用。

# Define a dataset class inheriting from torch.utils.data.Dataset
class TextToVideoDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        # Initialize the dataset with root directory and optional transform
        self.root_dir = root_dir
        self.transform = transform
        # List all subdirectories in the root directory
        self.video_dirs = [os.path.join(root_dir, d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]
        # Initialize lists to store frame paths and corresponding prompts
        self.frame_paths = []
        self.prompts = []


        # Loop through each video directory
        for video_dir in self.video_dirs:
            # List all PNG files in the video directory and store their paths
            frames = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith('.png')]
            self.frame_paths.extend(frames)
            # Read the prompt text file in the video directory and store its content
            with open(os.path.join(video_dir, 'prompt.txt'), 'r') as f:
                prompt = f.read().strip()
            # Repeat the prompt for each frame in the video and store in prompts list
            self.prompts.extend([prompt] * len(frames))


    # Return the total number of samples in the dataset
    def __len__(self):
        return len(self.frame_paths)


    # Retrieve a sample from the dataset given an index
    def __getitem__(self, idx):
        # Get the path of the frame corresponding to the given index
        frame_path = self.frame_paths[idx]
        # Open the image using PIL (Python Imaging Library)
        image = Image.open(frame_path)
        # Get the prompt corresponding to the given index
        prompt = self.prompts[idx]


        # Apply transformation if specified
        if self.transform:
            image = self.transform(image)


        # Return the transformed image and the prompt
        return image, prompt

在继续编写架构代码之前,我们需要对训练数据进行归一化处理。我们使用 16 的 batch 大小并对数据进行混洗以引入更多随机性。

实现文本嵌入层

你可能已经看到,在 Transformer 架构中,起点是将文本输入转换为嵌入,从而在多头注意力中进行进一步处理。类似地,我们在这里必须编写一个文本嵌入层。基于该层,GAN 架构训练在我们的嵌入数据和图像张量上进行。

# Define a class for text embedding
class TextEmbedding(nn.Module):
    # Constructor method with vocab_size and embed_size parameters
    def __init__(self, vocab_size, embed_size):
        # Call the superclass constructor
        super(TextEmbedding, self).__init__()
        # Initialize embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_size)


    # Define the forward pass method
    def forward(self, x):
        # Return embedded representation of input
        return self.embedding(x)

词汇量将基于我们的训练数据,在稍后进行计算。嵌入大小将为 10。如果使用更大的数据集,你还可以使用 Hugging Face 上已有的嵌入模型。

实现生成器层

现在我们已经知道生成器在 GAN 中的作用,接下来让我们对这一层进行编码,然后了解其内容。

class Generator(nn.Module):
 def __init__(self, text_embed_size):
 super(Generator, self).__init__()
 
 # Fully connected layer that takes noise and text embedding as input
 self.fc1 = nn.Linear(100 + text_embed_size, 256 * 8 * 8)
 
 # Transposed convolutional layers to upsample the input
 self.deconv1 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
 self.deconv2 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
 self.deconv3 = nn.ConvTranspose2d(64, 3, 4, 2, 1) # Output has 3 channels for RGB images
 
 # Activation functions
 self.relu = nn.ReLU(True) # ReLU activation function
 self.tanh = nn.Tanh() # Tanh activation function for final output


 def forward(self, noise, text_embed):
 # Concatenate noise and text embedding along the channel dimension
 x = torch.cat((noise, text_embed), dim=1)
 
 # Fully connected layer followed by reshaping to 4D tensor
 x = self.fc1(x).view(-1, 256, 8, 8)
 
 # Upsampling through transposed convolution layers with ReLU activation
 x = self.relu(self.deconv1(x))
 x = self.relu(self.deconv2(x))
 
 # Final layer with Tanh activation to ensure output values are between -1 and 1 (for images)
 x = self.tanh(self.deconv3(x))
 
 return x

该 Generator 类负责根据随机噪声和文本嵌入的组合创建视频帧,旨在根据给定的文本描述生成逼真的视频帧。该网络从完全连接层 (nn.Linear) 开始,将噪声向量和文本嵌入组合成单个特征向量。然后,该向量被重新整形并经过一系列的转置卷积层 (nn.ConvTranspose2d),这些层将特征图逐步上采样到所需的视频帧大小。

这些层使用 ReLU 激活 (nn.ReLU) 实现非线性,最后一层使用 Tanh 激活 (nn.Tanh) 将输出缩放到 [-1, 1] 的范围。因此,生成器将抽象的高维输入转换为以视觉方式表示输入文本的连贯视频帧。

实现判别器层

在编写完生成器层之后,我们需要实现另一半,即判别器部分。

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        # Convolutional layers to process input images
        self.conv1 = nn.Conv2d(3, 64, 4, 2, 1)   # 3 input channels (RGB), 64 output channels, kernel size 4x4, stride 2, padding 1
        self.conv2 = nn.Conv2d(64, 128, 4, 2, 1) # 64 input channels, 128 output channels, kernel size 4x4, stride 2, padding 1
        self.conv3 = nn.Conv2d(128, 256, 4, 2, 1) # 128 input channels, 256 output channels, kernel size 4x4, stride 2, padding 1
        
        # Fully connected layer for classification
        self.fc1 = nn.Linear(256 * 8 * 8, 1)  # Input size 256x8x8 (output size of last convolution), output size 1 (binary classification)
        
        # Activation functions
        self.leaky_relu = nn.LeakyReLU(0.2, inplace=True)  # Leaky ReLU activation with negative slope 0.2
        self.sigmoid = nn.Sigmoid()  # Sigmoid activation for final output (probability)


    def forward(self, input):
        # Pass input through convolutional layers with LeakyReLU activation
        x = self.leaky_relu(self.conv1(input))
        x = self.leaky_relu(self.conv2(x))
        x = self.leaky_relu(self.conv3(x))
        
        # Flatten the output of convolutional layers
        x = x.view(-1, 256 * 8 * 8)
        
        # Pass through fully connected layer with Sigmoid activation for binary classification
        x = self.sigmoid(self.fc1(x))
        
        return x

判别器类用作二元分类器,区分真实视频帧和生成的视频帧。目的是评估视频帧的真实性,从而指导生成器产生更真实的输出。该网络由卷积层 (nn.Conv2d) 组成,这些卷积层从输入视频帧中提取分层特征, Leaky ReLU 激活 (nn.LeakyReLU) 增加非线性,同时允许负值的小梯度。

然后,特征图被展平并通过完全连接层 (nn.Linear),最终以 S 形激活 (nn.Sigmoid) 输出指示帧是真实还是假的概率分数。

通过训练判别器准确地对帧进行分类,生成器同时接受训练以创建更令人信服的视频帧,从而骗过判别器。

编写训练参数

我们必须设置用于训练 GAN 的基础组件,例如损失函数、优化器等。

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Create a simple vocabulary for text prompts
all_prompts = [prompt for prompt, _, _ in prompts_and_movements]  # Extract all prompts from prompts_and_movements list
vocab = {word: idx for idx, word in enumerate(set(" ".join(all_prompts).split()))}  # Create a vocabulary dictionary where each unique word is assigned an index
vocab_size = len(vocab)  # Size of the vocabulary
embed_size = 10  # Size of the text embedding vector


def encode_text(prompt):
    # Encode a given prompt into a tensor of indices using the vocabulary
    return torch.tensor([vocab[word] for word in prompt.split()])


# Initialize models, loss function, and optimizers
text_embedding = TextEmbedding(vocab_size, embed_size).to(device)  # Initialize TextEmbedding model with vocab_size and embed_size
netG = Generator(embed_size).to(device)  # Initialize Generator model with embed_size
netD = Discriminator().to(device)  # Initialize Discriminator model
criterion = nn.BCELoss().to(device)  # Binary Cross Entropy loss function
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))  # Adam optimizer for Discriminator
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))  # Adam optimizer for Generator

这是我们必须转换代码以在 GPU 上运行的部分(如果可用)。我们已经编写了代码来查找 vocab_size,并且我们正在为生成器和判别器使用 ADAM 优化器。你可以选择自己的优化器。在这里,我们将学习率设置为较小的值 0.0002,嵌入大小为 10,这比其他可供公众使用的 Hugging Face 模型要小得多。

编写训练 loop

就像其他神经网络一样,我们将以类似的方式对 GAN 架构训练进行编码。

# Number of epochs
num_epochs = 13


# Iterate over each epoch
for epoch in range(num_epochs):
    # Iterate over each batch of data
    for i, (data, prompts) in enumerate(dataloader):
        # Move real data to device
        real_data = data.to(device)
        
        # Convert prompts to list
        prompts = [prompt for prompt in prompts]


        # Update Discriminator
        netD.zero_grad()  # Zero the gradients of the Discriminator
        batch_size = real_data.size(0)  # Get the batch size
        labels = torch.ones(batch_size, 1).to(device)  # Create labels for real data (ones)
        output = netD(real_data)  # Forward pass real data through Discriminator
        lossD_real = criterion(output, labels)  # Calculate loss on real data
        lossD_real.backward()  # Backward pass to calculate gradients
       
        # Generate fake data
        noise = torch.randn(batch_size, 100).to(device)  # Generate random noise
        text_embeds = torch.stack([text_embedding(encode_text(prompt).to(device)).mean(dim=0) for prompt in prompts])  # Encode prompts into text embeddings
        fake_data = netG(noise, text_embeds)  # Generate fake data from noise and text embeddings
        labels = torch.zeros(batch_size, 1).to(device)  # Create labels for fake data (zeros)
        output = netD(fake_data.detach())  # Forward pass fake data through Discriminator (detach to avoid gradients flowing back to Generator)
        lossD_fake = criterion(output, labels)  # Calculate loss on fake data
        lossD_fake.backward()  # Backward pass to calculate gradients
        optimizerD.step()  # Update Discriminator parameters


        # Update Generator
        netG.zero_grad()  # Zero the gradients of the Generator
        labels = torch.ones(batch_size, 1).to(device)  # Create labels for fake data (ones) to fool Discriminator
        output = netD(fake_data)  # Forward pass fake data (now updated) through Discriminator
        lossG = criterion(output, labels)  # Calculate loss for Generator based on Discriminator's response
        lossG.backward()  # Backward pass to calculate gradients
        optimizerG.step()  # Update Generator parameters
    
    # Print epoch information
    print(f"Epoch [{epoch + 1}/{num_epochs}] Loss D: {lossD_real + lossD_fake}, Loss G: {lossG}")

通过反向传播,我们的损失将针对生成器和判别器进行调整。我们在训练 loop 中使用了 13 个 epoch。我们测试了不同的值,但如果 epoch 高于这个值,结果并没有太大差异。此外,过度拟合的风险很高。如果我们的数据集更加多样化,包含更多动作和形状,则可以考虑使用更高的 epoch,但在这里没有这样做。

当我们运行此代码时,它会开始训练,并在每个 epoch 之后 print 生成器和判别器的损失。

## OUTPUT ##


Epoch [1/13] Loss D: 0.8798642754554749, Loss G: 1.300612449645996
Epoch [2/13] Loss D: 0.8235711455345154, Loss G: 1.3729925155639648
Epoch [3/13] Loss D: 0.6098687052726746, Loss G: 1.3266581296920776


...

保存训练的模型

训练完成后,我们需要保存训练好的 GAN 架构的判别器和生成器,这只需两行代码即可实现。

# Save the Generator model's state dictionary to a file named 'generator.pth'
torch.save(netG.state_dict(), 'generator.pth')


# Save the Discriminator model's state dictionary to a file named 'discriminator.pth'
torch.save(netD.state_dict(), 'discriminator.pth')

生成 AI 视频

正如我们所讨论的,我们在未见过的数据上测试模型的方法与我们训练数据中涉及狗取球和猫追老鼠的示例类似。因此,我们的测试 prompt 可能涉及猫取球或狗追老鼠等场景。

在我们的特定情况下,圆圈向上移动然后向右移动的运动在训练数据中不存在,因此模型不熟悉这种特定运动。但是,模型已经在其他动作上进行了训练。我们可以使用此动作作为 prompt 来测试我们训练过的模型并观察其性能。

# Inference function to generate a video based on a given text promptdef generate_video(text_prompt, num_frames=10):    # Create a directory for the generated video frames based on the text prompt    os.makedirs(f'generated_video_{text_prompt.replace(" ", "_")}', exist_ok=True)        # Encode the text prompt into a text embedding tensor    text_embed = text_embedding(encode_text(text_prompt).to(device)).mean(dim=0).unsqueeze(0)        # Generate frames for the video    for frame_num in range(num_frames):        # Generate random noise        noise = torch.randn(1, 100).to(device)                # Generate a fake frame using the Generator network        with torch.no_grad():            fake_frame = netG(noise, text_embed)                # Save the generated fake frame as an image file        save_image(fake_frame, f'generated_video_{text_prompt.replace(" ", "_")}/frame_{frame_num}.png')# usage of the generate_video function with a specific text promptgenerate_video('circle moving up-right')

当我们运行上述代码时,它将生成一个目录,其中包含我们生成视频的所有帧。我们需要使用一些代码将所有这些帧合并为一个短视频。

# Define the path to your folder containing the PNG frames
folder_path = 'generated_video_circle_moving_up-right'

# Get the list of all PNG files in the folder
image_files = [f for f in os.listdir(folder_path) if f.endswith('.png')]


# Sort the images by name (assuming they are numbered sequentially)
image_files.sort()


# Create a list to store the frames
frames = []


# Read each image and append it to the frames list
for image_file in image_files:
 image_path = os.path.join(folder_path, image_file)
 frame = cv2.imread(image_path)
 frames.append(frame)


# Convert the frames list to a numpy array for easier processing
frames = np.array(frames)


# Define the frame rate (frames per second)
fps = 10


# Create a video writer object
fourcc = cv2.VideoWriter_fourcc(*'XVID')
out = cv2.VideoWriter('generated_video.avi', fourcc, fps, (frames[0].shape[1], frames[0].shape[0]))


# Write each frame to the video
for frame in frames:
 out.write(frame)


# Release the video writer
out.release()

确保文件夹路径指向你新生成的视频所在的位置。运行此代码后,你将成功创建 AI 视频。让我们看看它是什么样子。

图片

我们进行了多次训练,训练次数相同。在两种情况下,圆圈都是从底部开始,出现一半。好消息是,我们的模型在两种情况下都尝试执行直立运动。

例如,在尝试 1 中,圆圈沿对角线向上移动,然后执行向上运动,而在尝试 2 中,圆圈沿对角线移动,同时尺寸缩小。在两种情况下,圆圈都没有向左移动或完全消失,这是一个好兆头。

最后,作者表示已经测试了该架构的各个方面,发现训练数据是关键。通过在数据集中包含更多动作和形状,你可以增加可变性并提高模型的性能。由于数据是通过代码生成的,因此生成更多样的数据不会花费太多时间;相反,你可以专注于完善逻辑。

此外,文章中讨论的 GAN 架构相对简单。你可以通过集成高级技术或使用语言模型嵌入 (LLM) 而不是基本神经网络嵌入来使其更复杂。此外,调整嵌入大小等参数会显著影响模型的有效性。

责任编辑:张燕妮 来源: 计算机视觉研究院
相关推荐

2024-05-13 08:20:00

GPU芯片

2024-07-31 08:14:17

2023-02-06 10:25:13

AI模型

2017-12-12 12:24:39

Python决策树

2018-01-16 11:00:25

2015-11-17 16:11:07

Code Review

2019-01-18 12:39:45

云计算PaaS公有云

2018-04-18 07:01:59

Docker容器虚拟机

2024-05-10 07:58:03

2020-07-02 15:32:23

Kubernetes容器架构

2013-09-11 09:37:17

企业级移动应用

2021-03-16 11:30:33

2017-07-19 10:22:07

2024-08-06 14:30:00

AI模型

2009-07-22 16:34:36

使用T4ASP.NET MVC

2018-09-14 17:16:22

云计算软件计算机网络

2010-05-26 17:35:08

配置Xcode SVN

2024-05-15 14:29:45

2015-08-11 11:08:19

中科创达

2023-06-02 07:37:12

LLM​大语言模型
点赞
收藏

51CTO技术栈公众号