当前位置: 首页 > news >正文

MindSpore深度概率推断算法与概率模型

上篇文章对MindSpore深度概率学习进行了背景和总体特性上的介绍,链接戳这里:

https://zhuanlan.zhihu.com/p/234931176

本篇文章会介绍深度概率学习的第二部分:深度概率推断算法与概率模型,并在MindSpore上进行代码的实践。

 

概率推断,也可以叫做贝叶斯推理,是概率统计中一个重要的问题。我们先从最基础的贝叶斯公式出发:

P(A|B)称为条件概率,描述的是在事件B已经发生的条件下,A事件发生的概率。这其实就可以理解为推断问题,上面的公式我们可以这样来理解:已知概率模型中的随机变量B,计算其他部分随机变量A的后验概率。

我们来重新理解一下机器学习,通常可以从两种角度来看,一个是优化问题,一个是积分问题。

首先我们来看一下优化问题,以SVM为例,其是一个很典型的优化问题,我们的优化目标:

 

上面也提到了推断本质上就是求后验分布,当然概率模型往往会比较复杂,涉及的变量会非常多,若直接精确求解,计算开销太大,因此需要一些近似计算的方法。

本篇文章主要讨论基于近似的变分推理(Variational Inference)方法,来解决概率推断问题。

l 什么是变分推理?

变分推理是什么呢?为了让大家通俗的理解,这里参考 @过小咩 的回答。

深蓝色的分布是我们的原始目标p,不好求。它看上去有点像高斯,那我们尝试从高斯分布中找一个红q和一个绿q,分别计算一下p和他们重叠部分面积,选更像p的q作为p的近似分布。

简单来说,为了求解p,但求解起来复杂,寻找容易求解的q,使得p和q尽可能接近,这个接近可以用KL散度来衡量。 

看到这里,大家应该对变分推断的核心步骤比较了解,其中涉及到的几个主要概念也已经提及,更多详细的,大家可以参考Variational Inference: A Review for Statisticians这篇论文。

l 概率模型之变分自编码器

变分推断最经典应用的概率模型就是变分自编码器(Variational Auto-Encoders),它是由 Kingma 等人于 2014 年提出的基于变分贝叶斯(Variational Bayes)推断的生成式网络结构。与传统的自编码器不同,它以概率的方式描述对潜在空间的观察。

直接看图,可能大家更好理解些。

变分自编码器简单来说主要分为三部分,编码器、解码器和隐向量空间。编码器将输入向量压缩成隐向量特征,但这个隐向量特征用概率分布来表示,比如常见的为正态分布,可以定义为隐向量空间,当从隐向量空间解码时,我们从中随机采样,生成一个向量作为解码器的输入。之后解码器进行解码,得到输出向量。与传统的自编码器相比,我们可以生成一些新图片,更具有多样性。

讲到这里大家可能觉得这和变分推断有啥关系呢,马上介绍。

l 变分自编码器之“变分”

变分自编码器需要在模型的准确率上和隐向量空间服从正态分布做一个权衡,即输入图片和生成图片之间的loss以及隐向量空间和正态分布的相似程度,大家应该知道了,这个相似程度用上文中讲的KL散度来度量,另一个loss可以用均方根误差来衡量。我们知道正态分布包含均值和方差两个参数,如何来获得这两个参数呢?采用神经网络来进行拟合,这样每个输入x都有一个专属的正态分布来进行采样生成新图片。

大家如果直接去看变分自编码器的论文,可能没有明确说变分法的概念,但其实就像刚刚讲到的,在隐向量空间的相似度量中,就用到了变分的推导和性质。

上面讲的VAE是无监督训练的,还有一种VAE的变体,叫做Conditional VAE。区别在于加入了标签的信息,可以控制来生成特定类型的图片,更多的细节大家可以参考Learning Structured Output Representation using Deep Conditional Generative Models这篇论文。

Talk is cheap, show me the code.

接下来想和大家讲讲在MindSpore深度概率学习库中,我们是如何构造VAE等概率网络,并通过变分推断来求解的。

首先,我们在mindspore.nn.probability.dpn下面实现了两类基本的接口:VAE和ConditionalVAE,前者为无监督模型,后者为有监督模型。dpn是deep probability network的简称,后续也会支持更多的深度概率网络。

VAE

根据上文的介绍,首先,我们需要先自定义encoder和decoder,调用mindspore.nn.probability.dpn.VAE接口来构建VAE网络,我们除了传入encoder和decoder之外,还需要传入encoder输出变量的维度hidden size,以及VAE网络存储潜在变量的维度latent size,一般latent size会小于hidden size。

import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.nn.probability.dpn import VAE
 
IMAGE_SHAPE = (-1, 1, 32, 32)
 
class Encoder(nn.Cell):
  def __init__(self):
    super(Encoder, self).__init__()
    self.fc1 = nn.Dense(1024, 800)
    self.fc2 = nn.Dense(800, 400)
    self.relu = nn.ReLU()
    self.flatten = nn.Flatten()
 
  def construct(self, x):
    x = self.flatten(x)
    x = self.fc1(x)
    x = self.relu(x)
    x = self.fc2(x)
    x = self.relu(x)
    return x
 
class Decoder(nn.Cell):
  def __init__(self):
    super(Decoder, self).__init__()
    self.fc1 = nn.Dense(400, 1024)
    self.sigmoid = nn.Sigmoid()
    self.reshape = P.Reshape()
 
  def construct(self, z):
    z = self.fc1(z)
    z = self.reshape(z, IMAGE_SHAPE)
    z = self.sigmoid(z)
    return z
 
encoder = Encoder()
decoder = Decoder()
vae = VAE(encoder, decoder, hidden_size=400, latent_size=20)

调用ELBO接口:

mindspore.nn.probability.infer.ELBO

来定义VAE网络的损失函数,这里需要传入的参数分别是隐变量空间和输出数据的先验分布,调用WithLossCell封装VAE网络和损失函数,并定义优化器,例如为Adam优化器,之后传入SVI接口(mindspore.nn.probability.infer.SVI)。SVI的run函数可理解为VAE网络的训练,可以指定训练的epochs,返回结果为训练好的网络;get_train_loss函数可以返回训练好后模型的loss。

from mindspore.nn.probability.infer import ELBO, SVI
 
net_loss = ELBO(latent_prior='Normal', output_prior='Normal')
net_with_loss = nn.WithLossCell(vae, net_loss)
optimizer = nn.Adam(params=vae.trainable_params(), learning_rate=0.001)
 
vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)
vae = vi.run(train_dataset=ds_train, epochs=10)
trained_loss = vi.get_train_loss()

最后,得到训练好的VAE网络后,我们可以使用vae.generate_sample生成新样本,需要传入待生成样本的个数,及生成样本的shape,shape需要保持和原数据集中的样本shape一样;当然,我们也可以使用vae.reconstruct_sample重构原来数据集中的样本,来测试VAE网络的重建能力。

IMAGE_SHAPE = (-1, 1, 32, 32)
generated_sample = vae.generate_sample(64, IMAGE_SHAPE)
for sample in ds_train.create_dict_iterator():
  sample_x = Tensor(sample['image'], dtype=mstype.float32)
  reconstructed_sample = vae.reconstruct_sample(sample_x)

完整的代码可以戳An example of VAE。

ConditionalVAE

类似地,ConditionalVAE与VAE的使用方法比较相近,不同的是,ConditionalVAE利用了数据集的标签信息,属于有监督学习算法,其生成效果一般会比VAE好。和VAE定义过程类似,先自定义encoder和decoder,并调用mindspore.nn.probability.dpn.ConditionalVAE接口来构建ConditionalVAE网络,这里的encoder和VAE的不同,因为需要传入数据集的标签信息;decoder和上述的一样。ConditionalVAE接口的传入则还需要传入数据集的标签类别个数,其余和VAE接口一样。之后的推理过程和VAE类似,这里就不再详述啦,完整代码戳An example of ConditionalVAE。

VAE-GAN

GAN网络大家应该都有听说过,其全称是生成对抗网络,主要结构包括一个生成器G(Generator)和一个判别器D(Discriminator),是一种博弈式的训练网络。

VAE-GAN的原理是用GAN来强化VAE,即将VAE网络的decoder作为GAN网络的Generator,后面接一个Discriminator,如下图所示。这样做有什么好处呢?上面有介绍VAE的reconstruction loss一般定义只定义MSE loss,这就会导致生成的图片比较模糊。VAE-GAN在VAE的基础上,引入了一个Discriminator,通过Discriminator来判别输入的图片是属于reconstruct之后的图片,还是属于真实的数据的图片,这样相当于额外增加了loss。

那么,基于VAE接口,我们可以很方便地构建VAE-GAN网络,自定义Encoder、Decoder(Generator)以及Discriminator。

class VaeGan(nn.Cell):
def __init__(self):
super(VaeGan, self).__init__()
self.E = Encoder()
self.G = Decoder()
self.D = Discriminator()
self.dense = nn.Dense(20, 400)
self.vae = VAE(self.E, self.G, 400, 20)
self.shape = P.Shape()
self.normal = C.normal
self.to_tensor = P.ScalarToArray()

def construct(self, x):
recon_x, x, mu, std = self.vae(x)
z_p = self.normal(self.shape(mu), self.to_tensor(0.0), self.to_tensor(1.0), seed=0)
z_p = self.dense(z_p)
x_p = self.G(z_p)
ld_real = self.D(x)
ld_fake = self.D(recon_x)
ld_p = self.D(x_p)
return ld_real, ld_fake, ld_p, recon_x, x, mu, std

损失函数的构造也十分方便,在原有ELBO的基础上,增加额外的loss部分。

class VaeGanLoss(ELBO):
def __init__(self):
super(VaeGanLoss, self).__init__()
self.zeros = P.ZerosLike()
self.mse = nn.MSELoss(reduction='sum')

def construct(self, data, label):
ld_real, ld_fake, ld_p, recon_x, x, mu, std = data
y_real = self.zeros(ld_real) + 1
y_fake = self.zeros(ld_fake)
loss_D = self.mse(ld_real, y_real)
loss_GD = self.mse(ld_p, y_fake)
loss_G = self.mse(ld_fake, y_real)
reconstruct_loss = self.recon_loss(x, recon_x)
kl_loss = self.posterior('kl_loss', 'Normal', self.zeros(mu), self.zeros(mu) + 1, mu, std)
elbo_loss = reconstruct_loss + self.sum(kl_loss)
return loss_D + loss_G + loss_GD + elbo_loss

之后,就可以进行模型的训练啦,完整代码戳An example of VAE-GAN。

本篇文章就到这里啦,这次主要分享了概率推断中的变分推断算法和变分自编码器及衍生模型,如果有不对之处欢迎大家批评指正哈。

l 参考文献:

[1] Blei D M, Kucukelbir A, McAuliffe J D. Variational inference: A review for statisticians [J]. Journal of the American Statistical Association, 2017, 112(518): 859-877.

[2] Kingma D P, Welling M. Auto-Encoding Variational Bayes[J]. stat, 2014, 1050: 10.

[3] yuxiangyu, 从自编码器到变分自编码器(其二)

[4] Goodfellow, Ian J., Pouget-Abadie, Jean, Mirza, Mehdi, Xu, Bing, Warde-Farley, David, Ozair, Sherjil, Courville, Aaron C., and Bengio, Yoshua. Generative adversarial nets. NIPS, 2014.

相关文章:

  • 热敏性聚N-乙烯基异丁酰胺(PNVIBA)/聚(N—乙烯基异丁酰胺)接枝聚苯乙烯微球的研究
  • Linux中的服务管理
  • 异步 PHP — 多进程、多线程和协程
  • 适用于90%网剧、网大的最新备案流程解析
  • 在PyG上构建自己的数据集
  • Docker部署Logstash 7.2.0
  • Nginx -- -- 配置SSL证书
  • DID革命:详解PoP、SBT和VC三种去中心化身份方案
  • Redis与Python交互
  • 算法基础: 位运算
  • 记录一次坑 | 包版本不一致产生的问题的排查过程
  • SmartX Everoute 如何通过微分段技术实现 “零信任” | 社区成长营分享回顾
  • “相信美好,即将发生”——天泽智云
  • 面试阿里技术专家岗,对答如流,这些面试题你能答出多少
  • Spring AOP与事务
  • [iOS]Core Data浅析一 -- 启用Core Data
  • 《微软的软件测试之道》成书始末、出版宣告、补充致谢名单及相关信息
  • 【140天】尚学堂高淇Java300集视频精华笔记(86-87)
  • 【399天】跃迁之路——程序员高效学习方法论探索系列(实验阶段156-2018.03.11)...
  • CentOS学习笔记 - 12. Nginx搭建Centos7.5远程repo
  • css布局,左右固定中间自适应实现
  • fetch 从初识到应用
  • happypack两次报错的问题
  • interface和setter,getter
  • JS变量作用域
  • JS实现简单的MVC模式开发小游戏
  • Linux各目录及每个目录的详细介绍
  • mysql外键的使用
  • Spring Cloud中负载均衡器概览
  • spring security oauth2 password授权模式
  • 分类模型——Logistics Regression
  • 基于webpack 的 vue 多页架构
  • 前端存储 - localStorage
  • 实战:基于Spring Boot快速开发RESTful风格API接口
  • 使用SAX解析XML
  • 小程序、APP Store 需要的 SSL 证书是个什么东西?
  • 白色的风信子
  • ​草莓熊python turtle绘图代码(玫瑰花版)附源代码
  • #14vue3生成表单并跳转到外部地址的方式
  • #设计模式#4.6 Flyweight(享元) 对象结构型模式
  • #使用清华镜像源 安装/更新 指定版本tensorflow
  • #我与Java虚拟机的故事#连载15:完整阅读的第一本技术书籍
  • (BFS)hdoj2377-Bus Pass
  • (二)fiber的基本认识
  • (二十一)devops持续集成开发——使用jenkins的Docker Pipeline插件完成docker项目的pipeline流水线发布
  • (附源码)计算机毕业设计SSM疫情下的学生出入管理系统
  • (附源码)计算机毕业设计大学生兼职系统
  • (七)理解angular中的module和injector,即依赖注入
  • (三)Pytorch快速搭建卷积神经网络模型实现手写数字识别(代码+详细注解)
  • (转)Linux下编译安装log4cxx
  • .bat批处理(四):路径相关%cd%和%~dp0的区别
  • .NET CORE 第一节 创建基本的 asp.net core
  • .NET Core Web APi类库如何内嵌运行?
  • .net core 微服务_.NET Core 3.0中用 Code-First 方式创建 gRPC 服务与客户端
  • .NET I/O 学习笔记:对文件和目录进行解压缩操作