当前位置: 首页 > news >正文

《昇思25天学习打卡营第24天》

接续上一天的学习任务,我们要继续进行下一步的操作

构造网络

当处理完数据后,就可以来进行网络的搭建了。按照DCGAN论文中的描述,所有模型权重均应从mean为0,sigma为0.02的正态分布中随机初始化。

接下来了解一下其他内容

生成器

生成器G的功能是将隐向量z映射到数据空间。实践场景中,该功能是通过一系列Conv2dTranspose转置卷积层来完成的,每个层都与BatchNorm2d层和ReLu激活层配对,输出数据会经过tanh函数,使其返回[-1,1]的数据范围内。

DCGAN论文生成图像如下所示:

通过输入部分中设置的nzngfnc来影响代码中的生成器结构。nz是隐向量z的长度,ngf与通过生成器传播的特征图的大小有关,nc是输出图像中的通道数。

代码实现

import mindspore as ms
from mindspore import nn, ops
from mindspore.common.initializer import Normalweight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)class Generator(nn.Cell):"""DCGAN网络生成器"""def __init__(self):super(Generator, self).__init__()self.generator = nn.SequentialCell(nn.Conv2dTranspose(nz, ngf * 8, 4, 1, 'valid', weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 8, ngf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 4, ngf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 2, ngf, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf, nc, 4, 2, 'pad', 1, weight_init=weight_init),nn.Tanh())def construct(self, x):return self.generator(x)generator = Generator()

判别器

判别器D是一个二分类网络模型,输出判定该图像为真实图的概率。

代码实现

class Discriminator(nn.Cell):"""DCGAN网络判别器"""def __init__(self):super(Discriminator, self).__init__()self.discriminator = nn.SequentialCell(nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, weight_init=weight_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf, ndf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 8, 1, 4, 1, 'valid', weight_init=weight_init),)self.adv_layer = nn.Sigmoid()def construct(self, x):out = self.discriminator(x)out = out.reshape(out.shape[0], -1)return self.adv_layer(out)discriminator = Discriminator()

接下来进入模型训练阶段

模型训练

其中分为几个要素:

损失函数

当定义了DG后,接下来将使用MindSpore中定义的二进制交叉熵损失函数BCELoss。

优化器

训练模型:训练判别器和训练生成器。

实现模型训练正向逻辑:

def generator_forward(real_imgs, valid):# 将噪声采样为发生器的输入z = ops.standard_normal((real_imgs.shape[0], nz, 1, 1))# 生成一批图像gen_imgs = generator(z)# 损失衡量发生器绕过判别器的能力g_loss = adversarial_loss(discriminator(gen_imgs), valid)return g_loss, gen_imgsdef discriminator_forward(real_imgs, gen_imgs, valid, fake):# 衡量鉴别器从生成的样本中对真实样本进行分类的能力real_loss = adversarial_loss(discriminator(real_imgs), valid)fake_loss = adversarial_loss(discriminator(gen_imgs), fake)d_loss = (real_loss + fake_loss) / 2return d_lossgrad_generator_fn = ms.value_and_grad(generator_forward, None,optimizer_G.parameters,has_aux=True)
grad_discriminator_fn = ms.value_and_grad(discriminator_forward, None,optimizer_D.parameters)@ms.jit
def train_step(imgs):valid = ops.ones((imgs.shape[0], 1), mindspore.float32)fake = ops.zeros((imgs.shape[0], 1), mindspore.float32)(g_loss, gen_imgs), g_grads = grad_generator_fn(imgs, valid)optimizer_G(g_grads)d_loss, d_grads = grad_discriminator_fn(imgs, gen_imgs, valid, fake)optimizer_D(d_grads)return g_loss, d_loss, gen_imgs

代码训练

结果展示就不多说了看成品

文末附上打卡时间

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • Springboot 开发之 RestTemplate 简介
  • 微信小程序-获取手机号:HttpClientErrorException: 412 Precondition Failed: [no body]
  • 人工智能与机器学习原理精解【11】
  • 【Git】git stash
  • 解决 Git 访问 GitHub 时的 SSL 错误
  • 等保测评与《网络安全法》的深度融合
  • 视频主题Qinmei 3.0视频站源码_WordPress影视视频主题/附详细安装教程
  • springboot电动自行车租赁系统-计算机毕业设计源码64081
  • SpringBoot的基础配置
  • leetcode-207. 课程表
  • java基础概念08-跳出多重循环嵌套
  • 基于主成分分析(PCA)的平面拟合(python)
  • 从0开始学习c++01-软件下载和安装
  • 诊断技巧分享 | 用WPS500压力传感器测试空调压力波形?
  • MySQL存储引擎MyISAM和InnoDB
  • android百种动画侧滑库、步骤视图、TextView效果、社交、搜房、K线图等源码
  • canvas实际项目操作,包含:线条,圆形,扇形,图片绘制,图片圆角遮罩,矩形,弧形文字...
  • es6--symbol
  • Essential Studio for ASP.NET Web Forms 2017 v2,新增自定义树形网格工具栏
  • GitUp, 你不可错过的秀外慧中的git工具
  • Material Design
  • Vue 动态创建 component
  • Vue全家桶实现一个Web App
  • 包装类对象
  • 分类模型——Logistics Regression
  • 分享几个不错的工具
  • 基于webpack 的 vue 多页架构
  • 聊聊redis的数据结构的应用
  • 区块链将重新定义世界
  • 如何借助 NoSQL 提高 JPA 应用性能
  • 入门到放弃node系列之Hello Word篇
  • 实习面试笔记
  • 世界编程语言排行榜2008年06月(ActionScript 挺进20强)
  • 数据仓库的几种建模方法
  • Redis4.x新特性 -- 萌萌的MEMORY DOCTOR
  • ​LeetCode解法汇总518. 零钱兑换 II
  • # Python csv、xlsx、json、二进制(MP3) 文件读写基本使用
  • #如何使用 Qt 5.6 在 Android 上启用 NFC
  • #在线报价接单​再坚持一下 明天是真的周六.出现货 实单来谈
  • $.proxy和$.extend
  • (bean配置类的注解开发)学习Spring的第十三天
  • (C语言)输入自定义个数的整数,打印出最大值和最小值
  • (floyd+补集) poj 3275
  • (Matlab)基于蝙蝠算法实现电力系统经济调度
  • (windows2012共享文件夹和防火墙设置
  • (板子)A* astar算法,AcWing第k短路+八数码 带注释
  • (差分)胡桃爱原石
  • (附源码)ssm考生评分系统 毕业设计 071114
  • (三分钟了解debug)SLAM研究方向-Debug总结
  • (译)2019年前端性能优化清单 — 下篇
  • *setTimeout实现text输入在用户停顿时才调用事件!*
  • .axf 转化 .bin文件 的方法
  • .NET core 自定义过滤器 Filter 实现webapi RestFul 统一接口数据返回格式
  • .NET/C# 中你可以在代码中写多个 Main 函数,然后按需要随时切换
  • @ConditionalOnProperty注解使用说明