当前位置: 首页 > news >正文

day44打卡

import torch

import torch.nn as nn

import torch.optim as optim

from torchvision import datasets, transforms, models

from torch.utils.data import DataLoader

import matplotlib.pyplot as plt

import time

import warnings

 

# 安装并导入 tqdm

%pip install tqdm

from tqdm import tqdm

 

warnings.filterwarnings("ignore")

 

# --- 步骤 1: 数据准备 (保持不变) ---

def get_cifar10_loaders(batch_size=128, resize_to=32):

    """

    获取CIFAR-10的数据加载器。

    新增 resize_to 参数以适应不同模型的输入要求。

    """

    print(f"--- 正在准备数据 (图像将缩放至 {resize_to}x{resize_to}) ---")

    

    # 训练集使用数据增强

    train_transform = transforms.Compose([

        transforms.Resize(resize_to), # 新增:缩放图像以匹配模型输入

        transforms.RandomCrop(resize_to, padding=4),

        transforms.RandomHorizontalFlip(),

        transforms.ToTensor(),

        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    ])

 

    # 测试集只进行必要的缩放和标准化

    test_transform = transforms.Compose([

        transforms.Resize(resize_to),

        transforms.ToTensor(),

        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    ])

 

    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)

    test_dataset = datasets.CIFAR10(root='./data', train=False, transform=test_transform)

    

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)

    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

    

    print("✅ 数据加载器准备完成。")

    return train_loader, test_loader

 

# --- 步骤 2: 模型创建函数 ---

 

# ResNet18 创建函数 (保持不变)

def create_resnet18(pretrained=True, num_classes=10):

    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)

    in_features = model.fc.in_features

    model.fc = nn.Linear(in_features, num_classes)

    return model

 

# 【新增】MobileNetV2 创建函数

def create_mobilenet_v2(pretrained=True, num_classes=10):

    model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1 if pretrained else None)

    # MobileNetV2的分类器是一个包含Linear层的Sequential

    in_features = model.classifier[1].in_features

    model.classifier[1] = nn.Linear(in_features, num_classes)

    return model

 

# --- 步骤 3: 训练与评估框架 (保持不变) ---

def run_experiment(model_name: str, model_creator, device, epochs, freeze_epochs):

    """运行一次完整的实验"""

    print(f"\n{'='*25} 开始实验: {model_name} {'='*25}")

    

    # 1. 准备数据和模型

    train_loader, test_loader = get_cifar10_loaders()

    model = model_creator(pretrained=True, num_classes=10).to(device)

    

    # 打印模型参数量

    total_params = sum(p.numel() for p in model.parameters())

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"模型总参数量: {total_params / 1e6:.2f}M")

    

    # 2. 冻结/解冻与优化器设置

    def set_freeze_state(freeze=True):

        print(f"--- {'冻结' if freeze else '解冻'} 特征提取层 ---")

        for name, param in model.named_parameters():

            # 最后一个全连接层始终可训练

            if 'fc' not in name and 'classifier' not in name:

                param.requires_grad = not freeze

    

    set_freeze_state(freeze=True) # 初始冻结

    

    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

    criterion = nn.CrossEntropyLoss()

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=2, verbose=True)

 

    # 3. 训练循环

    start_time = time.time()

    for epoch in range(1, epochs + 1):

        # 解冻控制

        if epoch == freeze_epochs + 1:

            set_freeze_state(freeze=False)

            # 解冻后需要重新定义优化器,以包含所有可训练参数

            optimizer = optim.Adam(model.parameters(), lr=1e-4) # 使用更小的学习率进行全局微调

            print("优化器已更新以包含所有参数。")

        

        model.train()

        loop = tqdm(train_loader, desc=f"Epoch [{epoch}/{epochs}]", leave=False)

        for data, target in loop:

            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()

            output = model(data)

            loss = criterion(output, target)

            loss.backward()

            optimizer.step()

            loop.set_postfix(loss=loss.item())

        loop.close()

        

        # 评估

        model.eval()

        test_loss = 0

        correct = 0

        with torch.no_grad():

            for data, target in test_loader:

                data, target = data.to(device), target.to(device)

                output = model(data)

                test_loss += criterion(output, target).item() * data.size(0)

                pred = output.argmax(dim=1)

                correct += pred.eq(target).sum().item()

        

        avg_test_loss = test_loss / len(test_loader.dataset)

        accuracy = 100. * correct / len(test_loader.dataset)

        print(f"Epoch {epoch} 完成 | 测试集损失: {avg_test_loss:.4f} | 测试集准确率: {accuracy:.2f}%")

        

        scheduler.step(avg_test_loss)

    

    end_time = time.time()

    print(f"✅ 实验 '{model_name}' 完成,总耗时: {end_time - start_time:.2f

@浙大疏锦行

http://www.lryc.cn/news/581219.html

相关文章:

  • cmd 的sftp传输;Conda出现环境问题: error: invalid value for --gpu-architecture (-arch)
  • 浅度解读-(未完成版)浅层神经网络-多个隐层神经元
  • 前端-CSS-day1
  • 【openp2p】学习3:【专利分析】一种基于混合网络的自适应切换方法、装 置、设备及介质
  • WSL命令
  • 【爬虫】逆向爬虫初体验之爬取音乐
  • 大模型算法面试笔记——Bert
  • 计算机网络(网页显示过程,TCP三次握手,HTTP1.0,1.1,2.0,3.0,JWT cookie)
  • 一键将 SQL 转为 Java 实体类,全面支持 MySQL / PostgreSQL / Oracle!
  • 永磁同步电机无速度算法--基于锁频环前馈锁相环的滑模观测器
  • 使用SSH隧道连接远程主机
  • 五、Python新特性指定类型用法
  • 【赵渝强老师】Oracle RMAN的目录数据库
  • 数据库-元数据表
  • 事务的原子性
  • 自建双因素认证器 2FAuth 完美替代 Google Auth / Microsoft Auth
  • CSS 文字浮雕效果:巧用 text-shadow 实现 3D 立体文字
  • 虚拟机与容器技术详解:VM、LXC、LXD与Docker
  • HarmonyOS学习3---ArkUI
  • 《Redis》哨兵模式
  • ✨ OpenAudio S1:影视级文本转语音与语音克隆Mac整合包
  • 构建未来交互体验:AG-UI 如何赋能智能体与前端通信?
  • openai和chatgpt什么关系
  • hono框架绑定cloudflare的d1数据库操作步骤
  • 2025最新Telegram快读助手:一款智能Telegram链接摘要机器人
  • 【leetcode100】最长回文子串
  • 探索 .NET 桌面开发:WinForms、WPF、.NET MAUI 和 Avalonia 的全面对比(截至2025年7月)
  • MAX3485在MCU芯片AS32S601-485通信外设中的应用
  • Java 创建对象过程 JVM 内存分配并发安全笔记
  • 介绍Flutter