ViT (Vision Transformer)
Since I spent today studying SegFormer (a Transformer model for segmentation), I'll take the chance to summarize the Vision Transformer as well. Without further ado, here are the links.
ViT project link: Vision Transformer (ViT) - Hugging Face
ViT paper: [2010.11929] An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
Vision Transformer architecture diagram
Core ideas
- Image patching
  - The input image is split into fixed-size patches (e.g. 16x16 pixels), and each patch is treated as a "visual token".
  - Example: a 224x224 image with 16x16 patches yields 14x14 = 196 patches.
- Linear embedding
  - Each patch is flattened into a 1-D vector and mapped by a linear transformation (essentially a fully connected layer) to the model's embedding dimension, producing the Patch Embeddings.
- Positional encoding
  - A learnable positional encoding is added to each patch embedding to preserve spatial information (the Transformer itself has no notion of position).
- Transformer encoder
  - The sequence formed by the patch embeddings plus positional encodings is fed into a standard Transformer encoder.
  - Multi-Head Self-Attention captures global dependencies between the image patches.
- Classification output
  - A special [CLS] token (as in NLP) aggregates the global information, and a fully connected layer produces the final classification (see the sketch after this list).
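To make the five steps concrete, here is a minimal, self-contained sketch. It is my own illustration rather than the reference implementation; the class name MinimalViT and its ViT-Base-like default hyperparameters are assumptions. It patchifies the image, adds a learnable positional embedding and a [CLS] token, runs a standard Transformer encoder, and classifies from the [CLS] token.

import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

class MinimalViT(nn.Module):
    """Illustrative ViT skeleton: patchify -> linear embed -> +pos -> encoder -> [CLS] head."""
    def __init__(self, image_size=224, patch_size=16, dim=768, depth=12, heads=12, num_classes=1000):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2     # 224/16 -> 14x14 = 196 patches
        patch_dim = 3 * patch_size * patch_size           # 16*16*3 = 768 values per patch
        # 1. patching + 2. linear embedding
        self.to_tokens = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
            nn.Linear(patch_dim, dim),
        )
        # 5. [CLS] token and 3. learnable positional encoding (one extra position for [CLS])
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
        # 4. standard Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=4 * dim,
                                                   batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, img):                               # img: (b, 3, 224, 224)
        x = self.to_tokens(img)                           # (b, 196, dim)
        cls = self.cls_token.expand(x.shape[0], -1, -1)   # (b, 1, dim)
        x = torch.cat([cls, x], dim=1) + self.pos_embedding  # prepend [CLS], add positions
        x = self.encoder(x)                               # global self-attention over all tokens
        return self.head(x[:, 0])                         # classify from the [CLS] token

logits = MinimalViT()(torch.randn(2, 3, 224, 224))        # torch.Size([2, 1000])

The demo further down swaps the learnable positional embedding for fixed 2D sin-cos embeddings and replaces the [CLS] token with mean pooling; that is the SimpleViT variant.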
ViT tips:
1. The core idea is to turn image content into strong semantic features, so the number of patches (i.e. how the image is split) matters a great deal; a poor patch size can make the network perform very badly.
2. ViT inherently picks up a lot of noise (it lacks the inductive biases of CNNs), so it needs a large amount of data and trains relatively slowly (on the order of 40k-100k images; too little data also leads to poor results).
3. One improvement is to feed the Transformer CNN-extracted feature maps instead of the raw image (see the sketch after this list).
4. Advantages: it extracts global features, captures long-range dependencies, and is more accurate when trained on large datasets.
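For tip 3, here is a hedged sketch of the hybrid idea; the class name ConvStemTokens and the layer sizes are assumptions, not from the original post. A small CNN stem downsamples the image, and each spatial position of its feature map becomes one token for the Transformer encoder.

import torch
import torch.nn as nn
from einops import rearrange

class ConvStemTokens(nn.Module):
    """Hypothetical hybrid stem: a small CNN downsamples the image, and every spatial
    position of the resulting feature map becomes one token for the Transformer."""
    def __init__(self, dim=256):
        super().__init__()
        self.cnn = nn.Sequential(                          # 224x224x3 -> 14x14xdim (overall stride 16)
            nn.Conv2d(3, 64, kernel_size=7, stride=4, padding=3), nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.Conv2d(128, dim, kernel_size=3, stride=2, padding=1),
        )

    def forward(self, img):
        feat = self.cnn(img)                               # (b, dim, 14, 14)
        return rearrange(feat, 'b d h w -> b (h w) d')     # (b, 196, dim): tokens carry CNN features

tokens = ConvStemTokens()(torch.randn(2, 3, 224, 224))     # torch.Size([2, 196, 256])
# These tokens can then replace raw pixel patches as the input sequence of the Transformer encoder.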
Demo code
PS: this is a simple classification-network demo I ran; change the corresponding directories and it can be run as-is. Note that this SimpleViT variant uses fixed 2D sin-cos positional embeddings and mean pooling over patch tokens instead of a learnable positional embedding and a [CLS] token.
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import os

"""
Einops lets you rearrange, reshape, and reduce the dimensions of arrays with simple
string expressions, instead of writing verbose and error-prone reshaping code.
"""
from einops import rearrange
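# A quick illustration of einops.rearrange (an assumed example, not part of the original demo):
#   rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=32, p2=32)
# turns a (2, 3, 224, 224) batch of images into (2, 49, 3072): 49 flattened 32x32x3 patches per image.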
from einops.layers.torch import Rearrange


def pair(t):
    return t if isinstance(t, tuple) else (t, t)


def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32):
    """
    2D sin-cos positional encoding
    :param h: number of patches along the height
    :param w: number of patches along the width
    :param dim: embedding dimension
    :param temperature: frequency temperature
    :param dtype: output dtype
    :return: (h*w, dim) positional encoding tensor
    """
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")  # grid position of every patch
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
    omega = torch.arange(dim // 4) / (dim // 4 - 1)  # frequency schedule
    omega = 1.0 / (temperature ** omega)
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)


class FeedForward(nn.Module):
    """Feed-forward (FFN) block"""
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    """Multi-head self-attention"""
    def __init__(self, dim, heads=8, dim_head=64):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class Transformer(nn.Module):
    """Encoder-only Transformer"""
    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads=heads, dim_head=dim_head),
                FeedForward(dim, mlp_dim)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)


class SimpleViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels=3, dim_head=64):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        patch_dim = channels * patch_height * patch_width
        # b c (h*p1) (w*p2) -> b (h*w) (p1*p2*c), e.g. 224x224 with 32x32 patches -> 49 patches of 3072 values
        self.patch = Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_height, p2=patch_width)
        self.to_patch_embedding = nn.Sequential(
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )
        self.pos_embedding = posemb_sincos_2d(
            h=image_height // patch_height,
            w=image_width // patch_width,
            dim=dim,
        )
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
        self.pool = "mean"
        self.to_latent = nn.Identity()
        self.linear_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        device = img.device
        # Split the image into patches
        """
        For example, a 256x256 input with 32x32 patches gives 1 x (8*8) x (32*32*3) = 1 x 64 x 3072:
        64 patches, each flattened into a 3072-dim vector (the 224x224 input used below gives 7x7 = 49 patches).
        Analogy with NLP:
        Representation:
        1. Patches vs. word vectors: in a vision model the image is split into patches, each with a feature vector;
        2. in BERT every word in a sentence also has a vector representation. Both chunk and featurize the input.
        Processing:
        1. In a vision model, the features of many patches together capture the context of the whole image;
        2. in BERT, the word vectors interact through self-attention to understand the sentence context.
        """
        x = self.patch(img)             # b x num_patches x 3072
        # Linear projection layer
        x = self.to_patch_embedding(x)  # b x num_patches x 1024
        # Positional encoding
        x += self.pos_embedding.to(device, dtype=x.dtype)
        # Encoder
        x = self.transformer(x)         # b x num_patches x 1024
        x = x.mean(dim=1)               # mean-pool over all patch tokens (this variant has no [CLS] token)
        # Classification head
        x = self.to_latent(x)
        return self.linear_head(x)


if __name__ == "__main__":
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # List for saving feature maps (not used in this demo)
    feature_maps = []
    weight_path = r'./torch_weight_file/simple_vit_model.pth'
    model = SimpleViT(
        image_size=224,
        patch_size=32,
        num_classes=7,
        dim=1024,      # each input patch is encoded as a 1024-dim vector to capture rich features
        depth=6,       # number of times the encoder block is repeated
        heads=16,      # number of attention heads
        mlp_dim=2048   # hidden units of the feed-forward layer
    ).to(device)
    # Resume from saved weights if they exist (the file is only created after a first training run)
    if os.path.exists(weight_path):
        model.load_state_dict(torch.load(weight_path, map_location=device))

    # Dataset path
    data_dir = r'your--data--dir'  # change to your dataset root (expects train/ and val/ subfolders)
    # Data preprocessing
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    # Load the datasets
    train_dataset = datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    val_dataset = datasets.ImageFolder(os.path.join(data_dir, 'val'), transform=transform)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0003)

    # Training and validation loop
    num_epochs = 40
    train_losses, train_accs, val_losses, val_accs = [], [], [], []
    for epoch in range(num_epochs):
        model.train()
        train_loss, correct = 0, 0
        total = 0
        for images, labels in tqdm(train_loader):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        train_losses.append(train_loss / len(train_loader))
        train_accs.append(correct / total)

        # Validation
        model.eval()
        val_loss, val_correct = 0, 0
        val_total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, val_predicted = torch.max(outputs.data, 1)
                val_total += labels.size(0)
                val_correct += (val_predicted == labels).sum().item()
        val_losses.append(val_loss / len(val_loader))
        val_accs.append(val_correct / val_total)

        print(f'Epoch [{epoch + 1}/{num_epochs}], '
              f'Train Loss: {train_losses[-1]:.4f}, Train Acc: {train_accs[-1]:.4f}, '
              f'Val Loss: {val_losses[-1]:.4f}, Val Acc: {val_accs[-1]:.4f}')

    # Save the model weights
    os.makedirs('./torch_weight_file', exist_ok=True)
    torch.save(model.state_dict(), './torch_weight_file/simple_vit_model.pth')

    # Plots
    plt.figure(figsize=(12, 5))
    # Training and validation loss
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.title('Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    # Training and validation accuracy
    plt.subplot(1, 2, 2)
    plt.plot(train_accs, label='Train Accuracy')
    plt.plot(val_accs, label='Validation Accuracy')
    plt.title('Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.show()
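Once the script above has produced a checkpoint, it can be reused for single-image inference. Below is a hedged sketch under the same SimpleViT configuration; the helper name predict, the path 'test.jpg', and the use of train_dataset.classes are placeholders/assumptions rather than part of the original demo.

from PIL import Image
import torch
from torchvision import transforms

def predict(model, image_path, device, class_names):
    # Same preprocessing as training: resize to 224x224 and convert to a tensor
    transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    img = transform(Image.open(image_path).convert('RGB')).unsqueeze(0).to(device)  # (1, 3, 224, 224)
    model.eval()
    with torch.no_grad():
        probs = model(img).softmax(dim=-1)   # class probabilities
    idx = probs.argmax(dim=-1).item()
    return class_names[idx], probs[0, idx].item()

# Example (assumes `model`, `device`, and `train_dataset` from the script above):
# label, confidence = predict(model, 'test.jpg', device, train_dataset.classes)
# print(label, confidence)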