当前位置: 首页 > news >正文

代码解读 | Hybrid Transformers for Music Source Separation[05]

一、背景

        0、Hybrid Transformer 论文解读

        1、代码复现|Demucs Music Source Separation_demucs架构原理-CSDN博客

        2、Hybrid Transformer 各个模块对应的代码具体在工程的哪个地方

        3、Hybrid Transformer 各个模块的底层到底是个啥(初步感受)?

        4、Hybrid Transformer 各个模块处理后,数据的维度大小是咋变换的?

        5、Hybrid Transformer 拆解STFT模块


        从模块上划分,Hybrid Transformer Demucs 共包含 (STFT模块、时域编码模块、频域编码模块、Cross-Domain Transformer Encoder模块、时域解码模块、频域解码模块、ISTFT模块)7个模块。

        本篇目标:拆解频域编码模块的底层

        时域编码和频域编码原理类似(后续不再拆解时域编码模块)。

二、频域编码模块


class HEncLayer(nn.Module):def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False,freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True,rewrite=True):"""Encoder layer. This used both by the time and the frequency branch.Args:chin: number of input channels.chout: number of output channels.norm_groups: number of groups for group norm.empty: used to make a layer with just the first conv. this is usedbefore merging the time and freq. branches.freq: this is acting on frequencies.dconv: insert DConv residual branches.norm: use GroupNorm.context: context size for the 1x1 conv.dconv_kw: list of kwargs for the DConv class.pad: pad the input. Padding is done so that the output size isalways the input size / stride.rewrite: add 1x1 conv at the end of the layer."""super().__init__()norm_fn = lambda d: nn.Identity()  # noqaif norm:norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqaif pad:pad = kernel_size // 4else:pad = 0klass = nn.Conv1dself.freq = freqself.kernel_size = kernel_sizeself.stride = strideself.empty = emptyself.norm = normself.pad = padif freq:kernel_size = [kernel_size, 1]stride = [stride, 1]pad = [pad, 0]klass = nn.Conv2dself.conv = klass(chin, chout, kernel_size, stride, pad)if self.empty:returnself.norm1 = norm_fn(chout)self.rewrite = Noneif rewrite:self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)self.norm2 = norm_fn(2 * chout)self.dconv = Noneif dconv:self.dconv = DConv(chout, **dconv_kw)def forward(self, x, inject=None):"""`inject` is used to inject the result from the time branch into the frequency branch,when both have the same stride."""if not self.freq and x.dim() == 4:B, C, Fr, T = x.shapex = x.view(B, -1, T)if not self.freq:le = x.shape[-1]if not le % self.stride == 0:x = F.pad(x, (0, self.stride - (le % self.stride)))y = self.conv(x)if self.empty:return yif inject is not None:assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)if inject.dim() == 3 and y.dim() == 4:inject = inject[:, :, None]y = y + injecty = F.gelu(self.norm1(y))if self.dconv:if self.freq:B, C, Fr, T = y.shapey = y.permute(0, 2, 1, 3).reshape(-1, C, T)y = self.dconv(y)if self.freq:y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)if self.rewrite:z = self.norm2(self.rewrite(y))z = F.glu(z, dim=1)else:z = yreturn z

        核心代码如上所示。

        使用print函数打印出各个关键节点的信息,可以得到频域编解码模块的全景图。

        编码层:Conv2d+Norm1+GELU,  Norm1:Identity()

        残差连接:(Conv1d+GroupNorm+GELU +Conv1d+GroupNorm+GLU+LayerScale())

        +(Conv2d+Norm2+GLU),Norm2:Identity() ,备注:Identity可以理解成直通

#上图均是自己读完代码绘制的。相信自己也可以。
#具体的,编码层1-4的Conv2d分别是:
Conv2d(4, 48, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
Conv2d(48, 96, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
Conv2d(96, 192, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
Conv2d(192, 384, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
#残差连接1
DConv((layers): ModuleList((0): Sequential((0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(1,))(1): GroupNorm(1, 6, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 96, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale())(1): Sequential((0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))(1): GroupNorm(1, 6, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 96, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale()))
)
Conv2d(48, 96, kernel_size=(1, 1), stride=(1, 1))#残差连接2
DConv((layers): ModuleList((0): Sequential((0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(1,))(1): GroupNorm(1, 12, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 192, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale())(1): Sequential((0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))(1): GroupNorm(1, 12, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 192, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale()))
)
Conv2d(96, 192, kernel_size=(1, 1), stride=(1, 1))#残差连接3
DConv((layers): ModuleList((0): Sequential((0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(1,))(1): GroupNorm(1, 24, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 384, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale())(1): Sequential((0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))(1): GroupNorm(1, 24, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 384, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale()))
)
Conv2d(192, 384, kernel_size=(1, 1), stride=(1, 1))#残差连接4
DConv((layers): ModuleList((0): Sequential((0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(1,))(1): GroupNorm(1, 48, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 768, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale())(1): Sequential((0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))(1): GroupNorm(1, 48, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 768, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale()))
)
Conv2d(384, 768, kernel_size=(1, 1), stride=(1, 1))

        关于,各个卷积模块输出数据的shape计算,可以读这篇文章。

        没有所谓天生的大佬,如果有那么我愿称他/她为圣人。我相信,能读到这儿的都会成为大佬~。Believe yourself,one day,you will be somebody.


         感谢阅读,最近开始写公众号(分享好用的AI工具),欢迎大家一起见证我的成长(桂圆学AI)

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 卡尔曼滤波的完整流程
  • 线程池介绍与应用
  • 【代码随想录】【算法训练营】【第30天 1】 [322]重新安排行程 [51]N皇后
  • easyexcel的简单使用(execl模板导出)
  • oracle块跟踪
  • OpenGL-ES 学习(6)---- Ubuntu OES 环境搭建
  • 探索AI视频生成技术的原理
  • Chromium源码阅读:Mojo实战:从浏览器JS API 到blink实现
  • vue中,设置全局的 input 为只读状态,并改变输入框背景色
  • AWS无服务器 应用程序开发—第四章 数据库(Amazon DynamoDB)
  • 关于下载 IDEA、WebStorm 的一些心得感想
  • 统信UOS屏蔽mysql显性的用户名称以及密码
  • vue技巧(十)全局配置使用(打包后可修改配置文件)
  • Hash算法、MD5算法、HashMap
  • SpringBoot 升级到2.4.0以上版本跨域设置
  • canvas绘制圆角头像
  • django开发-定时任务的使用
  • gitlab-ci配置详解(一)
  • Hibernate【inverse和cascade属性】知识要点
  • Linux编程学习笔记 | Linux IO学习[1] - 文件IO
  • Odoo domain写法及运用
  • PAT A1050
  • Sass 快速入门教程
  • swift基础之_对象 实例方法 对象方法。
  • vue-loader 源码解析系列之 selector
  • 编写符合Python风格的对象
  • 对话 CTO〡听神策数据 CTO 曹犟描绘数据分析行业的无限可能
  • 关于for循环的简单归纳
  • 推荐一款sublime text 3 支持JSX和es201x 代码格式化的插件
  • 系统认识JavaScript正则表达式
  • 用Python写一份独特的元宵节祝福
  • 你对linux中grep命令知道多少?
  • k8s使用glusterfs实现动态持久化存储
  • ​中南建设2022年半年报“韧”字当头,经营性现金流持续为正​
  • ‌‌雅诗兰黛、‌‌兰蔻等美妆大品牌的营销策略是什么?
  • (26)4.7 字符函数和字符串函数
  • (CPU/GPU)粒子继承贴图颜色发射
  • (php伪随机数生成)[GWCTF 2019]枯燥的抽奖
  • (阿里云在线播放)基于SpringBoot+Vue前后端分离的在线教育平台项目
  • (二十六)Java 数据结构
  • (附源码)SSM环卫人员管理平台 计算机毕设36412
  • (经验分享)作为一名普通本科计算机专业学生,我大学四年到底走了多少弯路
  • (十五)使用Nexus创建Maven私服
  • (总结)Linux下的暴力密码在线破解工具Hydra详解
  • ./configure,make,make install的作用(转)
  • ./indexer: error while loading shared libraries: libmysqlclient.so.18: cannot open shared object fil
  • .Net IE10 _doPostBack 未定义
  • .NET 表达式计算:Expression Evaluator
  • .Net 执行Linux下多行shell命令方法
  • .net(C#)中String.Format如何使用
  • .Net(C#)自定义WinForm控件之小结篇
  • .NET/C# 项目如何优雅地设置条件编译符号?
  • .NET/MSBuild 中的发布路径在哪里呢?如何在扩展编译的时候修改发布路径中的文件呢?
  • .NET未来路在何方?
  • ;号自动换行