当前位置: 首页 > news >正文

35- tensorboard的使用 (PyTorch系列) (深度学习)

知识要点

  • FashionMNIST数据集: 十种产品的分类.        # T-shirt/top, Trouser, Pullover, Dress, Coat,Sandal, Shirt, Sneaker, Bag, Ankle Boot.
  • writer = SummaryWriter('run/fashion_mnist_experiment_1')    # 网站显示


一 tensorboard的使用

  • 在网站显示pytorch的架构:

1.1 导包

import matplotlib.pyplot as plt 
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

1.2 数据导入

# transforms
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, ), (0.5))])  # 正则化# datasets
trainset = torchvision.datasets.FashionMNIST('./data',download=True,train=True,transform = transform)testset = torchvision.datasets.FashionMNIST('./data',download=True,train=False,transform = transform)
  • dataloader 设置
# dataloaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True, num_workers=2)# constant for classes
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat','Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

1.3 定义模型

# 定义模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 4 * 4, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 4 * 4)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xnet = Net()
# 定义损失和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr = 0.001, momentum=0.9)

1.4 tensorboard的使用

  • tensorboard的安装: pip install tensorboard -i https://pypi.douban.com/simple
def matplotlib_imshow(img, one_channel=False):if one_channel:img = img.mean(dim=0)img = img / 2 + 0.5npimg = img.numpy()if one_channel:plt.imshow(npimg, cmap= 'Greys')else:plt.imshow(np.transpose(npimg, (1, 2, 0)))
dataiter = iter(trainloader)
images, labels = next(dataiter)
images.shape   # torch.Size([4, 1, 28, 28])
# torchvision 中make_grid 可以把多张图合并成一张图
img_grid = torchvision.utils.make_grid(images)
img_grid.shape   # torch.Size([3, 32, 122])
matplotlib_imshow(img_grid, one_channel=True)

  • from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('run/fashion_mnist_experiment_1')
writer.add_image('four_fashion_mnist_images', img_grid)images, labels = next(dataiter)
img_grid2 = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid2, one_channel=True)

writer.add_image('img_grid2', img_grid2)

1.5 添加模型的结构图

writer.add_graph(net, images)  # 模型可视化

1.6 添加损失变化

# writer.add_scaler()
running_loss = 0.0
for epoch in range(1):  # loop over the dataset multiple timesfor i, data in enumerate(trainloader, 0):# get the inputs: data is list of [inputs, labels]inputs, labels = data# zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 1000 == 999:  # every 1000 mini-batches...# log the running losswriter.add_scalar('training loss',running_loss / 1000,epoch * len(trainloader) + i)running_loss = 0.0
print('Finished Training')
http://www.lryc.cn/news/33886.html

相关文章:

  • ChatGPT在工业领域的用法
  • 使用Chakra-UI封装简书的登录页面组件(React)
  • Three.js初试——基础概念(二)
  • Qt音视频开发21-mpv内核万能属性机制
  • C语言学生随机抽号演讲计分系统
  • Spring Boot 3.0系列【12】核心特性篇之任务调度
  • Java操作XML
  • 女神节灯笼祝福【HTML+CSS】
  • CUDA并行计算基础知识
  • 88. 合并两个有序数组
  • 卢益贵(码客):软件开发团队的管理要素
  • 中小企业的TO B蓝海,如何「掘金」?
  • C++ 算法主题系列之集结0-1背包问题的所有求解方案
  • 【Vue】Vue常见的6种指令
  • 计算机科学与技术(嵌入式)四年学习资料_文件目录树
  • 【java】Java 继承
  • 自媒体账号数据分析从何入手?
  • Clickhouse新版本JSON字段数据写入方式
  • HNU-电路与电子学-实验2
  • 从0开始学python -49
  • Spring MVC 详解(连接、获取参数、返回数据)
  • IT女神节(致敬中国IT界永远的女神严蔚敏-数据结构)
  • Java 集合分页
  • 代码随想录之哈希表(力扣题号)
  • 如何在知行之桥EDI系统中定时自动更换交易伙伴AS2证书?
  • 辽宁千圣文化:抖音店铺怎么做二次优化?
  • 检测js代码中可能导致内存泄漏的工具
  • linux和centos读写日期到文件并对日期进行比较
  • Espressif-IDE v2.8.0 新增功能及开发方向
  • C++学习笔记之基础