当前位置: 首页 > news >正文

【Week-G7】Semi-Supervised GAN 实践,使用MNIST数据集

文章目录

  • 一、基础知识
  • 二、代码实现
    • 2.1 导入所需模块 & 设置网络初始参数
    • 2.2 初始化权重
    • 2.3 定义算法模型
    • 2.4 配置模型
    • 2.5 训练模型
    • 2.6 训练结果

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

本次学习进行Semi-Supervised GAN的实践,数据集为MNIST
主要为了解惑:加入生成的图像本身就携带标签,比如数字1~9,那么:为什么还需要鉴别器判断输入图像的真假,而不直接判断图像属于0-9中的哪一个数字?

一、基础知识

本次学习使用到的SGAN将GAN扩展到半监督学习方式,通过强制判别器D来输出类别标签。具体结构如下图:
在这里插入图片描述

输入数据集:N类中某一个
生成器G:输出第N+1个类
判别器D:充当分类器C的效果
训练时:判别器D被用于预测输入时属于N+1类中的哪一个

SGAN可以用于训练效果更好的判别器D,并且比普通的GAN产生更加高质量的样本。
在这里插入图片描述

二、代码实现

2.1 导入所需模块 & 设置网络初始参数

import argparse
import os
import numpy as np
import mathimport torchvision.transforms as transforms
from torchvision.utils import save_imagefrom torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variableimport torch.nn as nn
import torch.nn.functional as F
import torchos.makedirs("images", exist_ok=True)parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=50, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--num_classes", type=int, default=10, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args()
print(opt)cuda = True if torch.cuda.is_available() else False

2.2 初始化权重

def weights_init_normal(m):classname = m.__class__.__name__if classname.find("Conv") != -1:torch.nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find("BatchNorm") != -1:torch.nn.init.normal_(m.weight.data, 1.0, 0.02)torch.nn.init.constant_(m.bias.data, 0.0)

2.3 定义算法模型


class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.label_emb = nn.Embedding(opt.num_classes, opt.latent_dim)self.init_size = opt.img_size // 4  # Initial size before upsamplingself.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))self.conv_blocks = nn.Sequential(nn.BatchNorm2d(128),nn.Upsample(scale_factor=2),nn.Conv2d(128, 128, 3, stride=1, padding=1),nn.BatchNorm2d(128, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Upsample(scale_factor=2),nn.Conv2d(128, 64, 3, stride=1, padding=1),nn.BatchNorm2d(64, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),nn.Tanh(),)def forward(self, noise):out = self.l1(noise)out = out.view(out.shape[0], 128, self.init_size, self.init_size)img = self.conv_blocks(out)return imgclass Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()def discriminator_block(in_filters, out_filters, bn=True):"""Returns layers of each discriminator block"""block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]if bn:block.append(nn.BatchNorm2d(out_filters, 0.8))return blockself.conv_blocks = nn.Sequential(*discriminator_block(opt.channels, 16, bn=False),*discriminator_block(16, 32),*discriminator_block(32, 64),*discriminator_block(64, 128),)# The height and width of downsampled imageds_size = opt.img_size // 2 ** 4# Output layersself.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, opt.num_classes + 1), nn.Softmax())def forward(self, img):out = self.conv_blocks(img)out = out.view(out.shape[0], -1)validity = self.adv_layer(out)label = self.aux_layer(out)return validity, label

2.4 配置模型


# Loss functions
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()if cuda:generator.cuda()discriminator.cuda()adversarial_loss.cuda()auxiliary_loss.cuda()# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(datasets.MNIST("../../data/mnist",train=True,download=True,transform=transforms.Compose([transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),),batch_size=opt.batch_size,shuffle=True,
)# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor

2.5 训练模型


# ----------
#  Training
# ----------for epoch in range(opt.n_epochs):for i, (imgs, labels) in enumerate(dataloader):batch_size = imgs.shape[0]# Adversarial ground truthsvalid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)fake_aux_gt = Variable(LongTensor(batch_size).fill_(opt.num_classes), requires_grad=False)# Configure inputreal_imgs = Variable(imgs.type(FloatTensor))labels = Variable(labels.type(LongTensor))# -----------------#  Train Generator# -----------------optimizer_G.zero_grad()# Sample noise and labels as generator inputz = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))# Generate a batch of imagesgen_imgs = generator(z)# Loss measures generator's ability to fool the discriminatorvalidity, _ = discriminator(gen_imgs)g_loss = adversarial_loss(validity, valid)g_loss.backward()optimizer_G.step()# ---------------------#  Train Discriminator# ---------------------optimizer_D.zero_grad()# Loss for real imagesreal_pred, real_aux = discriminator(real_imgs)d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2# Loss for fake imagesfake_pred, fake_aux = discriminator(gen_imgs.detach())d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, fake_aux_gt)) / 2# Total discriminator lossd_loss = (d_real_loss + d_fake_loss) / 2# Calculate discriminator accuracypred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)gt = np.concatenate([labels.data.cpu().numpy(), fake_aux_gt.data.cpu().numpy()], axis=0)d_acc = np.mean(np.argmax(pred, axis=1) == gt)d_loss.backward()optimizer_D.step()batches_done = epoch * len(dataloader) + iif batches_done % opt.sample_interval == 0:save_image(gen_imgs.data[:25], "images/GAN/sgan/%d.png" % batches_done, nrow=5, normalize=True)print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item()))

2.6 训练结果

下载MNIST数据集:
在这里插入图片描述
训练过程:
在这里插入图片描述
训练输出的图像:
在这里插入图片描述

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • Oracle DBA常用 sql
  • AI时代,我们还可以做什么?
  • android系统中data下的xml乱码无法查看问题剖析及解决方法
  • C++ 11 for 循环和容器
  • Linux安全与高级应用(七)深入Linux Shell脚本编程:循环与分支结构的高级应用
  • 【算法】装箱问题
  • Apache Kylin分布式的分析数据仓库
  • pdf怎么加密码怎么设置密码?pdf加密码的几种设置方法
  • Python的安装环境以及应用
  • 日撸Java三百行(day17:链队列)
  • Adobe Premiere Pro 2024 v24.5.0.057 最新免费修改版
  • Flink Maven 依赖
  • gorm入门——如何实现分页查询
  • LVS(Linux virual server)详解
  • 密码学基础-为什么使用真随机数(True Random Number Generators)
  • android百种动画侧滑库、步骤视图、TextView效果、社交、搜房、K线图等源码
  • Android系统模拟器绘制实现概述
  • dva中组件的懒加载
  • javascript 总结(常用工具类的封装)
  • java中具有继承关系的类及其对象初始化顺序
  • JS变量作用域
  • Swoft 源码剖析 - 代码自动更新机制
  • V4L2视频输入框架概述
  • 包装类对象
  • 搭建gitbook 和 访问权限认证
  • 互联网大裁员:Java程序员失工作,焉知不能进ali?
  • 基于遗传算法的优化问题求解
  • 开放才能进步!Angular和Wijmo一起走过的日子
  • 如何打造100亿SDK累计覆盖量的大数据系统
  • 新手搭建网站的主要流程
  • 新书推荐|Windows黑客编程技术详解
  • TPG领衔财团投资轻奢珠宝品牌APM Monaco
  • ​​​​​​​​​​​​​​汽车网络信息安全分析方法论
  • ​sqlite3 --- SQLite 数据库 DB-API 2.0 接口模块​
  • ​人工智能之父图灵诞辰纪念日,一起来看最受读者欢迎的AI技术好书
  • # Panda3d 碰撞检测系统介绍
  • # Python csv、xlsx、json、二进制(MP3) 文件读写基本使用
  • #pragma once
  • (1)(1.11) SiK Radio v2(一)
  • (145)光线追踪距离场柔和阴影
  • (6)设计一个TimeMap
  • (回溯) LeetCode 78. 子集
  • (六)c52学习之旅-独立按键
  • (论文阅读11/100)Fast R-CNN
  • .[hudsonL@cock.li].mkp勒索病毒数据怎么处理|数据解密恢复
  • .htaccess配置常用技巧
  • .NET 4 并行(多核)“.NET研究”编程系列之二 从Task开始
  • .net core 使用js,.net core 使用javascript,在.net core项目中怎么使用javascript
  • .net dataexcel 脚本公式 函数源码
  • .net Stream篇(六)
  • .NET/C# 使用反射注册事件
  • .NET单元测试使用AutoFixture按需填充的方法总结
  • .NET关于 跳过SSL中遇到的问题
  • .NET企业级应用架构设计系列之应用服务器
  • .NET微信公众号开发-2.0创建自定义菜单