当前位置: 首页 > news >正文

Open-Sora代码详细解读(2):时空3D VAE

Diffusion Models视频生成

前言:目前开源的DiT视频生成模型不是很多,Open-Sora是开发者生态最好的一个,涵盖了DiT、时空DiT、3D VAE、Rectified Flow、因果卷积等Diffusion视频生成的经典知识点。本篇博客从Open-Sora的代码出发,深入解读背后的原理。

目录

3D VAE原理

代码剖析

2D VAE

时间VAE

因果3D卷积


3D VAE原理

之前绝大多数都是2D VAE,特别是SDXL的VAE相当好用,很多人都拿来直接用了。但是在DiT-based的模型中,时间序列上如果再不做压缩的话,就已经很难训得动了。因此非常有必要在时间序列上进行压缩,3D VAE应运而生。

Open-Sora的方案是在2D VAE的基础上,再添加一个时间VAE,相比于EasyAnimate 和 CogVideoX的方案的Full Attention 存在劣势,但是可以充分利用到2D VAE的权重,成本更低。

代码剖析

2D VAE

来自华为pixart sdxl vae:

    vae_2d = dict(type="VideoAutoencoderKL",from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",subfolder="vae",micro_batch_size=micro_batch_size,local_files_only=local_files_only,)

时间VAE

    vae_temporal = dict(type="VAE_Temporal_SD",from_pretrained=None,)
@MODELS.register_module()
class VAE_Temporal(nn.Module):def __init__(self,in_out_channels=4,latent_embed_dim=4,embed_dim=4,filters=128,num_res_blocks=4,channel_multipliers=(1, 2, 2, 4),temporal_downsample=(True, True, False),num_groups=32,  # for nn.GroupNormactivation_fn="swish",):super().__init__()self.time_downsample_factor = 2 ** sum(temporal_downsample)# self.time_padding = self.time_downsample_factor - 1self.patch_size = (self.time_downsample_factor, 1, 1)self.out_channels = in_out_channels# NOTE: following MAGVIT, conv in bias=False in encoder first convself.encoder = Encoder(in_out_channels=in_out_channels,latent_embed_dim=latent_embed_dim * 2,filters=filters,num_res_blocks=num_res_blocks,channel_multipliers=channel_multipliers,temporal_downsample=temporal_downsample,num_groups=num_groups,  # for nn.GroupNormactivation_fn=activation_fn,)self.quant_conv = CausalConv3d(2 * latent_embed_dim, 2 * embed_dim, 1)self.post_quant_conv = CausalConv3d(embed_dim, latent_embed_dim, 1)self.decoder = Decoder(in_out_channels=in_out_channels,latent_embed_dim=latent_embed_dim,filters=filters,num_res_blocks=num_res_blocks,channel_multipliers=channel_multipliers,temporal_downsample=temporal_downsample,num_groups=num_groups,  # for nn.GroupNormactivation_fn=activation_fn,)def get_latent_size(self, input_size):latent_size = []for i in range(3):if input_size[i] is None:lsize = Noneelif i == 0:time_padding = (0if (input_size[i] % self.time_downsample_factor == 0)else self.time_downsample_factor - input_size[i] % self.time_downsample_factor)lsize = (input_size[i] + time_padding) // self.patch_size[i]else:lsize = input_size[i] // self.patch_size[i]latent_size.append(lsize)return latent_sizedef encode(self, x):time_padding = (0if (x.shape[2] % self.time_downsample_factor == 0)else self.time_downsample_factor - x.shape[2] % self.time_downsample_factor)x = pad_at_dim(x, (time_padding, 0), dim=2)encoded_feature = self.encoder(x)moments = self.quant_conv(encoded_feature).to(x.dtype)posterior = DiagonalGaussianDistribution(moments)return posteriordef decode(self, z, num_frames=None):time_padding = (0if (num_frames % self.time_downsample_factor == 0)else self.time_downsample_factor - num_frames % self.time_downsample_factor)z = self.post_quant_conv(z)x = self.decoder(z)x = x[:, :, time_padding:]return xdef forward(self, x, sample_posterior=True):posterior = self.encode(x)if sample_posterior:z = posterior.sample()else:z = posterior.mode()recon_video = self.decode(z, num_frames=x.shape[2])return recon_video, posterior, z

因果3D卷积

class CausalConv3d(nn.Module):def __init__(self,chan_in,chan_out,kernel_size: Union[int, Tuple[int, int, int]],pad_mode="constant",strides=None,  # allow custom stride**kwargs,):super().__init__()kernel_size = cast_tuple(kernel_size, 3)time_kernel_size, height_kernel_size, width_kernel_size = kernel_sizeassert is_odd(height_kernel_size) and is_odd(width_kernel_size)dilation = kwargs.pop("dilation", 1)stride = strides[0] if strides is not None else kwargs.pop("stride", 1)self.pad_mode = pad_modetime_pad = dilation * (time_kernel_size - 1) + (1 - stride)height_pad = height_kernel_size // 2width_pad = width_kernel_size // 2self.time_pad = time_padself.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)stride = strides if strides is not None else (stride, 1, 1)dilation = (dilation, 1, 1)self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)def forward(self, x):x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)x = self.conv(x)return x

http://www.lryc.cn/news/438001.html

相关文章:

  • 基于微信平台的旅游出行必备商城小程序+ssm(lw+演示+源码+运行)
  • AI绘画:科技赋能艺术的崭新时代
  • 性能诊断的方法(四):自下而上的资源诊断方法和发散的异常信息诊断方法
  • GDPU Vue前端框架开发 计数器
  • 最大流笔记
  • el-tree父子不互相关联时,手动实现全选、反选、子级全选、清空功能
  • 模板与泛型编程笔记(一)入门篇
  • 浅谈WebApi
  • 9月14日,每日信息差
  • 无人机控制与三维AI感知处理平台正式上线!
  • 9.11-kubeadm方式安装k8s
  • 限流,流量整形算法
  • 【C++知识扫盲】------C++ 中的引用入门
  • 【机器学习】6 ——最大熵模型
  • 小程序——生命周期
  • 基于微信小程序的宠物之家的设计与实现
  • 自定义EPICS在LabVIEW中的测试
  • 基于深度学习的农作物病害检测
  • 【C#】命名规范
  • 超级帐本(Hyperledger)
  • 如何精细优化网站关键词排名:实战经验分享
  • Ruoyi Cloud 本地启动
  • Nginx解析:入门笔记
  • 在 Mac 上安装双系统会影响性能吗,安装双系统会清除数据吗?
  • vue3提交按钮限制重复点击
  • Java | Leetcode Java题解之第395题至少有K个重复字符的最长子串
  • 20240915 每日AI必读资讯
  • 量化交易需要注意的关于股票交易挂单排队规则的问题
  • 应急响应实战---是谁修改了我的密码?
  • 知识的通用性