当前位置: 首页 > news >正文

【PyTorch】生成对抗网络

生成对抗网络是什么

概念

Generative Adversarial Nets,简称GAN
GAN:生成对抗网络 —— 一种可以生成特定分布数据的模型
《Generative Adversarial Nets》 Ian J Goodfellow-2014

GAN网络结构

Recent Progress on Generative Adversarial Networks (GANs): A Survey
在这里插入图片描述

How Generative Adversarial Networks and Their Variants Work: An Overview
在这里插入图片描述

Generative Adversarial Networks_ A Survey and Taxonomy

在这里插入图片描述

GAN的训练

训练目的

  1. 对于D:对真样本输出高概率
  2. 对于G:输出使D会给出高概率的数据

GAN 的训练和监督学习训练模式的差异

在监督学习的训练模式中,训练数经过模型得到输出值,然后使用损失函数计算输出值与标签之间的差异,根据差异值进行反向传播,更新模型的参数,如下图所示。
在这里插入图片描述
在 GAN 的训练模式中,Generator 接收随机数得到输出值,目标是让输出值的分布与训练数据的分布接近,但是这里不是使用人为定义的损失函数来计算输出值与训练数据分布之间的差异,而是使用 Discriminator 来计算这个差异。需要注意的是这个差异不是单个数字上的差异,而是分布上的差异。如下图所示。
在这里插入图片描述

具体训练过程

step1:训练D
输入:真实数据加G生成的假数据
输出:二分类概率

step2:训练G
输入:随机噪声z
输出:分类概率——D(G(z))

在这里插入图片描述

DCGAN

Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks
在这里插入图片描述

Discriminator:卷积结构的模型
Generator:卷积结构的模型

DCGAN 的定义如下:

from collections import OrderedDict
import torch
import torch.nn as nnclass Generator(nn.Module):def __init__(self, nz=100, ngf=128, nc=3):super(Generator, self).__init__()self.main = nn.Sequential(# input is Z, going into a convolutionnn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),nn.BatchNorm2d(ngf * 8),nn.ReLU(True),# state size. (ngf*8) x 4 x 4nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 4),nn.ReLU(True),# state size. (ngf*4) x 8 x 8nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 2),nn.ReLU(True),# state size. (ngf*2) x 16 x 16nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf),nn.ReLU(True),# state size. (ngf) x 32 x 32nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),nn.Tanh()# state size. (nc) x 64 x 64)def forward(self, input):return self.main(input)def initialize_weights(self, w_mean=0., w_std=0.02, b_mean=1, b_std=0.02):for m in self.modules():classname = m.__class__.__name__if classname.find('Conv') != -1:nn.init.normal_(m.weight.data, w_mean, w_std)elif classname.find('BatchNorm') != -1:nn.init.normal_(m.weight.data, b_mean, b_std)nn.init.constant_(m.bias.data, 0)class Discriminator(nn.Module):def __init__(self, nc=3, ndf=128):super(Discriminator, self).__init__()self.main = nn.Sequential(# input is (nc) x 64 x 64nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf) x 32 x 32nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 2),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf*2) x 16 x 16nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 4),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf*4) x 8 x 8nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 8),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf*8) x 4 x 4nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),nn.Sigmoid())def forward(self, input):return self.main(input)def initialize_weights(self, w_mean=0., w_std=0.02, b_mean=1, b_std=0.02):for m in self.modules():classname = m.__class__.__name__if classname.find('Conv') != -1:nn.init.normal_(m.weight.data, w_mean, w_std)elif classname.find('BatchNorm') != -1:nn.init.normal_(m.weight.data, b_mean, b_std)nn.init.constant_(m.bias.data, 0)

相关文章:

  • C++游戏开发详解:从入门到实践
  • c++primier第十二章类和动态内存
  • openKylin--安装 .net6.0
  • 锁住K8S集群版本和系统内核版本
  • 生产环境升级mysql流程及配置主从服务
  • 【深度学习】ubuntu系统下docker部署cvat的自动标注功能(yolov8 segmentation)
  • 投影算子 Projection
  • CMake构建学习笔记18-cpp-httplib库的构建
  • CC攻击和DDOS攻击的区别有哪些?
  • [单master节点k8s部署]22.构建EFK日志收集平台(一)
  • 河南人社厅:注册满两年可按条件认定副高
  • 【Python报错已解决】TypeError: object of type ‘complex‘ has no len()
  • 红绿灯倒计时读秒数字识别系统源码分享
  • VisualGLM-6B——原理与部署
  • 【题解】Codeforces Round 975 (Div. 2) A~E
  • [译]Python中的类属性与实例属性的区别
  • 【跃迁之路】【463天】刻意练习系列222(2018.05.14)
  • CNN 在图像分割中的简史:从 R-CNN 到 Mask R-CNN
  • css属性的继承、初识值、计算值、当前值、应用值
  • JS函数式编程 数组部分风格 ES6版
  • Linux各目录及每个目录的详细介绍
  • 聊聊flink的BlobWriter
  • 扑朔迷离的属性和特性【彻底弄清】
  • 写代码的正确姿势
  • 深度学习之轻量级神经网络在TWS蓝牙音频处理器上的部署
  • 3月7日云栖精选夜读 | RSA 2019安全大会:企业资产管理成行业新风向标,云上安全占绝对优势 ...
  • # windows 运行框输入mrt提示错误:Windows 找不到文件‘mrt‘。请确定文件名是否正确后,再试一次
  • #NOIP 2014#Day.2 T3 解方程
  • (~_~)
  • (3)医疗图像处理:MRI磁共振成像-快速采集--(杨正汉)
  • (SpringBoot)第二章:Spring创建和使用
  • (vue)页面文件上传获取:action地址
  • (分享)自己整理的一些简单awk实用语句
  • (附源码)计算机毕业设计SSM教师教学质量评价系统
  • (简单有案例)前端实现主题切换、动态换肤的两种简单方式
  • (区间dp) (经典例题) 石子合并
  • (全部习题答案)研究生英语读写教程基础级教师用书PDF|| 研究生英语读写教程提高级教师用书PDF
  • (三)SvelteKit教程:layout 文件
  • (四)Linux Shell编程——输入输出重定向
  • (原創) 系統分析和系統設計有什麼差別? (OO)
  • (转)利用PHP的debug_backtrace函数,实现PHP文件权限管理、动态加载 【反射】...
  • (转)四层和七层负载均衡的区别
  • .bat批处理(十一):替换字符串中包含百分号%的子串
  • .gitignore不生效的解决方案
  • .NET / MSBuild 扩展编译时什么时候用 BeforeTargets / AfterTargets 什么时候用 DependsOnTargets?
  • .NET Compact Framework 3.5 支持 WCF 的子集
  • .NET/C# 使用反射调用含 ref 或 out 参数的方法
  • .NET的数据绑定
  • @Import注解详解
  • @modelattribute注解用postman测试怎么传参_接口测试之问题挖掘
  • @Tag和@Operation标签失效问题。SpringDoc 2.2.0(OpenApi 3)和Spring Boot 3.1.1集成
  • [ 云计算 | Azure 实践 ] 在 Azure 门户中创建 VM 虚拟机并进行验证
  • []使用 Tortoise SVN 创建 Externals 外部引用目录
  • [1]-基于图搜索的路径规划基础
  • [2024最新教程]地表最强AGI:Claude 3注册账号/登录账号/访问方法,小白教程包教包会