
[AI Painting] Model Conversion and Fast Image Generation (Based on diffusers)


1. Overview

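This chapter shows how to convert text-to-image models between different formats (single-file safetensors checkpoints and the diffusers folder layout) and how to generate images quickly.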

(Figure: SDXL text-to-image example)

2. How sdxl_lightning Works

The basic principles of the model are as follows.

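SDXL-Lightning uses distillation to obtain a model that generates images in only a few sampling steps (the UNet keeps the SDXL architecture; what shrinks is the number of steps). First, the paper distills directly from 128 steps to 32 steps using an MSE loss; it found MSE to be sufficient at this early stage. At this stage the paper also applies only classifier-free guidance (CFG), with a guidance scale of 6 and no negative prompts.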
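The two ingredients of this first stage, CFG with scale 6 on the teacher and an MSE loss between student and teacher, can be sketched in NumPy. This is a minimal illustration under assumptions: the arrays are toy stand-ins for UNet noise predictions, and the teacher target is a single guided prediction, whereas in the paper the target comes from running several teacher steps.

```python
import numpy as np


def cfg_combine(eps_uncond, eps_cond, guidance_scale=6.0):
    """Classifier-free guidance: extrapolate from the unconditional toward the conditional prediction."""
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)


def mse_loss(student_out, teacher_out):
    """Mean squared error between student output and (guided) teacher output."""
    return float(np.mean((student_out - teacher_out) ** 2))


# Toy teacher predictions (no negative prompt is used in this stage)
eps_uncond = np.array([0.0, 0.5])
eps_cond = np.array([1.0, 0.0])
teacher_target = cfg_combine(eps_uncond, eps_cond)  # guidance scale 6

student_pred = np.array([5.0, -2.0])
loss = mse_loss(student_pred, teacher_target)
print(teacher_target, loss)
```

Because the guidance is baked into the teacher target, the distilled student learns to produce guided outputs in a single forward pass.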

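Next, the paper uses an adversarial loss to reduce the number of steps further, in the order 32 → 8 → 4 → 2 → 1. At each stage, it first trains with the conditional objective to preserve the probability flow, and then with the unconditional objective to relax mode coverage.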
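The stage ordering above can be written out as a training plan. A minimal sketch: the tuple layout `(teacher_steps, student_steps, objective)` and the helper name `adversarial_stages` are hypothetical, chosen only to make the order explicit; each stage's student becomes the next stage's teacher.

```python
# Step counts for the adversarial stages, taken from the text
# (the 128 -> 32 MSE stage precedes these)
STEP_SCHEDULE = [32, 8, 4, 2, 1]


def adversarial_stages(schedule):
    """Return (teacher_steps, student_steps, objective) tuples in training order."""
    plan = []
    for teacher_steps, student_steps in zip(schedule, schedule[1:]):
        # Conditional objective first (preserve the probability flow) ...
        plan.append((teacher_steps, student_steps, "conditional"))
        # ... then the unconditional objective (relax mode coverage)
        plan.append((teacher_steps, student_steps, "unconditional"))
    return plan


for stage in adversarial_stages(STEP_SCHEDULE):
    print(stage)
```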

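At each stage, the paper first trains a LoRA with both objectives, then merges the LoRA and further trains the whole UNet with the unconditional objective. The paper found that fine-tuning the whole UNet gives better performance, while the LoRA modules can be applied to other base models. The LoRA setup is the same as in LCM-LoRA: rank 64 on all convolution and linear weights, except the input and output convolutions and the shared time-embedding linear layers. No LoRA is used on the discriminator, and the discriminator is re-initialized at every stage.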
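The LoRA bookkeeping described above can be sketched in NumPy: a rank-64 update `B @ A` is trained next to a frozen weight `W` and later merged into it, and only some layers receive LoRA. This is an illustration under assumptions: the scale `alpha / r` with `alpha = r`, the matrix sizes, and the layer names (modeled on diffusers' `UNet2DConditionModel` naming) are hypothetical, not taken from the paper's code.

```python
import numpy as np

RANK = 64  # LoRA rank from the text (same setting as LCM-LoRA)

# Layers excluded from LoRA per the text; names are illustrative assumptions
EXCLUDED_PREFIXES = ("conv_in", "conv_out", "time_embedding")


def lora_targets(layer_names):
    """Keep every layer except the input/output convs and the shared time embedding."""
    return [name for name in layer_names if not name.startswith(EXCLUDED_PREFIXES)]


def merge_lora(W, A, B, alpha=RANK):
    """Fold a LoRA update into the base weight: W' = W + (alpha / r) * B @ A."""
    r = A.shape[0]
    return W + (alpha / r) * (B @ A)


rng = np.random.default_rng(0)
out_dim, in_dim = 320, 320
W = rng.standard_normal((out_dim, in_dim))       # frozen base weight
A = rng.standard_normal((RANK, in_dim)) * 0.01   # LoRA down-projection (r x in)
B = rng.standard_normal((out_dim, RANK)) * 0.01  # LoRA up-projection (out x r)
x = rng.standard_normal(in_dim)

W_merged = merge_lora(W, A, B)
# After merging, a single matmul reproduces "base output + LoRA branch"
print(np.allclose(W_merged @ x, W @ x + B @ (A @ x)))

names = ["conv_in", "time_embedding.linear_1", "down_blocks.0.resnets.0.conv1",
         "mid_block.attentions.0.proj_in", "conv_out"]
print(lora_targets(names))
```

Merging is what lets the paper continue training the whole UNet after the LoRA phase: once folded in, the update is an ordinary part of the weights.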

3. Environment Setup

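diffusers is a diffusion library released by Hugging Face. It provides simple, convenient pipelines for diffusion inference and training, and is backed by a community hub of models and datasets: much like torch.hub, code can load datasets and pretrained checkpoints that others have uploaded directly from a given repository. It is also easy to install, and its code is clearly structured and thoroughly commented, which makes further development efficient.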

# pip
pip install --upgrade "diffusers[torch]"
# conda
conda install -c conda-forge diffusers
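To confirm what is installed without importing the heavy packages, a small check with the standard library can help. A minimal sketch: the package list is an assumption, and `transformers` is included because the SDXL pipelines use CLIP text encoders from it, while, to my knowledge, the `diffusers[torch]` extra does not install it.

```python
import importlib.metadata
import importlib.util


def installed_versions(packages=("diffusers", "torch", "transformers", "safetensors")):
    """Return {package: version, "unknown", or None} without importing the packages."""
    result = {}
    for name in packages:
        if importlib.util.find_spec(name) is None:
            result[name] = None  # not importable at all
            continue
        try:
            result[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            result[name] = "unknown"  # importable, but no distribution metadata
    return result


print(installed_versions())
```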

4. Code Implementation

Main test code:

4.1 Text-to-Image with sdxl_lightning


import torch
from diffusers import EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
from safetensors.torch import load_file

device = "cuda"

# SDXL base model (stabilityai/stable-diffusion-xl-base-1.0) unpacked locally, e.g.:
# !unzip ./data/data283423/SDXL.zip -d ./data/data283423/

# Build the UNet from the SDXL config, then load the 4-step SDXL-Lightning weights into it
unet = UNet2DConditionModel.from_config("./data/data283423/SDXL/unet/config.json").to(device, torch.float16)
unet.load_state_dict(load_file("./data/data283423/sdxl_lightning_4step_unet.safetensors", device=device))

# Assemble the SDXL pipeline around the Lightning UNet
base = StableDiffusionXLPipeline.from_pretrained(
    "./data/data283423/SDXL/", unet=unet, torch_dtype=torch.float16, variant="fp16"
).to(device)

# SDXL-Lightning expects "trailing" timestep spacing (its model card uses EulerDiscreteScheduler)
base.scheduler = EulerDiscreteScheduler.from_config(base.scheduler.config, timestep_spacing="trailing")

# The number of steps must match the checkpoint (4step -> 4)
n_steps = 4
prompt = "masterpiece, best quality, Realistic, cinematic quality, A majestic lion jumping from a big stone at night"
negative_prompt = "flaws in the eyes, flaws in the face, flaws, lowres, non-HDRi, low quality, worst quality,"

# guidance_scale=0 disables CFG as the model card recommends; negative_prompt is then ignored
image = base(
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=n_steps,
    guidance_scale=0,
).images[0]
image.save("./data/section-1/h5.png")
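Why the scheduler's `timestep_spacing` matters for a 4-step model: the sketch below re-implements, in plain Python, the "leading" and "trailing" rules as I understand them from diffusers' `EulerDiscreteScheduler` (it is not the library code). It assumes 1000 training timesteps and, for leading spacing, the `steps_offset=1` found in the SDXL scheduler config. Leading spacing starts sampling at t=751, while trailing starts at t=999, i.e. from almost pure noise.

```python
def leading_timesteps(num_train_timesteps, num_inference_steps, steps_offset=0):
    """Evenly spaced from t=0 upward (integer step ratio), then reversed and offset."""
    step_ratio = num_train_timesteps // num_inference_steps
    ts = [i * step_ratio + steps_offset for i in range(num_inference_steps)]
    return ts[::-1]


def trailing_timesteps(num_train_timesteps, num_inference_steps):
    """Evenly spaced downward from the last training timestep."""
    step_ratio = num_train_timesteps / num_inference_steps
    return [round(num_train_timesteps - i * step_ratio) - 1 for i in range(num_inference_steps)]


print("leading :", leading_timesteps(1000, 4, steps_offset=1))
print("trailing:", trailing_timesteps(1000, 4))
```

With only four steps, skipping the noisiest quarter of the schedule (leading) leaves the model starting from an input it was not distilled for, which is consistent with the model card's advice to use trailing spacing.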

4.2 Loading a safetensors Model

The usual way to load a model in diffusers expects a repository in the diffusers multi-folder layout, for example:


import torch
from diffusers import AutoPipelineForImage2Image

pipeline = AutoPipelineForImage2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)

To load a single-file safetensors model instead, change it to:


from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt

# Load a single-file checkpoint directly into a pipeline
base = download_from_original_stable_diffusion_ckpt(
    checkpoint_path_or_dict="./data/data282269/SDXL_doll.safetensors",
    from_safetensors=True,
)
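In recent diffusers releases, the documented way to load a single-file checkpoint is the `from_single_file` class method on the pipeline. A minimal sketch, assuming an SDXL checkpoint and fp16 weights; the helper name `load_sdxl_single_file` is hypothetical, and the imports sit inside the function only so the snippet can be defined where torch and diffusers are absent.

```python
def load_sdxl_single_file(checkpoint_path, device="cuda"):
    """Load a single-file SDXL checkpoint (.safetensors) into a diffusers pipeline."""
    # Imported lazily so defining this helper does not require torch/diffusers
    import torch
    from diffusers import StableDiffusionXLPipeline

    pipe = StableDiffusionXLPipeline.from_single_file(checkpoint_path, torch_dtype=torch.float16)
    return pipe.to(device)


# Usage (path from the example above):
# base = load_sdxl_single_file("./data/data282269/SDXL_doll.safetensors")
```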

4.3 Converting a safetensors Model

To convert a safetensors model into the standard diffusers folder format, use the following script:


"""Conversion script for the LDM checkpoints."""
import argparse
import importlib

import torch

from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--checkpoint_path", default="./data/data282269/SDXL_doll.safetensors", type=str,
                        help="Path to the checkpoint to convert.")
    # For SD v1.x checkpoints the original config can be fetched with:
    # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
    parser.add_argument("--original_config_file", default=None, type=str,
                        help="The YAML config file corresponding to the original architecture.")
    parser.add_argument("--config_files", default=None, type=str,
                        help="The YAML config file corresponding to the architecture.")
    parser.add_argument("--num_in_channels", default=None, type=int,
                        help="The number of input channels. If `None` number of input channels will be automatically inferred.")
    parser.add_argument("--scheduler_type", default="pndm", type=str,
                        help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']")
    parser.add_argument("--pipeline_type", default=None, type=str,
                        help="The pipeline type. One of 'FrozenOpenCLIPEmbedder', 'FrozenCLIPEmbedder', 'PaintByExample'. "
                             "If `None` pipeline will be automatically inferred.")
    parser.add_argument("--image_size", default=None, type=int,
                        help="The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and "
                             "Stable Diffusion v2 Base. Use 768 for Stable Diffusion v2.")
    parser.add_argument("--prediction_type", default=None, type=str,
                        help="The prediction type that the model was trained on. Use 'epsilon' for Stable Diffusion v1.X "
                             "and Stable Diffusion v2 Base. Use 'v_prediction' for Stable Diffusion v2.")
    parser.add_argument("--extract_ema", action="store_true",
                        help="Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract "
                             "the EMA weights or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. "
                             "EMA weights usually yield higher quality images for inference. Non-EMA weights are usually "
                             "better to continue fine-tuning.")
    parser.add_argument("--upcast_attention", action="store_true",
                        help="Whether the attention computation should always be upcasted. "
                             "This is necessary when running stable diffusion 2.1.")
    # A real flag (the original default="true" made any passed value, even "false", truthy)
    parser.add_argument("--from_safetensors", action="store_true",
                        help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.")
    parser.add_argument("--to_safetensors", action="store_true",
                        help="Whether to store pipeline in safetensors format or not.")
    parser.add_argument("--dump_path", default="./data/data282269/", type=str, help="Path to the output model.")
    parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
    parser.add_argument("--stable_unclip", type=str, default=None, required=False,
                        help="Set if this is a stable unCLIP model. One of 'txt2img' or 'img2img'.")
    parser.add_argument("--stable_unclip_prior", type=str, default=None, required=False,
                        help="Set if this is a stable unCLIP txt2img model. Selects which prior to use. If `--stable_unclip` "
                             "is set to `txt2img`, the karlo prior (https://huggingface.co/kakaobrain/karlo-v1-alpha/tree/main/prior) "
                             "is selected by default.")
    parser.add_argument("--clip_stats_path", type=str, required=False,
                        help="Path to the clip stats file. Only required if the stable unclip model's config specifies "
                             "`model.params.noise_aug_config.params.clip_stats_path`.")
    parser.add_argument("--controlnet", action="store_true", default=None, help="Set flag if this is a controlnet checkpoint.")
    parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
    parser.add_argument("--vae_path", type=str, default=None, required=False,
                        help="Set to a path, hub id to an already converted vae to not convert it again.")
    parser.add_argument("--pipeline_class_name", type=str, default=None, required=False,
                        help="Specify the pipeline class name")
    args = parser.parse_args()

    # Treat .safetensors files as safetensors even when the flag is not passed
    from_safetensors = args.from_safetensors or args.checkpoint_path.endswith(".safetensors")

    # Optionally resolve a specific pipeline class from diffusers by name
    if args.pipeline_class_name is not None:
        library = importlib.import_module("diffusers")
        pipeline_class = getattr(library, args.pipeline_class_name)
    else:
        pipeline_class = None

    pipe = download_from_original_stable_diffusion_ckpt(
        checkpoint_path_or_dict=args.checkpoint_path,
        original_config_file=args.original_config_file,
        config_files=args.config_files,
        image_size=args.image_size,
        prediction_type=args.prediction_type,
        model_type=args.pipeline_type,
        extract_ema=args.extract_ema,
        scheduler_type=args.scheduler_type,
        num_in_channels=args.num_in_channels,
        upcast_attention=args.upcast_attention,
        from_safetensors=from_safetensors,
        device=args.device,
        stable_unclip=args.stable_unclip,
        stable_unclip_prior=args.stable_unclip_prior,
        clip_stats_path=args.clip_stats_path,
        controlnet=args.controlnet,
        vae_path=args.vae_path,
        pipeline_class=pipeline_class,
    )

    if args.half:
        pipe.to(dtype=torch.float16)

    if args.controlnet:
        # only save the controlnet model
        pipe.controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
    else:
        pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
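After the script finishes, a quick way to check the result is to confirm that `dump_path` has the diffusers folder layout. A minimal sketch, assuming an SDXL checkpoint: the component list below is my assumption about what an SDXL pipeline folder normally contains (adjust it for other model types), and `missing_components` is a hypothetical helper.

```python
from pathlib import Path

# Assumed component folders of a saved SDXL pipeline
SDXL_COMPONENTS = ("unet", "vae", "text_encoder", "text_encoder_2",
                   "tokenizer", "tokenizer_2", "scheduler")


def missing_components(dump_path, components=SDXL_COMPONENTS):
    """List what is missing from a saved diffusers pipeline folder (empty list = looks complete)."""
    root = Path(dump_path)
    missing = [] if (root / "model_index.json").is_file() else ["model_index.json"]
    missing += [name for name in components if not (root / name).is_dir()]
    return missing


# Usage (dump_path from the script above):
# print(missing_components("./data/data282269/"))
```

If the list comes back empty, the folder should load with `StableDiffusionXLPipeline.from_pretrained(dump_path)` like any other diffusers model.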

5. Resources

https://www.liblib.art/modelinfo/8345679083144158adb64b80c58e3afd

