当前位置: 首页 > news >正文

FLOW MATCHING FOR GENERATIVE MODELING 阅读笔记

Flow Matching (FM)是一种训练连续标准化流Continuous Normalizing Flow (CNF)的方法。
FM可以训练扩散路径,作者发现基于FM训练扩散路径的方法训练扩散模型更稳定。

核心的思想是把无条件估计问题的转换为有条件的问题的来学习。这个思想在其他工作也出现了,本文作者正是从denoised score matching得到的启发。

We first show that we can construct such target vector fields through per-example (i.e., conditional) formulations. Then, inspired by denoising score matching, we show that a per-example training objective, termed Conditional Flow Matching (CFM), provides equivalent gradients and does not require explicit knowledge of the intractable target vector field.

连续标准化流

数据点 x ∈ R d \pmb x \in \mathbb R^d xRd,时变概率密度路径 p : [ 0 , 1 ] × R d → R > 0 p:[0,1] \times \mathbb R^d \rightarrow \mathbb R_{>0} p:[0,1]×RdR>0,时变向量场 v t : [ 0 , 1 ] × R d → R d v_t:[0,1] \times \mathbb R^d \rightarrow \mathbb R^d vt:[0,1]×RdRd
流flow把一个分布映射成另一个分布,可以通过常微分方程用 v t v_t vt构建flow ϕ : [ 0 , 1 ] × R d → R d \phi:[0,1] \times \mathbb R^d \rightarrow \mathbb R^d ϕ:[0,1]×RdRd
d ϕ t ( x ) d t = v t ( ϕ t ( x ) ) ϕ 0 ( x ) = x (1) \frac{d\phi_t(\pmb x)}{dt}=v_t(\phi_t(\pmb x)) \tag{1} \\ \phi_0(\pmb x)=\pmb x dtdϕt(x)=vt(ϕt(x))ϕ0(x)=x(1)时变向量场可以用神经网络 v t ( x ; θ ) v_t(\pmb x; \theta) vt(x;θ)来建模,这样构建的flow ϕ t \phi_t ϕt叫做连续标准化流(Continuous Normalizing Flow,CNF)。CNF通常用于把一个简单的分布 p 0 p_0 p0变成一个复杂的分布 p 1 p_1 p1,其符合push-forward方程:
p t ( x ) = [ ϕ t ] ⋆ p 0 ( x ) = p 0 ( ϕ t − 1 ( x ) ) det ⁡ [ ∂ ϕ t − 1 ∂ x ( x ) ] p_t(x)=[\phi_t]_\star p_0(x)=p_0(\phi_t^{-1}(x))\det[\frac{\partial \phi_t^{-1}}{\partial x}(x)] pt(x)=[ϕt]p0(x)=p0(ϕt1(x))det[xϕt1(x)]我们的目标是采样服从复杂目标分布的样本,方法是首先随机采样服从简单分布的噪声样本 x ∼ N ( 0 , I ) \pmb x \sim \mathcal N (\pmb 0, \pmb I) xN(0,I),然后使用ODE求解器在区间 t ∈ [ 0 , 1 ] t \in [0, 1] t[0,1]上使用训练得到的向量场 v t v_t vt求解方程(1)得到服从目标分布的样本 ϕ 1 ( x ) \phi_1(\pmb x) ϕ1(x)。所以主要的问题是如何学习 v t ( x ; θ ) v_t(\pmb x; \theta) vt(x;θ)

Flow Matching(FM)

x 1 \pmb x_1 x1表示服从未知的目标分布 q ( x 1 ) q(\pmb x_1) q(x1)的随机变量,我们不知道 q ( x 1 ) q(\pmb x_1) q(x1)的密度函数,但可以获得服从 q ( x 1 ) q(\pmb x_1) q(x1)的样本。用 p t p_t pt表示概率密度路径, p 0 p_0 p0服从标准高斯分布, p 1 p_1 p1近似 q q q
Flow Matching的训练目标是学习 v t v_t vt,损失函数是 L F M ( θ ) = E t , p t ( x ) ∥ v t ( x ; θ ) − u t ( x ) ∥ 2 \mathcal L_{FM}(\theta)=\mathbb E_{t,p_t(\pmb x)}\|v_t(\pmb x; \theta)-u_t(\pmb x)\|^2 LFM(θ)=Et,pt(x)vt(x;θ)ut(x)2流匹配的损失函数很简单,但在实践中没法使用,因为我们不知道如何定义合适的 p t p_t pt u t u_t ut

Conditional Flow Matching(CFM)

为了解决上面的问题,考虑条件流匹配。条件流匹配的损失函数是 L C F M ( θ ) = E t , q ( x 1 ) , p t ( x ∣ x 1 ) ∥ v t ( x ; θ ) − u t ( x ∣ x 1 ) ∥ 2 \mathcal L_{CFM}(\theta)=\mathbb E_{t,q(\pmb x_1),p_t(\pmb x|\pmb x_1)}\|v_t(\pmb x; \theta)-u_t(\pmb x|\pmb x_1)\|^2 LCFM(θ)=Et,q(x1),pt(xx1)vt(x;θ)ut(xx1)2与流匹配的目标不同,条件流匹配的目标允许我们轻松地对无偏估计进行采样,只要我们可以从 p t ( x ∣ x 1 ) p_t(\pmb x|\pmb x_1) pt(xx1) 有效地采样并计算 u t ( x ∣ x 1 ) u_t(\pmb x|\pmb x_1) ut(xx1),这两者都可以很容易地完成,因为它们是对每个样本定义的。
论文中证明了优化CFM目标等价于优化FM目标(从期望的角度)。所以,剩下的问题是如何设计合适的条件概率路径 p t ( x ∣ x 1 ) p_t(\pmb x|\pmb x_1) pt(xx1)和向量场 u t ( x ∣ x 1 ) u_t(\pmb x|\pmb x_1) ut(xx1)

条件概率路径和条件向量场

上面的讨论是通用的,并没有规定条件概率路径和条件向量场的形式。为了简单,作者讨论的是高斯条件概率路径:
p t ( x ∣ x 1 ) = N ( x ∣ μ t ( x 1 ) , σ t ( x 1 ) 2 I ) p_t(\pmb x|\pmb x_1)=\mathcal N(\pmb x| \mu_t(\pmb x_1), \sigma_t(\pmb x_1)^2\pmb I) pt(xx1)=N(xμt(x1),σt(x1)2I)其中 μ 0 ( x 1 ) = 0 \mu_0(\pmb x_1)=0 μ0(x1)=0 σ 0 ( x 1 ) = 1 \sigma_0(\pmb x_1)=1 σ0(x1)=1 μ 1 ( x 1 ) = x 1 \mu_1(\pmb x_1)=\pmb x_1 μ1(x1)=x1 σ 1 ( x 1 ) = σ min ⁡ \sigma_1(\pmb x_1)=\sigma_{\min} σ1(x1)=σmin
有无数的向量场可以产生给定的概率路径,这里作者讨论的是最简单的典型变换。
考虑条件flow:
ψ t ( x ) = σ t ( x 1 ) x + μ t ( x 1 ) \psi_t(\pmb x)= \sigma_t(\pmb x_1)\pmb x + \mu_t(\pmb x_1) ψt(x)=σt(x1)x+μt(x1)对应的条件向量场可以通过求解方程得到,并有封闭解:
u t ( x ∣ x 1 ) = σ t ′ ( x 1 ) σ t ( x 1 ) ( x − μ t ( x 1 ) ) + μ t ′ ( x 1 ) u_t(\pmb x|\pmb x_1)=\frac{\sigma'_t(\pmb x_1)}{\sigma_t(\pmb x_1)}(x-\mu_t(\pmb x_1))+\mu'_t(\pmb x_1) ut(xx1)=σt(x1)σt(x1)(xμt(x1))+μt(x1)优化的损失函数是 L C F M ( θ ) = E t , q ( x 1 ) , p ( x 0 ) ∥ v t ( ψ t ( x 0 ) ; θ ) − u t ( ψ t ( x 0 ) ∣ x 1 ) ∥ 2 \mathcal L_{CFM}(\theta)=\mathbb E_{t,q(\pmb x_1),p(\pmb x_0)}\|v_t(\psi_t(\pmb x_0); \theta)-u_t(\psi_t(\pmb x_0)|\pmb x_1)\|^2 LCFM(θ)=Et,q(x1),p(x0)vt(ψt(x0);θ)ut(ψt(x0)x1)2

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • C++ primer plus 第17 章 输入、输出和文件:用cout进行格式化
  • Hibernate Validator 数据校验框架
  • 【从零开始一步步学习VSOA开发】创建VSOA的client端
  • poetry配置镜像
  • 【秋招笔试】2024-08-03-科大讯飞秋招笔试题(算法岗)-三语言题解(CPP/Python/Java)
  • DREAMLLM: SYNERGISTIC MULTIMODALCOMPREHENSION AND CREATION
  • C语言基础题:吃冰棍(C语言版)
  • Android笔试面试题AI答之Activity常见考点
  • AI智能测评应用平台项目分享
  • 数值分析——埃尔米特(Hermit)插值
  • Apple在Swift中引入同态加密
  • Stable Diffusion 官方模型V1.5版本下载
  • LLM - 理解 主流大模型 LLM 使用 Decoder Only 架构 (8点)
  • 回顾前面刷过的算法(4)
  • HanLP和Jieba区别
  • CentOS 7 修改主机名
  • iOS | NSProxy
  • JavaSE小实践1:Java爬取斗图网站的所有表情包
  • Js基础知识(一) - 变量
  • OpenStack安装流程(juno版)- 添加网络服务(neutron)- controller节点
  • sessionStorage和localStorage
  • sublime配置文件
  • Tornado学习笔记(1)
  • vue.js框架原理浅析
  • 动态规划入门(以爬楼梯为例)
  • 给第三方使用接口的 URL 签名实现
  • 构建二叉树进行数值数组的去重及优化
  • 前端路由实现-history
  • 前端自动化解决方案
  • 融云开发漫谈:你是否了解Go语言并发编程的第一要义?
  • 资深实践篇 | 基于Kubernetes 1.61的Kubernetes Scheduler 调度详解 ...
  • ​ubuntu下安装kvm虚拟机
  • # 再次尝试 连接失败_无线WiFi无法连接到网络怎么办【解决方法】
  • #include到底该写在哪
  • (1/2)敏捷实践指南 Agile Practice Guide ([美] Project Management institute 著)
  • (145)光线追踪距离场柔和阴影
  • (23)Linux的软硬连接
  • (android 地图实战开发)3 在地图上显示当前位置和自定义银行位置
  • (NO.00004)iOS实现打砖块游戏(十二):伸缩自如,我是如意金箍棒(上)!
  • (动态规划)5. 最长回文子串 java解决
  • (三)docker:Dockerfile构建容器运行jar包
  • (三)终结任务
  • (一)VirtualBox安装增强功能
  • (转)清华学霸演讲稿:永远不要说你已经尽力了
  • .CSS-hover 的解释
  • .NET 4.0中的泛型协变和反变
  • .net core + vue 搭建前后端分离的框架
  • .NET 直连SAP HANA数据库
  • .NET(C#、VB)APP开发——Smobiler平台控件介绍:Bluetooth组件
  • .NET程序集编辑器/调试器 dnSpy 使用介绍
  • .Net程序猿乐Android发展---(10)框架布局FrameLayout
  • .NET中GET与SET的用法
  • .NET中的十进制浮点类型,徐汇区网站设计
  • .net中生成excel后调整宽度
  • .php文件都打不开,打不开php文件怎么办