当前位置: 首页 > news >正文

【扩散模型(六)】IP-Adapter 是如何训练的?2 源码篇(IP-Adapter Plus)

系列文章目录

  • 【扩散模型(二)】IP-Adapter 从条件分支的视角,快速理解相关的可控生成研究
  • 【扩散模型(三)】IP-Adapter 源码详解1-训练输入 介绍了训练代码中的 image prompt 的输入部分,即 img projection 模块。
  • 【扩散模型(四)】IP-Adapter 源码详解2-训练核心(cross-attention)详细介绍 IP-Adapter 训练代码的核心部分,即插入 Unet 中的、针对 Image prompt 的 cross-attention 模块。
  • 【扩散模型(五)】IP-Adapter 源码详解3-推理代码 详细介绍 IP-Adapter 推理过程代码。
  • 【可控图像生成系列论文(四)】IP-Adapter 具体是如何训练的?1公式篇
  • 本文则以 IP-Adapter Plus 训练代码为例,进行详细介绍。

文章目录

  • 系列文章目录
  • 整体训练框架
  • 一、训了哪些部分?
      • 第一块 - image_proj_model
      • 第二块 - adapter_modules
  • 二、训练目标


整体训练框架

在这里插入图片描述

一、训了哪些部分?

本文以原仓库 1 的 /path/IP-Adapter/tutorial_train_plus.py 为例,该文件为 SD1.5 IP-Adapter Plus 的训练代码。

从以下代码可以看出,IPAdapter 主要由 unet, image_proj_model, adapter_modules 3 个部分组成,而权重需要被优化的(训练到的)只有 ip_adapter.image_proj_model.parameters(), 和 ip_adapter.adapter_modules.parameters() 。

	ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)# optimizerparams_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(),  ip_adapter.adapter_modules.parameters())optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)...# Prepare everything with our `accelerator`.ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)

第一块 - image_proj_model

在 IP-Adapter Plus 中,采用的是 Resampler 作为img embedding 到 ip_tokens 的映射网络,对图像(image prompt)中信息的抽取更加细粒度。其他模块都不需要梯度下降,如下代码所示。

	# freeze parameters of models to save more memoryunet.requires_grad_(False)vae.requires_grad_(False)text_encoder.requires_grad_(False)image_encoder.requires_grad_(False)#ip-adapter-plusimage_proj_model = Resampler(dim=unet.config.cross_attention_dim,depth=4,dim_head=64,heads=12,num_queries=args.num_tokens,embedding_dim=image_encoder.config.hidden_size,output_dim=unet.config.cross_attention_dim,ff_mult=4)...

第二块 - adapter_modules

Decoupled cross-attention 则在以下代码中进行初始化,关键是在特定的 unet 层中进行替换,详细位置可以参考前文中的图片,本文的重点是后续训练的实现。

	# init adapter modulesattn_procs = {}unet_sd = unet.state_dict()for name in unet.attn_processors.keys():cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dimif name.startswith("mid_block"):hidden_size = unet.config.block_out_channels[-1]elif name.startswith("up_blocks"):block_id = int(name[len("up_blocks.")])hidden_size = list(reversed(unet.config.block_out_channels))[block_id]elif name.startswith("down_blocks"):block_id = int(name[len("down_blocks.")])hidden_size = unet.config.block_out_channels[block_id]if cross_attention_dim is None:attn_procs[name] = AttnProcessor()else:layer_name = name.split(".processor")[0]weights = {"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],}attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=args.num_tokens)attn_procs[name].load_state_dict(weights)unet.set_attn_processor(attn_procs)adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())

二、训练目标

每个 epoch 是遍历完一整个 dataset,我们直接从每个训练步的循环中来看:

  • latents 是通过 vae 将输入的 image prompt 压到了隐空间(latent space)中。
  • 准备相应的 noise 和 timesteps ,再通过 noise_scheduler 来制作出 noisy_latents。
        for step, batch in enumerate(train_dataloader):load_data_time = time.perf_counter() - beginwith accelerator.accumulate(ip_adapter):# Convert images to latent spacewith torch.no_grad():latents = vae.encode(batch["images"].to(accelerator.device, dtype=weight_dtype)).latent_dist.sample()latents = latents * vae.config.scaling_factor# Sample noise that we'll add to the latentsnoise = torch.randn_like(latents)bsz = latents.shape[0]# Sample a random timestep for each imagetimesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)timesteps = timesteps.long()# Add noise to the latents according to the noise magnitude at each timestep# (this is the forward diffusion process)noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
  • clip_images 和 drop_image_embed 是在准备数据的过程中,做了一个随机 drop 的方式进行数据增强,提升模型鲁棒性。
    • 数据增强:通过随机丢弃一些图像,模型被迫学习从剩余的图像中提取信息,这可以增加模型的泛化能力。
    • 模型鲁棒性:训练模型以处理不完整的数据,使其在实际应用中对缺失数据更加鲁棒。
     clip_images = []for clip_image, drop_image_embed in zip(batch["clip_images"], batch["drop_image_embeds"]):if drop_image_embed == 1:clip_images.append(torch.zeros_like(clip_image))else:clip_images.append(clip_image)clip_images = torch.stack(clip_images, dim=0)with torch.no_grad():image_embeds = image_encoder(clip_images.to(accelerator.device, dtype=weight_dtype), output_hidden_states=True).hidden_states[-2]with torch.no_grad():encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]

  1. https://github.com/tencent-ailab/IP-Adapter/tree/main ↩︎

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 类图的关联关系
  • VUE-组件间通信(三)全局事件总线
  • CAD二次开发IFoxCAD框架系列(25)- 自动加载和初始化的使用
  • 【flask】python框架flask的hello world
  • YOLOv8改进 | 主干篇 | YOLOv8引入EfficientViT替换Backbone
  • 行为识别实战第二天——Yolov5+SlowFast+deepsort: Action Detection(PytorchVideo)
  • 【算法每日一练及解题思路】计算以空格隔开的字符串的最后一个单词的长度
  • 1.【R语言】R语言的下载和安装
  • css中 display block属性的用法
  • 找单身狗(c语言)
  • 【论文阅读】通过使用实体增强框架融合多种多模态线索来改进假新闻检测
  • Kotlin 泛型小知识: `<T>`, `<out T>`, `<in T>` 的区别
  • Oracle查询优化--分区表建立/普通表转分区表
  • C++:string类(1)
  • 根DNS服务器
  • [NodeJS] 关于Buffer
  • Docker入门(二) - Dockerfile
  • js中forEach回调同异步问题
  • mongodb--安装和初步使用教程
  • Mysql优化
  • Next.js之基础概念(二)
  • React-redux的原理以及使用
  • WordPress 获取当前文章下的所有附件/获取指定ID文章的附件(图片、文件、视频)...
  • 规范化安全开发 KOA 手脚架
  • 前端知识点整理(待续)
  • 通过git安装npm私有模块
  • 在 Chrome DevTools 中调试 JavaScript 入门
  • Java数据解析之JSON
  • 新海诚画集[秒速5センチメートル:樱花抄·春]
  • ​2021半年盘点,不想你错过的重磅新书
  • ​猴子吃桃问题:每天都吃了前一天剩下的一半多一个。
  • #pragma data_seg 共享数据区(转)
  • %@ page import=%的用法
  • (12)Hive调优——count distinct去重优化
  • (C语言)球球大作战
  • (ISPRS,2021)具有遥感知识图谱的鲁棒深度对齐网络用于零样本和广义零样本遥感图像场景分类
  • (ZT)薛涌:谈贫说富
  • (九)信息融合方式简介
  • (论文阅读11/100)Fast R-CNN
  • (算法)前K大的和
  • (转)从零实现3D图像引擎:(8)参数化直线与3D平面函数库
  • (转)可以带来幸福的一本书
  • (转)平衡树
  • (轉)JSON.stringify 语法实例讲解
  • .NET COER+CONSUL微服务项目在CENTOS环境下的部署实践
  • .NET Compact Framework 多线程环境下的UI异步刷新
  • .Net Core/.Net6/.Net8 ,启动配置/Program.cs 配置
  • .NET 动态调用WebService + WSE + UsernameToken
  • .net 流——流的类型体系简单介绍
  • .NET编程——利用C#调用海康机器人工业相机SDK实现回调取图与软触发取图【含免费源码】
  • .Net的DataSet直接与SQL2005交互
  • .NET基础篇——反射的奥妙
  • :class的用法及应用
  • [ linux ] linux 命令英文全称及解释
  • [20170705]lsnrctl status LISTENER_SCAN1