.net网站开发后编译,城市建设协会网站,wordpress 搜索结果分页,百度怎么推广AIGC专栏9——Scalable Diffusion Models with Transformers #xff08;DiT#xff09;结构解析 学习前言源码下载地址网络构建一、什么是Diffusion Transformer (DiT)二、DiT的组成三、生成流程1、采样流程a、生成初始噪声b、对噪声进行N次采样c、单次采样解析I、预测噪声I… AIGC专栏9——Scalable Diffusion Models with Transformers DiT结构解析 学习前言源码下载地址网络构建一、什么是Diffusion Transformer (DiT)二、DiT的组成三、生成流程1、采样流程a、生成初始噪声b、对噪声进行N次采样c、单次采样解析I、预测噪声II、施加噪声 d、预测噪声过程中的网络结构解析i、adaLN-Zero结构解析ii、patch分块处理iii、Transformer特征提取iv、上采样 3、隐空间解码生成图片 类别到图像预测过程代码 学习前言
近期Sora大火它底层是Diffusion Transformer本质上是使用Transformer结构代替原本的Unet进行噪声预测好处是统一了文本生成与视频生成的结构。这训练优化和预测优化而言是个好事因为只需要优化一种结构就够了。虽然觉得OpenAI是大力出奇迹但还是得学
源码下载地址
https://github.com/bubbliiiing/DiT-pytorch
喜欢的可以点个star噢。
网络构建
一、什么是Diffusion Transformer (DiT)
DiT基于扩散模型所以不免包含不断去噪的过程如果是图生图的话还有不断加噪的过程此时离不开DDPM那张老图如下 DiT相比于DDPM使用了更快的采样器也使用了更大的分辨率与Stable Diffusion一样使用了隐空间的扩散但可能更偏研究性质一些没有使用非常大的数据集进行预训练只使用了imagenet进行预训练。
与Stable Diffusion不同的是DiT的网络结构完全由Transformer组成没有Unet中大量的上下采样结构更为简单清晰。
本文主要是解析一下整个DiT模型的结构组成并简单一次扩散多次扩散的流程。本文代码来自于DiffusersDiffusers代码较为简单清晰是一个非常好的仓库学习起来也比较快。
二、DiT的组成
DiT由三大部分组成。 1、Sampler采样器。 2、Variational Autoencoder (VAE) 变分自编码器。 3、UNet 主网络噪声预测器。
每一部分都很重要由于DiT的官方版本并没有在 大规模文本图片 的 数据集上训练只使用了imagenet进行预训练。所以它并没有文本输入而是以标签作为输入。因此DiT只能按照类别进行图片生成可以生成imagenet中的1000类
三、生成流程 生成流程分为两个部分 1、生成正态分布向量后进行若干次采样。 2、进行解码。
由于DiT只能按照类别进行图片生成所以无需对文本进行编码直接传入类别的对应的id0-1000即可指定类别。
# --------------------------------- #
# 前处理
# --------------------------------- #
# 生成latent
latents randn_tensor(shape(batch_size, latent_channels, latent_size, latent_size),generatorgenerator,deviceself._execution_device,dtypeself.transformer.dtype,
)
latent_model_input torch.cat([latents] * 2) if guidance_scale 1 else latents# 将输入的label 与 null label进行concatnull label是负向提示类。
class_labels torch.tensor(class_labels, deviceself._execution_device).reshape(-1)
class_null torch.tensor([1000] * batch_size, deviceself._execution_device)
class_labels_input torch.cat([class_labels, class_null], 0) if guidance_scale 1 else class_labels# 设置生成的步数
self.scheduler.set_timesteps(num_inference_steps)# --------------------------------- #
# 扩散生成
# --------------------------------- #
# 开始N步扩散的循环
for t in self.progress_bar(self.scheduler.timesteps):if guidance_scale 1:half latent_model_input[: len(latent_model_input) // 2]latent_model_input torch.cat([half, half], dim0)latent_model_input self.scheduler.scale_model_input(latent_model_input, t)# 处理timestepstimesteps tif not torch.is_tensor(timesteps):is_mps latent_model_input.device.type mpsif isinstance(timesteps, float):dtype torch.float32 if is_mps else torch.float64else:dtype torch.int32 if is_mps else torch.int64timesteps torch.tensor([timesteps], dtypedtype, devicelatent_model_input.device)elif len(timesteps.shape) 0:timesteps timesteps[None].to(latent_model_input.device)# broadcast to batch dimension in a way thats compatible with ONNX/Core MLtimesteps timesteps.expand(latent_model_input.shape[0])# 将隐含层特征、时间步和种类输入传入到transformers中noise_pred self.transformer(latent_model_input, timesteptimesteps, class_labelsclass_labels_input).sample# perform guidanceif guidance_scale 1:# 在通道上做分割取出生图部分的通道eps, rest noise_pred[:, :latent_channels], noise_pred[:, latent_channels:]cond_eps, uncond_eps torch.split(eps, len(eps) // 2, dim0)half_eps uncond_eps guidance_scale * (cond_eps - uncond_eps)eps torch.cat([half_eps, half_eps], dim0)noise_pred torch.cat([eps, rest], dim1)# 对结果进行分割取出生图部分的通道if self.transformer.config.out_channels // 2 latent_channels:model_output, _ torch.split(noise_pred, latent_channels, dim1)else:model_output noise_pred# 通过采样器将这一步噪声施加到隐含层latent_model_input self.scheduler.step(model_output, t, latent_model_input).prev_sampleif guidance_scale 1:latents, _ latent_model_input.chunk(2, dim0)
else:latents latent_model_input# --------------------------------- #
# 后处理
# --------------------------------- #
# 通过vae进行解码
latents 1 / self.vae.config.scaling_factor * latents
samples self.vae.decode(latents).samplesamples (samples / 2 0.5).clamp(0, 1)# 转化为float32类别
samples samples.cpu().permute(0, 2, 3, 1).float().numpy()1、采样流程
a、生成初始噪声 在生成初始噪声前介绍一下VAEVAE是变分自编码器可以将输入图片进行编码一个高宽原本为256x256x3的图片在使用VAE编码后会变成32x32x4这个4是人为设定的不必纠结为什么不是3。这个时候我们就使用一个相对简单的矩阵代替原有的256x256x3的图片了传输与存储成本就很低。在实际要去看的时候可以对32x32x4的矩阵进行解码获得256x256x3的图片。
因此如果 我们要生成一个256x256x3的图片那么我们只需要初始化一个32x32x4的隐向量在隐空间进行扩散即可。在隐空间扩散好后再使用解码器就可以生成256x256x3的图像。
在代码中我们确实是这么做的初始噪声的生成函数为randn_tensor是diffusers自带的一个函数尽管它写的很长但实际生成初始噪声的代码只有一行
latents torch.randn(shape, generatorgenerator, devicerand_device, dtypedtype, layoutlayout).to(device)代码本来位于diffusers的工具文件中为了方便查看我将其复制到nets/pipeline.py中。
b、对噪声进行N次采样 既然Stable Diffusion是一个不断扩散的过程那么少不了不断的去噪声那么怎么去噪声便是一个问题。
在上一步中我们已经获得了一个latents它是一个符合正态分布的向量我们便从它开始去噪声。
在代码中我们有一个对时间步的循环会不断的将隐含层向量输入到transformers中进行噪声预测并且一步一步的去噪。
# --------------------------------- #
# 扩散生成
# --------------------------------- #
# 开始N步扩散的循环
for t in self.progress_bar(self.scheduler.timesteps):if guidance_scale 1:half latent_model_input[: len(latent_model_input) // 2]latent_model_input torch.cat([half, half], dim0)latent_model_input self.scheduler.scale_model_input(latent_model_input, t)# 处理timestepstimesteps tif not torch.is_tensor(timesteps):is_mps latent_model_input.device.type mpsif isinstance(timesteps, float):dtype torch.float32 if is_mps else torch.float64else:dtype torch.int32 if is_mps else torch.int64timesteps torch.tensor([timesteps], dtypedtype, devicelatent_model_input.device)elif len(timesteps.shape) 0:timesteps timesteps[None].to(latent_model_input.device)# broadcast to batch dimension in a way thats compatible with ONNX/Core MLtimesteps timesteps.expand(latent_model_input.shape[0])# 将隐含层特征、时间步和种类输入传入到transformers中noise_pred self.transformer(latent_model_input, timesteptimesteps, class_labelsclass_labels_input).sample# perform guidanceif guidance_scale 1:# 在通道上做分割取出生图部分的通道eps, rest noise_pred[:, :latent_channels], noise_pred[:, latent_channels:]cond_eps, uncond_eps torch.split(eps, len(eps) // 2, dim0)half_eps uncond_eps guidance_scale * (cond_eps - uncond_eps)eps torch.cat([half_eps, half_eps], dim0)noise_pred torch.cat([eps, rest], dim1)# 对结果进行分割取出生图部分的通道if self.transformer.config.out_channels // 2 latent_channels:model_output, _ torch.split(noise_pred, latent_channels, dim1)else:model_output noise_pred# 通过采样器将这一步噪声施加到隐含层latent_model_input self.scheduler.step(model_output, t, latent_model_input).prev_samplec、单次采样解析
I、预测噪声
在进行单次采样前需要首先判断是否有负向提示类如果有我们需要同时处理负向提示类否则仅仅需要处理正向提示类。实际使用的时候一般都有负向提示类效果会好一些所以默认进入对应的处理过程。
在处理负向提示类时我们对输入进来的隐向量进行复制一个属于正向提示类0-999一个属于负向提示类1000。它们是在batch_size维度进行堆叠二者不会互相影响。然后我们将正向提示类和负向提示类1000也在batch_size维度堆叠。代码中如果guidance_scale1则代表需要负向提示类。
# --------------------------------- #
# 前处理
# --------------------------------- #
# 生成latent
latents randn_tensor(shape(batch_size, latent_channels, latent_size, latent_size),generatorgenerator,deviceself._execution_device,dtypeself.transformer.dtype,
)
latent_model_input torch.cat([latents] * 2) if guidance_scale 1 else latents# 将输入的label 与 null label进行concatnull label是负向提示类。
class_labels torch.tensor(class_labels, deviceself._execution_device).reshape(-1)
class_null torch.tensor([1000] * batch_size, deviceself._execution_device)
class_labels_input torch.cat([class_labels, class_null], 0) if guidance_scale 1 else class_labels堆叠完后我们将隐向量、步数和类别条件一起传入网络中将结果在bs维度进行使用chunk进行分割。
因为我们在堆叠时正向提示类放在了前面。因此分割好后前半部分cond_eps属于利用正向提示类得到的后半部分uncond_eps属于利用负向提示类得到的我们本质上应该扩大正向提示类的影响远离负向提示类的影响。因此我们使用cond_eps-uncond_eps计算二者的距离使用scale扩大二者的距离。在uncond_eps基础上得到最后的隐向量。
# 堆叠完后隐向量、步数和prompt条件一起传入网络中将结果在bs维度进行使用chunk进行分割
e_t_uncond, e_t self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t e_t_uncond unconditional_guidance_scale * (e_t - e_t_uncond)此时获得的eps就是通过隐向量和提示类共同获得的预测噪声啦。
II、施加噪声
在获得噪声后我们还要将获得的新噪声按照一定的比例添加到原来的原始噪声上。
diffusers的代码并没有将施加噪声的代码写在明面上而是使用采样器的step方法替代采样流程与DDIM一致因此直接参考DDIM公式即可此前在Stable Diffusion相关博文中写到过DDIM公式可以参考对应博文了解一下。
latent_model_input self.scheduler.step(model_output, t, latent_model_input).prev_sampled、预测噪声过程中的网络结构解析
这个部分是DiT与Stable Diffusion最大的不同DiT将网络结构从Unet转换成了Transformers
i、adaLN-Zero结构解析
Transformers主要做的工作是结合 时间步t 和 类别 计算这一时刻的噪声。此处的Transformers结构与VIT中的Transformers基本一致但为了融合时间步t和类别新增了一个Embed层和adaLN-Zero结构。
Embed层主要是将输入进来的timestep和label进行向量化。adaLN-Zero则是通过全连接对向量化后的timestep和label进行映射然后分为6个部分分别作用于DiT的不同阶段用于缩放scale、偏置shift、bias与门函数gate。
如下是Embed层和adaLN-Zero结构的代码与示意图
class CombinedTimestepLabelEmbeddings(nn.Module):def __init__(self, num_classes, embedding_dim, class_dropout_prob0.1):super().__init__()self.time_proj Timesteps(num_channels256, flip_sin_to_cosTrue, downscale_freq_shift1)self.timestep_embedder TimestepEmbedding(in_channels256, time_embed_dimembedding_dim)self.class_embedder LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)def forward(self, timestep, class_labels, hidden_dtypeNone):timesteps_proj self.time_proj(timestep)timesteps_emb self.timestep_embedder(timesteps_proj.to(dtypehidden_dtype)) # (N, D)class_labels self.class_embedder(class_labels) # (N, D)conditioning timesteps_emb class_labels # (N, D)return conditioningclass AdaLayerNormZero(nn.Module):Norm layer adaptive layer norm zero (adaLN-Zero).def __init__(self, embedding_dim, num_embeddings):super().__init__()self.emb CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)self.silu nn.SiLU()self.linear nn.Linear(embedding_dim, 6 * embedding_dim, biasTrue)self.norm nn.LayerNorm(embedding_dim, elementwise_affineFalse, eps1e-6)def forward(self, x, timestep, class_labels, hidden_dtypeNone):emb self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtypehidden_dtype)))shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp emb.chunk(6, dim1)x self.norm(x) * (1 scale_msa[:, None]) shift_msa[:, None]return x, gate_msa, shift_mlp, scale_mlp, gate_mlpii、patch分块处理
在代码中我们使用一个PatchEmbed类对输入的隐含层向量进行分块该操作便是VIT中的patchc操作通过卷积进行类似于下采样的操作可以减少计算量。 如下为patch分块处理的代码核心是使用步长和卷积核大小一样的Conv2d模块进行处理由于步长和卷积核大小一致每个图片区域的特征提取过程就不会有重叠。
我们初始化生成的隐含层向量为32x32x4。在DiT-XL-2中patch处理的步长和卷积核大小为2通道为1152在处理完成后特征的通道上升高宽被压缩此时我们获得一个16x16x1152的新特征然后我们将其在长宽上进行平铺获得一个256x1152的向量并且加上位置信息。
class PatchEmbed(nn.Module):2D Image to Patch Embeddingdef __init__(self,height224,width224,patch_size16,in_channels3,embed_dim768,layer_normFalse,flattenTrue,biasTrue,):super().__init__()num_patches (height // patch_size) * (width // patch_size)self.flatten flattenself.layer_norm layer_normself.proj nn.Conv2d(in_channels, embed_dim, kernel_size(patch_size, patch_size), stridepatch_size, biasbias)if layer_norm:self.norm nn.LayerNorm(embed_dim, elementwise_affineFalse, eps1e-6)else:self.norm Nonepos_embed get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))self.register_buffer(pos_embed, torch.from_numpy(pos_embed).float().unsqueeze(0), persistentFalse)def forward(self, latent):latent self.proj(latent)if self.flatten:latent latent.flatten(2).transpose(1, 2) # BCHW - BNCif self.layer_norm:latent self.norm(latent)return latent self.pos_embediii、Transformer特征提取
此后我们将向量传入Transformer中进行特征提取对应图中的DiT Block。
256x1152的特征会通过图中红框的部分而时间步t 和 类别会通过途中绿框的部分。
红框部分的结构除了缩放scale、偏置shift、bias与门函数gate对应图中的α代码中是gate但图中写scale外其它部分与VIT一模一样可参考博文VIT结构解析进行了解主要工作的模块是Self-Attention和Pointwise FeedforwardMLP。这两个模块的输入和输出均为256x1152的特征。
而缩放scale、偏置shift、bias与门函数gate分别对应了图中的γ、β和α。通过adaLN-Zero结构获得γ、β分别在 Self-Attention和Pointwise Feedforward 的处理前 进行特征的 缩放与偏置 而Pointwise Feedforward则在 Self-Attention和Pointwise Feedforward 的处理后 进行特征的 缩放。在代码中我添加了中文注释方便读者区分添加缩放、偏置和门函数的位置。
DiT Block的输入和输出特征均为256x1152。
class BasicTransformerBlock(nn.Module):def __init__(self,dim: int,num_attention_heads: int,attention_head_dim: int,dropout0.0,cross_attention_dim: Optional[int] None,activation_fn: str geglu,num_embeds_ada_norm: Optional[int] None,attention_bias: bool False,only_cross_attention: bool False,double_self_attention: bool False,upcast_attention: bool False,norm_elementwise_affine: bool True,norm_type: str layer_norm,final_dropout: bool False,):super().__init__().......def forward(self,hidden_states: torch.FloatTensor,attention_mask: Optional[torch.FloatTensor] None,encoder_hidden_states: Optional[torch.FloatTensor] None,encoder_attention_mask: Optional[torch.FloatTensor] None,timestep: Optional[torch.LongTensor] None,cross_attention_kwargs: Dict[str, Any] None,class_labels: Optional[torch.LongTensor] None,):# Notice that normalization is always applied before the real computation in the following blocks.# 1. Self-Attentionif self.use_ada_layer_norm:norm_hidden_states self.norm1(hidden_states, timestep)elif self.use_ada_layer_norm_zero:# 在norm1中已经进行了输入特征的缩放与偏置norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp self.norm1(hidden_states, timestep, class_labels, hidden_dtypehidden_states.dtype)else:norm_hidden_states self.norm1(hidden_states)cross_attention_kwargs cross_attention_kwargs if cross_attention_kwargs is not None else {}attn_output self.attn1(norm_hidden_states,encoder_hidden_statesencoder_hidden_states if self.only_cross_attention else None,attention_maskattention_mask,**cross_attention_kwargs,)# 在self attention后再次进行了特征的缩放gateif self.use_ada_layer_norm_zero:attn_output gate_msa.unsqueeze(1) * attn_outputhidden_states attn_output hidden_states# 2. Cross-Attentionif self.attn2 is not None:norm_hidden_states (self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states))attn_output self.attn2(norm_hidden_states,encoder_hidden_statesencoder_hidden_states,attention_maskencoder_attention_mask,**cross_attention_kwargs,)hidden_states attn_output hidden_states# 3. Feed-forwardnorm_hidden_states self.norm3(hidden_states)# 在mlp前进行了输入特征的缩放与偏置if self.use_ada_layer_norm_zero:norm_hidden_states norm_hidden_states * (1 scale_mlp[:, None]) shift_mlp[:, None]if self._chunk_size is not None:# feed_forward_chunk_size can be used to save memoryif norm_hidden_states.shape[self._chunk_dim] % self._chunk_size ! 0:raise ValueError(fhidden_states dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate chunk_size when calling unet.enable_forward_chunking.)num_chunks norm_hidden_states.shape[self._chunk_dim] // self._chunk_sizeff_output torch.cat([self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dimself._chunk_dim)],dimself._chunk_dim,)else:ff_output self.ff(norm_hidden_states)# 在mlp后再次进行了特征的缩放gateif self.use_ada_layer_norm_zero:ff_output gate_mlp.unsqueeze(1) * ff_outputhidden_states ff_output hidden_statesreturn hidden_statesiv、上采样
虽然这个部分学名可能不叫上采样但是我觉得用上采样来描述它还是比较合适的因为我们前面做过patch分块处理所以隐含层的高宽被压缩而这一步则是将隐含层的高宽再还原回去。
在这里我们会对256x1152进行两次全连接一次LayerNorm两次全连接的神经元个数分别为2304和patch_size * patch_size * out_channels。第一次全连接目的是扩宽通道数第二次全链接则是还原高宽。两次全连接后在DiT-XL-2中out_channels为88可拆分为4 4前面的4用于直接预测噪声后面的4用于根据 x t − 1 x_{t-1} xt−1均值和方差计算KL散度特征层的shape从256x1152变为256x32。
然后我们会进行一系列shape变换首先将256x1152变为16x16x2x2x8然后进行转置变为8x16x2x16x2然后还原高宽变为8x32x32。此时上采样结束。该部分对应了图中的Linear And Reshape。
上采样代码如下所示
# 3. Output
conditioning self.transformer_blocks[0].norm1.emb(timestep, class_labels, hidden_dtypehidden_states.dtype
)
shift, scale self.proj_out_1(F.silu(conditioning)).chunk(2, dim1)
hidden_states self.norm_out(hidden_states) * (1 scale[:, None]) shift[:, None]
hidden_states self.proj_out_2(hidden_states)# unpatchify
height width int(hidden_states.shape[1] ** 0.5)
hidden_states hidden_states.reshape(shape(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
hidden_states torch.einsum(nhwpqc-nchpwq, hidden_states)
output hidden_states.reshape(shape(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
)3、隐空间解码生成图片
通过上述步骤已经可以多次采样获得结果然后我们便可以通过隐空间解码生成图片。
隐空间解码生成图片的过程非常简单将上文多次采样后的结果使用vae的decode方法即可生成图片。
# --------------------------------- #
# 后处理
# --------------------------------- #
# 通过vae进行解码
latents 1 / self.vae.config.scaling_factor * latents
samples self.vae.decode(latents).samplesamples (samples / 2 0.5).clamp(0, 1)# 转化为float32类别
samples samples.cpu().permute(0, 2, 3, 1).float().numpy()类别到图像预测过程代码
整体预测代码如下 import torch
import json
import os
from diffusers import DPMSolverMultistepScheduler, AutoencoderKLfrom nets.transformer_2d import Transformer2DModel
from nets.pipeline import DiTPipeline# 模型路径
model_path model_data/DiT-XL-2-256# 初始化DiT的各个组件
scheduler DPMSolverMultistepScheduler.from_pretrained(model_path, subfolderscheduler)
transformer Transformer2DModel.from_pretrained(model_path, subfoldertransformer)
vae AutoencoderKL.from_pretrained(model_path, subfoldervae)
id2label json.load(open(os.path.join(model_path, model_index.json), r))[id2label]# 初始化DiT的Pipeline
pipe DiTPipeline(schedulerscheduler, transformertransformer, vaevae, id2labelid2label)
pipe pipe.to(cuda)# imagenet种类 对应的 名称
words [white shark, umbrella]
# 获得imagenet对应的ids
class_ids pipe.get_label_ids(words)
# 设置seed
generator torch.manual_seed(42)# pipeline前传
output pipe(class_labelsclass_ids, num_inference_steps25, generatorgenerator)# 保存图片
for index, image in enumerate(output.images):image.save(foutput-{index}.png)