【diffusers 极速入门(二)】如何得到扩散去噪的中间结果?Pipeline callbacks 管道回调函数
本文是对 Hugging Face Diffusers 文档中关于回调函数的翻译与总结,:
管道回调函数
在管道的去噪循环中,可以使用callback_on_step_end
参数添加自定义回调函数。该回调函数在每一步结束时执行,并修改管道属性和变量,以供下一步使用。这在动态调整某些管道属性或修改张量变量时非常有用。利用回调函数,你可以实现新的功能而无需修改底层代码。
目前,Diffusers 仅支持callback_on_step_end
,如果你有其他执行点的回调需求,可以在 github 上提出功能请求。
官方回调函数
官方提供了一些可用于修改去噪循环的回调函数列表:
SDCFGCutoffCallback
:在一定步数后禁用 CFG。对于 SD 1.5 pipelines 适用, 包括 text-to-image, image-to-image, inpaint, controlnet。SDXLCFGCutoffCallback
:在一定步数后禁用 CFG。对于 SDXL pipelines 适用, 包括 text-to-image, image-to-image, inpaint, controlnet。IPAdapterScaleCutoffCallback
:在一定步数后禁用 IP Adapter。对所有支持 IP-Adapter 的 pipelines 适用。
要设置回调函数,可以指定cutoff_step_ratio
或cutoff_step_index
参数。
cutoff_step_ratio
:带有步长比的浮点数。cutoff_step_index
:一个整数,包含步数的确切编号。
示例代码
import torch
from diffusers import DPMSolverMultistepScheduler, StableDiffusionXLPipeline
from diffusers.callbacks import SDXLCFGCutoffCallbackcallback = SDXLCFGCutoffCallback(cutoff_step_ratio=0.4)
# 也可以用 cutoff_step_index
# callback = SDXLCFGCutoffCallback(cutoff_step_ratio=None, cutoff_step_index=10)pipeline = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0",torch_dtype=torch.float16,variant="fp16",
).to("cuda")
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, use_karras_sigmas=True)prompt = "a sports car at the road, best quality, high quality, high detail, 8k resolution"
generator = torch.Generator(device="cpu").manual_seed(2628670641)out = pipeline(prompt=prompt,negative_prompt="",guidance_scale=6.5,num_inference_steps=25,generator=generator,callback_on_step_end=callback,
)out.images[0].save("official_callback.png")
动态无分类器引导
动态无分类器引导(classifier-free guidance,CFG)允许在一定步数后禁用 CFG,从而节省计算成本。回调函数应包含以下参数:
pipeline
:访问管道实例属性(如num_timesteps和guidance_scale)。step_index
和timestep
:当前步骤索引和时间步。在达到num_timesteps的40%后,使用step_index
关闭CFG。callback_kwargs
:包含在去噪循环中可以修改的张量变量。是一个dict,包含可以在去噪循环中修改的张量变量。- 它只包括
callback_on_step_end_tensor_inputs
参数中指定的变量,该参数被传递给管道的__call__
方法。 - 不同的管道可能使用不同的变量集,因此请检查管道的
_callback_tensor_inputs
属性以获取可以修改的变量列表。一些常见的变量包括latents和prompt_embeds。 - 对于此函数,请在将
guidance_scale
设置为0.0后更改prompt_embeds
的批处理大小,以使其正常工作。
- 它只包括
示例回调函数:
def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
# adjust the batch_size of prompt_embeds according to guidance_scaleif step_index == int(pipeline.num_timesteps * 0.4):prompt_embeds = callback_kwargs["prompt_embeds"]prompt_embeds = prompt_embeds.chunk(2)[-1]# update guidance_scale and prompt_embedspipeline._guidance_scale = 0.0callback_kwargs["prompt_embeds"] = prompt_embedsreturn callback_kwargs
每步生成后显示图像(中间结果)
通过访问并转换潜在空间,可以在每步生成后显示图像。以下函数将 SDXL 的潜在空间(4 通道)转换为 RGB 张量(3 通道)。
- 使用以下函数将SDXL潜伏时间(4个通道)转换为RGB张量(3个通道)
def latents_to_rgb(latents):weights = ((60, -60, 25, -70),(60, -5, 15, -50),(60, 10, -5, -35))weights_tensor = torch.t(torch.tensor(weights, dtype=latents.dtype).to(latents.device))biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(latents.device)rgb_tensor = torch.einsum("...lxy,lr -> ...rxy", latents, weights_tensor) + biases_tensor.unsqueeze(-1).unsqueeze(-1)image_array = rgb_tensor.clamp(0, 255)[0].byte().cpu().numpy()image_array = image_array.transpose(1, 2, 0)return Image.fromarray(image_array)
- 使用该函数在每步生成后解码并保存潜在空间为图像。
def decode_tensors(pipe, step, timestep, callback_kwargs):latents = callback_kwargs["latents"]image = latents_to_rgb(latents)image.save(f"{step}.png")return callback_kwargs
- 将
decode_tensors
函数传递给callback_on_step_end
参数,以在每一步之后对张量进行解码。还需要在callback_on_step_end_tensor_inputs
参数中指定要修改的内容,在本例中为 latents。
from diffusers import AutoPipelineForText2Image
import torch
from PIL import Imagepipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0",torch_dtype=torch.float16,variant="fp16",use_safetensors=True
).to("cuda")image = pipeline(prompt="A croissant shaped like a cute bear.",negative_prompt="Deformed, ugly, bad anatomy",callback_on_step_end=decode_tensors,callback_on_step_end_tensor_inputs=["latents"],
).images[0]
详细内容请参见Hugging Face Diffusers 官方文档。