当前位置: 首页 > news >正文

部署并运行Vim/Vmamba在ImageNet上的训练与测试

部署并运行Vim/Vmamba在ImageNet上的训练与测试

下面我将指导你如何部署和运行Vision Mamba (Vim/VMamba)模型在ImageNet-1K数据集上的训练与测试。

1. 环境准备

首先需要设置Python环境并安装必要的依赖:

# 创建conda环境(推荐)
conda create -n vmamba python=3.9 -y
conda activate vmamba# 安装PyTorch (根据你的CUDA版本选择)
pip install torch torchvision torchaudio# 安装其他依赖
pip install timm tensorboardX einops tqdm

2. 获取官方代码

从官方仓库克隆代码:

git clone https://github.com/hustvl/Vim.git
cd Vim

或者对于VMamba:

git clone https://github.com/MzeroMiko/VMamba.git
cd VMamba

3. 准备数据集

ImageNet-1K数据集应该按照以下结构组织:

imagenet/train/class1/img1.JPEGimg2.JPEG...class2/img1.JPEG...val/class1/img1.JPEG...class2/img1.JPEG...

4. 训练脚本

创建一个训练脚本 train.py

import argparse
import os
import time
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from timm.utils import accuracy
from models.vim import Vim  # 或 from models.vmamba import VMambadef parse_args():parser = argparse.ArgumentParser(description='Vim/VMamba Training')parser.add_argument('--data-path', type=str, required=True, help='Path to ImageNet dataset')parser.add_argument('--batch-size', type=int, default=256, help='Batch size')parser.add_argument('--epochs', type=int, default=300, help='Number of epochs')parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')parser.add_argument('--model', type=str, default='vim_tiny', choices=['vim_tiny', 'vim_small', 'vim_base'])parser.add_argument('--output-dir', type=str, default='./output', help='Output directory')parser.add_argument('--num-workers', type=int, default=8, help='Number of data loading workers')return parser.parse_args()def main():args = parse_args()# 创建输出目录os.makedirs(args.output_dir, exist_ok=True)# 数据增强和加载train_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),])val_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),])train_dataset = torchvision.datasets.ImageFolder(os.path.join(args.data_path, 'train'),transform=train_transform)val_dataset = torchvision.datasets.ImageFolder(os.path.join(args.data_path, 'val'),transform=val_transform)train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,num_workers=args.num_workers, pin_memory=True)val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,num_workers=args.num_workers, pin_memory=True)# 创建模型model = Vim(img_size=224,patch_size=16,in_chans=3,num_classes=1000,embed_dim=192,depths=[2, 2, 9, 2],drop_path_rate=0.1,model_type=args.model).cuda()# 损失函数和优化器criterion = nn.CrossEntropyLoss().cuda()optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)# 学习率调度器scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)# 训练循环for epoch in range(args.epochs):model.train()start_time = time.time()for i, (images, target) in enumerate(train_loader):images = images.cuda()target = target.cuda()# 前向传播output = model(images)loss = criterion(output, target)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()if i % 100 == 0:acc1, acc5 = accuracy(output, target, topk=(1, 5))print(f'Epoch: [{epoch}/{args.epochs}], Step: [{i}/{len(train_loader)}], 'f'Loss: {loss.item():.4f}, Acc@1: {acc1.item():.3f}, Acc@5: {acc5.item():.3f}')# 验证model.eval()val_loss = 0val_acc1 = 0val_acc5 = 0total = 0with torch.no_grad():for images, target in val_loader:images = images.cuda()target = target.cuda()output = model(images)loss = criterion(output, target)acc1, acc5 = accuracy(output, target, topk=(1, 5))val_loss += loss.item() * images.size(0)val_acc1 += acc1.item() * images.size(0)val_acc5 += acc5.item() * images.size(0)total += images.size(0)val_loss = val_loss / totalval_acc1 = val_acc1 / totalval_acc5 = val_acc5 / totalprint(f'Validation - Epoch: {epoch}, Loss: {val_loss:.4f}, 'f'Acc@1: {val_acc1:.3f}, Acc@5: {val_acc5:.3f}')# 更新学习率scheduler.step()# 保存检查点checkpoint = {'model': model.state_dict(),'optimizer': optimizer.state_dict(),'epoch': epoch,'args': args}torch.save(checkpoint, os.path.join(args.output_dir, f'checkpoint_{epoch}.pth'))print('Training completed!')if __name__ == '__main__':main()

5. 测试脚本

创建一个测试脚本 test.py

import argparse
import os
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from timm.utils import accuracy
from models.vim import Vim  # 或 from models.vmamba import VMambadef parse_args():parser = argparse.ArgumentParser(description='Vim/VMamba Testing')parser.add_argument('--data-path', type=str, required=True, help='Path to ImageNet validation set')parser.add_argument('--batch-size', type=int, default=256, help='Batch size')parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint')parser.add_argument('--num-workers', type=int, default=8, help='Number of data loading workers')return parser.parse_args()def main():args = parse_args()# 数据加载transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),])val_dataset = torchvision.datasets.ImageFolder(args.data_path,transform=transform)val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,num_workers=args.num_workers, pin_memory=True)# 加载模型checkpoint = torch.load(args.checkpoint)model = Vim(img_size=224,patch_size=16,in_chans=3,num_classes=1000,embed_dim=192,depths=[2, 2, 9, 2],drop_path_rate=0.1,model_type=checkpoint['args'].model).cuda()model.load_state_dict(checkpoint['model'])model.eval()# 测试val_acc1 = 0val_acc5 = 0total = 0with torch.no_grad():for images, target in val_loader:images = images.cuda()target = target.cuda()output = model(images)acc1, acc5 = accuracy(output, target, topk=(1, 5))val_acc1 += acc1.item() * images.size(0)val_acc5 += acc5.item() * images.size(0)total += images.size(0)val_acc1 = val_acc1 / totalval_acc5 = val_acc5 / totalprint(f'Test Results - Acc@1: {val_acc1:.3f}, Acc@5: {val_acc5:.3f}')if __name__ == '__main__':main()

6. 运行训练和测试

训练模型

python train.py --data-path /path/to/imagenet --model vim_tiny --batch-size 256 --epochs 300 --output-dir ./output

测试模型

python test.py --data-path /path/to/imagenet/val --checkpoint ./output/checkpoint_299.pth

7. 自定义选项

你可以通过以下方式自定义训练:

  1. 模型大小:使用 --model 参数选择 vim_tiny, vim_small, 或 vim_base
  2. 学习率:使用 --lr 调整学习率
  3. 批量大小:使用 --batch-size 调整
  4. 训练周期:使用 --epochs 设置
  5. 数据路径:使用 --data-path 指定ImageNet数据集位置

8. 预期结果

根据官方论文,Vim/VMamba在ImageNet-1K上的预期性能大约为:

模型参数量Top-1 Acc
Vim-Tiny26M~80.5%
Vim-Small44M~82.5%
Vim-Base66M~83.5%

注意:实际结果可能会因超参数设置、训练时长和数据增强策略而有所不同。

9. 使用预训练模型

如果你想快速验证模型性能,可以下载官方预训练权重:

from models.vim import vim_tiny_pretrainedmodel = vim_tiny_pretrained()
model.eval()

然后使用测试脚本进行评估。

希望这个指南能帮助你成功部署和运行Vim/VMamba模型!

http://www.lryc.cn/news/582935.html

相关文章:

  • JavaScript之数组方法详解
  • (C++)list列表相关基础用法(C++教程)(STL库基础教程)
  • HTTP/3.x协议详解:基于QUIC的下一代Web传输协议
  • 音频被动降噪技术
  • nng库使用
  • Android Handler机制与底层原理详解
  • Java 阻塞队列:7种类型全解析
  • 华为eNSP防火墙实验(包含详细步骤)
  • AR 双缝干涉实验亮相:创新科技实验范式,开拓 AR 技术新局​
  • Kafka多组消费:同一Topic,不同Group ID
  • 如何用Python编程计算权重?
  • 常见的网络攻击方式及防御措施
  • 分布式接口幂等性的演进和最佳实践,含springBoot 实现(Java版本)
  • 【c++学习记录】状态模式,实现一个登陆功能
  • 【ES实战】ES客户端线程量分析
  • 从 .proto 到 Python:使用 Protocol Buffers 的完整实践指南
  • 实战Linux进程状态观察:R、S、D、T、Z状态详解与实验模拟
  • 蓝桥杯 第十六届(2025)真题思路复盘解析
  • 50天50个小项目 (Vue3 + Tailwindcss V4) ✨ | StickyNavbar(粘性导航栏)
  • SPI / I2C / UART 哪个更适合初学者?
  • 【C++】AVL树底层思想 and 大厂面试
  • 27.移除元素(快慢指针)
  • AI大模型应用-Ollama本地千问大模型stream流乱码
  • HDLBits刷题笔记和一些拓展知识(十一)
  • 学习设计模式《十七》——状态模式
  • 美团Java面试分享
  • 基于模板设计模式开发优惠券推送功能以及对过期优惠卷进行定时清理
  • 在Docker中安装nexus3(作为maven私服)
  • [创业之路-489]:企业经营层 - 营销 - 如何将缺点转化为特点、再将特点转化为卖点
  • Java基础回顾(1)