图像风格迁移也有框架了:使用Python编写,与PyTorch完美兼容,外行也能用

开发 开发工具
pystiche 是一个用 Python 编写的 NST 框架,基于 PyTorch 构建,并与之完全兼容。相关研究由 pyOpenSci 进行同行评审,并发表在 JOSS 期刊 (Journal of Open Source Software) 上。

易于使用的神经风格迁移框架 pystiche。

将内容图片与艺术风格图片进行融合,生成一张具有特定风格的新图,这种想法并不新鲜。早在 2015 年,Gatys、 Ecker 以及 Bethge 开创性地提出了神经风格迁移(Neural Style Transfer ,NST)。

不同于深度学习,目前 NST 还没有现成的库或框架。因此,新的 NST 技术要么从头开始实现所有内容,要么基于现有的方法实现。但这两种方法都有各自的缺点:前者由于可重用部分的冗长实现,限制了技术创新;后者继承了 DL 硬件和软件快速发展导致的技术债务。

最近,新项目 pystiche 很好地解决了这些问题,虽然它的核心受众是研究人员,但其易于使用的用户界面为非专业人员使用 NST 提供了可能。

pystiche 是一个用 Python 编写的 NST 框架,基于 PyTorch 构建,并与之完全兼容。相关研究由 pyOpenSci 进行同行评审,并发表在 JOSS 期刊 (Journal of Open Source Software) 上。

论文地址:https://joss.theoj.org/papers/10.21105/joss.02761

项目地址:https://github.com/pmeier/pystiche

在深入实现之前,我们先来回顾一下 NST 的原理。它有两种优化方式:基于图像的优化和基于模型的优化。虽然 pystiche 能够很好地处理后者,但更为复杂,因此本文只讨论基于图像的优化方法。

在基于图像的方法中,将图像的像素迭代调整训练,来拟合感知损失函数(perceptual loss)。感知损失是 NST 的核心部分,分为内容损失(content loss)和风格损失(style loss),这些损失评估输出图像与目标图像的匹配程度。与传统的风格迁移算法不同,感知损失包含一个称为编码器的多层模型,这就是 pystiche 基于 PyTorch 构建的原因。

如何使用 pystiche

让我们用一个例子介绍怎么使用 pystiche 生成神经风格迁移图片。首先导入所需模块,选择处理设备。虽然 pystiche 的设计与设备无关,但使用 GPU 可以将 NST 的速度提高几个数量级。

模块导入与设备选择:

  1. import torch 
  2. import pystiche 
  3. from pystiche import demo, enc, loss, ops, optim 
  4.  
  5. print(f"pystiche=={pystiche.__version__}") 
  6.  
  7. device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

输出:

  1. pystiche==0.7.0 

多层编码器

content_loss 和 style_loss 是对图像编码进行操作而不是图像本身,这些编码是由在不同层级的预训练编码器生成的。pystiche 定义了 enc.MultiLayerEncoder 类,该类在单个前向传递中可以有效地处理编码问题。该示例使用基于 VGG19 架构的 vgg19_multi_layer_encoder。默认情况下,它将加载 torchvision 提供的权重。

多层编码器:

  1. multi_layer_encoder = enc.vgg19_multi_layer_encoder() 
  2. print(multi_layer_encoder) 

输出:

  1. VGGMultiLayerEncoder( 
  2.   arch=vgg19framework=torchallow_inplace=True 
  3.   (preprocessing): TorchPreprocessing( 
  4.    (0): Normalize( 
  5.      mean=('0.485', '0.456', '0.406'), 
  6.      std=('0.229', '0.224', '0.225') 
  7.     ) 
  8.   ) 
  9.  (conv1_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  10.  (relu1_1): ReLU(inplace=True
  11.  (conv1_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  12.  (relu1_2): ReLU(inplace=True
  13.  (pool1): MaxPool2d(kernel_size=2stride=2padding=0dilation=1ceil_mode=False
  14.  (conv2_1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  15.  (relu2_1): ReLU(inplace=True
  16.  (conv2_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  17.  (relu2_2): ReLU(inplace=True
  18.  (pool2): MaxPool2d(kernel_size=2stride=2padding=0dilation=1ceil_mode=False
  19.  (conv3_1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  20.  (relu3_1): ReLU(inplace=True
  21.  (conv3_2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  22.  (relu3_2): ReLU(inplace=True
  23.  (conv3_3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  24.  (relu3_3): ReLU(inplace=True
  25.  (conv3_4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  26.  (relu3_4): ReLU(inplace=True
  27.  (pool3): MaxPool2d(kernel_size=2stride=2padding=0dilation=1ceil_mode=False
  28.  (conv4_1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  29.  (relu4_1): ReLU(inplace=True
  30.  (conv4_2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  31.  (relu4_2): ReLU(inplace=True
  32.  (conv4_3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  33.  (relu4_3): ReLU(inplace=True
  34.  (conv4_4): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  35.  (relu4_4): ReLU(inplace=True
  36.  (pool4): MaxPool2d(kernel_size=2stride=2padding=0dilation=1ceil_mode=False
  37.  (conv5_1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  38.  (relu5_1): ReLU(inplace=True
  39.  (conv5_2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  40.  (relu5_2): ReLU(inplace=True
  41.  (conv5_3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  42.  (relu5_3): ReLU(inplace=True
  43.  (conv5_4): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
  44.  (relu5_4): ReLU(inplace=True
  45.  (pool5): MaxPool2d(kernel_size=2stride=2padding=0dilation=1ceil_mode=False

感知损失

pystiche 将内容损失和风格损失定义为操作符。使用 ops.FeatureReconstructionOperator 作为 content_loss,直接与编码进行对比。如果编码器针对分类任务进行过训练,如该示例中这些编码表示内容。对于content_layer,选择 multi_layer_encoder 的较深层来获取抽象的内容表示,而不是许多不必要的细节。

  1. content_layer = "relu4_2" 
  2. encoder = multi_layer_encoder.extract_encoder(content_layer) 
  3. content_loss = ops.FeatureReconstructionOperator(encoder) 

pystiche 使用 ops.GramOperator 作为 style_loss 的基础,通过比较编码各个通道之间的相关性来丢弃空间信息。这样就可以在输出图像中的任意区域合成风格元素,而不仅仅是风格图像中它们所在的位置。对于 ops.GramOperator,如果它在浅层和深层 style_layers 都能很好地运行,则其性能达到最佳。

style_weight 可以控制模型对输出图像的重点——内容或风格。为了方便起见,pystiche 将所有内容包装在 ops.MultiLayerEncodingOperator 中,该操作处理在同一 multi_layer_encoder 的多个层上进行操作的相同类型操作符的情况。

  1. style_layers = ("relu1_1", "relu2_1", "relu3_1", "relu4_1", "relu5_1") 
  2. style_weight = 1e3 
  3.  
  4.  
  5. def get_encoding_op(encoder, layer_weight): 
  6.     return ops.GramOperator(encoder, score_weight=layer_weight
  7.  
  8.  
  9. style_loss = ops.MultiLayerEncodingOperator( 
  10.     multi_layer_encoder, style_layers, get_encoding_op, score_weight=style_weight

loss.PerceptualLoss 结合了 content_loss 与 style_loss,将作为优化的标准。

  1. criterion = loss.PerceptualLoss(content_loss, style_loss).to(device) 
  2. print(criterion) 

输出:

  1. PerceptualLoss( 
  2.  (content_loss): FeatureReconstructionOperator( 
  3.    score_weight=1
  4.    encoder=VGGMultiLayerEncoder
  5.      layer=relu4_2
  6.      arch=vgg19
  7.      framework=torch
  8.      allow_inplace=True 
  9.    ) 
  10.  ) 
  11.  (style_loss): MultiLayerEncodingOperator( 
  12.    encoder=VGGMultiLayerEncoder
  13.      arch=vgg19
  14.      framework=torch
  15.      allow_inplace=True 
  16.  ), 
  17.  score_weight=1000 
  18.  (relu1_1): GramOperator(score_weight=0.2) 
  19.  (relu2_1): GramOperator(score_weight=0.2) 
  20.  (relu3_1): GramOperator(score_weight=0.2) 
  21.  (relu4_1): GramOperator(score_weight=0.2) 
  22.  (relu5_1): GramOperator(score_weight=0.2) 
  23.  ) 

图像加载

首先加载并显在 NST 需要的目标图片。因为 NST 占用内存较多,故将图像大小调整为 500 像素。

  1. size = 500 
  2. images = demo.images() 
  1. content_image = images["bird1"].read(sizesize=size, devicedevice=device) 
  2. criterion.set_content_image(content_image) 

内容图片

  1. style_image = images["paint"].read(sizesize=size, devicedevice=device) 
  2. criterion.set_style_image(style_image) 

风格图片

神经风格迁移

创建 input_image。从 content_image 开始执行 NST,这样可以实现快速收敛。image_optimization 函数是为了方便,也可以由手动优化循环代替,且不受限制。如果没有指定,则使用 torch.optim.LBFGS 作为优化器。

  1. input_image = content_image.clone() 
  2. output_image = optim.image_optimization(input_image, criterion, num_steps=500

【本文是51CTO专栏机构“机器之心”的原创译文,微信公众号“机器之心( id: almosthuman2014)”】 

戳这里,看该作者更多好文

责任编辑:赵宁宁 来源: 51CTO专栏
相关推荐

2024-12-11 15:15:42

2023-12-13 15:00:38

浅拷贝深拷贝Python

2009-09-17 08:35:56

Windows 7Itunes兼容性

2009-02-01 14:34:26

PythonUnix管道风格

2009-11-26 11:00:28

Chrome浏览器Windows 7

2009-11-27 09:05:19

Windows 7Chrome兼容性

2018-05-07 14:11:15

RootAndroidXposed

2017-03-10 16:32:44

Apache Spar大数据工具

2023-12-04 09:00:00

PythonRuff

2024-12-13 16:01:35

2020-11-20 11:05:39

编程工具开发

2023-09-27 23:08:08

Web前端Vue.jsVue3.0

2021-06-18 15:15:51

机器学习Rust框架

2016-03-30 11:20:10

2022-10-30 15:00:40

小样本学习数据集机器学习

2009-07-15 11:00:48

proxool连接池

2013-06-10 23:23:29

操作系统OS X

2023-10-09 14:36:28

工具PLGEFK

2012-04-26 19:27:18

2022-03-16 15:30:25

barAndroid开发者
点赞
收藏

51CTO技术栈公众号