当前位置: 首页 > news >正文

pytorch深度学习逻辑回归 logistic regression

# logistic regression 二分类
# 导入pytorch  和 torchvision
import numpy as np
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as pltx_data = torch.tensor([[1.0], [2.0], [3.0]])  # x_data是一个张量
y_data = torch.Tensor([[0], [0], [1]])  # Tensor是一个类,tesor是一个张量# 定义logistic regression模型
class LogisticRegressionModel(nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__()  # 等价于nn.Module.__init__(self)self.linear = nn.Linear(1, 1)  # 输入和输出的维度都是1def forward(self, x):  # forward函数是必须要有的,用来构建计算图# 二分类问题,所以用sigmoid函数作为激活函数y_pred = torch.sigmoid(self.linear(x))  # forwardreturn y_predmodel = LogisticRegressionModel()  # 实例化一个模型
criterion = nn.BCELoss(size_average=False)  # 损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # 优化器 lr为学习率# 训练模型
for epoch in range(100):  # 训练100次y_pred = model(x_data)  # forwardloss = criterion(y_pred, y_data)  # compute lossprint(epoch, loss.item())  # 打印lossoptimizer.zero_grad()  # 梯度清零loss.backward()  # backwardoptimizer.step()  # update# 测试模型
x_test = torch.tensor([[4.0]])
y_test = model(x_test)
print("predict (after training)", y_test.data)  # 预测# 绘制训练次数和预测值的关系
x = np.linspace(0, 10, 200)  # 从0到10均匀取200个点
x_t = torch.Tensor(x).view(200, 1)  # 转换成200行1列的张量 用Tensor是因为要用到torch.sigmoid
y_t = model(x_t)  # 预测
y = y_t.data.numpy()  # 转换成numpy数组
plt.plot(x, y)  # 绘制预测值和x的关系
plt.plot([0, 10], [0.5, 0.5], c='r')  # 绘制y=0.5的直线
plt.xlabel("Hours")  # x轴标签
plt.ylabel("Probability of Pass")  # y轴标签
plt.grid()  # 绘制网格
plt.show()  # 显示图像

结果

 

http://www.lryc.cn/news/95039.html

相关文章:

  • 数据仓库建设-数仓分层
  • 共享与协作:时下最热门的企业共享网盘推荐!
  • mysql取24小时数据
  • TCP/IP网络编程 第十五章:套接字和标准I/O
  • SaleSmartly,客户满意度调查的绝对好助手
  • MySQL高阶语句
  • 手机快充协议
  • centos 7升级gcc到10.5.0
  • 从脚手架搭建到部署访问路程梳理
  • 数据库应用:MySQL数据库SQL高级语句与操作
  • xshell连接WSL2
  • Flask新手教程
  • 拼多多API接口,百亿补贴商品详情页面采集
  • C++入门(未完待续)
  • Python爬虫学习笔记(四)————XPath解析
  • 知识图谱推理的学习逻辑规则(上)
  • 【从零开始学习C++ | 第二十一篇】C++新增特性 (上)
  • 你真的会用async和await么?
  • vscode远程连接提示:过程试图写入的管道不存在(删除C:\Users\<用户名>\.ssh\known_hosts然后重新连接)
  • 【005】基于深度学习的图像语 通信系统
  • 基于ssm的社区生活超市的设计与实现
  • 长短期记忆网络(LSTM)原理解析
  • vscode debug的方式
  • 微信加粉计数器后台开发
  • 黑客是什么?想成为黑客需要学习什么?
  • iOS中__attribute__的使用
  • 腾讯、飞书等在线表格自动化编辑--python
  • 开源库nlohmann json使用备忘
  • 语音识别开源框架 openAI-whisper
  • php做的中秋博饼游戏之绘制骰子图案功能示例