当前位置: 首页 > news >正文

diffusion model(十四): prompt-to-prompt 深度剖析

info
paperPrompt-to-Prompt Image Editing with Cross Attention Control
githubhttps://github.com/google/prompt-to-prompt
Org:Google Research
个人复现https://github.com/myhz0606/diffusion_learning
个人博客主页http://myhz0606.com/article/p2p

1 前言

基于扩散模型(diffusion model)的图片编辑技术当下取得了飞跃的进展,涌现出了大量优秀的工作,例如:InstructPix2Pix[1]和EmuEdit[2]。这些工作致力于实现直接通过文字指令来编辑图片,极大地提升了传统图像编辑流程的效率。这种新兴的技术领域被称作基于指令的图像编辑(instruction-based image editing)。饮水思源,这类技术成功的背后,离不开Google在2022年提出的Prompt-to-Prompt(下文简称为p2p)这项工作。

为了深入理解技术细节,笔者借鉴google的开源代码对其进行复现。

2 P2P提出的Motivation

目前大火的文生图技术(text to image),给定一段文本(prompt)和随机种子,文生图模型会基于这两者生成一张图片。生成图片的不同由两个变量决定

  • 随机种子。随机种子决定初始的噪声 x T x_T xT
  • prompt。prompt是通过文本编码器(如CLIP的text encoder)转为语义向量再送入到diffusion modelcross-attention层中与图片信息交互。

假定up sampler不引入随机性,如DDIM; classifier-guidance-score; generation step是系统变量维持不变

如果我们固定了随机种子,仅微小的改变prompt,输出的图片是否相似?如果可行,那么根据这个特性,很方便的可以通过修改prompt来编辑图片了。很遗憾,事情没有那么简单。仅微小改动prompt,输出的图片也有很大差异。下图展示了固定随机种子,仅更改蛋糕种类的生成结果。

在这里插入图片描述

过去为了解决上述问题,Repaint[3]、Diffedit[4]在做图片编辑时,会引入一个mask,在编辑阶段,只更新mask区域的像素值,这类方法也取得了一些令人惊叹的结果,但上述方法同样存在三个问题:

  1. 需要手动构建mask,比较麻烦。(现在一般会接入SAM[5]来加速这个过程)
  2. 由于在编辑过程只修改mask区域的像素值,未考虑mask区域与非mask区域的结构信息,导致生成的图片语义连贯性较差。
  3. 这类方法只能实现object-level的编辑,无法实现图片风格、纹理的编辑。

在这篇文章中,作者提出了一种p2p的文字编辑方法(textual editing),无需训练任何参数、添加任何模块,仅用预训练的文生图模型(如stable diffusion)即能实现卓越的textual editing能力。下图展示了引入p2p技术后,同样的随机种子和prompt的生成结果。

在这里插入图片描述

下面来看p2p具体是怎么做的吧。

3 方法

3.1 什么是prompt-to-prompt 🤔

通过上面的背景和动机介绍,我们知道p2p做的是这样一件事:

给定原始图片的prompt( P \mathcal{P} P)与编辑图片的prompt ( P ∗ \mathcal{P}^* P),通过文生图模型,分别获得原始图片 I \mathcal{I} I与编辑后的图片 I ∗ \mathcal{I}^* I I \mathcal{I} I I ∗ \mathcal{I}^* I除了编辑区域外尽可能的近。

举个🌰,当我输入prompt a photo of a house on a mountain.用文生图生成了一张在山上的房子的图片,现在我们想维持生成图片的整体布局,仅将其改为冬景。用p2p技术可以很方便实现,如下图所示

在这里插入图片描述

3.2 prompt-to-prompt的具体实现 🤔

在详细介绍p2p之前,我们先来回答motivation中的一个问题。

为什么给定了随机种子,仅微小的改变prompt,输出的图片却差异很大?

我们知道在文生图中,prompt与diffusion model是在cross-attention层进行交互(text embedding作为cross-attention的key和value)。如下图所示(灰色的块代表mask)。

在这里插入图片描述

📌忘记文生图条件融合的话,可以回顾 classifier-free guided的内容。

假定当prompt的第二个token发生改变时,根据下图的计算流,可以看到整个attention score的数值都会发生改变。从而导致最终输出结果发生改变。

在这里插入图片描述

3.2.1 cross-attention对生成图片的影响

通过对diffusion model网络内部的观察,作者发现生成图片的空间布局和几何形状都是由内部的cross-attention层的attention map决定(上图的 a t t e n \mathrm{atten} atten)。下图是由prompt: “a furry bear watching a bird”生成的图片,我们分别看每一个token对应的attention map对应生成图片的相应位置。并在time step的早期这个对应关系就已形成。

在这里插入图片描述

在这里插入图片描述

这里提供一张attention map随时间步变化的gif图。

在这里插入图片描述

3.2.1 controlling the cross-attention

control的思路很简单。既然cross-attention的attention map决定生成图片的结构信息,那我们维持原始的attention map即可。

p2p的整体算法流程如下图所示

每一个时间步 t t t分别计算原始prompt P \mathcal{P} P的attention map M t M_t Mt和新的prompt P ∗ \mathcal{P}^* P的attention map M t ∗ M^*_t Mt并用特定的替换规则 E d i t ( M t , M t ∗ , t ) Edit(M_t, M_t^*, t) Edit(Mt,Mt,t)替换后再进行生成。

在这里插入图片描述

作者根据不同的编辑类型,设计了不同的替换方式

在这里插入图片描述

(一)Word Swap

这个编辑类型是指将原始prompt中的某个token用新的token进行替换。 P = \mathcal{P} = P= “photo of a cat riding on a bicycle”, P ∗ = \mathcal{P}^* = P= “photo of a cat riding on a motorcycle”。此时的替换规则是

E d i t ( M t , M t ∗ , t ) : = { M t ∗ i f t < τ M t o t h e r w i s e . (1) E d i t ( M _ { t } , M _ { t } ^ { * } , t ) : = \left\{ \begin{array} { c l } { M _ { t } ^ { * } } & { \quad \mathrm { i f \ } t \lt \tau } \\ { M _ { t } } & { \quad \mathrm { o t h e r w i s e . } } \\ \end{array} \right . \tag{1} Edit(Mt,Mt,t):={MtMtif t<τotherwise.(1)

τ \tau τ表示某一时间步。当时间步小于 τ \tau τ时不做替换,否则用原始prompt的attention map做替换。(当两个词的长度不同时,可以对少的进行复制)引入 τ \tau τ的目的是:有一些编辑对图像的几何改变会很大,可以通过引入控制时机 τ \tau τ来缓和。Word Swap的编辑形式可以很方便的对图片中某个物体进行局部编辑。

在这里插入图片描述

(二)Adding a New Phrase

指的是在原始prompt P \mathcal{P} P新增一些token。如 P = \mathcal{P}= P= “a photo of a house on a mountain”, P ∗ = \mathcal{P}^* = P= “a photo of a house on a mountain at winter”。

( E d i t ( M t , M t ∗ , t ) ) i , j : = { ( M t ∗ ) i , j i f A ( j ) = N o n e ( M t ) i , A ( j ) o t h e r w i s e . (2) ( E d i t ( M _ { t } , M _ { t } ^ { * } , t ) ) _ { i , j } : = \left \{ \begin{array} { l l } { { ( M _ { t } ^ { * } ) _ { i , j } } } & { { \mathrm { i f } \ A ( j ) = N o n e } } \\ { { ( M _ { t } ) _ { i , A ( j ) } } } & { { \mathrm { o t h e r w i s e } . } } \end{array} \right . \tag{2} (Edit(Mt,Mt,t))i,j:={(Mt)i,j(Mt)i,A(j)if A(j)=Noneotherwise.(2)

i i i 表示visual token的索引位置, j j j 表示 P ∗ \mathcal{P}^* P中text token 的索引位置; A ( j ) A(j) A(j)表示, P ∗ \mathcal{P}^* P的text token j j j P \mathcal{P} P中的索引位置。这种类型的control同样可以引入word swap中的 τ \tau τ来控制control的时机。用这个方法可以对图像进行全局的编辑,如下面例子的改变风格整体图片的风格为“winter”。

在这里插入图片描述

(三)Attention Re–weighting

基于p2p还可以精细的控制prompt每一个token的控制强度。这个场景 P \mathcal{P} P P ∗ \mathcal{P}^* P是相同的,可以更改特定token的权重来控制图像。

( E d i t ( M t , M t ∗ , t ) ) i , j : = { c ⋅ ( M t ) i , j i f j = j ∗ ( M t ) i , j o f h e r w i s e . (3) ( E d i t ( M _ { t } , M _ { t } ^ { * } , t ) ) _ { i , j } : = \left \{ \begin{array} { c c } { c \cdot ( M _ { t } ) _ { i , j } } & { \mathrm { i f } \ j = j ^ { * } } \\ { ( M _ { t } ) _ { i , j } } & { \mathrm { o f h e r w i s e } . } \\ \end{array} \right . \tag{3} (Edit(Mt,Mt,t))i,j:={c(Mt)i,j(Mt)i,jif j=jofherwise.(3)
在这里插入图片描述

4 核心部分代码说明

diffusers.version == 0.10.0

4.1 修改cross-attention层的forward

p2p的核心是修改cross-attention层的计算方式,为此我们需要重写diffusers内部cross-attention的forward函数,引入controller.control() 来控制attention map的编辑。

def control_cross_attn_forward(self, controller: BaseController, place_in_unet):def forward(x, context: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None):batch_size, sequence_length, dim = x.shapeh = self.headsq = self.to_q(x)is_cross = context is not Nonecontext = context if is_cross else xk = self.to_k(context)v = self.to_v(context)q = self.reshape_heads_to_batch_dim(q)k = self.reshape_heads_to_batch_dim(k)v = self.reshape_heads_to_batch_dim(v)sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scaleif mask is not None:mask = mask.reshape(batch_size, -1)max_neg_value = -torch.finfo(sim.dtype).maxmask = mask[:, None, :].repeat(h, 1, 1)sim.masked_fill_(~mask, max_neg_value)# attention, what we cannot get enough ofattn = sim.softmax(dim=-1)# print(f"attn shape: {attn.shape}")attn = controller.control(attn, is_cross, place_in_unet)  # AttentionStore时相当于将attention值缓存到controller中out = torch.einsum("b i j, b j d -> b i d", attn, v)out = self.reshape_batch_dim_to_heads(out)to_out = self.to_outif type(to_out) is torch.nn.modules.container.ModuleList:to_out = self.to_out[0]  # 忽略dropoutelse:to_out = self.to_outreturn to_out(out)return forwarddef register_attention_control_mine(unet, controller):cross_attn_name_ls = []for i in unet.named_children():name, cur_module = i[:2]if cur_module.__class__.__name__ == "CrossAttention":cur_module.forward = control_cross_attn_forward(cur_module, controller, name)cross_attn_name_ls.append(name)elif hasattr(cur_module, "children"):module_ls = [(name, cur_module)]while module_ls:name, cur_module = module_ls.pop()for sub_name, sub_module in cur_module.named_children():if sub_module.__class__.__name__ == "CrossAttention":sub_module.forward = control_cross_attn_forward(sub_module,controller,f"{name}.{sub_name}")cross_attn_name_ls.append(f"{name}.{sub_name}")elif hasattr(sub_module, "children"):module_ls.append((f"{name}.{sub_name}", sub_module))controller.num_att_layers = len(cross_attn_name_ls)controller.cross_attn_name_ls = cross_attn_name_ls

4.2 control attention map

controller.control() 内部的实现方式

class EditControllerMemEfficient(BaseController):def __init__(self, edit_params: EditParams,max_vis_pixel_num=MAX_VIS_PIXEL_NUM,cached_attn_info_flag=False,logger=base_logger):super(EditControllerMemEfficient, self).__init__(max_vis_pixel_num=max_vis_pixel_num, cached_attn_info_flag=cached_attn_info_flag, edit_params=edit_params, logger=logger)self.control_info_checking()def cross_attn_control(self, attn: torch.Tensor, place_in_unet: str) -> torch.Tensor:assert attn.shape[0] > 1, f"attn shape: {attn.shape}"source_replace_mask = self.replace_index_map["source_mask"]target_replace_mask = self.replace_index_map["target_mask"]source_token_weight = self.replace_index_map["source_token_weight"]target_token_weight = self.replace_index_map["target_token_weight"]if self.do_cross_attn_control_flag:attn = rearrange(attn, "(b h) p c -> b h p c", b=self.batch_size)source_attn = attn[:1, ...]target_attn = attn[1:, ...]source_attn_for_merge = source_attn * source_token_weighttarget_attn = target_attn * target_token_weighttarget_attn[..., target_replace_mask] = source_attn_for_merge[..., source_replace_mask]attn = torch.cat([source_attn, target_attn], dim=0)attn = rearrange(attn, "b h p c -> (b h) p c")if self.do_local_blend and self.text_branch_flag:  # local blend whatever cross controlblend_attn = attnself.set_blend_attn_map(place_in_unet, True, blend_attn)return attndef self_attn_control(self, attn: torch.Tensor, place_in_unet: str) -> torch.Tensor:if attn.shape[2] <= 16 ** 2:attn = rearrange(attn, "(b h) p c -> b h p c", b=self.batch_size)source_attn = attn[:1, ...]if self.do_self_attn_control_flag:attn = source_attn.expand(self.batch_size, *source_attn.shape[1:])attn = rearrange(attn, "b h p c -> (b h) p c")return attndef control(self, attn: torch.Tensor, is_cross: bool, place_in_unet: str) -> torch.Tensor:# print(f">>>cached_attn_flag: {self.cached_attn_info_flag}")assert self.current_step is not None, f"please set current time step by 'self.set_step'!"pixel_num = attn.shape[1]if pixel_num > self.max_vis_pixel_num:self.not_control_attn_name_set.add(place_in_unet)return attnif place_in_unet not in self.cached_attn.keys():self.cached_attn[place_in_unet] = dict() if is_cross:attn = self.cross_attn_control(attn, place_in_unet)else:attn = self.self_attn_control(attn, place_in_unet)if self.cached_attn_info_flag:self.cached_attn_name_set.add(place_in_unet)if is_cross and self.do_cross_attn_control_flag:self.set_cached_attn(place_in_unet, is_cross, attn)elif is_cross and not self.do_cross_attn_control_flag:self.set_cached_attn(place_in_unet, is_cross, None)elif not is_cross and self.do_self_attn_control_flag:self.set_cached_attn(place_in_unet, is_cross, attn)else:self.set_cached_attn(place_in_unet, is_cross, None)return attn

4.3 支持的编辑方式

代码中通过EditParams类来指定编辑的参数

class EditParams:source_prompt: strtarget_prompt: strcross_merge_end_step: Union[float, int]  # cross attention merge step, 0-(cross_merge_step * diffusion step)  using cross-attn injection self_merge_end_step: Union[float, int]  # self attention merge step, 0-(self_merge_step * diffusion step) using self-attn injectioncross_merge_start_step: Union[float, int] = 0  # cross attention merge step, 0-(cross_merge_step * diffusion step)  using cross-attn injectionself_merge_start_step: Union[float, int] = 0  # self attention merge step, 0-(self_merge_step * diffusion step) using self-attn injectionaddition_token_control_info: Optional[Dict] = Nonedo_noise_branch_control: bool = Falsedo_local_blend: bool = False  # using local blendblend_focus_text: Optional[List] = None

5 One More Thing

5.1 p2p with additional constraints

的edit能力通过引入以下3个约束还能进一步提升

  • self attention的约束

将原始图片在self attention处的attention map迁移给编辑图片,非编辑区域维持性会更强。详细可见pnp[7]论文。

下图展现了当使用self- attention control时的编辑效果。应用的步长越多,非edit区域的维持性越好。

source prompt: "a photo of a house on a mountain.”

target_prompt: "a photo of a house on a mountain at winter"

在这里插入图片描述

同样,有一些编辑对图像的几何改变会很大,不宜控制过多

在这里插入图片描述

  • 引入local blend

仅更改需要编辑区域的pixel,保留其它区域的pixel。编辑区域的mask为token对应的attention map。底层原理可见repaint[8] paper。

如:当引入“mountain”的local blend限制时,只有山的区域变为雪景。

在这里插入图片描述

当local-blend还可以结合re-weight等编辑策略,可以实现更细粒度的控制

在这里插入图片描述

  • noise分支引入self attention的约束

我们知道对于classifier-free的文生图,需要同时计算条件分支的噪声估计 ϵ θ ( x t , y , t ) \epsilon_{\theta}(x_t, y, t) ϵθ(xt,y,t)和非条件分支的噪声估计 ϵ θ ( x t , y = ∅ , t ) \epsilon_{\theta}(x_t, y=\empty, t) ϵθ(xt,y=,t) ,再通过classifier-free的方式融合。尝试发现,非条件分支引入self-attention control有助于进一步提升编辑效果(相比前面,提升不太大)。

ϵ ^ θ ( x t , y , t ) = ϵ θ ( x t , y = ∅ , t ) + s [ ϵ θ ( x t , y , t ) − ϵ θ ( x t , y = ∅ , t ) ] (4) \begin{align} \hat{\epsilon}_{\theta}(x_t, y, t)=\epsilon_{\theta}(x_t, y=\empty,t) + s[\epsilon_{\theta}(x_t, y, t) - \epsilon_{\theta}(x_t, y=\empty, t) ] \end{align} \tag{4} ϵ^θ(xt,y,t)=ϵθ(xt,y=,t)+s[ϵθ(xt,y,t)ϵθ(xt,y=,t)](4)

5.2 p2p for real image editing

若要采用p2p论文中的方法进行编辑需要知道两个信息:1)图片的初始噪声分布;2)图片的prompt。如果直接拿一张图过来是没有办法进行p2p进行编辑的。需要先得到以下两个信息:

1)给定或生成这张图的prompt;

2)估计出给定prompt下这张图的噪声。

在作者后续的Null-text Inversion [9]工作中对这类情形进一步研究,后续文章中将详细介绍。

参考文献

[1] InstructPix2Pix: Learning to Follow Image Editing Instructions

[2] Emu Edit: Precise Image Editing via Recognition and Generation Tasks

[3] RePaint: Inpainting using Denoising Diffusion Probabilistic Models

[4] DiffEdit: Diffusion-based semantic image editing with mask guidance

[5] Segment Anything

[6] classifier-free diffusion model

[7] Plug-and-Play Diffusion Features for Text-Driven Image-to-Image Translation

[8] RePaint: Inpainting using Denoising Diffusion Probabilistic Models

[9] Null-text Inversion for Editing Real Images using Guided Diffusion Models

相关文章:

  • QT 驾校系统界面布局编写
  • Nginx安装和平滑升级
  • 阿里云发布 AI 编程助手 “通义灵码”——VSCode更强了 !!
  • PX4|基于FAST-LIO mid360的无人机室内自主定位及定点悬停
  • layui table列表重载后保持进度条位置不变
  • 论文浅尝 | GPT-RE:基于大语言模型针对关系抽取的上下文学习
  • Json Web Token(JWT) 快速入门
  • [项目设计]基于websocket实现网络对战五子棋
  • Python使用whisper实现语音识别(ASR)
  • 【鸿蒙系统】 ---Harmony 鸿蒙编译构建指导(一)
  • 【Python】使用selenium对Poe批量模拟注册脚本
  • Docker使用之java项目工程的部署
  • Linux操作系统-汇编LED驱动程序基础
  • FX-数组的使用
  • 【OCR】OCR开源文字识别工具
  • Electron入门介绍
  • extjs4学习之配置
  • flutter的key在widget list的作用以及必要性
  • Java 内存分配及垃圾回收机制初探
  • js 实现textarea输入字数提示
  • laravel5.5 视图共享数据
  • PHP 7 修改了什么呢 -- 2
  • react 代码优化(一) ——事件处理
  • 当SetTimeout遇到了字符串
  • 得到一个数组中任意X个元素的所有组合 即C(n,m)
  • 构建二叉树进行数值数组的去重及优化
  • 基于Vue2全家桶的移动端AppDEMO实现
  • 力扣(LeetCode)56
  • 聊一聊前端的监控
  • 猫头鹰的深夜翻译:JDK9 NotNullOrElse方法
  • 如何学习JavaEE,项目又该如何做?
  • 使用 Xcode 的 Target 区分开发和生产环境
  • 使用iElevator.js模拟segmentfault的文章标题导航
  • Prometheus VS InfluxDB
  • TPG领衔财团投资轻奢珠宝品牌APM Monaco
  • #经典论文 异质山坡的物理模型 2 有效导水率
  • (01)ORB-SLAM2源码无死角解析-(66) BA优化(g2o)→闭环线程:Optimizer::GlobalBundleAdjustemnt→全局优化
  • (52)只出现一次的数字III
  • (delphi11最新学习资料) Object Pascal 学习笔记---第7章第3节(封装和窗体)
  • (Matalb时序预测)WOA-BP鲸鱼算法优化BP神经网络的多维时序回归预测
  • (九十四)函数和二维数组
  • (数位dp) 算法竞赛入门到进阶 书本题集
  • (五)网络优化与超参数选择--九五小庞
  • ***测试-HTTP方法
  • .NET C# 使用 SetWindowsHookEx 监听鼠标或键盘消息以及此方法的坑
  • .net 调用php,php 调用.net com组件 --
  • .NET的数据绑定
  • .NET命令行(CLI)常用命令
  • .NET应用架构设计:原则、模式与实践 目录预览
  • ;号自动换行
  • @Repository 注解
  • [ 攻防演练演示篇 ] 利用通达OA 文件上传漏洞上传webshell获取主机权限
  • [android] 手机卫士黑名单功能(ListView优化)
  • [AndroidStudio]_[初级]_[修改虚拟设备镜像文件的存放位置]
  • [Angular] 笔记 16:模板驱动表单 - 选择框与选项