当前位置: 首页 > news >正文

pytorch 实现一个最简单的 GAN:用mnist数据集生成新图像

文章目录

    • 一、代码
    • 二、生成结果
      • 2.1 loss的变化
      • 2.2 生成的虚假图像的变化
    • 三、不足之处

用 pytorch 实现一个最简单的GAN:用mnist数据集生成新图像

一、代码

训练细节见代码注释:

# @Time    : 2022/9/25
# @Function: 用pytorch实现一个最简单的GAN,用MNIST数据集生成新图片

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

import os
import shutil
from tqdm import tqdm


# 判别器,判断一张图片来源于真实数据集的概率,输入0-1之间的数,数值越大表示数据来源于真实数据集的概率越高。
class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(in_features=img_dim, out_features=128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 1),
            nn.Sigmoid(),  # 将输出值映射到0-1之间
        )

    def forward(self, x):
        return self.disc(x)


# 生成器,用随机噪声生成图片
class Generator(nn.Module):
    def __init__(self, noise_dim, img_dim):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(noise_dim, 256),
            nn.LeakyReLU(0.01),
            nn.Linear(256, img_dim),
            nn.Tanh(),
            # normalize inputs to [-1, 1] so make outputs [-1, 1]
            # 一般二分类问题中,隐藏层用Tanh函数,输出层用Sigmod函数
        )

    def forward(self, x):
        return self.gen(x)


if __name__ == '__main__':
    device = "cuda" if torch.cuda.is_available() else "cpu"
    lr = 3e-4
    noise_dim = 50  # noise
    image_dim = 28 * 28 * 1  # 784
    batch_size = 32
    num_epochs = 200

    # dataset
    transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5), (0.5))])
    dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    fixed_noise = torch.randn((batch_size, noise_dim)).to(device)

    D = Discriminator(image_dim).to(device)
    G = Generator(noise_dim, image_dim).to(device)
    opt_disc = optim.Adam(D.parameters(), lr=lr)
    opt_gen = optim.Adam(G.parameters(), lr=lr)
    criterion = nn.BCELoss()     # 二分类交叉熵损失函数

    # 存放log的文件夹
    log_dir = "log"
    if (os.path.exists(log_dir)):
        shutil.rmtree(log_dir)
    writer = SummaryWriter(log_dir)

    for epoch in tqdm(range(num_epochs), desc='epochs'):
        # GAN不需要真实label
        for batch_idx, (img, _) in enumerate(loader):
            img = img.view(-1, 784).to(device)
            batch_size = img.shape[0]

            # 训练判别器: max log(D(x)) + log(1 - D(G(z)))
            noise = torch.randn(batch_size, noise_dim).to(device)
            fake_img = G(noise)    # 根据随机噪声生成虚假数据
            disc_fake = D(fake_img)    # 判别器判断生成数据为真的概率
            # torch.zeros_like(x) 表示生成与 x 形状相同、元素全为0的张量
            lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))    # 虚假数据与0计算损失
            disc_real = D(img)    # 判别器判断真实数据为真的概率
            lossD_real = criterion(disc_real, torch.ones_like(disc_real))     # 真实数据与1计算损失
            lossD = (lossD_real + lossD_fake) / 2

            D.zero_grad()
            lossD.backward(retain_graph=True)
            opt_disc.step()

            # 训练生成器: 在此过程中将判别器固定,min log(1 - D(G(z))) <-> max log(D(G(z))
            output = D(fake_img)
            lossG = criterion(output, torch.ones_like(output))
            G.zero_grad()
            lossG.backward()
            opt_gen.step()

            if batch_idx == 0:
                # print( f"Epoch [{epoch+1}/{num_epochs}]  Batch {batch_idx}/{len(loader)}   lossD = {lossD:.4f}, lossG = {lossG:.4f}")
                with torch.no_grad():
                    # 用固定的噪声数据生成图像,以对比经过不同epoch训练后的生成器的生成能力
                    fake_img = G(fixed_noise).reshape(-1, 1, 28, 28)
                    real_img = img.reshape(-1, 1, 28, 28)

                    # make_grid的作用是将若干幅图像拼成一幅图像
                    img_grid_fake = torchvision.utils.make_grid(fake_img, normalize=True)
                    img_grid_real = torchvision.utils.make_grid(real_img, normalize=True)

                    writer.add_image("Fake Images", img_grid_fake, global_step=epoch)
                    writer.add_image("Real Images", img_grid_real, global_step=epoch)
                    writer.add_scalar(tag="lossD", scalar_value=lossD, global_step=epoch)
                    writer.add_scalar(tag="lossG", scalar_value=lossG, global_step=epoch)

二、生成结果

2.1 loss的变化

使用 tensorboard可视化,生成器和判别器的loss变化如下:
在这里插入图片描述
这里训练了200个epoch,每个epoch保存了一次loss。按照之前每个batch保存一次loss的结果来看,在训练100个epoch左右时,生成器和判别器的loss达到平衡,可以视为收敛,之后模型过拟合了。

2.2 生成的虚假图像的变化

使用相同的噪声生成图像,以观测经过不同epoch训练后的生成器的生成能力(以假乱真能力):

epoch=3:

在这里插入图片描述
epoch=20:

在这里插入图片描述
epoch=53:

在这里插入图片描述
epoch=141:
在这里插入图片描述
epoch=199:

在这里插入图片描述

三、不足之处

程序还有很多不足之处:

(1)程序实现的是最早的GAN版本,生成器是一个MLP(多层感知机)而不是神经网络,因此特征提取和生成能力较差。

(2)图像的生成效果与超参数设置有很大关系,如学习率的设置(包括学习率的演化策略)、训练次数、随机噪声的维度,甚至数据集的归一化参数(transforms.Normalize((0.5), (0.5)))都会对生成效果产生一定影响。

(3)理论上损失函数只要能够适用于二分类即可,如MSE,但一般使用BCE。有一种观点认为BCE的形式与GAN的理论代价函数是一致的,二者可以互推,可以参考 GAN网络概述及LOSS函数详解

相关文章:

  • 七雄争霸武将技能搭配
  • 利用Python进行数据分析-Numpy入门基础知识
  • QML的Popup遇到的坑
  • 解数独 视频讲解 c++
  • kubernetes 网络
  • 运维流程化和标准化
  • LeetCode104. 二叉树的最大深度和N叉树的最大深度
  • Games104 引擎工具链笔记
  • 如何梳理当天的事情?
  • 【历年NeurIPS论文下载】一文了解NeurIPS国际顶会(含NeurIPS2022)
  • 《JVM学习笔记》字节码基础
  • Java 学习 --SpringBoot 常用注解详解
  • 基于springboot网上书城系统
  • Java项目:JSP药店药品商城管理系统
  • app启动流程
  • ----------
  • 【React系列】如何构建React应用程序
  • Android交互
  • CentOS从零开始部署Nodejs项目
  • go语言学习初探(一)
  • HTTP传输编码增加了传输量,只为解决这一个问题 | 实用 HTTP
  • WinRAR存在严重的安全漏洞影响5亿用户
  • Zsh 开发指南(第十四篇 文件读写)
  • 那些被忽略的 JavaScript 数组方法细节
  • 排序算法学习笔记
  • 前端知识点整理(待续)
  • 微服务核心架构梳理
  • 一道闭包题引发的思考
  • 原生 js 实现移动端 Touch 滑动反弹
  • AI又要和人类“对打”,Deepmind宣布《星战Ⅱ》即将开始 ...
  • 关于Kubernetes Dashboard漏洞CVE-2018-18264的修复公告
  • ​DB-Engines 11月数据库排名:PostgreSQL坐稳同期涨幅榜冠军宝座
  • ​io --- 处理流的核心工具​
  • ​无人机石油管道巡检方案新亮点:灵活准确又高效
  • (HAL库版)freeRTOS移植STMF103
  • (第61天)多租户架构(CDB/PDB)
  • (二)c52学习之旅-简单了解单片机
  • (二)学习JVM —— 垃圾回收机制
  • (附源码)计算机毕业设计SSM疫情社区管理系统
  • (三)centos7案例实战—vmware虚拟机硬盘挂载与卸载
  • (一)基于IDEA的JAVA基础1
  • (中等) HDU 4370 0 or 1,建模+Dijkstra。
  • .NET : 在VS2008中计算代码度量值
  • .Net Core/.Net6/.Net8 ,启动配置/Program.cs 配置
  • .Net Framework 4.x 程序到底运行在哪个 CLR 版本之上
  • .NET 同步与异步 之 原子操作和自旋锁(Interlocked、SpinLock)(九)
  • .NET/C# 在 64 位进程中读取 32 位进程重定向后的注册表
  • .net图片验证码生成、点击刷新及验证输入是否正确
  • .pub是什么文件_Rust 模块和文件 - 「译」
  • [ Algorithm ] N次方算法 N Square 动态规划解决
  • [ 云计算 | AWS ] 对比分析:Amazon SNS 与 SQS 消息服务的异同与选择
  • [23] GaussianAvatars: Photorealistic Head Avatars with Rigged 3D Gaussians
  • [Android]RecyclerView添加HeaderView出现宽度问题
  • [C++从入门到精通] 14.虚函数、纯虚函数和虚析构(virtual)
  • [HCTF 2018]WarmUp (代码审计)