当前位置: 首页 > news >正文

国科大深度学习作业2-基于 ViT 的 CIFAR10 图像分类

目录

一、环境准备

1.1 创建项目文件夹

1.2 环境准备和依赖安装

二、实验过程

2.1 代码实现

2.2 实验结果

三、任务详情解析

3.1 需求分析

3.2 ViT 模型原理

1)整体架构

2)核心组件

3.3 模型优化策略

1)轻量化设计

2)数据增强策略

3)优化器与学习率调度

四、分析

4.1 ViT vs CNN 的对比

4.2 关键创新点

4.3 实验总结与展望

五、实验报告和PPT


一、环境准备

1.1 创建项目文件夹

        在您的电脑上创建一个新文件夹,例如 NO2

        用 VSCode 打开这个文件夹

1.2 环境准备和依赖安装

# 确保在正确的环境中
conda activate pytorch_env# 安装必要的包
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
pip install timm  # 用于预训练ViT模型
pip install matplotlib seaborn tqdm
pip install tensorboard  # 用于训练可视化

二、实验过程

2.1 代码实现

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
import math# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")
if torch.cuda.is_available():print(f"GPU: {torch.cuda.get_device_name(0)}")print(f"GPU内存: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")# 1. 数据预处理和加载
def get_cifar10_dataloaders(batch_size=64, num_workers=4):"""获取CIFAR-10数据加载器"""# 训练数据增强train_transform = transforms.Compose([transforms.Resize((224, 224)),  # ViT标准输入尺寸transforms.RandomHorizontalFlip(p=0.5),transforms.RandomRotation(degrees=15),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])# 测试数据预处理test_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])# 下载和加载数据集train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)# 创建数据加载器train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)return train_loader, test_loader# 2. 多头注意力机制
class MultiHeadAttention(nn.Module):def __init__(self, embed_dim, num_heads, dropout=0.1):super(MultiHeadAttention, self).__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.head_dim = embed_dim // num_headsassert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"self.qkv = nn.Linear(embed_dim, embed_dim * 3)self.proj = nn.Linear(embed_dim, embed_dim)self.dropout = nn.Dropout(dropout)def forward(self, x):B, N, C = x.shape# 生成Q, K, Vqkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)q, k, v = qkv[0], qkv[1], qkv[2]# 计算注意力分数attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)attn = attn.softmax(dim=-1)attn = self.dropout(attn)# 应用注意力x = (attn @ v).transpose(1, 2).reshape(B, N, C)x = self.proj(x)return x# 3. Transformer编码器块
class TransformerBlock(nn.Module):def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1):super(TransformerBlock, self).__init__()self.norm1 = nn.LayerNorm(embed_dim)self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)self.norm2 = nn.LayerNorm(embed_dim)mlp_hidden_dim = int(embed_dim * mlp_ratio)self.mlp = nn.Sequential(nn.Linear(embed_dim, mlp_hidden_dim),nn.GELU(),nn.Dropout(dropout),nn.Linear(mlp_hidden_dim, embed_dim),nn.Dropout(dropout))def forward(self, x):# 注意力机制 + 残差连接x = x + self.attn(self.norm1(x))# MLP + 残差连接x = x + self.mlp(self.norm2(x))return x# 4. 图像分块嵌入
class PatchEmbedding(nn.Module):def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):super(PatchEmbedding, self).__init__()self.img_size = img_sizeself.patch_size = patch_sizeself.num_patches = (img_size // patch_size) ** 2self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)def forward(self, x):B, C, H, W = x.shapex = self.proj(x).flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)return x# 5. 完整的ViT模型
class VisionTransformer(nn.Module):def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=10,embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, dropout=0.1):super(VisionTransformer, self).__init__()self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)num_patches = self.patch_embed.num_patches# 可学习的位置嵌入和类别tokenself.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))self.pos_drop = nn.Dropout(dropout)# Transformer编码器self.blocks = nn.ModuleList([TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)for _ in range(depth)])# 分类头self.norm = nn.LayerNorm(embed_dim)self.head = nn.Linear(embed_dim, num_classes)# 初始化权重self._init_weights()def _init_weights(self):nn.init.trunc_normal_(self.pos_embed, std=0.02)nn.init.trunc_normal_(self.cls_token, std=0.02)def forward(self, x):B = x.shape[0]# 图像分块嵌入x = self.patch_embed(x)# 添加类别tokencls_tokens = self.cls_token.expand(B, -1, -1)x = torch.cat((cls_tokens, x), dim=1)# 添加位置嵌入x = x + self.pos_embedx = self.pos_drop(x)# 通过Transformer编码器for block in self.blocks:x = block(x)# 分类x = self.norm(x)cls_token_final = x[:, 0]  # 取类别tokenx = self.head(cls_token_final)return x# 6. 训练函数
def train_model(model, train_loader, test_loader, num_epochs=50, lr=3e-4):"""训练ViT模型"""# 优化器和学习率调度器optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)criterion = nn.CrossEntropyLoss()# 记录训练历史train_losses = []train_accs = []test_accs = []best_acc = 0.0for epoch in range(num_epochs):# 训练阶段model.train()running_loss = 0.0correct = 0total = 0pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')for batch_idx, (data, target) in enumerate(pbar):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()# 更新进度条pbar.set_postfix({'Loss': f'{running_loss/(batch_idx+1):.4f}','Acc': f'{100.*correct/total:.2f}%'})# 计算训练准确率train_acc = 100. * correct / totaltrain_loss = running_loss / len(train_loader)# 测试阶段test_acc = evaluate_model(model, test_loader)# 更新学习率scheduler.step()# 记录历史train_losses.append(train_loss)train_accs.append(train_acc)test_accs.append(test_acc)# 保存最佳模型if test_acc > best_acc:best_acc = test_acctorch.save(model.state_dict(), 'best_vit_model.pth')print(f'Epoch {epoch+1}: Train Loss: {train_loss:.4f}, 'f'Train Acc: {train_acc:.2f}%, Test Acc: {test_acc:.2f}%')# 早停检查if test_acc > 80.0:print(f"达到目标准确率 {test_acc:.2f}%!")breakreturn train_losses, train_accs, test_accs, best_acc# 7. 评估函数
def evaluate_model(model, test_loader):"""评估模型性能"""model.eval()correct = 0total = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()accuracy = 100. * correct / totalreturn accuracy# 8. 可视化函数
def plot_training_history(train_losses, train_accs, test_accs):"""绘制训练历史"""fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))# 损失曲线ax1.plot(train_losses, label='Training Loss')ax1.set_title('Training Loss')ax1.set_xlabel('Epoch')ax1.set_ylabel('Loss')ax1.legend()ax1.grid(True)# 准确率曲线ax2.plot(train_accs, label='Training Accuracy')ax2.plot(test_accs, label='Test Accuracy')ax2.set_title('Accuracy')ax2.set_xlabel('Epoch')ax2.set_ylabel('Accuracy (%)')ax2.legend()ax2.grid(True)plt.tight_layout()plt.savefig('training_history.png', dpi=300, bbox_inches='tight')plt.show()# 9. 主函数
def main():"""主函数"""print("=== ViT CIFAR-10 图像分类实验 ===")# 检查GPUif not torch.cuda.is_available():print("警告: 未检测到GPU,将使用CPU训练(速度较慢)")# 超参数设置batch_size = 32  # 根据GPU内存调整num_epochs = 50learning_rate = 3e-4print(f"批次大小: {batch_size}")print(f"训练轮数: {num_epochs}")print(f"学习率: {learning_rate}")# 加载数据print("\n1. 加载CIFAR-10数据集...")train_loader, test_loader = get_cifar10_dataloaders(batch_size=batch_size)print(f"训练集大小: {len(train_loader.dataset)}")print(f"测试集大小: {len(test_loader.dataset)}")# 创建模型print("\n2. 创建ViT模型...")model = VisionTransformer(img_size=224,patch_size=16,in_channels=3,num_classes=10,embed_dim=384,  # 减小模型大小以适应GPU内存depth=6,        # 减少层数num_heads=6,mlp_ratio=4.0,dropout=0.1).to(device)# 计算模型参数total_params = sum(p.numel() for p in model.parameters())print(f"模型参数总数: {total_params:,}")# 训练模型print("\n3. 开始训练...")start_time = time.time()train_losses, train_accs, test_accs, best_acc = train_model(model, train_loader, test_loader, num_epochs, learning_rate)training_time = time.time() - start_timeprint(f"\n训练完成! 耗时: {training_time/60:.2f} 分钟")print(f"最佳测试准确率: {best_acc:.2f}%")# 绘制训练历史print("\n4. 绘制训练历史...")plot_training_history(train_losses, train_accs, test_accs)# 加载最佳模型进行最终测试print("\n5. 最终测试...")model.load_state_dict(torch.load('best_vit_model.pth'))final_acc = evaluate_model(model, test_loader)print(f"最终测试准确率: {final_acc:.2f}%")if final_acc >= 80.0:print("🎉 恭喜! 达到了80%以上的目标准确率!")else:print("💡 提示: 可以尝试调整超参数或增加训练轮数来提高准确率")return model, final_accif __name__ == "__main__":model, accuracy = main()

2.2 实验结果

  • 损失曲线:从 1.8 快速下降到 0.4,收敛稳定
  • 训练准确率:从 40% 稳步上升至 89%
  • 测试准确率:最终达到 80.2%,满足实验要求

性能指标

指标数值
最终测试准确率80.2%
模型参数量~11M
训练时间~30 epochs
收敛稳定性良好

三、任务详情解析

3.1 需求分析

Vision Transformer (ViT) 作为近年来计算机视觉领域的重要突破,首次将自然语言处理中成功的 Transformer 架构完全引入到图像处理任务中。本文将详细介绍如何从零开始构建一个完整的 ViT 模型,并在 CIFAR-10 数据集上实现超过 80% 的分类准确率。

实验目标

  • 技术目标:掌握 ViT 模型的核心原理和实现方法

  • 性能目标:在 CIFAR-10 测试集上达到 80% 以上的分类准确率

  • 工程目标:完成深度学习项目的完整流程

3.2 ViT 模型原理

1)整体架构

ViT 模型摒弃了传统 CNN 的卷积操作,转而采用纯粹的注意力机制处理图像。其核心思想是将图像分割成固定大小的 patch,然后将这些 patch 序列化处理,就像处理文本序列一样。

class VisionTransformer(nn.Module):def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=10,embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, dropout=0.1):super(VisionTransformer, self).__init__()self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)num_patches = self.patch_embed.num_patches# 可学习的位置嵌入和类别tokenself.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))self.pos_drop = nn.Dropout(dropout)# Transformer编码器self.blocks = nn.ModuleList([TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)for _ in range(depth)])# 分类头self.norm = nn.LayerNorm(embed_dim)self.head = nn.Linear(embed_dim, num_classes)

2)核心组件

1. 图像分块嵌入 (Patch Embedding):

将 224×224 的图像分割成 16×16 的 patch,总共产生 196 个 patch

class PatchEmbedding(nn.Module):def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):super(PatchEmbedding, self).__init__()self.img_size = img_sizeself.patch_size = patch_sizeself.num_patches = (img_size // patch_size) ** 2self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)def forward(self, x):B, C, H, W = x.shapex = self.proj(x).flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)return x

2. 多头自注意力机制:

这是 ViT 的核心,允许每个 patch 与其他所有 patch 进行信息交互:

class MultiHeadAttention(nn.Module):def __init__(self, embed_dim, num_heads, dropout=0.1):super(MultiHeadAttention, self).__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.head_dim = embed_dim // num_headsself.qkv = nn.Linear(embed_dim, embed_dim * 3)self.proj = nn.Linear(embed_dim, embed_dim)self.dropout = nn.Dropout(dropout)def forward(self, x):B, N, C = x.shape# 生成Q, K, Vqkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)q, k, v = qkv[0], qkv[1], qkv[2]# 计算注意力分数attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)attn = attn.softmax(dim=-1)attn = self.dropout(attn)# 应用注意力x = (attn @ v).transpose(1, 2).reshape(B, N, C)x = self.proj(x)return x

注意力机制的数学表达式为:

3.3 模型优化策略

1)轻量化设计

考虑到 CIFAR-10 数据集的特点,对标准 ViT 模型进行了轻量化设计:

  • 嵌入维度:从 768 降低到 384

  • 编码器层数:从 12 层减少到 6 层

  • 注意力头数:从 12 个减少到 6 个

这样既保持了模型的表达能力,又避免了过拟合:

model = VisionTransformer(img_size=224,patch_size=16,in_channels=3,num_classes=10,embed_dim=384,  # 轻量化设计depth=6,        # 减少层数num_heads=6,mlp_ratio=4.0,dropout=0.1
)

2)数据增强策略

针对 CIFAR-10 的特点,设计了系统的数据增强策略:

train_transform = transforms.Compose([transforms.Resize((224, 224)),  # ViT标准输入尺寸transforms.RandomHorizontalFlip(p=0.5),transforms.RandomRotation(degrees=15),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

3)优化器与学习率调度

采用 AdamW 优化器配合余弦退火学习率调度:

optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

余弦退火学习率的变化遵循:

四、分析

4.1 ViT vs CNN 的对比

特性ViTCNN
感受野全局,每个patch都能看到整个图像局部,受卷积核大小限制
归纳偏置较少,主要依赖数据学习较强,内置平移不变性
计算复杂度O(n²) 随序列长度平方增长O(n) 线性增长
数据需求需要大量数据或预训练相对较少的数据即可训练

4.2 关键创新点

  • 轻量化设计:针对 CIFAR-10 优化模型规模

  • 数据增强:系统性的增强策略提升泛化能力

  • 训练策略:AdamW + 余弦退火的组合优化

4.3 实验总结与展望

主要成果

  • 目标达成:测试准确率达到 80.2%,超过实验要求
  • 架构理解:深入掌握 ViT 的核心原理和实现
  • 工程实践:完成完整的深度学习项目流程

技术收获

  • Transformer 在视觉任务中的应用机制

  • 自注意力机制的实现和优化

  • 深度学习项目的系统性工程实践

未来改进方向

  • 模型架构优化

    • 探索更高效的注意力机制(如 Swin Transformer)

    • 研究混合架构(CNN + Transformer)

  • 训练策略改进

    • 引入知识蒸馏技术

    • 探索更先进的数据增强方法

  • 应用扩展

    • 迁移到其他视觉任务(目标检测、语义分割)

    • 探索多模态应用场景

五、实验报告和PPT

狗头)...

http://www.lryc.cn/news/577218.html

相关文章:

  • 工业级PHP任务管理系统开发:模块化设计与性能调优实践
  • DBeaver 设置阿里云中央仓库地址的操作步骤
  • 提示技术系列——链式提示
  • 数据结构入门-图的基本概念与存储结构
  • 【软考高项论文】论信息系统项目的干系人管理
  • 利用不坑盒子的Copilot,快速排值班表
  • upload-labs靶场通关详解:第15-16关
  • docker-compose部署Nacos、Seata、MySQL
  • 《Effective Python》第十一章 性能——使用 timeit 微基准测试优化性能关键代码
  • 初始CNN(卷积神经网络)
  • C++ cstring 库解析:C 风格字符串函数
  • 深入理解Webpack的灵魂:Tapable插件架构解析
  • 人工智能和云计算对金融未来的影响
  • 大模型在急性左心衰竭预测与临床方案制定中的应用研究
  • spring-ai 工作流
  • Github 2FA(Two-Factor Authentication/两因素认证)
  • 基于Flask技术的民宿管理系统的设计与实现
  • [论文阅读] Neural Architecture Search: Insights from 1000 Papers
  • macos 使用 vllm 启动模型
  • 在 VS Code 中安装与配置 Gemini CLI 的完整指南
  • java JNDI高版本绕过 工具介绍 自动化bypass
  • 【Debian】1- 安装Debian到物理主机
  • leedcode:找到字符串中所有字母异位词
  • 【Actix Web】Rust Web开发JWT认证
  • C#跨线程共享变量指南:从静态变量到AsyncLocal的深度解析
  • Excel转pdf实现动态数据绑定
  • Java设计模式之结构型模式(外观模式)介绍与说明
  • BUUCTF在线评测-练习场-WebCTF习题[MRCTF2020]你传你[特殊字符]呢1-flag获取、解析
  • FPGA实现CameraLink视频解码转SDI输出,基于LVDS+GTX架构,提供2套工程源码和技术支持
  • AWS 开源 Strands Agents SDK,简化 AI 代理开发流程