当前位置: 首页 > news >正文

昇思训练营打卡第二十天(CycleGAN图像风格迁移互换)

CycleGAN(循环生成对抗网络)是一种生成对抗网络(GAN),它能够在没有成对训练样本的情况下,将一种风格的图片转换成另一种风格。CycleGAN通常用于图像到图像的转换任务,比如将马的图片转换成斑马的图片,或者将夏天的风景转换成冬天的风景。

CycleGAN的核心思想是通过两个循环一致性损失来训练两个映射函数:一个将图片从源域X转换到目标域Y,另一个将图片从目标域Y转换回源域X。这样,CycleGAN可以确保在转换过程中保留原始图片的内容。

CycleGAN的架构包括两个生成器和一个判别器。每个生成器负责将图片从一个域转换到另一个域,而判别器则负责区分真实图片和由生成器生成的图片。

训练CycleGAN时,生成器和判别器会进行对抗性训练,生成器尝试生成能够欺骗判别器的图片,而判别器尝试正确地识别真实图片和生成图片。同时,CycleGAN还会通过循环一致性损失来确保转换后的图片能够转换回原始图片。

from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/CycleGAN_apple2orange.zip"download(url, ".", kind="zip", replace=True)
from mindspore.dataset import MindDataset# 读取MindRecord格式数据
name_mr = "./CycleGAN_apple2orange/apple2orange_train.mindrecord"
data = MindDataset(dataset_files=name_mr)
print("Datasize: ", data.get_dataset_size())batch_size = 1
dataset = data.batch(batch_size)
datasize = dataset.get_dataset_size()

可视化

通过 create_dict_iterator 函数将数据转换成字典迭代器,然后使用 matplotlib 模块可视化部分训练数据。

import numpy as np
import matplotlib.pyplot as pltmean = 0.5 * 255
std = 0.5 * 255plt.figure(figsize=(12, 5), dpi=60)
for i, data in enumerate(dataset.create_dict_iterator()):if i < 5:show_images_a = data["image_A"].asnumpy()show_images_b = data["image_B"].asnumpy()plt.subplot(2, 5, i+1)show_images_a = (show_images_a[0] * std + mean).astype(np.uint8).transpose((1, 2, 0))plt.imshow(show_images_a)plt.axis("off")plt.subplot(2, 5, i+6)show_images_b = (show_images_b[0] * std + mean).astype(np.uint8).transpose((1, 2, 0))plt.imshow(show_images_b)plt.axis("off")else:break
plt.show()

构建生成器

本案例生成器的模型结构参考的 ResNet 模型的结构,参考原论文,对于128×128大小的输入图片采用6个残差块相连,图片大小为256×256以上的需要采用9个残差块相连,所以本文网络有9个残差块相连,超参数 n_layers 参数控制残差块数。

import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Normalweight_init = Normal(sigma=0.02)class ConvNormReLU(nn.Cell):def __init__(self, input_channel, out_planes, kernel_size=4, stride=2, alpha=0.2, norm_mode='instance',pad_mode='CONSTANT', use_relu=True, padding=None, transpose=False):super(ConvNormReLU, self).__init__()norm = nn.BatchNorm2d(out_planes)if norm_mode == 'instance':norm = nn.BatchNorm2d(out_planes, affine=False)has_bias = (norm_mode == 'instance')if padding is None:padding = (kernel_size - 1) // 2if pad_mode == 'CONSTANT':if transpose:conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='same',has_bias=has_bias, weight_init=weight_init)else:conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',has_bias=has_bias, padding=padding, weight_init=weight_init)layers = [conv, norm]else:paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))pad = nn.Pad(paddings=paddings, mode=pad_mode)if transpose:conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='pad',has_bias=has_bias, weight_init=weight_init)else:conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',has_bias=has_bias, weight_init=weight_init)layers = [pad, conv, norm]if use_relu:relu = nn.ReLU()if alpha > 0:relu = nn.LeakyReLU(alpha)layers.append(relu)self.features = nn.SequentialCell(layers)def construct(self, x):output = self.features(x)return outputclass ResidualBlock(nn.Cell):def __init__(self, dim, norm_mode='instance', dropout=False, pad_mode="CONSTANT"):super(ResidualBlock, self).__init__()self.conv1 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode)self.conv2 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode, use_relu=False)self.dropout = dropoutif dropout:self.dropout = nn.Dropout(p=0.5)def construct(self, x):out = self.conv1(x)if self.dropout:out = self.dropout(out)out = self.conv2(out)return x + outclass ResNetGenerator(nn.Cell):def __init__(self, input_channel=3, output_channel=64, n_layers=9, alpha=0.2, norm_mode='instance', dropout=False,pad_mode="CONSTANT"):super(ResNetGenerator, self).__init__()self.conv_in = ConvNormReLU(input_channel, output_channel, 7, 1, alpha, norm_mode, pad_mode=pad_mode)self.down_1 = ConvNormReLU(output_channel, output_channel * 2, 3, 2, alpha, norm_mode)self.down_2 = ConvNormReLU(output_channel * 2, output_channel * 4, 3, 2, alpha, norm_mode)layers = [ResidualBlock(output_channel * 4, norm_mode, dropout=dropout, pad_mode=pad_mode)] * n_layersself.residuals = nn.SequentialCell(layers)self.up_2 = ConvNormReLU(output_channel * 4, output_channel * 2, 3, 2, alpha, norm_mode, transpose=True)self.up_1 = ConvNormReLU(output_channel * 2, output_channel, 3, 2, alpha, norm_mode, transpose=True)if pad_mode == "CONSTANT":self.conv_out = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad',padding=3, weight_init=weight_init)else:pad = nn.Pad(paddings=((0, 0), (0, 0), (3, 3), (3, 3)), mode=pad_mode)conv = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad', weight_init=weight_init)self.conv_out = nn.SequentialCell([pad, conv])def construct(self, x):x = self.conv_in(x)x = self.down_1(x)x = self.down_2(x)x = self.residuals(x)x = self.up_2(x)x = self.up_1(x)output = self.conv_out(x)return ops.tanh(output)# 实例化生成器
net_rg_a = ResNetGenerator()
net_rg_a.update_parameters_name('net_rg_a.')net_rg_b = ResNetGenerator()
net_rg_b.update_parameters_name('net_rg_b.')

优化器和损失函数

根据不同模型需要单独的设置优化器,这是训练过程决定的。

对生成器 𝐺𝐺 及其判别器 𝐷𝑌𝐷𝑌 ,目标损失函数定义为:

𝐿𝐺𝐴𝑁(𝐺,𝐷𝑌,𝑋,𝑌)=𝐸𝑦−𝑝𝑑𝑎𝑡𝑎(𝑦)[𝑙𝑜𝑔𝐷𝑌(𝑦)]+𝐸𝑥−𝑝𝑑𝑎𝑡𝑎(𝑥)[𝑙𝑜𝑔(1−𝐷𝑌(𝐺(𝑥)))]𝐿𝐺𝐴𝑁(𝐺,𝐷𝑌,𝑋,𝑌)=𝐸𝑦−𝑝𝑑𝑎𝑡𝑎(𝑦)[𝑙𝑜𝑔𝐷𝑌(𝑦)]+𝐸𝑥−𝑝𝑑𝑎𝑡𝑎(𝑥)[𝑙𝑜𝑔(1−𝐷𝑌(𝐺(𝑥)))]

其中 𝐺𝐺 试图生成看起来与 𝑌𝑌 中的图像相似的图像 𝐺(𝑥)𝐺(𝑥) ,而 𝐷𝑌𝐷𝑌 的目标是区分翻译样本 𝐺(𝑥)𝐺(𝑥) 和真实样本 𝑦𝑦 ,生成器的目标是最小化这个损失函数以此来对抗判别器。即 𝑚𝑖𝑛𝐺𝑚𝑎𝑥𝐷𝑌𝐿𝐺𝐴𝑁(𝐺,𝐷𝑌,𝑋,𝑌)𝑚𝑖𝑛𝐺𝑚𝑎𝑥𝐷𝑌𝐿𝐺𝐴𝑁(𝐺,𝐷𝑌,𝑋,𝑌) 。

单独的对抗损失不能保证所学函数可以将单个输入映射到期望的输出,为了进一步减少可能的映射函数的空间,学习到的映射函数应该是周期一致的,例如对于 𝑋𝑋 的每个图像 𝑥𝑥 ,图像转换周期应能够将 𝑥𝑥 带回原始图像,可以称之为正向循环一致性,即 𝑥→𝐺(𝑥)→𝐹(𝐺(𝑥))≈𝑥𝑥→𝐺(𝑥)→𝐹(𝐺(𝑥))≈𝑥 。对于 𝑌𝑌 ,类似的 𝑥→𝐺(𝑥)→𝐹(𝐺(𝑥))≈𝑥𝑥→𝐺(𝑥)→𝐹(𝐺(𝑥))≈𝑥 。可以理解采用了一个循环一致性损失来激励这种行为。

循环一致损失函数定义如下:

𝐿𝑐𝑦𝑐(𝐺,𝐹)=𝐸𝑥−𝑝𝑑𝑎𝑡𝑎(𝑥)[‖𝐹(𝐺(𝑥))−𝑥‖1]+𝐸𝑦−𝑝𝑑𝑎𝑡𝑎(𝑦)[‖𝐺(𝐹(𝑦))−𝑦‖1]𝐿𝑐𝑦𝑐(𝐺,𝐹)=𝐸𝑥−𝑝𝑑𝑎𝑡𝑎(𝑥)[‖𝐹(𝐺(𝑥))−𝑥‖1]+𝐸𝑦−𝑝𝑑𝑎𝑡𝑎(𝑦)[‖𝐺(𝐹(𝑦))−𝑦‖1]

循环一致损失能够保证重建图像 𝐹(𝐺(𝑥))𝐹(𝐺(𝑥)) 与输入图像 𝑥𝑥 紧密匹配。

# 构建生成器,判别器优化器
optimizer_rg_a = nn.Adam(net_rg_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_rg_b = nn.Adam(net_rg_b.trainable_params(), learning_rate=0.0002, beta1=0.5)optimizer_d_a = nn.Adam(net_d_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_d_b = nn.Adam(net_d_b.trainable_params(), learning_rate=0.0002, beta1=0.5)# GAN网络损失函数,这里最后一层不使用sigmoid函数
loss_fn = nn.MSELoss(reduction='mean')
l1_loss = nn.L1Loss("mean")def gan_loss(predict, target):target = ops.ones_like(predict) * targetloss = loss_fn(predict, target)return loss

前向计算

搭建模型前向计算损失的过程,过程如下代码。

为了减少模型振荡[1],这里遵循 Shrivastava 等人的策略[2],使用生成器生成图像的历史数据而不是生成器生成的最新图像数据来更新鉴别器。这里创建 image_pool 函数,保留了一个图像缓冲区,用于存储生成器生成前的50个图像。

import mindspore as ms# 前向计算def generator(img_a, img_b):fake_a = net_rg_b(img_b)fake_b = net_rg_a(img_a)rec_a = net_rg_b(fake_b)rec_b = net_rg_a(fake_a)identity_a = net_rg_b(img_a)identity_b = net_rg_a(img_b)return fake_a, fake_b, rec_a, rec_b, identity_a, identity_blambda_a = 10.0
lambda_b = 10.0
lambda_idt = 0.5def generator_forward(img_a, img_b):true = Tensor(True, dtype.bool_)fake_a, fake_b, rec_a, rec_b, identity_a, identity_b = generator(img_a, img_b)loss_g_a = gan_loss(net_d_b(fake_b), true)loss_g_b = gan_loss(net_d_a(fake_a), true)loss_c_a = l1_loss(rec_a, img_a) * lambda_aloss_c_b = l1_loss(rec_b, img_b) * lambda_bloss_idt_a = l1_loss(identity_a, img_a) * lambda_a * lambda_idtloss_idt_b = l1_loss(identity_b, img_b) * lambda_b * lambda_idtloss_g = loss_g_a + loss_g_b + loss_c_a + loss_c_b + loss_idt_a + loss_idt_breturn fake_a, fake_b, loss_g, loss_g_a, loss_g_b, loss_c_a, loss_c_b, loss_idt_a, loss_idt_bdef generator_forward_grad(img_a, img_b):_, _, loss_g, _, _, _, _, _, _ = generator_forward(img_a, img_b)return loss_gdef discriminator_forward(img_a, img_b, fake_a, fake_b):false = Tensor(False, dtype.bool_)true = Tensor(True, dtype.bool_)d_fake_a = net_d_a(fake_a)d_img_a = net_d_a(img_a)d_fake_b = net_d_b(fake_b)d_img_b = net_d_b(img_b)loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)loss_d = (loss_d_a + loss_d_b) * 0.5return loss_ddef discriminator_forward_a(img_a, fake_a):false = Tensor(False, dtype.bool_)true = Tensor(True, dtype.bool_)d_fake_a = net_d_a(fake_a)d_img_a = net_d_a(img_a)loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)return loss_d_adef discriminator_forward_b(img_b, fake_b):false = Tensor(False, dtype.bool_)true = Tensor(True, dtype.bool_)d_fake_b = net_d_b(fake_b)d_img_b = net_d_b(img_b)loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)return loss_d_b# 保留了一个图像缓冲区,用来存储之前创建的50个图像
pool_size = 50
def image_pool(images):num_imgs = 0image1 = []if isinstance(images, Tensor):images = images.asnumpy()return_images = []for image in images:if num_imgs < pool_size:num_imgs = num_imgs + 1image1.append(image)return_images.append(image)else:if random.uniform(0, 1) > 0.5:random_id = random.randint(0, pool_size - 1)tmp = image1[random_id].copy()image1[random_id] = imagereturn_images.append(tmp)else:return_images.append(image)output = Tensor(return_images, ms.float32)if output.ndim != 4:raise ValueError("img should be 4d, but get shape {}".format(output.shape))return output

计算梯度和反向传播

from mindspore import value_and_grad# 实例化求梯度的方法
grad_g_a = value_and_grad(generator_forward_grad, None, net_rg_a.trainable_params())
grad_g_b = value_and_grad(generator_forward_grad, None, net_rg_b.trainable_params())grad_d_a = value_and_grad(discriminator_forward_a, None, net_d_a.trainable_params())
grad_d_b = value_and_grad(discriminator_forward_b, None, net_d_b.trainable_params())# 计算生成器的梯度,反向传播更新参数
def train_step_g(img_a, img_b):net_d_a.set_grad(False)net_d_b.set_grad(False)fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib = generator_forward(img_a, img_b)_, grads_g_a = grad_g_a(img_a, img_b)_, grads_g_b = grad_g_b(img_a, img_b)optimizer_rg_a(grads_g_a)optimizer_rg_b(grads_g_b)return fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib# 计算判别器的梯度,反向传播更新参数
def train_step_d(img_a, img_b, fake_a, fake_b):net_d_a.set_grad(True)net_d_b.set_grad(True)loss_d_a, grads_d_a = grad_d_a(img_a, fake_a)loss_d_b, grads_d_b = grad_d_b(img_b, fake_b)loss_d = (loss_d_a + loss_d_b) * 0.5optimizer_d_a(grads_d_a)optimizer_d_b(grads_d_b)return loss_d

模型训练

训练分为两个主要部分:训练判别器和训练生成器,在前文的判别器损失函数中,论文采用了最小二乘损失代替负对数似然目标。

  • 训练判别器:训练判别器的目的是最大程度地提高判别图像真伪的概率。按照论文的方法需要训练判别器来最小化 𝐸𝑦−𝑝𝑑𝑎𝑡𝑎(𝑦)[(𝐷(𝑦)−1)2]𝐸𝑦−𝑝𝑑𝑎𝑡𝑎(𝑦)[(𝐷(𝑦)−1)2] ;

  • 训练生成器:如 CycleGAN 论文所述,我们希望通过最小化 𝐸𝑥−𝑝𝑑𝑎𝑡𝑎(𝑥)[(𝐷(𝐺(𝑥)−1)2]𝐸𝑥−𝑝𝑑𝑎𝑡𝑎(𝑥)[(𝐷(𝐺(𝑥)−1)2] 来训练生成器,以产生更好的虚假图像。

下面定义了生成器和判别器的训练过程:

import os
import time
import random
import numpy as np
from PIL import Image
from mindspore import Tensor, save_checkpoint
from mindspore import dtype# 由于时间原因,epochs设置为1,可根据需求进行调整
epochs = 1
save_step_num = 80
save_checkpoint_epochs = 1
save_ckpt_dir = './train_ckpt_outputs/'print('Start training!')for epoch in range(epochs):g_loss = []d_loss = []start_time_e = time.time()for step, data in enumerate(dataset.create_dict_iterator()):start_time_s = time.time()img_a = data["image_A"]img_b = data["image_B"]res_g = train_step_g(img_a, img_b)fake_a = res_g[0]fake_b = res_g[1]res_d = train_step_d(img_a, img_b, image_pool(fake_a), image_pool(fake_b))loss_d = float(res_d.asnumpy())step_time = time.time() - start_time_sres = []for item in res_g[2:]:res.append(float(item.asnumpy()))g_loss.append(res[0])d_loss.append(loss_d)if step % save_step_num == 0:print(f"Epoch:[{int(epoch + 1):>3d}/{int(epochs):>3d}], "f"step:[{int(step):>4d}/{int(datasize):>4d}], "f"time:{step_time:>3f}s,\n"f"loss_g:{res[0]:.2f}, loss_d:{loss_d:.2f}, "f"loss_g_a: {res[1]:.2f}, loss_g_b: {res[2]:.2f}, "f"loss_c_a: {res[3]:.2f}, loss_c_b: {res[4]:.2f}, "f"loss_idt_a: {res[5]:.2f}, loss_idt_b: {res[6]:.2f}")epoch_cost = time.time() - start_time_eper_step_time = epoch_cost / datasizemean_loss_d, mean_loss_g = sum(d_loss) / datasize, sum(g_loss) / datasizeprint(f"Epoch:[{int(epoch + 1):>3d}/{int(epochs):>3d}], "f"epoch time:{epoch_cost:.2f}s, per step time:{per_step_time:.2f}, "f"mean_g_loss:{mean_loss_g:.2f}, mean_d_loss:{mean_loss_d :.2f}")if epoch % save_checkpoint_epochs == 0:os.makedirs(save_ckpt_dir, exist_ok=True)save_checkpoint(net_rg_a, os.path.join(save_ckpt_dir, f"g_a_{epoch}.ckpt"))save_checkpoint(net_rg_b, os.path.join(save_ckpt_dir, f"g_b_{epoch}.ckpt"))save_checkpoint(net_d_a, os.path.join(save_ckpt_dir, f"d_a_{epoch}.ckpt"))save_checkpoint(net_d_b, os.path.join(save_ckpt_dir, f"d_b_{epoch}.ckpt"))print('End of training!')

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 基于JavaScript、puppeteer的爬虫
  • 【Unix/Linux】Unix/Linux如何查看系统版本
  • 双系统ubuntu20.04扩容
  • 使用 Qt 和 ECharts 进行数据可视化
  • 百川工作手机实现销售管理微信监控系统
  • The IsA relationship and HasA relationship
  • Ubuntu安装PostgreSQL
  • Python开发 ——循环中的 `continue` 语句
  • DNS隧道
  • Kafka 和 RabbitMQ对比
  • 我跟ai学web知识点:“短链接”
  • React@16.x(51)路由v5.x(16)- 手动实现文件目录参考
  • ARM/Linux嵌入式面经(十):极氪
  • 静态网页基础知识
  • 19_谷歌GoogLeNet(InceptionV1)深度学习图像分类算法
  • 【Redis学习笔记】2018-06-28 redis命令源码学习1
  • 【vuex入门系列02】mutation接收单个参数和多个参数
  • 【剑指offer】让抽象问题具体化
  • 77. Combinations
  • Java IO学习笔记一
  • Javascript基础之Array数组API
  • javascript面向对象之创建对象
  • Java精华积累:初学者都应该搞懂的问题
  • leetcode46 Permutation 排列组合
  • Node项目之评分系统(二)- 数据库设计
  • PermissionScope Swift4 兼容问题
  • Spring声明式事务管理之一:五大属性分析
  • Webpack4 学习笔记 - 01:webpack的安装和简单配置
  • 闭包--闭包作用之保存(一)
  • 测试开发系类之接口自动化测试
  • 如何学习JavaEE,项目又该如何做?
  • 探索 JS 中的模块化
  • 腾讯视频格式如何转换成mp4 将下载的qlv文件转换成mp4的方法
  • 微信开放平台全网发布【失败】的几点排查方法
  • 温故知新之javascript面向对象
  • Java数据解析之JSON
  • 宾利慕尚创始人典藏版国内首秀,2025年前实现全系车型电动化 | 2019上海车展 ...
  • ​1:1公有云能力整体输出,腾讯云“七剑”下云端
  • # 20155222 2016-2017-2 《Java程序设计》第5周学习总结
  • #我与Java虚拟机的故事#连载06:收获颇多的经典之作
  • (M)unity2D敌人的创建、人物属性设置,遇敌掉血
  • (二)fiber的基本认识
  • (二)PySpark3:SparkSQL编程
  • (二)springcloud实战之config配置中心
  • (六)Hibernate的二级缓存
  • (十二)springboot实战——SSE服务推送事件案例实现
  • (一)RocketMQ初步认识
  • (一)为什么要选择C++
  • (最简单,详细,直接上手)uniapp/vue中英文多语言切换
  • **Java有哪些悲观锁的实现_乐观锁、悲观锁、Redis分布式锁和Zookeeper分布式锁的实现以及流程原理...
  • .gitignore不生效的解决方案
  • .gitignore文件_Git:.gitignore
  • .JPG图片,各种压缩率下的文件尺寸
  • .NET Compact Framework 多线程环境下的UI异步刷新
  • .net core 使用js,.net core 使用javascript,在.net core项目中怎么使用javascript