当前位置: 首页 > news >正文

pytorch学习笔记-Loss的使用、在神经网络中加入Loss、优化器(optimizer)的使用

博主最近真要累鼠了…
anyway上号更新一点,预计下周就能把该系列学完了= =

Loss的使用

具体关注一下官网上对于形状的说明,如果报错就看看是不是形状不符合要求,其他没啥

import torch
from torch.nn import L1Loss,MSELoss, CrossEntropyLoss
from torch import nn#L1Loss
inputs = torch.tensor([1,2,3],dtype=torch.float32)
targets = torch.tensor([1,2,5],dtype=torch.float32)l1_loss = L1Loss()
res = l1_loss(inputs, targets)print("L1Loss:",res)#MSE
mse_loss = MSELoss()
res = mse_loss(inputs, targets)print("mseLoss:",res)#crossentropy
x = torch.tensor([0.1,0.2,0.3])#每一类的概率
y = torch.tensor([1])#目标类别编号# Input: Shape(N,C), N:bt_size C:class
x = torch.reshape(x,(1,3))
# print(x)cross_loss = CrossEntropyLoss()
res = cross_loss(x,y)
print("crossentropyLoss:",res)# L1Loss: tensor(0.6667)
# mseLoss: tensor(1.3333)
# crossentropyLoss: tensor(1.1019)

在神经网络中加入Loss

了解Loss的本质就是计算真值和目标值之间的差距后,怎么在神经网络中引入loss也蛮自然的:

import torch
import torch.nn as nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential, CrossEntropyLoss
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, datasets
from torch.utils.data import DataLoaderdata_transforms = transforms.Compose([transforms.ToTensor()
])test_data = datasets.CIFAR10(root="./dataset",train=False,transform=data_transforms)dataloader = DataLoader(test_data,batch_size=64)class myModule(nn.Module):def __init__(self):super().__init__()self.model1 = Sequential(Conv2d(3,32,5,padding=2),MaxPool2d(2),Conv2d(32,32,5,padding=2),MaxPool2d(2),Conv2d(32,64,5,padding=2),MaxPool2d(2),Flatten(),Linear(1024,64),Linear(64,10))def forward(self, x):x = self.model1(x)return xmy_module = myModule()loss = CrossEntropyLoss()for data in dataloader:imgs, targets = dataoutputs = my_module(imgs)res_loss = loss(outputs,targets)print(res_loss)

优化器的使用

引入Loss的目的是为了更好的进行参数更新,因此需要引入优化器
事实上引入这一步后也就基本知道了模型如何进行训练了
一般在一次学习中就进行了多次更新,需要进行多次学习,注意的点就是每次梯度计算优化前需要先将上一轮计算得到的梯度清零,因为上一批次的对本次的结果意义不大,剩下的就是用法:

import torch
import torch.nn as nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential, CrossEntropyLoss
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torch import optimdata_transforms = transforms.Compose([transforms.ToTensor()
])test_data = datasets.CIFAR10(root="./dataset",train=False,transform=data_transforms)dataloader = DataLoader(test_data,batch_size=64)class myModule(nn.Module):def __init__(self):super().__init__()self.model1 = Sequential(Conv2d(3,32,5,padding=2),MaxPool2d(2),Conv2d(32,32,5,padding=2),MaxPool2d(2),Conv2d(32,64,5,padding=2),MaxPool2d(2),Flatten(),Linear(1024,64),Linear(64,10))def forward(self, x):x = self.model1(x)return xmy_module = myModule()loss = CrossEntropyLoss()#设置优化器
optimizer = torch.optim.SGD(my_module.parameters(), lr=0.01)for epoch in range(5):running_loss = 0.0 #计算每一个epoch的lossfor data in dataloader:imgs, targets = dataoutputs = my_module(imgs)res_loss = loss(outputs,targets)# print(res_loss)running_loss+=res_lossoptimizer.zero_grad()#每一次要重新清零上一轮的梯度res_loss.backward()optimizer.step()#进行优化lossprint(running_loss)# tensor(360.6527, grad_fn=<AddBackward0>)
# tensor(355.2374, grad_fn=<AddBackward0>)
# tensor(336.8181, grad_fn=<AddBackward0>)
# tensor(320.9179, grad_fn=<AddBackward0>)
# tensor(312.2012, grad_fn=<AddBackward0>)
http://www.lryc.cn/news/621303.html

相关文章:

  • Linux 对 YUM 包的管理
  • HTTPS 工作原理
  • Java使用Apache POI读取Excel文件
  • dkms安装nvidia驱动和多内核支持
  • label studio 服务器端打开+xshell端口转发设置
  • UniApp 中使用 tui-xecharts插件(或类似图表库如 uCharts)
  • 2025年Java大厂面试场景题全解析:高频考点与实战攻略
  • 20道DOM相关前端面试题
  • Java面试场景题大全精简版
  • VSCode打开新的文件夹之后当前打开的文件夹被覆盖
  • 树形DP详解
  • 基于springboot的信息化在线教学平台的设计与实现(源码+论文)
  • 2025天府杯数学建模C题
  • Python网络爬虫(二) - 解析静态网页
  • MFC的使用——使用ChartCtrl绘制曲线
  • 数据结构初阶(13)排序算法-选择排序(选择排序、堆排序)(动图演示)
  • 手机实时提取SIM卡打电话的信令声音-整体解决方案规划
  • 百度智能云x中科大脑:「城市智能体」如何让城市更会思考
  • pyecharts可视化图表-pie:从入门到精通
  • QT中ARGB32转ARGB4444优化4K图像性能的实现方案(完整源码)
  • 基于SpringBoot的救援物资管理系统 受灾应急物资管理系统 物资管理小程序
  • 日志系统(log4cpp)
  • Torch -- 卷积学习day2 -- 卷积扩展、数据集、模型
  • AM32电调学习-使用Keil编译uboot
  • JVM的逃逸分析深入学习
  • 一、linux内存管理学习(1):物理内存探测
  • 18 ABP Framework 模块管理
  • Encoder-Decoder Model编码器-解码器模型
  • MCP入门:Python开发者的模型上下文协议实战指南
  • 蓝桥杯STL stack