当前位置: 首页 > news >正文

python打卡day45

@疏锦行

知识点回顾:

  1. tensorboard的发展历史和原理
  2. tensorboard的常见操作
  3. tensorboard在cifar上的实战:MLP和CNN模型

    作业:对resnet18在cifar10上采用微调策略下,用tensorboard监控训练过程。

    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torchvision import datasets, transforms, models
    from torch.utils.data import DataLoader
    from torch.utils.tensorboard import SummaryWriter
    import os# 设置随机种子保证可重复性
    torch.manual_seed(42)# 定义数据预处理
    transform = {    'train': transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),'test': transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    }# 加载 CIFAR-10 数据集
    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform['train'])
    test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform['test'])train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)# 初始化 TensorBoard 的 SummaryWriter
    log_dir = 'runs/resnet18_cifar10_finetune'
    if os.path.exists(log_dir):i = 1while os.path.exists(f"{log_dir}_{i}"):i += 1log_dir = f"{log_dir}_{i}"
    writer = SummaryWriter(log_dir)# 加载预训练的 ResNet18 模型
    model = models.resnet18(pretrained=True)# 修改最后一层全连接层以适应 CIFAR-10 的 10 个类别
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 10)# 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)# 设置训练设备
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)# 训练和验证函数
    num_epochs = 10
    for epoch in range(num_epochs):# 训练阶段model.train()running_loss = 0.0running_corrects = 0for inputs, labels in train_loader:inputs = inputs.to(device)labels = labels.to(device)optimizer.zero_grad()outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / len(train_dataset)epoch_acc = running_corrects.double() / len(train_dataset)# 使用 TensorBoard 记录训练损失和准确率writer.add_scalar('Train/Loss', epoch_loss, epoch)writer.add_scalar('Train/Accuracy', epoch_acc, epoch)# 验证阶段model.eval()val_running_loss = 0.0val_running_corrects = 0with torch.no_grad():for inputs, labels in test_loader:inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)val_running_loss += loss.item() * inputs.size(0)val_running_corrects += torch.sum(preds == labels.data)val_epoch_loss = val_running_loss / len(test_dataset)val_epoch_acc = val_running_corrects.double() / len(test_dataset)# 使用 TensorBoard 记录验证损失和准确率writer.add_scalar('Validation/Loss', val_epoch_loss, epoch)writer.add_scalar('Validation/Accuracy', val_epoch_acc, epoch)print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f}, Val Loss: {val_epoch_loss:.4f}, Val Acc: {val_epoch_acc:.4f}')# 可视化模型结构
    dataiter = iter(train_loader)
    images, labels = next(dataiter)
    images = images.to(device)
    writer.add_graph(model, images)
    writer.close()

http://www.lryc.cn/news/576614.html

相关文章:

  • 顺序表的常见算法
  • FPGA设计的时序分析概要
  • 鸿蒙 Grid 与 GridItem 深度解析:二维网格布局解决方案
  • 【 Linux 输入子系统】
  • python的医疗废弃物收运管理系统
  • 【力扣 中等 C】79. 单词搜索
  • Webpack 核心与基础使用
  • 数据结构之——顺序栈与链式栈
  • 个人日记本小程序开发方案(使用IntelliJ IDEA)
  • ORB-SLAM + D435i提取相机位姿 + ROS发布
  • 现代串口通讯UI框架性能对比
  • 容器安全——AI教你学Docker
  • 机器学习——线性回归
  • 【数据标注师】3D标注
  • 使用Calibre对GDS进行数据遍历
  • Note2.4 机器学习:Batch Normalization Introduction
  • 【go】初学者入门环境配置,GOPATH,GOROOT,GOCACHE,以及GoLand使用配置注意
  • LNA设计
  • 【安卓Sensor框架-1】SensorService 的启动流程
  • iOS 使用 SceneKit 实现全景图
  • MCPA2APPT:基于 A2A+MCP+ADK 的多智能体流式并发高质量 PPT 智能生成系统
  • 微处理原理与应用篇---STM32寄存器控制GPIO
  • Unity2D 街机风太空射击游戏 学习记录 #16 道具父类提取 旋涡道具
  • FPGA内部资源介绍
  • Python爬虫实战:研究sanitize库相关技术
  • 笔记07:网表的输出与导入
  • SQL关键字三分钟入门:RANK() —— 窗口函数
  • Java AI 新纪元:Spring AI 与 Spring AI Alibaba 的崛起
  • JavaScript正则表达式之正向先行断言(Positive Lookahead)深度解析
  • 第8章-财务数据