Deploying and Running Vim/VMamba Training and Testing on ImageNet
This guide walks through deploying and running the Vision Mamba (Vim/VMamba) models for training and evaluation on the ImageNet-1K dataset.
1. Environment Setup
First, set up a Python environment and install the required dependencies:
# Create a conda environment (recommended)
conda create -n vmamba python=3.9 -y
conda activate vmamba

# Install PyTorch (choose the build matching your CUDA version)
pip install torch torchvision torchaudio

# Install other dependencies
pip install timm tensorboardX einops tqdm
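In addition to the packages above, both Vim and VMamba build on the Mamba selective-scan CUDA kernels. Depending on the repo version you clone, you will typically also need the following PyPI packages (check the repo's README for the exact pinned versions; they require a CUDA toolkit that matches your PyTorch build):

# Mamba CUDA kernels (exact versions are pinned in each repo's README)
pip install causal-conv1d
pip install mamba-ssm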
2. Get the Official Code
Clone the code from the official repository:
git clone https://github.com/hustvl/Vim.git
cd Vim
Or, for VMamba:
git clone https://github.com/MzeroMiko/VMamba.git
cd VMamba
3. Prepare the Dataset
The ImageNet-1K dataset should be organized in the following structure:
imagenet/
  train/
    class1/
      img1.JPEG
      img2.JPEG
      ...
    class2/
      img1.JPEG
      ...
  val/
    class1/
      img1.JPEG
      ...
    class2/
      img1.JPEG
      ...
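Before launching training, it can save time to sanity-check that torchvision can read this layout. A minimal sketch (the path below is a placeholder; scanning the full train split takes a moment):

import torchvision

# ImageFolder infers class labels from the sub-directory names
train_set = torchvision.datasets.ImageFolder('/path/to/imagenet/train')
val_set = torchvision.datasets.ImageFolder('/path/to/imagenet/val')

# ImageNet-1K should yield 1000 classes in both splits
print(len(train_set.classes), len(val_set.classes))  # expected: 1000 1000
print(len(train_set), len(val_set))                  # ~1.28M train images, 50000 val images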
4. Training Script
Create a training script train.py:
import argparse
import os
import time
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from timm.utils import accuracy
from models.vim import Vim  # or: from models.vmamba import VMamba

def parse_args():
    parser = argparse.ArgumentParser(description='Vim/VMamba Training')
    parser.add_argument('--data-path', type=str, required=True, help='Path to ImageNet dataset')
    parser.add_argument('--batch-size', type=int, default=256, help='Batch size')
    parser.add_argument('--epochs', type=int, default=300, help='Number of epochs')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
    parser.add_argument('--model', type=str, default='vim_tiny', choices=['vim_tiny', 'vim_small', 'vim_base'])
    parser.add_argument('--output-dir', type=str, default='./output', help='Output directory')
    parser.add_argument('--num-workers', type=int, default=8, help='Number of data loading workers')
    return parser.parse_args()

def main():
    args = parse_args()

    # Create the output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Data augmentation and loading
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    train_dataset = torchvision.datasets.ImageFolder(
        os.path.join(args.data_path, 'train'), transform=train_transform)
    val_dataset = torchvision.datasets.ImageFolder(
        os.path.join(args.data_path, 'val'), transform=val_transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.num_workers, pin_memory=True)

    # Build the model
    model = Vim(
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=192,
        depths=[2, 2, 9, 2],
        drop_path_rate=0.1,
        model_type=args.model
    ).cuda()

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

    # Learning-rate scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    # Training loop
    for epoch in range(args.epochs):
        model.train()
        start_time = time.time()
        for i, (images, target) in enumerate(train_loader):
            images = images.cuda()
            target = target.cuda()

            # Forward pass
            output = model(images)
            loss = criterion(output, target)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % 100 == 0:
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                print(f'Epoch: [{epoch}/{args.epochs}], Step: [{i}/{len(train_loader)}], '
                      f'Loss: {loss.item():.4f}, Acc@1: {acc1.item():.3f}, Acc@5: {acc5.item():.3f}')

        # Validation
        model.eval()
        val_loss = 0
        val_acc1 = 0
        val_acc5 = 0
        total = 0
        with torch.no_grad():
            for images, target in val_loader:
                images = images.cuda()
                target = target.cuda()
                output = model(images)
                loss = criterion(output, target)
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                val_loss += loss.item() * images.size(0)
                val_acc1 += acc1.item() * images.size(0)
                val_acc5 += acc5.item() * images.size(0)
                total += images.size(0)
        val_loss = val_loss / total
        val_acc1 = val_acc1 / total
        val_acc5 = val_acc5 / total
        print(f'Validation - Epoch: {epoch}, Loss: {val_loss:.4f}, '
              f'Acc@1: {val_acc1:.3f}, Acc@5: {val_acc5:.3f}')

        # Update the learning rate
        scheduler.step()

        # Save a checkpoint
        checkpoint = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
            'args': args
        }
        torch.save(checkpoint, os.path.join(args.output_dir, f'checkpoint_{epoch}.pth'))

    print('Training completed!')

if __name__ == '__main__':
    main()
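The loop above runs in full precision on a single GPU. If memory or throughput becomes a bottleneck, one common modification (a sketch only, not part of the official recipe) is to wrap the forward/backward pass in PyTorch automatic mixed precision:

scaler = torch.cuda.amp.GradScaler()

for i, (images, target) in enumerate(train_loader):
    images, target = images.cuda(), target.cuda()

    optimizer.zero_grad()
    # Run the forward pass in mixed precision
    with torch.cuda.amp.autocast():
        output = model(images)
        loss = criterion(output, target)

    # Scale the loss to avoid fp16 gradient underflow, then step the optimizer
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()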
5. Test Script
Create a test script test.py:
import argparse
import os
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from timm.utils import accuracy
from models.vim import Vim  # or: from models.vmamba import VMamba

def parse_args():
    parser = argparse.ArgumentParser(description='Vim/VMamba Testing')
    parser.add_argument('--data-path', type=str, required=True, help='Path to ImageNet validation set')
    parser.add_argument('--batch-size', type=int, default=256, help='Batch size')
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint')
    parser.add_argument('--num-workers', type=int, default=8, help='Number of data loading workers')
    return parser.parse_args()

def main():
    args = parse_args()

    # Data loading
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    val_dataset = torchvision.datasets.ImageFolder(args.data_path, transform=transform)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.num_workers, pin_memory=True)

    # Load the model
    checkpoint = torch.load(args.checkpoint)
    model = Vim(
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=192,
        depths=[2, 2, 9, 2],
        drop_path_rate=0.1,
        model_type=checkpoint['args'].model
    ).cuda()
    model.load_state_dict(checkpoint['model'])
    model.eval()

    # Evaluate
    val_acc1 = 0
    val_acc5 = 0
    total = 0
    with torch.no_grad():
        for images, target in val_loader:
            images = images.cuda()
            target = target.cuda()
            output = model(images)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            val_acc1 += acc1.item() * images.size(0)
            val_acc5 += acc5.item() * images.size(0)
            total += images.size(0)

    val_acc1 = val_acc1 / total
    val_acc5 = val_acc5 / total
    print(f'Test Results - Acc@1: {val_acc1:.3f}, Acc@5: {val_acc5:.3f}')

if __name__ == '__main__':
    main()
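Besides accuracy, you may also want to measure inference throughput. A small sketch that can be dropped into the test script after the model is built (batch size, warm-up, and iteration counts below are arbitrary illustrative choices):

import time
import torch

def measure_throughput(model, batch_size=64, image_size=224, warmup=10, iters=50):
    model.eval()
    dummy = torch.randn(batch_size, 3, image_size, image_size, device='cuda')
    with torch.no_grad():
        # Warm-up iterations so CUDA kernels are initialized before timing
        for _ in range(warmup):
            model(dummy)
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(iters):
            model(dummy)
        torch.cuda.synchronize()
        elapsed = time.time() - start
    return iters * batch_size / elapsed  # images per second

# Example: print(f'{measure_throughput(model):.1f} images/s')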
6. Run Training and Testing
Train the model
python train.py --data-path /path/to/imagenet --model vim_tiny --batch-size 256 --epochs 300 --output-dir ./output
Test the model
python test.py --data-path /path/to/imagenet/val --checkpoint ./output/checkpoint_299.pth
7. Customization Options
You can customize training with the following options (an example command follows the list):
- Model size: use --model to choose vim_tiny, vim_small, or vim_base
- Learning rate: adjust with --lr
- Batch size: adjust with --batch-size
- Training epochs: set with --epochs
- Data path: use --data-path to point at the ImageNet dataset location
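For example, a run that combines several of these options (the values are illustrative, not recommended hyperparameters; these flags are exactly the ones defined in the train.py above):

python train.py --data-path /path/to/imagenet --model vim_small --lr 5e-4 --batch-size 128 --epochs 300 --output-dir ./output_vim_small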
8. Expected Results
According to the official papers, the expected performance of Vim/VMamba on ImageNet-1K is roughly:
| Model | Params | Top-1 Acc |
|---|---|---|
| Vim-Tiny | 26M | ~80.5% |
| Vim-Small | 44M | ~82.5% |
| Vim-Base | 66M | ~83.5% |
Note: actual results may differ depending on hyperparameter settings, training schedule, and data augmentation strategy; the numbers reported in the Vim and VMamba papers also differ by model variant, so check the official READMEs for the exact configuration you train.
9. Using Pretrained Models
If you want to quickly verify model performance, you can download the official pretrained weights:
from models.vim import vim_tiny_pretrained

model = vim_tiny_pretrained()
model.eval()
Then evaluate it with the test script above.
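If the repo version you use does not expose such a helper, a minimal sketch of loading a downloaded weight file by hand is shown below. The file name and checkpoint key layout are assumptions (releases differ), so inspect the file and adjust; the Vim constructor arguments mirror the scripts above:

import torch
from models.vim import Vim  # constructor as assumed in the scripts above

model = Vim(img_size=224, patch_size=16, in_chans=3, num_classes=1000,
            embed_dim=192, depths=[2, 2, 9, 2], drop_path_rate=0.1,
            model_type='vim_tiny').cuda()

# Released checkpoints often store weights under a 'model' or 'state_dict' key;
# 'vim_tiny_pretrained.pth' is a placeholder file name.
ckpt = torch.load('vim_tiny_pretrained.pth', map_location='cpu')
state_dict = ckpt.get('model', ckpt) if isinstance(ckpt, dict) else ckpt
model.load_state_dict(state_dict, strict=False)
model.eval()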
I hope this guide helps you successfully deploy and run Vim/VMamba!