当前位置: 首页 > news >正文

AIGC笔记--特征线性调制(FiLM)层的实现

目录

1--特征线性调制层的作用

2--特征线性调制层的实现

3--论文实例


1--特征线性调制层的作用

        特征线性调制(Feature-wise Linear Modulation,FiLM)层是一种神经网络模块,它可以用来实现特征的条件调整。FiLM层的主要功能是对输入特征进行缩放(scaling)和偏移(shifting),并且这个缩放和偏移是可以学习的。

        FiLM层的工作原理如下:给定一个输入特征x,FiLM层首先通过一个全连接层或其他形式的网络结构生成两个参数γβ,然后对输入特征进行缩放和偏移,即y = γ * x + β。这里,γ和β是与输入特征x同样大小的向量,它们决定了对输入特征的缩放偏移程度

        FiLM层的主要作用是实现特征的条件调整,使得模型可以根据特定的条件(例如来自其他模态的信息)来调整特征的表示。这种机制在许多任务中都很有用,例如在图像生成任务中,FiLM层可以用来根据文本描述来调整生成的图像特征;在视频理解任务中,FiLM层可以用来根据音频信息来调整视频特征

        总的来说,FiLM层是一种强大的特征调整工具,它可以帮助模型更好地利用条件信息,从而提高模型的性能。

2--特征线性调制层的实现

import torch
import torch.nn as nnclass FiLM(nn.Module):def __init__(self, input_dim, condition_dim):super(FiLM, self).__init__()# 全连接层,用于生成γ和β参数self.fc_gamma = nn.Linear(condition_dim, input_dim)self.fc_beta = nn.Linear(condition_dim, input_dim)def forward(self, x, condition):# 根据条件特征获取缩放scale参数和移位参数shift,即计算γ和β参数gamma = self.fc_gamma(condition)beta = self.fc_beta(condition)# 对输入特征x进行缩放和偏移,实现条件特征调整输入特征y = gamma * x + beta return yif __name__ == "__main__":input_dim = 64 # 输入特征condition_dim = 128 # 条件特征# 创建一个FiLM层实例film_layer = FiLM(input_dim, condition_dim)# 初始化输入特征x和条件特征conditionx = torch.randn(1, input_dim)condition = torch.randn(1, condition_dim)# 使用FiLM层对输入特征x进行条件调整y = film_layer(x, condition)print(y.shape) # [1, 64]

3--论文实例

Audio2Photoreal中,利用音频特征来调整动作特征:

import torch
import torch.nn as nn
from einops import rearrangeclass DenseFiLM(nn.Module):def __init__(self, embed_channels):super().__init__()self.embed_channels = embed_channelsself.block = nn.Sequential(nn.Mish(), nn.Linear(embed_channels, embed_channels * 2)) # nn.Mish()激活函数def forward(self, position): # position [B dim]pos_encoding = self.block(position) # pos_encoding [B 2*dim]pos_encoding = rearrange(pos_encoding, "b c -> b 1 c") # [B 1 2*dim]scale_shift = pos_encoding.chunk(2, dim=-1) # two [B 1 dim]return scale_shiftdef featurewise_affine(x, scale_shift):# 获取缩放因子和移位因子scale, shift = scale_shift # scale [B 1 dim] shift [B 1 dim]return (scale + 1) * x + shift # 调整特征if __name__ == "__main__":B = 2Frame_Residual_depth = 20*4dim = 64input_x = torch.rand(B, Frame_Residual_depth, dim) # 运动特征condition_t = torch.rand(B, dim) # 音频条件特征film = DenseFiLM(dim)# 调用film(condition_t)获取缩放因子和移位因子output_x = input_x + featurewise_affine(input_x, film(condition_t)) # 通过print(output_x.shape) # [B, Frame_Residual_depth, dim]

相关文章:

  • Linux上常用网络操作
  • Android面试官爱问的12个自定义View的问题
  • Mysql深度分页优化的一个实践
  • openssl3.2 - 官方demo学习 - signature - rsa_pss_hash.c
  • 芯片设计重要工具—— IBM LSF 分布式高性能计算调度平台
  • #laravel 通过手动安装依赖PHPExcel#
  • python期末:组合数据
  • 【springboot】配置文件入门
  • 链表的常见操作
  • 【设计模式之美】重构(三)之解耦方法论:如何通过封装、抽象、模块化、中间层等解耦代码?
  • 如何使用阿里云CDN服务?
  • Pandas实战100例 | 案例 100: 将 DataFrame 保存为 CSV 文件
  • 以后要做GIS开发的话是学GIS专业还是学计算机专业好一些?
  • mysql主从报错:Last_IO_Error: Error connecting to source解决方法
  • 京东ES支持ZSTD压缩算法上线了:高性能,低成本 | 京东云技术团队
  • Cookie 在前端中的实践
  • ECS应用管理最佳实践
  • emacs初体验
  • IP路由与转发
  • JS 面试题总结
  • Js基础知识(一) - 变量
  • Laravel 中的一个后期静态绑定
  • Material Design
  • npx命令介绍
  • Ruby 2.x 源代码分析:扩展 概述
  • 读懂package.json -- 依赖管理
  • 给Prometheus造假数据的方法
  • 基于Mobx的多页面小程序的全局共享状态管理实践
  • 使用 @font-face
  • 一、python与pycharm的安装
  • 格斗健身潮牌24KiCK获近千万Pre-A轮融资,用户留存高达9个月 ...
  • 好程序员web前端教程分享CSS不同元素margin的计算 ...
  • ​Linux Ubuntu环境下使用docker构建spark运行环境(超级详细)
  • # centos7下FFmpeg环境部署记录
  • (多级缓存)缓存同步
  • (二十三)Flask之高频面试点
  • (附源码)ssm考生评分系统 毕业设计 071114
  • (三)uboot源码分析
  • (终章)[图像识别]13.OpenCV案例 自定义训练集分类器物体检测
  • (转)Mysql的优化设置
  • . NET自动找可写目录
  • .describe() python_Python-Win32com-Excel
  • .NET Project Open Day(2011.11.13)
  • .Net Redis的秒杀Dome和异步执行
  • .NET 中 GetHashCode 的哈希值有多大概率会相同(哈希碰撞)
  • .NET6 命令行启动及发布单个Exe文件
  • .net最好用的JSON类Newtonsoft.Json获取多级数据SelectToken
  • ?.的用法
  • @Autowired和@Resource的区别
  • @Resource和@Autowired的区别
  • @Transaction注解失效的几种场景(附有示例代码)
  • @value 静态变量_Python彻底搞懂:变量、对象、赋值、引用、拷贝
  • @四年级家长,这条香港优才计划+华侨生联考捷径,一定要看!
  • [ 隧道技术 ] cpolar 工具详解之将内网端口映射到公网
  • []AT 指令 收发短信和GPRS上网 SIM508/548