当前位置: 首页 > news >正文

python梯度下降法求解三元线性回归系数,并绘制结果

import numpy as np
import matplotlib.pyplot as plt

# 生成随机数据
np.random.seed(0)
X1 = 2 * np.random.rand(100, 1)
X2 = 3 * np.random.rand(100, 1)
X3 = 4 * np.random.rand(100, 1)
y = 4 + 3 * X1 + 5 * X2 + 2 * X3 + np.random.randn(100, 1)

# 合并特征
X_b = np.hstack([np.ones((100, 1)), X1, X2, X3])

# 梯度下降求解多元线性回归系数
eta = 0.1  # 学习率
n_iterations = 1000  # 迭代次数
m = 100  # 样本数

theta = np.random.randn(4, 1)  # 初始化参数

for iteration in range(n_iterations):
    gradients = 2/m * X_b.T.dot(X_b.dot(theta) - y)
    theta -= eta * gradients

# 打印得到的参数
print("得到的参数为:", theta)

# 绘制结果
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# 绘制原始数据点
ax.scatter(X1, X2, y, c='b', marker='o')

# 生成新数据点
X1_new = np.linspace(0, 2, 100)
X2_new = np.linspace(0, 3, 100)
X1_new, X2_new = np.meshgrid(X1_new, X2_new)
X3_new = (-theta[0] - theta[1] * X1_new - theta[2] * X2_new) / theta[3]

# 绘制平面
ax.plot_surface(X1_new, X2_new, X3_new, alpha=0.5)

ax.set_xlabel('X1')
ax.set_ylabel('X2')
ax.set_zlabel('y')

plt.show()
 

http://www.lryc.cn/news/352818.html

相关文章:

  • Linux基础(五):常用基本命令
  • 原始字面常量(C++11)
  • C++|设计模式(〇)|设计模式的六大原则
  • 【排序算法】——归并排序(递归与非递归)含动图
  • Mysql自增id、uuid、雪花算法id的比较
  • 【会议征稿,IEEE出版】第九届信息科学、计算机技术与交通运输国际学术会议(ISCTT 2024,6月28-30)
  • 二十八篇:嵌入式系统实战指南:案例研究与未来挑战
  • 探索编程乐趣:绘制螺旋图的奇幻之旅
  • C# 语法糖
  • ubuntu 安装VMtool 实现复制粘贴
  • 智慧仓储新动力:EasyCVR+AI视频智能监管系统方案助力仓储安全高效管理
  • gcc源码分析(AST抽象语法树)
  • ES基础概念
  • 断更是我的错
  • 红队攻防渗透技术实战流程:云安全之云原生安全:云堡垒机
  • Down with typename
  • CSS3背景与渐变
  • 线性表——链式存储
  • VUE3和VUE2
  • mysql5.5版本安装过程
  • 工厂生产管理系统
  • Atlas 200I DK A2安装MindSpore Ascend版本
  • Go 生成UUID唯一标识
  • 【知识蒸馏】deeplabv3 logit-based 知识蒸馏实战,对剪枝的模型进行蒸馏训练
  • 02.爬虫---HTTP基本原理
  • HTTP响应的基本概念
  • 链栈的存储
  • 常见网络协议及端口号
  • 几张自己绘制的UML图
  • [读论文]精读Self-Attentive Sequential Recommendation