当前位置: 首页 > news >正文

PyTorch-Loss Function and BP

目录

1. Loss Function

1.1 L1Loss

1.2 MSELoss

1.3 CrossEntropyLoss

2. 交叉熵与神经网络模型的结合

2.1 反向传播

1. Loss Function

目的: 

a. 计算预测值与真实值之间的差距;

b. 可通过此条件,进行反向传播。

1.1 L1Loss

import torch
from torch.nn import L1Lossinputs = torch.tensor([1, 2, 3], dtype=torch.float32)
targets = torch.tensor([1, 2, 5], dtype=torch.float32)
inputs = torch.reshape(inputs, (1, 1, 1, 3))  # 1-batch_size,1-channel,1×3
targets = torch.reshape(targets, (1, 1, 1, 3))
loss = L1Loss()
result = loss(inputs, targets)
print(result)  # tensor(0.6667)
loss1 = L1Loss(reduction='sum')
result1 = loss1(inputs, targets)
print(result1)  # tensor(2.)

1.2 MSELoss

import torch
from torch.nn import L1Loss, MSELossinputs = torch.tensor([1, 2, 3], dtype=torch.float32)
targets = torch.tensor([1, 2, 5], dtype=torch.float32)
inputs = torch.reshape(inputs, (1, 1, 1, 3))  # 1-batch_size,1-channel,1×3
targets = torch.reshape(targets, (1, 1, 1, 3))
loss_mse = MSELoss()
res = loss_mse(inputs, targets)
print(res)  # tensor(1.3333)

1.3 CrossEntropyLoss

图片来源于:b站up主 我是土堆

It is useful when training a classification problem with C classes. 

import torch
from torch import nnx = torch.tensor([0.1, 0.2, 0.3])
y = torch.tensor([1])
x = torch.reshape(x, (1, 3))  # 1-batch_size,3 classes
loss_cross = nn.CrossEntropyLoss()
res = loss_cross(x, y)
print(res)  # tensor(1.1019)

2. 交叉熵与神经网络模型的结合

nn_loss_network.py

import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoaderdataset = torchvision.datasets.CIFAR10('./dataset', train=False, transform=torchvision.transforms.ToTensor(),download=True)
dataloader = DataLoader(dataset, batch_size=1)class MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return xmyModule1 = MyModule()
for data in dataloader:imgs, targets = dataoutputs = myModule1(imgs)print(outputs)print(targets)

tensor([[-0.1187,  0.1490, -0.1015,  0.0767, -0.0677, -0.0625,  0.0553, -0.0932,
         -0.0866,  0.0746]], grad_fn=<AddmmBackward0>)
tensor([1])

计算交叉熵损失

loss = nn.CrossEntropyLoss()
myModule1 = MyModule()
for data in dataloader:imgs, targets = dataoutputs = myModule1(imgs)res_loss = loss(outputs, targets)print(res_loss)

tensor(2.4315, grad_fn=<NllLossBackward0>)
tensor(2.3594, grad_fn=<NllLossBackward0>)
tensor(2.3659, grad_fn=<NllLossBackward0>)

...

2.1 反向传播

for data in dataloader:imgs, targets = dataoutputs = myModule1(imgs)res_loss = loss(outputs, targets)res_loss.backward()

http://www.lryc.cn/news/91009.html

相关文章:

  • centos docker安装mysql8
  • Java中synchronized锁的深入理解
  • Find My资讯|iOS17将重点改进钱包、Find My、SharePlay和AirPlay等功能
  • 什么是webSocket?
  • 黑马Redis视频教程高级篇(一:分布式缓存)
  • SLMi331数明深力科带DESAT保护功能隔离驱动应用笔记
  • 【嵌入式Linux基础】启动初始化程序--init程序
  • 基于Java实现农产品交易平台的设计与实现_kaic
  • 视频转换、视频压缩、录屏等工具合集:迅捷视频工具箱
  • 理解时序数据库的时间线
  • 音视频技术开发周刊 | 295
  • 15稳压二级管
  • 一些零零碎碎的记录
  • MyBatis - Spring Boot 集成 MyBatis
  • 常见开源协议介绍
  • 第十九章行为型模式—中介者模式
  • AKStream部署1:ZLMediaKit流媒体服务器(win)
  • 【Redis】Redis 中地理位置功能 Geospatial 了解一下?
  • Qt Qml 实现键鼠长时间未操作锁屏
  • 常用的数字高程模型(DEM)数据介绍,附免费下载
  • 字节跳动面试挂在2面,复盘后,决定二战.....
  • 简述熔断、限流、降级
  • Maven 工具
  • iptables扩展匹配条件
  • 直播录音时准备一副监听耳机,实现所听即所得,丁一号G800S上手
  • 回归测试最小化(贪心算法,帕累托支配)
  • Python系列模块之标准库shutil详解
  • pb如何播放Flash
  • 独立成分分析ICA
  • 从零开始之如何在React Native中使用导航