当前位置: 首页 > news >正文

VIT(视觉Transformer)

由于今天研究了一下SegFormer(分割任务Transformer模型),顺带就总结一下Vision Transformer,话不多说。直接上链接。

vit项目链接:视觉Transformer (ViT) - Hugging Face 机器学习平台

vit论文地址:[2010.11929] An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

Vision Transformer 网络结构图

核心思想

  1. 图像分块(Patching)

    • 将输入图像分割为固定大小的块(如 16x16 像素),每个块被视为一个“视觉 token”。
    • 例如:224x224 的图像被分割为 14x14=196 个块。
  2. 线性嵌入(Linear Embedding)

    • 将每个图像块展平为一维向量,并通过线性变换(类似全连接层)映射到低维特征空间,得到 Patch Embeddings。
  3. 位置编码(Positional Encoding)

    • 为每个块添加可学习的位置编码,以保留图像的空间信息(Transformer 本身不感知位置)。
  4. Transformer 编码器

    • 将 Patch Embeddings 和位置编码组成的序列输入标准 Transformer 编码器。
    • 通过 多头自注意力机制(Multi-Head Self-Attention) 捕捉图像块之间的全局依赖关系。
  5. 分类输出

    • 使用特殊的 [CLS] 标记(类似 NLP 中的 [CLS])聚合全局信息,最终通过全连接层进行分类。

ViT 小tips:

1,核心思路是将图像特征提取成为强语义特征所以这个切分的patch数很重要)(切不好很可能网络效果很差)。

2,天生带大量的噪声,所以要大量数据。并且训练较缓慢。(4 -10W张数据,数据少也会造成网络效果很差。)

3, 改进可以不喂原图,喂用CNN提取过的特征图。

4,优点,提取全局特征,具有长距离依赖,大数据情况下更精准。

demo代码

ps:这个代码是博主跑一个分类网络的简易demo。对应目录修改即可运行。

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import os"""
Einops允许你通过简单的字符串表达式来重新排列、重塑和减少数组的维度,而无需编写冗长且容易出错的代码。
"""
from einops import rearrange
from einops.layers.torch import Rearrangedef pair(t):return t if isinstance(t, tuple) else (t, t)def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32):"""位置编码:param h::param w::param dim::param temperature::param dtype::return:"""y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")  # 每个区域的位置编码assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"omega = torch.arange(dim // 4) / (dim // 4 - 1)  # 频率设置omega = 1.0 / (temperature ** omega)y = y.flatten()[:, None] * omega[None, :]x = x.flatten()[:, None] * omega[None, :]pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)return pe.type(dtype)class FeedForward(nn.Module):"""FFN层"""def __init__(self, dim, hidden_dim):super().__init__()self.net = nn.Sequential(nn.LayerNorm(dim),nn.Linear(dim, hidden_dim),nn.GELU(),nn.Linear(hidden_dim, dim),)def forward(self, x):return self.net(x)class Attention(nn.Module):"""多头注意力"""def __init__(self, dim, heads=8, dim_head=64):super().__init__()inner_dim = dim_head * headsself.heads = headsself.scale = dim_head ** -0.5self.norm = nn.LayerNorm(dim)self.attend = nn.Softmax(dim=-1)self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)self.to_out = nn.Linear(inner_dim, dim, bias=False)def forward(self, x):x = self.norm(x)qkv = self.to_qkv(x).chunk(3, dim=-1)q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)dots = torch.matmul(q, k.transpose(-1, -2)) * self.scaleattn = self.attend(dots)out = torch.matmul(attn, v)out = rearrange(out, 'b h n d -> b n (h d)')return self.to_out(out)class Transformer(nn.Module):"""只有encoder"""def __init__(self, dim, depth, heads, dim_head, mlp_dim):super().__init__()self.norm = nn.LayerNorm(dim)self.layers = nn.ModuleList([])for _ in range(depth):self.layers.append(nn.ModuleList([Attention(dim, heads=heads, dim_head=dim_head),FeedForward(dim, mlp_dim)]))def forward(self, x):for attn, ff in self.layers:x = attn(x) + xx = ff(x) + xreturn self.norm(x)class SimpleViT(nn.Module):def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels=3, dim_head=64):super().__init__()image_height, image_width = pair(image_size)patch_height, patch_width = pair(patch_size)assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'patch_dim = channels * patch_height * patch_width# b 3 (256//32,32) (256//32,32) -> b (256//32,256//32) (32,32,3)self.patch = Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_height, p2=patch_width)self.to_patch_embedding = nn.Sequential(nn.LayerNorm(patch_dim),nn.Linear(patch_dim, dim),nn.LayerNorm(dim),)self.pos_embedding = posemb_sincos_2d(h=image_height // patch_height,w=image_width // patch_width,dim=dim,)self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)self.pool = "mean"self.to_latent = nn.Identity()self.linear_head = nn.Linear(dim, num_classes)def forward(self, img):device = img.device# 分割成8x8的区域"""输出为:1x(8*8)x(32,32,3),即 1x64x(32,32,3)。输出的形状 1x64x3072 是因为每个补丁的大小是 32*32*3=3072,而总共有 64 个补丁64个补丁,每个补丁的特征映射为3072维向量相似之处:表示形式:1。补丁和词向量:在视觉模型中,图像被分成补丁,每个补丁都有一个特征表示;2。在 BERT 中,句子中的每个词也有一个向量表示。这两者都可以被看作是对输入的分块和特征化。处理方式:1。在视觉模型中,多个补丁的特征可以被用来捕捉整个图像的上下文信息;2。在 BERT 中,多个词的向量通过自注意力机制来理解句子的上下文。"""x = self.patch(img)  # bx64x3072# 线性投射层x = self.to_patch_embedding(x)  # b x 64 x 1024# 位置编码x += self.pos_embedding.to(device, dtype=x.dtype)# encoderx = self.transformer(x)  # 1x64x1024x = x.mean(dim=1)  # 取多通道中的第1个通道来做为分类预测# ffnx = self.to_latent(x)return self.linear_head(x)if __name__ == "__main__":device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 保存特征图的列表feature_maps = []weight_path = r'./torch_weight_file/simple_vit_model.pth'model = SimpleViT(image_size=224,patch_size=32,num_classes=7,dim=1024,  # Transformer 在编码过程中,输入的每个补丁会被转换为一个 1024 维的向量,以捕捉丰富的特征信息。depth=6,  # encoder重复的次数heads=16,  # 多头的数量mlp_dim=2048  # 全连接的神经元).to(device)model.load_state_dict(torch.load(weight_path, map_location=device))# 数据集路径data_dir = r'your—-data--dir'# 数据预处理transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),])# 加载数据集train_dataset = datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=transform)train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)val_dataset = datasets.ImageFolder(os.path.join(data_dir, 'val'), transform=transform)val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)criterion = nn.CrossEntropyLoss()optimizer = torch.optim.AdamW(model.parameters(), lr=0.0003)# 训练和验证过程num_epochs = 40train_losses, train_accs, val_losses, val_accs = [], [], [], []for epoch in range(num_epochs):model.train()train_loss, correct = 0, 0total = 0for images, labels in tqdm(train_loader):images, labels = images.to(device), labels.to(device)optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()train_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()train_losses.append(train_loss / len(train_loader))train_accs.append(correct / total)# 验证过程model.eval()val_loss, val_correct = 0, 0val_total = 0with torch.no_grad():for images, labels in val_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)loss = criterion(outputs, labels)val_loss += loss.item()_, val_predicted = torch.max(outputs.data, 1)val_total += labels.size(0)val_correct += (val_predicted == labels).sum().item()val_losses.append(val_loss / len(val_loader))val_accs.append(val_correct / val_total)print(f'Epoch [{epoch + 1}/{num_epochs}], 'f'Train Loss: {train_losses[-1]:.4f}, Train Acc: {train_accs[-1]:.4f}, 'f'Val Loss: {val_losses[-1]:.4f}, Val Acc: {val_accs[-1]:.4f}')# 保存模型权重torch.save(model.state_dict(), './torch_weight_file/simple_vit_model.pth')# 绘图plt.figure(figsize=(12, 5))# 训练和验证损失plt.subplot(1, 2, 1)plt.plot(train_losses, label='Train Loss')plt.plot(val_losses, label='Validation Loss')plt.title('Loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend()# 训练和验证准确率plt.subplot(1, 2, 2)plt.plot(train_accs, label='Train Accuracy')plt.plot(val_accs, label='Validation Accuracy')plt.title('Accuracy')plt.xlabel('Epoch')plt.ylabel('Accuracy')plt.legend()plt.show()

http://www.lryc.cn/news/585202.html

相关文章:

  • 【爬虫】- 爬虫原理及其入门
  • 提示工程:突破Transformer极限的计算科学
  • 进程状态 + 进程优先级切换调度-进程概念(5)
  • 需求升级,创新破局!苏州金龙赋能旅游客运新生态
  • 20250711荣品RD-RK3588开发板在Android13下的开机自启动的配置步骤
  • 宝塔命令Composer 更改数据源不生效
  • 动态组件和插槽
  • 基于定制开发开源AI智能名片与S2B2C商城小程序的旅游日志创新应用研究
  • nessus最新安装
  • [Meetily后端框架] Whisper转录服务器 | 后端服务管理脚本
  • 20.缓存问题与解决方案详解教程
  • NodeJs后端常用三方库汇总
  • 录音实时上传
  • 2025河南高考生物真题及解析
  • 国际学术期刊IJCAST发布最新一期论文
  • 【达梦数据库|JPA】后端数据库国产化迁移记录
  • uniapp类似抖音视频滑动
  • [python]在drf中使用drf_spectacular
  • 持续集成 简介环境搭建
  • STM32G473串口通信-USART/UART配置和清除串口寄存器状态的注意事项
  • Rail开发日志_5
  • 基于Selenium和FFmpeg的全平台短视频自动化发布系统
  • Maven下载与配置对Java项目的理解
  • RISC-V:开源芯浪潮下的技术突围与职业新赛道 (三)RISC-V架构深度解剖(下)
  • SpringBoot 使用注解获取配置文件中的值
  • c/c++拷贝函数
  • Claude Code是什么?国内如何使用到Claude Code?附国内最新使用教程
  • FlashBots 之 MEV-boost
  • 决策树算法在医学影像诊断中的广泛应用
  • 用Python和OpenCV从零搭建一个完整的双目视觉系统(六 最终篇)