当前位置: 首页 > news >正文

Gan生成手写数字

一、GAN

生成对抗网络(Generative Adversarial Network,GAN)在2014年被Ian Goodfellow等人首次提出,此后迅速流行,成为热门的深度学习模型。Gan能生成非常逼真的图像、画作、音乐,近些年不乏有利用Gan生成的画作斩获大奖的情况,如前段时间引起争议的由AI生成的《空间歌剧院》在科罗拉多州博览会(Colorado State Fair)的美术比赛中,获得了第一名,如下图所示。
在这里插入图片描述

图1 利用AI生成的画作《空间歌剧院》

尽管借助AI的力量去和人类竞争,目前仍存在很大的争议,但不可否认的是AI已经在这一领域展现了很强大的应用前景。
GAN的原理非常简单,其利用两个模型,其中一个不断生成“假”数据,另外一个模型判断前一个生成的“假”数据,如果能够骗过判别模型,则说明生成了可以以假乱真的数据。

二、GAN的步骤

在GAN中,A是生成器,负责生成“假”数据,B是判别器,负责判断A生成的数据质量,其是一个博弈的过程。
生成器: 接受一个随机噪声向量x作为输入,生成一个张量G(x)
判别器: 接受一个张量作为输入,输出其真假
以图像为例,GAN的整个训练过程如下:
(1) 生成器接受随机噪声,并生成假图像
(2)判别器接受假图像和真图像组合的数据,学习如何判别真假图像
(3)生成器生成新的图像,并使用判别器来判别真假,同时通过判别器结果来判别此次造假的的水平。
(4)重复步骤(1)~(3)

三、生成器

原则上讲生成器并无特定的模型,只要能够生成图像的模型即可,但目前考虑到模型的训练一般选择神经网络,因为可以和判别器一同训练。生成器负责生成一副图片,当然此时的图片为噪声,类似于下图,细看啥也不是,但这不重要,因为GAN中,生成器不需要任何真数据!,是的,你没看错,不管它生成的是什么样的数据,它的老师判别器会告诉他这副图的真假,换句话讲,下图太假了,老师一眼就辨别出来了。
在这里插入图片描述

图2 生成器生成的图片

生成器的代码如下:

import matplotlib.pyplot as plt
from tensorflow import keras
import numpy as np
import tqdm
from IPython.display import clear_output

L = keras.layers
LATENT_DIM = 100  # 潜在空间的维度
IMAGE_SHAPE = (28, 28, 1)  # 输出图像的尺寸

# 生成器
generate_net = [
    L.Input(shape=(LATENT_DIM, )),
    L.Dense(256),
    L.LeakyReLU(alpha=0.2),
    L.BatchNormalization(momentum=0.8),
    L.Dense(512),
    L.LeakyReLU(alpha=0.2),
    L.BatchNormalization(momentum=0.8),
    L.Dense(1024),
    L.LeakyReLU(alpha=0.2),
    L.BatchNormalization(momentum=0.8),
    L.Dense(np.prod(IMAGE_SHAPE), activation='tanh'),
    L.Reshape(IMAGE_SHAPE)
]
generate = keras.models.Sequential(generate_net)
generate.summary()

注意
1、我们用到LeakyReLU激活函数,这是GAN中常用的激活函数
2、归一化方式常用的有Batch Normalization(BN),Instance Normalization(IN),Spectral Normalization(SN)
3、生成器的最后一层激活函数一般用tanh函数

四、判别器

GAN中判别器的作用就是判断生成的数据的水平,要判断真假,所以要先训练判别器,类似于你要预测房价,肯定要先去学习(拟合)房价数据,因此,GAN中判别器的训练中一部分是真实数据,这部分考虑到大家下载复制复现代码方便,我们用MNIST数据集,MNIST数据集可以通过tensorflow直接下载,比较方便。有真数据还不行,那必须喂给模型假数据,不然怎么学习真假对吧,那假数据从那来呢?对!生成器,我们生成器不正好可以生成假数据嘛,在这种思路下,我们就可以构造我们的判别器了。
mnist数据集

图2 MNIST数据集

判别器的代码如下:

# 判别器
discriminator_net = [
    L.Input(shape=IMAGE_SHAPE),
    L.Flatten(),
    L.Dense(512),
    L.LeakyReLU(alpha=0.2),
    L.Dense(256),
    L.LeakyReLU(alpha=0.2),
    L.Dense(1, activation='sigmoid')
]
discriminator = keras.models.Sequential(discriminator_net)
discriminator.compile(loss='binary_crossentropy', optimizer=keras.optimizers.Adam(0.0002, 0.5), metrics=['accuracy'])
discriminator.summary()

五、生成对抗模型

前面我们已经得知GAN由生成器和判别器两部分构成,且我们已经搭建好了生成器和判别器,把他们一组和即是GAN,但为什么要分开搭建,原因是GAN的训练是不断迭代的一个过程,要分开训练生成器和判别器,生成器生成的好不好要判别器判断,此时应该是冻结判别器的权重,只更新生成器的权重,因为生成器的目标是不断提升“造假”的能力。
GAN的模型如下

adversarial_net = generate_net + discriminator_net
# 冻结判别器的权重
for layer in discriminator_net:
    layer.trainable = False

adversarial = keras.models.Sequential(adversarial_net)
adversarial.compile(loss='binary_crossentropy', optimizer=keras.optimizers.Adam(0.0002, 0.5), metrics=['accuracy'])
adversarial.summary()

六、GAN 的训练

GAN的训练是一个不断调试不断优化的过程,需要大量的经验,推荐一些新手在训练GAN时看一些其他博主建议的小技巧——训练GAN的小技巧

首先要明白,GAN的训练是一个交替迭代的过程

前面我们提到生成器最开始的造假水平很垃圾,只是随便给一副噪声图,这部分假数据我们赋予标签0,代表假数据,然后和真实数据集,也就是MNIST,进行组合,形成一个二分类的图像数据集,来训练我们的判别器,训练结束后我们的判别器已经具备了分辨真假数据的能力(二分类)

判别器训练好后,我们开始训练生成器

生成器我们希望生成非常逼真的数据,但生成器生成的好坏,换句话说生成的数据要让判别器判断为1,也就是以假乱真的水平才行,所以一般来说训练生成器的时候其实是生成器+判别器的串接网络,只不过判别器的权重被冻结,类似于只训练生成器

此时我们用生成器生成一副图片,但我们要把这幅图的标签设置为1,也就是真,别急!!!,你没看错,就是要赋予标签1,这幅图被我们的判别器判断后输出0,也就是假,这样前后形成了非常大的误差,明明是假图,生成器却说真,此时生成器就会拼命的调整参数,直到判别器判断为真!这个过程标准的说法就是反向传播,对生成器网络的参数进行大更新!等到后续生成器能够产生出逼真的图片时,反向传播对生成器的参数就是微调,不断优化的一个过程!

交替训练判别器和生成器即可实现GAN的训练

训练代码如下:

# 数据可视化
def sample_images(batch):
    rows, columns = 3, 10
    sample_count = rows * columns
    plt.figure(figsize=(columns, rows))
    # 使用生成器生成图像
    noise = np.random.normal(0, 1, (sample_count, LATENT_DIM))
    gen_imgs = generate.predict(noise)
    # 生成器图像张量的范围从【-1,1】改为【0,1】
    gen_imgs = 0.5 * gen_imgs + 0.5

    index = 0
    for row in range(rows):
        for column in range(columns):
            image = np.reshape(gen_imgs[index], [28, 28])
            plt.subplot(rows, columns, index+1)
            plt.imshow(image, cmap='gray')
            plt.axis('off')
            index += 1
    plt.tight_layout()
    plt.show()
    return gen_imgs

# 训练
def train(batch=30000, batch_size=32):
    # 读取数据,无需标签
    (image_set, _), (_, _) = keras.datasets.mnist.load_data()
    # 数据归一化
    image_set = image_set / 127.5 - 1.
    # 数据格式转换
    image_set = image_set.reshape(len(image_set), 28, 28, 1)
    # 准备batch_size同样大小的真假标签
    valid = np.ones((batch_size))
    fake = np.zeros((batch_size))
    # 利用tqdm生成迭代器
    batch_list = tqdm.trange(batch)
    for batch in batch_list:
        #  生成器生成图像
        idx = np.random.randint(0, image_set.shape[0], batch_size)
        imgs = image_set[idx]

        # 生成噪声数据并作为生成器的输入
        noise = np.random.normal(0, 1, (batch_size, LATENT_DIM))
        # 使用生成器生成图像
        gen_imgs = generate.predict(noise)

        # 训练判别器
        d_state_real = discriminator.train_on_batch(imgs, valid)
        d_state_fake = discriminator.train_on_batch(gen_imgs, fake)
        d_state = 0.5 * np.add(d_state_real, d_state_fake)

        # 训练生成器
        adv_state = adversarial.train_on_batch(noise, valid)

        # 更新进度条后缀文本,用于输出训练进度
        state = f"[D loss:{d_state[0]:.4f} acc: {d_state[1]:.4f}]" \
                f"[A loss:{adv_state[0]:.4f} acc: {adv_state[1]:.4f}"
        batch_list.set_postfix(state=state)
        if (batch + 1) % 50 == 0:
            clear_output(wait=True)
            _ = sample_images(batch)

train()

训练过程中我们设定每个50个batch输出生成器生成的图片
初始图片
在这里插入图片描述
迭代2000次后生成的图片
在这里插入图片描述
迭代10000次生成的图片
在这里插入图片描述
迭代30000次生成的图片
在这里插入图片描述
整个训练过程的GIF图如下:

在这里插入图片描述

七、GAN的注意事项

GAN的训练过程是一个动态过程,每个批次是新的开始,不会有简单的梯度下降过程,而是一个不断对抗平衡的过程,类似于minimax,我们要的是最小的判别器损失,最大的生成器误差,因此GAN的训练需要一些技巧,比如
1、一开始无需分类精度很高的判别器
2、初始学习率要小,否则下降过快或者Max过大不利于GAN的拟合
3、生成器和迭代器无需训练相同次数,比如可以生成器训练1次,判别器训练5次
4、迭代次数需要不断微调,迭代次数过小有可能生成的图像效果一般,迭代次数过大也会导致生成的图像效果一般,很多人会疑问迭代次数过大为什么会导致生成的图像效果一般,因为判别器每次训练更新权重用到的是生成器生成的假数据和真实数据,后期生成器生成的数据已经非常逼真了,而判别器学习到仍然判定为假,因此反而会导致生成器又开始生成很假的数据,如该博主利用GAN生成动漫头像文章点击这,迭代200次的图片如下
在这里插入图片描述
当其迭代750次后出现了上面提到的问题,由于生成器生成了图像质量非常高,但判别器仍然判定为假,导致生成器开始产生反向作用,如下所示
在这里插入图片描述

八、GAN的评价

GAN的评价一直是一个难题,早期人们通过肉眼判定生成的图像质量,但不可否认的是这种评价方式明显存在缺陷,2016年来,GAN的评价方式开始如雨后春笋般展现,目前比较流行的是:

1、Inception Score

Inception Score(IS)通过利用谷歌图像分类模型nception Net来衡量模型生成图像的清晰度和多样性,Inception Score越高,表示模型越好。

2、Frechet Inception距离

Frechet Inception距离(FID)通过对比真实样本和生成样本在Inception V3模型上的抽象特征的差异来评估生成样本和真实样本的差异,FID越小,表示模型越小。

九、其他

生成对抗网络有很多种,如GAN 、ACGAN、DCGAN、Pix2Pix等。

要查看我的其他博客,点击这里

相关文章:

  • 基于Springboot+Vue开发前后端端分离农产品进销存系统
  • poi-tl 用word模板生成报告
  • leveldb-FilterBlock实现
  • 关于移动端H5获取微信非静默授权被拦截进入【微信快照页】问题及解决方案
  • token和JWT token区别、登录安全、页面权限、数据权限、单点登录
  • Liteos信号量的使用
  • 基于Verilog搭建一个卷积运算单元的简单实现
  • pytorch-实现mnist手写数字识别(彩色)
  • C/C++语言100题练习计划 99——找第一个只出现一次的字符
  • Go使用Gin+mysql实现增删改查
  • PIE-Engine:房山区洪涝灾害风险评价
  • 【我的渲染技术进阶之旅】如何编译Filament的windows版本程序?
  • 03 C++ 字符串、向量和数组
  • python 代码 C 执行
  • 字节外包凭借【ui自动化测试框架】成功进入内部编制
  • 【mysql】环境安装、服务启动、密码设置
  • 11111111
  • 3.7、@ResponseBody 和 @RestController
  • Apache Pulsar 2.1 重磅发布
  • go append函数以及写入
  • HTML5新特性总结
  • Mithril.js 入门介绍
  • Sass Day-01
  • Spring Cloud中负载均衡器概览
  • UMLCHINA 首席专家潘加宇鼎力推荐
  • 后端_ThinkPHP5
  • 入口文件开始,分析Vue源码实现
  • 原创:新手布局福音!微信小程序使用flex的一些基础样式属性(一)
  • 源码安装memcached和php memcache扩展
  • 湖北分布式智能数据采集方法有哪些?
  • #define用法
  • #pragma once与条件编译
  • $分析了六十多年间100万字的政府工作报告,我看到了这样的变迁
  • (¥1011)-(一千零一拾一元整)输出
  • (cljs/run-at (JSVM. :browser) 搭建刚好可用的开发环境!)
  • (vue)el-checkbox 实现展示区分 label 和 value(展示值与选中获取值需不同)
  • (顺序)容器的好伴侣 --- 容器适配器
  • (转)shell中括号的特殊用法 linux if多条件判断
  • (状压dp)uva 10817 Headmaster's Headache
  • .NET CORE 第一节 创建基本的 asp.net core
  • .NET 设计模式初探
  • .NET 使用配置文件
  • .NetCore Flurl.Http 升级到4.0后 https 无法建立SSL连接
  • .net反混淆脱壳工具de4dot的使用
  • .net访问oracle数据库性能问题
  • @DependsOn:解析 Spring 中的依赖关系之艺术
  • @html.ActionLink的几种参数格式
  • @JSONField或@JsonProperty注解使用
  • @value 静态变量_Python彻底搞懂:变量、对象、赋值、引用、拷贝
  • [ 云计算 | AWS ] AI 编程助手新势力 Amazon CodeWhisperer:优势功能及实用技巧
  • [100天算法】-二叉树剪枝(day 48)
  • [202209]mysql8.0 双主集群搭建 亲测可用
  • [C#]C# winform部署yolov8目标检测的openvino模型
  • [C]编译和预处理详解
  • [Codeforces] combinatorics (R1600) Part.2