当前位置: 首页 > news >正文

8-pytorch-损失函数与反向传播

b站小土堆pytorch教程学习笔记

根据loss更新模型参数
1.计算实际输出与目标之间的差距
2.为我们更新输出提供一定的依据(反向传播)

在这里插入图片描述

1 MSEloss

import torch
from torch.nn import L1Loss
from torch import nninputs=torch.tensor([1,2,3],dtype=torch.float32)
targets=torch.tensor([1,2,5],dtype=torch.float32)inputs=torch.reshape(inputs,(-1,1,1,3))
targets=torch.reshape(targets,(-1,1,1,3))loss=L1Loss()
result=loss(inputs,targets)loss_mse=nn.MSELoss()
result_mse=loss_mse(inputs,targets)print(result)
print(result_mse)

tensor(0.6667)
tensor(1.3333)

2 Cross EntropyLoss

在这里插入图片描述

x=torch.tensor([0.1,0.2,0.3])#需要reshape为要求的(batch_size,class)
y=torch.tensor([1])#target已经为要求的batch_size无需reshape
x=torch.reshape(x,(-1,3))
loss_cross=nn.CrossEntropyLoss()
result_cross=loss_cross(x,y)
print(result_cross)

tensor(1.1019)

3 在具体的神经网络中使用loss

import torch
import torchvision.datasets
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset=torchvision.datasets.CIFAR10('dataset',train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader=DataLoader(dataset,batch_size=1)class Han(nn.Module):def __init__(self):super(Han, self).__init__()self.model1=Sequential(Conv2d(3,32,5,padding=2),MaxPool2d(2),Conv2d(32,32,5,padding=2),MaxPool2d(2),Conv2d(32,64,5,padding=2),MaxPool2d(2),Flatten(),Linear(1024,64),Linear(64,10))def forward(self,x):x=self.model1(x)return xloss=nn.CrossEntropyLoss()
han=Han()
for data in dataloader:imgs,target=dataoutput=han(imgs)# print(target)# print(output)result_loss=loss(output,target)print(result_loss)

*tensor([7])
tensor([[ 0.0057, -0.0201, -0.0796, 0.0556, -0.0625, 0.0125, -0.0413, -0.0056,
0.0624, -0.1072]], grad_fn=)…

tensor(2.2664, grad_fn=)…

4 反向传播 优化器

  1. 定义优化器
  2. 将待更新的每个参数梯度清零
  3. 调用损失函数的反向传播函数求出每个节点的梯度
  4. 使用step函数对模型的每个参数调优
import torch
import torchvision.datasets
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset=torchvision.datasets.CIFAR10('dataset',train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader=DataLoader(dataset,batch_size=64)class Han(nn.Module):def __init__(self):super(Han, self).__init__()self.model1=Sequential(Conv2d(3,32,5,padding=2),MaxPool2d(2),Conv2d(32,32,5,padding=2),MaxPool2d(2),Conv2d(32,64,5,padding=2),MaxPool2d(2),Flatten(),Linear(1024,64),Linear(64,10))def forward(self,x):x=self.model1(x)return xloss=nn.CrossEntropyLoss()
han=Han()
optim=torch.optim.SGD(han.parameters(),lr=0.01)for epoch in range(5):running_loss=0.0#一个epoch结束的loss和for data in dataloader:imgs,target=dataoutput=han(imgs)result_loss=loss(output,target)#每次迭代的lossoptim.zero_grad()#将网络中每个可调节参数对应的梯度调为0result_loss.backward()#优化器需要每个参数的梯度,使用反向传播获得optim.step()#对每个参数调优running_loss=running_loss+result_lossprint(running_loss)

Files already downloaded and verified
tensor(361.0316, grad_fn=)
tensor(357.6938, grad_fn=)
tensor(343.0560, grad_fn=)
tensor(321.8132, grad_fn=)
tensor(313.3173, grad_fn=)

http://www.lryc.cn/news/305795.html

相关文章:

  • MySQL高级特性篇(8)-数据库连接池的配置与优化
  • mac下使用jadx反编译工具
  • 分布式一致性软件-zookeeper
  • 企业计算机服务器中了babyk勒索病毒怎么办?Babyk勒索病毒解密数据恢复
  • 板块一 Servlet编程:第五节 Cookie对象全解 来自【汤米尼克的JAVAEE全套教程专栏】
  • 自动驾驶---Motion Planning之Path Boundary
  • Leetcode 3048. Earliest Second to Mark Indices I
  • 从源码学习单例模式
  • axios介绍和使用
  • redis雪崩问题
  • [SUCTF 2019]EasySQL1 题目分析与详解
  • TestNG与ExtentReport单元测试导出报告文档
  • 【JavaEE】_form表单构造HTTP请求
  • Mysql中INFORMATION_SCHEMA虚拟库使用
  • 【《高性能 MySQL》摘录】第 2 章 MySQL 基准测试
  • 常用的Web应用程序的自动测试工具有哪些
  • 人工智能与开源机器学习框架
  • 高通XBL阶段读取分区
  • [极客大挑战2019]upload
  • [FastDDS] 基于eProsima FastDDS的移动机器人数据中间件
  • 实现外网手机或者电脑随时随地远程访问家里的电脑主机(linux为例)
  • spring boot集成redis
  • Docker的常用命令
  • JSON简介与基本使用
  • 好物周刊#40:多功能文件管理器
  • 【洛谷 P8780】[蓝桥杯 2022 省 B] 刷题统计 题解(贪心算法+模拟+四则运算)
  • 【蓝桥杯入门记录】静态数码管例程
  • 6.openEuler系统服务的配置和管理(二)
  • 一招鲜吃遍天!ChatGPT高级咒语揭秘:记忆、洗稿、速写SEO文章(一)
  • LeetCode 每日一题 2024/2/19-2024/2/25