当前位置: 首页 > news >正文

逻辑斯特回归

*分类是离散的,回归是连续的

下载数据集

train=True:下载训练集

逻辑斯蒂函数保证输出值在0-1之间

能够把实数值映射到0-1之间

 导函数类似正态分布

 其他饱和函数sigmoid functions

循环神经网络经常使用tanh函数

与线性回归区别

塞戈马无参数,构造函数无区别

 更改损失函数MSE->BCE损失(越小越好)

分布的差异:KL散度,cross-entropy交叉熵 

二分类的交叉熵

 

# -*- coding: utf-8 -*-
# @Time    : 2023-07-18 20:26
# @Author  : yuer
# @FileName: exercise06.py
# @Software: PyCharm
import matplotlib.pyplot as plt
import numpy as np
import torch# 数据集
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])# 先根据x算出y值再根据y的范围找到分类class logisticRegressionModel(torch.nn.Module):def __init__(self):super(logisticRegressionModel, self).__init__()self.linear = torch.nn.Linear(1, 1)# x_data,y_data都是一维,与线性回归相比构造没有函数区别def forward(self, x):y_pred = torch.sigmoid(self.linear(x))return y_predmodel = logisticRegressionModel()# 默认情况size_average=True 即loss是1/n倍的,False设置loss不除n
criterion = torch.nn.BCELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# SGD梯度下降优化方法 初始化w,b都为0for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data)print(epoch, loss.item())optimizer.zero_grad()  # 清空梯度loss.backward()  # 反馈算梯度并更新optimizer.step()  # 更新w,b的值print('w=', model.linear.weight.item())
print('b=', model.linear.bias.item())x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred=', y_test.data)x = np.linspace(0, 10, 200)  # 在线性空间中以均匀步长生成数字序列;在0-10之间的200个点
x_t = torch.Tensor(x).view((200, 1))  # 转换为200*1的矩阵
y_t = model(x_t)  # 利用模型训练
y = y_t.data.numpy()
plt.plot(x, y)
plt.plot([0, 10], [0.5, 0.5], c='r')
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()

http://www.lryc.cn/news/98684.html

相关文章:

  • OpenCV 算法解析
  • springboot创建并配置环境(一) - 创建环境
  • 2023JAVA 架构师面试 130 题含答案:JVM+spring+ 分布式 + 并发编程》...
  • layui手机端上传文件时返回404 Not Found的解决方案(client_body_temp权限设置)
  • 网络编程知识
  • 线性神经网路——线性回归随笔【深度学习】【PyTorch】【d2l】
  • js实现多种按钮
  • getopt函数(未更新完)
  • SpringCloud学习路线(9)——服务异步通讯RabbitMQ
  • postcss-pxtorem适配插件动态配置rootValue(根据文件路径名称,动态改变vue.config里配置的值)
  • 代码随想录算法训练营第二十三天 | 额外题目系列
  • UiAutomator
  • stm32标准库开发常用函数的使用和代码说明
  • 有关合泰BA45F5260中断的思考
  • Numpy-算数函数与数学函数
  • Nginx在springboot中起到的作用
  • 12.(开发工具篇vscode+git)vscode 不能识别npm命令
  • 如何在MacBook上彻底删除mysql
  • web攻击面试|网络渗透面试(一)
  • VBA操作WORD(六)另存为不含宏的文档
  • 分享69个Java源码,总有一款适合您
  • 《cool! autodistill帮你标注数据训练yolov8模型》学习笔记
  • Rust vs Go:常用语法对比(十)
  • SliverPersistentHeader组件 实现Flutter吸顶效果
  • Nginx性能优化配置
  • 杭州多校2023“钉耙编程”中国大学生算法设计超级联赛(4)
  • 音视频入门之音频采集、编码、播放
  • 在 Linux 系统中,如何发起POST/GET请求
  • 文心一言大数据模型-文心千帆大模型平台
  • django学习笔记(1)