当前位置: 首页 > news >正文

【论文翻译】结构化状态空间模型

文章目录

    • 3.2 对角结构化状态空间模型
      • 3.2.1 S4D:对角SSM算法
      • 3.2.2 完整应用实例
    • 3.3 对角化加低秩(DPLR)参数化
      • 3.3.1 DPLR 状态空间核算法
      • 3.3.2 S4-DPLR 算法和计算复杂度
      • 3.3.3赫尔维兹(稳定)DPLR形式

这篇文章是Mamba作者博士论文 MODELING SEQUENCES WITH STRUCTURED STATE SPACES
的第三章的部分翻译,为了解决计算上存在的代价问题,引入了结构化状态空间模型,介绍了对角结构化状态空间模型和低秩对角结构化状态空间模型。

3.2 对角结构化状态空间模型

为了解决SSM的计算瓶颈,我们使用一个允许我们变换和简化SSM的结构化结果。

Lemma 3.3 共轭是SSM的等价关系:
( A , B , C ) ∼ ( V − 1 AV , V − 1 B , CV ) (\textbf A, \textbf B, \textbf C) \sim(\textbf V^{-1}\textbf A \textbf V, \textbf V^{-1}\textbf B, \textbf C \textbf V) (A,B,C)(V1AV,V1B,CV)

证明:写出两个SSM, x x x x ~ \tilde{x} x~为对应的状态:
x ′ = A x + B u x ~ = V − 1 AV x ~ + V − 1 B u y = C x y = CV x ~ x^{'} = \textbf Ax +\textbf Bu \ \ \ \ \ \ \ \ \ \tilde x = \textbf V^{-1}\textbf A \textbf V\tilde x +\textbf V^{-1}\textbf Bu \\ y = \textbf C x \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ y = \textbf C \textbf V \tilde x x=Ax+Bu         x~=V1AVx~+V1Buy=Cx                  y=CVx~

当用 V \textbf V V乘以SSM右侧,两个SSM相同,其中 x = V x ~ x = \textbf V \tilde x x=Vx~。因此它们在计算相同的算子 u ↦ y u\mapsto y uy,但是状态基被 V \textbf V V改变了。

Lemma 3.3说明了状态空间 ( A , B , C ) (\textbf A, \textbf B, \textbf C) (A,B,C) ( V − 1 AV , V − 1 B , CV ) (\textbf V^{-1}\textbf A \textbf V, \textbf V^{-1}\textbf B, \textbf C \textbf V) (V1AV,V1B,CV)
实际是等价的。换句话说,它们表示的是同一个映射 u ↦ y u\mapsto y uy,在SSM文献中也被叫做状态空间变换

对此一个非常自然的选择是对角矩阵形式,可能是最典型的形式。众所周知,几乎所有的矩阵在复平面上对角化。

Proposition 3.4 集合 D ⊂ C N × N \mathcal D \subset \mathcal C ^{N\times N} DCN×N 可对角矩阵在 C N × N \mathcal C ^{N\times N} CN×N 上稠密且满测度

换句话说,Proposition 3.4说明(几乎)所有的SSM可以等价成一个对角SSM。除此之外,对角SSM结构化可以解决问题一和问题二,特别是计算 K ‾ \overline{\textbf K} K成为一个成熟的结构化矩阵乘法有高效的时间和空间复杂度。

3.2.1 S4D:对角SSM算法

Remark 3.6. 对于对角SSM的例子, A \textbf A A是对交的,因此我们重载定义 A n \textbf A_n An代表其对角的迹。回想我们定义SISO情况为 B ∈ R N × 1 \textbf B \in \mathcal R ^ {N\times 1} BRN×1 C ∈ R 1 × N \textbf C \in \mathcal R ^ {1\times N} CR1×N,因此我们令 B n , C n \textbf B_n, \textbf C_n Bn,Cn直接索引它们的元素。

现在我们提出了S4D在对角SSM上解决了问题一和问题二

S4D 递归

在对角SSM上计算任何对角化都很简单,因为对角矩阵上的解析函数简化为其对角线上按元素进行。实现一个对角矩阵的矩阵乘法也很简单,因为它减少了元素级的乘法。因此对角SSM轻松地适合Definition3.1

S4D 卷积核:范德蒙矩阵乘法

A \textbf A A是对角的,计算卷积核变得十分简单:
K ‾ ℓ = ∑ n = 0 N − 1 C n A ‾ n ℓ B ‾ n ⟹ K ‾ = ( B ‾ ⊤ ∘ C ) ⋅ V L ( A ‾ ) where V L ( A ‾ ) n , ℓ = A ‾ n ℓ ( 3.2 ) \begin{aligned}\overline{K}_\ell=\sum_{n=0}^{N-1}C_n\overline{A}_n^\ell\overline{B}_n\implies\overline{K}=(\overline{B}^\top\circ C)\cdot\mathcal{V}_L(\overline{A})\quad\text{where}\quad\mathcal{V}_L(\overline{A})_{n,\ell}=\overline{A}_n^\ell\quad(3.2)\end{aligned} K=n=0N1CnAnBnK=(BC)VL(A)whereVL(A)n,=An(3.2)
∘ \circ 是哈达玛积, ⋅ \cdot 是矩阵乘法, V \mathcal V V被称为范德蒙矩阵

再展开一下,我们可以把 K ‾ \overline {\textbf K} K写成下面的范德蒙矩阵-向量乘法
K ‾ = [ B ‾ 0 C 0 … B ‾ N − 1 C N − 1 ] [ 1 A ‾ 0 A ‾ 0 2 … A ‾ 0 L − 1 1 A ‾ 1 A ‾ 1 2 … A ‾ 1 L − 1 ⋮ ⋮ ⋮ ⋱ ⋮ 1 A ‾ N − 1 A ‾ N − 1 2 … A ‾ N − 1 L − 1 ] \overline{K}=\begin{bmatrix}\overline{B}_0C_0&\ldots&\overline{B}_{N-1}C_{N-1}\end{bmatrix}\begin{bmatrix}1&\overline{A}_0&\overline{A}_0^2&\ldots&\overline{A}_0^{L-1}\\1&\overline{A}_1&\overline{A}_1^2&\ldots&\overline{A}_1^{L-1}\\\vdots&\vdots&\vdots&\ddots&\vdots\\1&\overline{A}_{N-1}&\overline{A}_{N-1}^2&\ldots&\overline{A}_{N-1}^{L-1}\end{bmatrix} K=[B0C0BN1CN1] 111A0A1AN1A02A12AN12A0L1A1L1AN1L1
在这里插入图片描述

对角化结构SSM(S4D)有一个非常简单的解释。(左)对角化结构允许它被看作1维SSM的集合,或者scalar递归(右)。作为一个卷积模型,S4D有一个简单的可解释的卷积核,可以用两行代码实现。颜色代表独立的1-D SSM;紫色代表可训练参数。

时间和空间复杂度

原始方法计算3.2是通过范德蒙矩阵 V L ( A ‾ ) \mathcal V_L(\overline{\textbf A}) VL(A)和实现一个矩阵乘法,需要 O ( N L ) O(NL) O(NL)的时间和空间。

然而,范德蒙矩阵已经经过大量研究在理论上乘法可以以 O ~ ( N + L ) \tilde O(N+L) O~(N+L)操作和 O ( N + L ) O(N+L) O(N+L)空间实现。

3.2.2 完整应用实例

整个S4D方法可以直接应用,仅仅需要几行代码来参数化和初始化,核计算和完整的前向传播。

最后,注意结合不同的参数化选择可能导致在kernel实现上的少许不同。图3.1说明了用ZOH离散化的S4D核甚至可以进一步简化到两行代码。

def parameters(N, dt_min = 1e-3, dt_max = 1e-1):#初始化#几何均匀时间尺度 [第五章]log_dt = np.rnadom.rand() * (np.log(dt_max) - np.log(dt_min)) + np.log(dt_min)# S4D-Lin 初始化 (A, B) [第六章]A = -0.5 + 1j * np.pi * np.arange(N // 2)B = np.one(n // 2) + 0j#方差保持初始化 [第五章]C = np.random.randn(N // 2) + 1j * np.random.randn(N)return log_dt, np,log(-A.real, A_imag, B, C)def kernel(L, log_dt, log_A_real, A_imag, B, C):#离散化(例如双线性变换)dt, A = np.exp(log_dt), -np.exp(log_A_real) + 1j * A_imagdA, dB = (1 + dt * A /2) / (1 - dt * A / 2), dt * B / (1 - dt * A / 2)#计算(范德蒙矩阵乘法-可以被优化)#返回实部两倍-核添加共轭对相同return 2 * ((B * C) @ (dA[:, None] ** np.arrange(L))).realdef forward(u, parameters):L = u.shape[-1]K = kernel(L, *parameters)#用FFT卷积 y = u * KK_f,u_f = np.fft.fft(K, n = 2 * L), np.fft.fft(u, n = 2 * L)return np.fft.ifft(K_f*u_f, n = 2 * L)[...,:L]

参数化和计算一通道S4D模型的完整Numpy示例

3.3 对角化加低秩(DPLR)参数化

当可能的时候,对角SSM在实际中使用是理想的因为它们的简单和灵活。然而,它们的强结构有时太过限制。特别是,Chapter 6将会说明基于HIPPO矩阵的重要SSM类(Chapter 4 和 5)不能在数值上表达为对角SSM,而使用一个对角结构的拓展替代。虽然我们推迟这一动机到部分二,这一部分从计算角度,独立地表示这个结构。除了和部分二中的特殊SSM的关系,这个重参数化背后的想法和算法理论上是独立的,在之后的序列模型中会用到3.6.2

这个部分定义了对角SSM的拓展依然可以高效计算的**对角低秩(DPLR)**SSM。我们主要的技术结果关注于发展这个参数化和展示如何高效计算所有的SSM表达(Section 2.3),特别是找到一个问题一和问题二的算法。

在这里插入图片描述

3.3.1给出了我们方法关键组成部分的总览并形式上定义了S4—DPLR参数化。3.3.2给出了主要的结构,说明S4是渐进有效的对于序列模型。证明在附录3.1。

3.3.1 DPLR 状态空间核算法

尽管从对角到DPLE矩阵的扩展看起来很小,额外的低秩项时矩阵计算更困难。特别是不像对角矩阵,计算等式2.8DPLR矩阵的幂次方依然很慢(和非结构化矩阵相同)并且难以被优化,我们通过同时应用三种新技术解决这个瓶颈。

  • 我们通过评估它的单元 ζ \zeta ζ的根截断生成函数 ∑ j = 0 L − 1 K ‾ j ζ j \sum _{j = 0}^{L - 1}\overline{\textbf K}_j\zeta^j j=0L1Kjζj来计算它的谱而不是直接计算 K ‾ \overline {\textbf K} K K ‾ \overline {\textbf K} K之后可以通过一个反FFT实现。
  • 这个生成函数和矩阵分解相近,现在包括一个矩阵求逆而不是幂。低秩项现在可以通过Woodbury恒等式(Proposition A.2)将$(A + PQ*){-1} 按 按 A^{-1}$真正减少到对角情形。
  • 最后,我们表明对角矩阵形式是Cauchy kernel 1 w j − ζ k \frac{1}{w_j - \zeta _k} wjζk1的等价形式,一个使用stable near-linear算法的充分研究问题。

3.3.2 S4-DPLR 算法和计算复杂度

我们的算法在循环和卷积表达下都是经过优化的,满足Definitions 3.1和3.2

Theorem 3.5 (S4递归) 给定任意步长 Δ \Delta Δ,计算提柜的以部可以在 O ( N ) O(N) O(N)操作下完成, N N N是状态大小。

Theorem 3.6 (S4卷积)给定任意步长 Δ \Delta Δ,计算SSM卷积核 K ‾ \overline{\textbf K} K可以被减少到4次Cauchy 乘法,需要仅仅 O ~ ( N + L ) \tilde O(N+L) O~(N+L)次操作和 O ( N + L ) O(N+L) O(N+L)空间

附录C.1,定义C.5形式上定义了Cauchy 矩阵,和有理插值问题相关。在数值分析上计算Cauchy 矩阵同样得到充分研究,有基于著名的快速多极子算法(FMM)的快速算术和数值算法。不同情况下这些算法的计算复杂度在附录C.1 Proposition C.6中展示。

3.3.3赫尔维兹(稳定)DPLR形式

独立于计算S4-DPLR的算法细节,我们使用一个基础DPLR参数化的修正来确保状态空间模型的稳定性。特别是,赫尔维茨矩阵(又称为稳定矩阵)是一类可以确保SSM渐进稳定的。

Definition 3.7. 一个赫尔维茨矩阵 A \textbf A A是一个所有本征值都有负实数部分的矩阵

从离散时间SSM角度,我们很容易明白为什么 A \textbf A A需要是一个赫尔维茨矩阵从基本原则和下面简单的观察。受限,展开RNN模式包含重复升幂 A ‾ \overline {\textbf A} A,只有在 A ‾ \overline {\textbf A} A的所有本征值在(复数)单位圆内或上才是稳定的。第二,变换(2.4)(不论是对于双线性还是ZOH离散化)映射复数左半平面到单位圆,因此计算一个SSM的RNN模式(例如自回归推断)需要 A \textbf A A是一个赫尔维兹矩阵。

从连续角度看,另一种方式看到至一点是线性ODE解是指数形式。我们也可以看到等价卷积形式有脉冲响应 K ( t ) = C e t A B K(t) = \textbf C e^{t\textbf A}\textbf B K(t)=CetAB t → ∞ t\rightarrow\infin t时, K ( t ) = C e t A B K(t) = \textbf C e^{t\textbf A}\textbf B K(t)=CetAB也会爆炸到 ∞ \infin

然而,控制一个常见DPLR矩阵的谱是困难的。在S4的先前版本,我们发现无限制DPLR矩阵在训练胡变成非赫尔维茨(因此不能再无限循环模式中运用)。

为了解决这一点,我们使用DPLR矩阵的小改建,我们称之为赫尔维茨 DPLR形式,我们可以使用参数 Λ − P P ∗ \Lambda - PP^* ΛPP代替 Λ + P Q ∗ \Lambda + PQ^* Λ+PQ。这相当于基本上绑定了参数 Q = − P Q = -P Q=P。注意在技术上这依然是一个DPLR,因此我们使用S4-DPLR算法作为黑盒。

接着,我们讨论这种参数化是如何让S4稳定。高阶想法是SSM的稳定性包含状态矩阵 A \textbf A A的谱,更容易被控制因为 − P P ∗ -\textbf P \textbf P^* PP是半负定矩阵(我们知道它的谱的符号)

Lemma 3. 8 一个矩阵 A = Λ − P P ∗ \textbf A = \Lambda - \textbf P \textbf P^{*} A=ΛPP是赫尔维茨的如果 Λ \Lambda Λ的所有迹有负的实数部分。

证明:我们首先观察到如果 A + A ∗ \textbf A +\textbf A^* A+A是半负定(NSD)的,那么 A \textbf A A是赫尔维茨的。这是因为 0 > v ∗ ( A + A ∗ ) v = ( v ∗ A v ) + ( v ∗ A v ) ∗ = 2 R e ( v ∗ A v ) = 2 λ 0>v^*(A+A^*)v = (v^*Av)+(v^*Av)^* = 2\mathcal Re(v^*Av) = 2\lambda 0>v(A+A)v=(vAv)+(vAv)=2Re(vAv)=2λ对于任何 A A A的(单位长度)本征对来说。之后,注意到条件暗示 A + A ∗ \textbf A +\textbf A^* A+A是半负定(NSD)的(非正数迹的实数对角矩阵)。因为矩阵 − P P ∗ -PP^* PP也是NSD的, A + A ∗ A+A^* A+A也是这样。

Lemma 3.8表明,对于赫尔维兹DPLR表示,控制学习的A矩阵的频谱变成简单地控制对角线部分 Λ \Lambda Λ。这是一个比控制一般DPLR矩阵容易得多的问题,可以通过正则化或重新参数化来强制执行(第3.4.2节)。

Remark 3.7. 赫尔维茨DPLR形式 Λ − P P ∗ \Lambda - PP* ΛPP有更少的参数而且在技术上表现能力差于不受限DPLR形式 Λ + P Q ∗ \Lambda + PQ^* Λ+PQ但在经验上并没有影响模型表现。

Remark 3.8. 潜在的稳定性问题只在使用S4在特定内容如自回归生成时上升,因为S4的卷积模式在训练时并没有升幂 A ‾ \overline {\textbf A} A因此对赫尔维茨矩阵并不是严格要求。在实践中,出于原则,我们仍然总是使用赫尔维茨DPLR。

相关文章:

  • 13【CPP】Hash(闭散列||开散列)
  • 软考笔记--软件架构风格
  • Matlab/Simulink验证MAB建模规范
  • Android布局优化之include、merge、ViewStub的使用,7年老Android一次坑爹的面试经历
  • 宠物的异味,用空气净化器可以解决吗?宠物空气净化器品牌推荐
  • 【C++】贪心算法
  • Redis是单线程还是多线程?
  • 代码随想录算法训练营第三十三天|LeetCode1005 K次取反后最大化的数组和 、LeetCode134 加油站、LeetCode135 分发糖果
  • 【vue/组件封装】封装一个带条件筛选的搜索框组件(多组条件思路、可多选)详细流程
  • Nginx 常用的基础配置(前端相关方面)
  • C# SwinV2 Stable Diffusion 提示词反推 Onnx Demo
  • 微软研究深度报告:Sora文转视频AI模型全景剖析及未来展望
  • 网关kong记录接口处理请求和响应插件 tcp-log-with-body的安装
  • [python] dict类型变量写在文件中
  • js设计模式:解释器模式
  • Angular 响应式表单之下拉框
  • Flex布局到底解决了什么问题
  • Hibernate【inverse和cascade属性】知识要点
  • JAVA并发编程--1.基础概念
  • jQuery(一)
  • magento2项目上线注意事项
  • Meteor的表单提交:Form
  • MobX
  • mysql_config not found
  • PAT A1017 优先队列
  • Perseus-BERT——业内性能极致优化的BERT训练方案
  • python学习笔记-类对象的信息
  • Spring Security中异常上抛机制及对于转型处理的一些感悟
  • springMvc学习笔记(2)
  • storm drpc实例
  • Webpack 4x 之路 ( 四 )
  • yii2权限控制rbac之rule详细讲解
  • -- 查询加强-- 使用如何where子句进行筛选,% _ like的使用
  • 第2章 网络文档
  • 高程读书笔记 第六章 面向对象程序设计
  • 理解IaaS, PaaS, SaaS等云模型 (Cloud Models)
  • 使用 5W1H 写出高可读的 Git Commit Message
  • 算法-插入排序
  • 用Canvas画一棵二叉树
  • ionic异常记录
  • Spark2.4.0源码分析之WorldCount 默认shuffling并行度为200(九) ...
  • 摩拜创始人胡玮炜也彻底离开了,共享单车行业还有未来吗? ...
  • ###STL(标准模板库)
  • #13 yum、编译安装与sed命令的使用
  • #数学建模# 线性规划问题的Matlab求解
  • $().each和$.each的区别
  • (Java)【深基9.例1】选举学生会
  • (JSP)EL——优化登录界面,获取对象,获取数据
  • (二十四)Flask之flask-session组件
  • (一)pytest自动化测试框架之生成测试报告(mac系统)
  • (转)创业的注意事项
  • .naturalWidth 和naturalHeight属性,
  • .net refrector
  • .NET企业级应用架构设计系列之技术选型
  • @Bean, @Component, @Configuration简析