Vision Transformer(ViT)模型实例化PyTorch逐行实现
为了让大家更好地理解,我们将从零开始,逐步构建 ViT 的各个核心组件,并最终将它们组合成一个完整的模型。我们会以一个在 CIFAR-10
数据集上应用的实例来贯穿整个讲解过程。
ViT 核心思想
在讲解代码之前,我们先快速回顾一下 ViT 的核心思想,这有助于理解代码每一部分的目的。
图片切块 (Image to Patches): 传统 CNN 逐像素处理图像,而 ViT 模仿 NLP 中处理单词 (Token) 的方式。它将一幅图像 (H*W*C) 切割成一个个小块 (Patch),每个小块大小为 P*P*C。
展平与线性投射 (Patch Flattening & Linear Projection): 将每个小块展平成一个一维向量,然后通过一个全连接层(线性投射)将其映射到一个固定的维度 D,这个向量就成为了 Transformer 的 "Token"。
类别令牌 (Class Token): 模仿 BERT 的 [CLS]
令牌,在所有 Patch Token 的最前面加入一个可学习的 [CLS]
Token。这个 Token 最终将用于图像分类。
位置编码 (Positional Embedding): Transformer 本身不包含位置信息。为了让模型知道每个 Patch 的原始位置,我们需要为每个 Token(包括 [CLS]
Token)添加一个可学习的位置编码。
Transformer 编码器 (Transformer Encoder): 将带有位置编码的 Token 序列输入到标准的 Transformer Encoder 中。Encoder 由多层堆叠而成,每一层都包含一个多头自注意力模块 (Multi-Head Self-Attention) 和一个前馈网络 (Feed-Forward Network)。
分类头 (MLP Head): 将 Transformer Encoder 输出的 [CLS]
Token 对应的向量,送入一个简单的多层感知机(MLP),最终输出分类结果。
实例设定
我们将以 CIFAR-10
数据集为例。
图片尺寸 (image_size): 32*32*3
Patch 尺寸 (patch_size): 4*4 (我们可以选择 8x8 或 16x16,这里用 4x4 举例)
类别数 (num_classes): 10
嵌入维度 (dim): 512 (每个 Patch 展平后映射到的维度)
Transformer Encoder 层数 (depth): 6
多头注意力头数 (heads): 8
MLP 内部维度 (mlp_dim): 2048
根据这些设定,我们可以计算出:
每张图片的 Patch 数量 (num_patches): (32/4)x(32/4)=8x8=64
PyTorch 代码逐行实现
我们将按照 ViT 的思想,一步步构建代码。
1. Patch Embedding (图像切块与线性投射)
这是 ViT 的第一步,我们的目标是将一个 (B, C, H, W)
的图像张量,转换成一个 (B, N, D)
的 Token 序列张量,其中 B
是批量大小,N
是 Patch 数量,D
是嵌入维度。
一个巧妙高效的实现方法是使用二维卷积。
思想: 我们可以设置一个卷积层,其卷积核大小 (kernel_size) 和步长 (stride) 都等于 patch_size
。这样,卷积核每次滑动的区域恰好就是一个不重叠的 Patch。卷积的输出通道数设为我们想要的嵌入维度 dim
。
import torch
from torch import nnclass PatchEmbedding(nn.Module):"""将图像分割成块并进行线性嵌入。参数:image_size (int): 输入图像的尺寸 (假设为正方形)。patch_size (int): 每个图像块的尺寸 (假设为正方形)。in_channels (int): 输入图像的通道数。dim (int): 线性投射后的嵌入维度。"""def __init__(self, image_size, patch_size, in_channels, dim):super().__init__()self.patch_size = patch_size# 检查图像尺寸是否能被 patch 尺寸整除if not (image_size % patch_size == 0):raise ValueError("error")# 计算 patch 的数量self.num_patches = (image_size // patch_size) ** 2# 核心:使用 Conv2d 实现 patch 化和线性投射# kernel_size 和 stride 都设为 patch_size,实现不重叠的块分割# out_channels 设为嵌入维度 dimself.projection = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)def forward(self, x):# 输入 x 的形状: (B, C, H, W)# 例如: (B, 3, 32, 32)# 经过卷积层,将图像转换为 patch 的特征图# 输出形状: (B, dim, H/P, W/P)# 例如: (B, 512, 8, 8)x = self.projection(x)# 将特征图展平# .flatten(2) 将从第2个维度开始展平 (H/P 和 W/P 维度)# 输出形状: (B, dim, N) 其中 N = (H/P) * (W/P)# 例如: (B, 512, 64)x = x.flatten(2)# 交换维度,以匹配 Transformer 输入格式 (B, N, D)# 输出形状: (B, N, dim)# 例如: (B, 64, 512)x = x.transpose(1, 2)return x
2. Transformer Encoder Block
Transformer Encoder 由多个相同的块 (Block) 堆叠而成。每个块包含两个主要部分:
多头自注意力 (Multi-Head Self-Attention)
前馈网络 (Feed-Forward Network / MLP)
每个部分都伴随着残差连接 (Residual Connection) 和层归一化 (Layer Normalization)。
class TransformerEncoderBlock(nn.Module):"""标准的 Transformer Encoder 块。参数:dim (int): 输入的 token 维度。heads (int): 多头注意力的头数。mlp_dim (int): MLP 层的隐藏维度。dropout (float): Dropout 的概率。"""def __init__(self, dim, heads, mlp_dim, dropout=0.1):super().__init__()# 第一个 LayerNormself.norm1 = nn.LayerNorm(dim)# 多头自注意力模块# PyTorch 内置的 MultiheadAttention 期望输入形状为 (N, B, D),# 但我们通常使用 (B, N, D)。设置 batch_first=True 可以解决这个问题。self.attention = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, dropout=dropout, batch_first=True)# 第二个 LayerNormself.norm2 = nn.LayerNorm(dim)# MLP / 前馈网络self.mlp = nn.Sequential(nn.Linear(dim, mlp_dim),nn.GELU(), # ViT 论文中使用的激活函数nn.Dropout(dropout),nn.Linear(mlp_dim, dim),nn.Dropout(dropout))def forward(self, x):# x 的形状: (B, N, D)# 1. 多头自注意力部分# 残差连接: x + Attention(LayerNorm(x))x_norm = self.norm1(x)# 注意力模块返回 attn_output 和 attn_weights,我们只需要前者attn_output, _ = self.attention(x_norm, x_norm, x_norm)x = x + attn_output# 2. 前馈网络部分# 残差连接: x + MLP(LayerNorm(x))x_norm = self.norm2(x)mlp_output = self.mlp(x_norm)x = x + mlp_outputreturn x
3. 完整的 Vision Transformer 模型
现在,我们将所有组件整合在一起。
class VisionTransformer(nn.Module):"""Vision Transformer 模型。参数:image_size (int): 输入图像尺寸。patch_size (int): Patch 尺寸。in_channels (int): 输入通道数。num_classes (int): 分类类别数。dim (int): 嵌入维度。depth (int): Transformer Encoder 层数。heads (int): 多头注意力头数。mlp_dim (int): MLP 隐藏维度。dropout (float): Dropout 概率。"""def __init__(self, image_size, patch_size, in_channels, num_classes,dim, depth, heads, mlp_dim, dropout=0.1):super().__init__()# 1. Patch Embeddingself.patch_embedding = PatchEmbedding(image_size, patch_size, in_channels, dim)# 计算 patch 数量num_patches = self.patch_embedding.num_patches# 2. Class Token# 这是一个可学习的参数,维度为 (1, 1, D)# '1' 个 batch,'1' 个 token,'D' 维self.cls_token = nn.Parameter(torch.randn(1, 1, dim))# 3. Positional Embedding# 这也是一个可学习的参数# 长度为 num_patches + 1 (为了包含 cls_token)# 维度为 (1, N+1, D)self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))self.dropout = nn.Dropout(dropout)# 4. Transformer Encoder# 使用 nn.Sequential 将多个 Encoder Block 堆叠起来self.transformer_encoder = nn.Sequential(*[TransformerEncoderBlock(dim, heads, mlp_dim, dropout) for _ in range(depth)])# 5. MLP Head (分类头)self.mlp_head = nn.Sequential(nn.LayerNorm(dim), # 在送入分类头前先进行一次 LayerNormnn.Linear(dim, num_classes))def forward(self, img):# img 形状: (B, C, H, W)# 1. 获取 Patch Embedding# x 形状: (B, N, D)x = self.patch_embedding(img)b, n, d = x.shape # b: batch_size, n: num_patches, d: dim# 2. 添加 Class Token# 将 cls_token 复制 b 份,拼接到 x 的最前面# cls_tokens 形状: (B, 1, D)cls_tokens = self.cls_token.expand(b, -1, -1) # x 形状变为: (B, N+1, D)x = torch.cat((cls_tokens, x), dim=1)# 3. 添加 Positional Embedding# pos_embedding 形状是 (1, N+1, D),利用广播机制直接相加x += self.pos_embeddingx = self.dropout(x)# 4. 通过 Transformer Encoder# x 形状不变: (B, N+1, D)x = self.transformer_encoder(x)# 5. 提取 Class Token 的输出用于分类# 只取序列的第一个 token (cls_token) 的输出# x 形状: (B, D)cls_token_output = x[:, 0]# 6. 通过 MLP Head 得到最终的分类 logits# output 形状: (B, num_classes)output = self.mlp_head(cls_token_output)return output
完整模型与实例
现在我们把所有代码放在一起,并用我们之前设定的 CIFAR-10 参数来实例化模型,看看它的输入和输出。
import torch
from torch import nn# --- 组件 1: PatchEmbedding ---
class PatchEmbedding(nn.Module):def __init__(self, image_size, patch_size, in_channels, dim):super().__init__()if not (image_size % patch_size == 0):raise ValueError("Image dimensions must be divisible by the patch size.")self.num_patches = (image_size // patch_size) ** 2self.projection = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)def forward(self, x):x = self.projection(x)x = x.flatten(2)x = x.transpose(1, 2)return x# --- 组件 2: TransformerEncoderBlock ---
class TransformerEncoderBlock(nn.Module):def __init__(self, dim, heads, mlp_dim, dropout=0.1):super().__init__()self.norm1 = nn.LayerNorm(dim)self.attention = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)self.norm2 = nn.LayerNorm(dim)self.mlp = nn.Sequential(nn.Linear(dim, mlp_dim),nn.GELU(),nn.Dropout(dropout),nn.Linear(mlp_dim, dim),nn.Dropout(dropout))def forward(self, x):attn_output, _ = self.attention(self.norm1(x), self.norm1(x), self.norm1(x))x = x + attn_outputmlp_output = self.mlp(self.norm2(x))x = x + mlp_outputreturn x# --- 主模型: VisionTransformer ---
class VisionTransformer(nn.Module):def __init__(self, image_size, patch_size, in_channels, num_classes,dim, depth, heads, mlp_dim, dropout=0.1):super().__init__()self.patch_embedding = PatchEmbedding(image_size, patch_size, in_channels, dim)num_patches = self.patch_embedding.num_patchesself.cls_token = nn.Parameter(torch.randn(1, 1, dim))self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))self.dropout = nn.Dropout(dropout)self.transformer_encoder = nn.Sequential(*[TransformerEncoderBlock(dim, heads, mlp_dim, dropout) for _ in range(depth)])self.mlp_head = nn.Sequential(nn.LayerNorm(dim),nn.Linear(dim, num_classes))def forward(self, img):x = self.patch_embedding(img)b, n, d = x.shapecls_tokens = self.cls_token.expand(b, -1, -1)x = torch.cat((cls_tokens, x), dim=1)x += self.pos_embeddingx = self.dropout(x)x = self.transformer_encoder(x)cls_token_output = x[:, 0]output = self.mlp_head(cls_token_output)return output# --- 实例化并测试 ---# CIFAR-10 实例参数
BATCH_SIZE = 4
IMAGE_SIZE = 32
IN_CHANNELS = 3
PATCH_SIZE = 4
NUM_CLASSES = 10
DIM = 512
DEPTH = 6
HEADS = 8
MLP_DIM = 2048# 创建模型实例
vit_model = VisionTransformer(image_size=IMAGE_SIZE,patch_size=PATCH_SIZE,in_channels=IN_CHANNELS,num_classes=NUM_CLASSES,dim=DIM,depth=DEPTH,heads=HEADS,mlp_dim=MLP_DIM
)# 创建一个假的输入图像张量 (Batch, Channels, Height, Width)
dummy_img = torch.randn(BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)# 将图像输入模型
logits = vit_model(dummy_img)# 打印输出的形状
print(f"输入图像形状: {dummy_img.shape}")
print(f"模型输出 (Logits) 形状: {logits.shape}")# 检查输出形状是否正确
assert logits.shape == (BATCH_SIZE, NUM_CLASSES)
print("\n模型构建成功,输入输出形状正确!")