# day44打卡 (day-44 check-in — commented out: a bare identifier raises NameError)
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import time
import warnings
# 安装并导入 tqdm
%pip install tqdm
from tqdm import tqdm
warnings.filterwarnings("ignore")
# --- 步骤 1: 数据准备 (保持不变) ---
def get_cifar10_loaders(batch_size=128, resize_to=32):
    """
    Build the CIFAR-10 train/test DataLoaders.

    Args:
        batch_size: samples per batch for both loaders.
        resize_to: side length images are rescaled to, so the same
            pipeline can feed models with different input-size needs.

    Returns:
        (train_loader, test_loader) tuple of DataLoader objects.
    """
    print(f"--- 正在准备数据 (图像将缩放至 {resize_to}x{resize_to}) ---")

    # CIFAR-10 channel statistics, shared by both pipelines.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))

    # Augmentation (random crop + horizontal flip) on the training split only.
    train_tf = transforms.Compose([
        transforms.Resize(resize_to),
        transforms.RandomCrop(resize_to, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    # Test split: deterministic resize + normalization, nothing else.
    test_tf = transforms.Compose([
        transforms.Resize(resize_to),
        transforms.ToTensor(),
        normalize,
    ])

    train_set = datasets.CIFAR10(root='./data', train=True, download=True,
                                 transform=train_tf)
    test_set = datasets.CIFAR10(root='./data', train=False, transform=test_tf)

    shared = dict(batch_size=batch_size, num_workers=2, pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **shared)
    test_loader = DataLoader(test_set, shuffle=False, **shared)

    print("✅ 数据加载器准备完成。")
    return train_loader, test_loader
# --- 步骤 2: 模型创建函数 ---
# ResNet18 创建函数 (保持不变)
def create_resnet18(pretrained=True, num_classes=10):
    """Build a ResNet-18, replacing its final FC layer with a fresh
    ``num_classes``-way head for CIFAR-10 fine-tuning."""
    weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
    net = models.resnet18(weights=weights)
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net
# 【新增】MobileNetV2 创建函数
def create_mobilenet_v2(pretrained=True, num_classes=10):
    """Build a MobileNetV2 with its classifier head replaced by a fresh
    ``num_classes``-way Linear layer.

    Note: MobileNetV2's classifier is a Sequential whose index 1 holds
    the Linear layer we swap out.
    """
    weights = models.MobileNet_V2_Weights.IMAGENET1K_V1 if pretrained else None
    net = models.mobilenet_v2(weights=weights)
    head_in = net.classifier[1].in_features
    net.classifier[1] = nn.Linear(head_in, num_classes)
    return net
# --- 步骤 3: 训练与评估框架 (保持不变) ---
def run_experiment(model_name: str, model_creator, device, epochs, freeze_epochs):
    """Run one complete transfer-learning experiment.

    Strategy: train with the feature extractor frozen for ``freeze_epochs``
    epochs (only the classification head learns), then unfreeze the whole
    network and continue fine-tuning at a smaller learning rate.

    Args:
        model_name: label used only in log output.
        model_creator: callable ``(pretrained, num_classes) -> nn.Module``
            (e.g. ``create_resnet18`` / ``create_mobilenet_v2``).
        device: torch device to train and evaluate on.
        epochs: total number of training epochs.
        freeze_epochs: epochs to keep the backbone frozen before unfreezing.
    """
    print(f"\n{'='*25} 开始实验: {model_name} {'='*25}")

    # 1. Data and model.
    train_loader, test_loader = get_cifar10_loaders()
    model = model_creator(pretrained=True, num_classes=10).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"模型总参数量: {total_params / 1e6:.2f}M")

    # 2. Freeze/unfreeze helper. The classification head ('fc' on ResNet,
    # 'classifier' on MobileNetV2) is never frozen.
    def set_freeze_state(freeze=True):
        print(f"--- {'冻结' if freeze else '解冻'} 特征提取层 ---")
        for name, param in model.named_parameters():
            if 'fc' not in name and 'classifier' not in name:
                param.requires_grad = not freeze

    def make_scheduler(opt):
        # Halve-ish the LR (factor 0.2) after 2 epochs without test-loss improvement.
        return optim.lr_scheduler.ReduceLROnPlateau(
            opt, mode='min', factor=0.2, patience=2, verbose=True)

    set_freeze_state(freeze=True)  # start with the backbone frozen
    # Only the (trainable) head parameters go into the initial optimizer.
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    scheduler = make_scheduler(optimizer)

    # 3. Training loop.
    start_time = time.time()
    for epoch in range(1, epochs + 1):
        if epoch == freeze_epochs + 1:
            set_freeze_state(freeze=False)
            # Rebuild the optimizer so it covers every (now trainable)
            # parameter; use a smaller LR for whole-network fine-tuning.
            optimizer = optim.Adam(model.parameters(), lr=1e-4)
            # BUGFIX: the scheduler must be rebuilt to track the NEW
            # optimizer — the original kept stepping a scheduler bound to
            # the discarded one, so its LR reductions had no effect.
            scheduler = make_scheduler(optimizer)
            print("优化器已更新以包含所有参数。")

        model.train()
        loop = tqdm(train_loader, desc=f"Epoch [{epoch}/{epochs}]", leave=False)
        for data, target in loop:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            loop.set_postfix(loss=loss.item())
        loop.close()

        # Evaluate on the test split after every epoch.
        model.eval()
        test_loss = 0.0
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                # Sum per-sample losses so the average below is over the
                # whole dataset, not over batches.
                test_loss += criterion(output, target).item() * data.size(0)
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()
        avg_test_loss = test_loss / len(test_loader.dataset)
        accuracy = 100. * correct / len(test_loader.dataset)
        print(f"Epoch {epoch} 完成 | 测试集损失: {avg_test_loss:.4f} | 测试集准确率: {accuracy:.2f}%")
        scheduler.step(avg_test_loss)  # plateau scheduler keys off test loss

    end_time = time.time()
    # BUGFIX: the original file was truncated mid f-string here; completed.
    print(f"✅ 实验 '{model_name}' 完成,总耗时: {end_time - start_time:.2f}秒")
# @浙大疏锦行 (author attribution — commented out: a dangling decorator at EOF is a SyntaxError)