当前位置: 首页 > news >正文

深度学习之用PyTorch实现逻辑回归

0.1 学习视频源于:b站:刘二大人《PyTorch深度学习实践》

0.2 本章内容为自主学习总结内容,若有错误欢迎指正!

代码(类比线性回归):

# 调用库
import torch
import torch.nn.functional as F# 数据准备
x_data = torch.Tensor([[1.0], [2.0], [3.0]])  # 训练集输入值
y_data = torch.Tensor([[0], [0], [1]])  # 训练集输出值# 定义逻辑回归模型
class LogisticRegressionModel(torch.nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__()  # 调用父类构造函数self.linear = torch.nn.Linear(1, 1)  # 实例化torch库nn模块的Linear类,特征一维,输出一维def forward(self, x):"""前馈运算:param x: 输入值:return: 线性回归预测结果"""y_pred = F.sigmoid(self.linear(x))return y_predmodel = LogisticRegressionModel()  # 实例化criterion = torch.nn.BCELoss(size_average=False)  # 损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # 优化器——梯度下降SGD# 训练过程
for epoch in range(1000):  # epoch:训练轮次y_pred = model(x_data)loss = criterion(y_pred, y_data)print(epoch, loss.item())optimizer.zero_grad()  # 梯度归零loss.backward()  # 反向传播optimizer.step()  # 权重自动更新print("w = ", model.linear.weight.item())
print("b = ", model.linear.bias.item())# 预测过程
x_test = torch.Tensor([[3.5]])
y_test = model(x_test)
print("y_pred = ", y_test.data)

BCEloss:

 

结果:

注:输出结果为类别是1的概率。

http://www.lryc.cn/news/118575.html

相关文章:

  • 04-4_Qt 5.9 C++开发指南_时间日期与定时器
  • 7个顶级开源数据集来训练自然语言处理(NLP)和文本模型
  • 计算机网络 网络层 边界网关协议BGP
  • GitHub上受欢迎的Android UI Library
  • cpm log2((cpm/10) + 1) nmf 1e6 1e5
  • 竞赛项目 深度学习的视频多目标跟踪实现
  • 如何避免用waveformRecord复制数组
  • RocketMQ 延迟消息
  • Dex文件混淆(一):BlackObfuscator
  • Linux下编译arm 32 出错(/bin/bash: arm-none-linux-gnueabi-gcc: command not found )
  • 最近遇到的两个小问题总结:git问题和node问题
  • Java # Spring(1)
  • SCL更换阿里数据源
  • 【web逆向】全报文加密流量的去加密测试方案
  • Django实现音乐网站 ⑼
  • 【脚踢数据结构】
  • uni-app使用vue语法进行开发注意事项
  • 数据结构---B树
  • c++11以后c++标准库定义的固定位宽的整数类型(Fixed width integer types)
  • Object.values()
  • Oracle 开发篇+Java调用OJDBC访问Oracle数据库
  • linux 查询后台任务及杀掉进程
  • 【Vue3 博物馆管理系统】使用Vue3、Element-plus菜单组件构建前台用户菜单
  • Windows 11清除无效、回收站、过期、缓存、补丁更新文件
  • 栈和队列详解(2)
  • EMC传导干扰滤波电路设计
  • 【win10专业版远程控制】 自带远程桌面公司内网电脑
  • Ubuntu 20.04 中安装docker一键安装脚本
  • Mysql之安装-字符集设置-用户及权限操作-sqlmode设置
  • 腾讯云香港服务器租用价格_CN2线路延迟速度测试