当前位置: 首页 > article >正文

梯度下降法求解线性回归之python实现

线性回归其实就是寻找一条直线拟合数据点,使得损失函数最小。直线的表达式为:

yi=ω1xi,1+ω2xi,2+ωjxi,j+...+b

损失函数的表达式为:
J=12i=0m(yiypredict_i)2

其中m为数据点总数。
现在我们使用梯度下降法求解函数 J 的最小值,梯度下降法原理示意图如下:

这里写图片描述
如上图所示,只要自变量x沿着负梯度的方向变化,就可以到达函数的最小值了,反之,如果沿着正梯度方向变化,就可以到达函数的最大值。
我们要求解 J 函数的最小值,那么就要求出每个ω的梯度和 b 的梯度,由于梯度太大,可能会导致自变量沿着负梯度方向变化时,J的值出现震荡,而不是一直变小,所以在梯度的前面乘上一个很小的系数 α
由以上可以总结出 ω b 的更新公式:

ωj=ωjαJ(ωj)

b=bαJ(b)

梯度公式(其实就是求导而已):
J(ωj)=Jωj=i=0m(yiypredict_i)(xi,j)=i=0m(ypredict_iyi)xi,j

J(b)=Jb=i=0m(ypredict_iyi)

系数 α 如果随着迭代的进行越来越小的话,有利于防止迭代后期震荡的发生,是算法收敛, α 的更新公式:
α=1i+1+0.001

其中i是迭代次数,起始为0
下面为使用python具体实现梯度下降法求解线性回归
原始数据:

x = np.arange(-2,2,0.1)
y = 2*x+np.random.random(len(x))
x = x.reshape((len(x),1))
y = y.reshape((len(x),1))

这里写图片描述

开始迭代:

for i in range(maxgen):alpha = 1/float(i+1)+alpha0e = np.dot(x,seta.reshape((len(seta),1)))+b-y # 二维列向量mse = np.linalg.norm(e)delta_seta = np.dot(e.T,x)[0]delta_seta_norm = np.linalg.norm(delta_seta)b = b-alpha*np.sum(e)seta = seta-alpha*delta_setaprint u'迭代次数:',iprint u'梯度:',delta_seta_norm,'seta',seta,'b:',b,'mse',mseprint 'alpha:',alpha,'sum(e):',sum(e)

算法运行结果:
这里写图片描述


这里写图片描述
如上图所示,最后梯度的值逐渐降为0,说明达到的 J <script type="math/tex" id="MathJax-Element-20">J</script>的极值点。

http://www.lryc.cn/news/2419310.html

相关文章:

  • ASP.NET运行环境配置
  • 云诺网盘为什么关停了好用的企业网盘有哪些
  • 使用vAPP管理资源
  • Unity - Shader - Projector 高空云层底下透明阴影 - semitransparent shadow
  • Linux 串口RS232/485/GPS 驱动实验(移植minicom)
  • MTK 平台屏蔽 factory mode
  • Redis可视化工具Windows版 Another Redis Desktop Manager 安装与使用_保姆级别
  • 多益网络,面试智商测试题
  • 如果人生太难,就去医院看看
  • Synchronized、lock、volatile、ThreadLocal、原子性总结、Condition
  • 内与外的困惑:找出System进程占用100%CPU的元凶
  • GIS空间分析(四)—— 空间分布类型
  • 如何取消标题栏
  • windows清理系统垃圾bat脚本
  • linux安装zend,linux安装配置Zend Optimizer详解
  • PowerSyncKM 包尔星克 对拷线无法自动链接windows和统信系统
  • GRUB的配置文件的menu.lst的写法(旧版grub)
  • 局域网、以太网、无线局域网学习笔记
  • 关于司南导航全系概况模糊学习记录
  • 数据结构和算法(38)之八皇后问题
  • 优秀网站博客集锦
  • 何为Turing Machine(图灵机)?
  • 手机中的com.android.provision删除可不可以,Android Provision (Setup Wizard)
  • windows7怎么一键重装系统 电脑重装操作系统Windows7
  • SSL连接建立过程分析(1)
  • MATLAB|基于QPSO-LSTM的短期风电负荷预测模型(完全复现)
  • java 环境变量的配置
  • Lucene(8_2_0)核心API学习 之 TokenStream(一)
  • Javafx程序开发-如何制作exe程序及制作安装包
  • arp病毒查杀:手动查杀ARP病毒