当前位置: 首页 > news >正文

梯度下降法、模拟训练、拟合二次曲线、最小二乘法、MSELoss、拟合:f(x)=ax^2+bx+c

本文目标:

f(x)=a*x^2+b*x+c

以这个公式为例,设计一个算法,用梯度下降法来模拟训练过程,最终得出参数a,b,c

原理介绍

目标函数:h_{\theta}(x) = a^{2}+bx+c

损失函数:Loss=\frac{1}{2m}\sum_{1}^{m}(h_{\theta}(x^{i})-y^{i})^2,就是mse

损失函数展开:Loss=\frac{1}{2m}\sum_{1}^{m}((ax^{i})^{2}+bx^i+c-y^{i})^2

损失函数对a,b,c求导数:

{L_{a}}^{'}=\frac{1}{m}\sum_{1}^{m}(ax^2+bx+c-y)*x^2

{L_{b}}^{'}=\frac{1}{m}\sum_{1}^{m}(ax^2+bx+c-y)*x

{L_{c}}^{'}=\frac{1}{m}\sum_{1}^{m}(ax^2+bx+c-y)

导数就是梯度,也就是目标参数与当前参数的差异,这个差异需要用梯度下降法更新

\Delta a{L_{a}}^{'}   \Delta b={L_{b}}^{'}     \Delta c={L_{c}}^{'}

a = a - lr*\Delta a

b = b - lr*\Delta b

c = c - lr*\Delta c

重复上面的过程,参数就可以更新了,然后可以看看新参数的效果,也就是损失有没有降低

具体流程

  1. 预设模型的表达式为:f(x)=a*x^2+b*x+c,也就是二次函数。同时随机初始化模型参数a,b,c。如果是其他函数如f(x)=ax^3+bx^2+cx,就无法在本版本适用(修改求导方式后才可用)。即本模型需要提前知道模型的表达式。
  2. 通过不断喂入(x_input,y_true),得出y_{out} = ax_{input}^{2}+bx_{input}+c.而y_out与y_true之间具有差异。
  3. 将差异封装成一个loss函数,并分别对a,b,c进行求导。得到a,b,c的梯度\Delta a\Delta b\Delta c
  4. \Delta a\Delta b\Delta c和原始的参数a,b,c和学习率作为输入,用梯度下降法来对a,b,c参数进行更新.
  5. 重复2,3,4过程。直到训练结束或者loss降低到较小值

python实现

  # 初始化a,b,c为:-11/6 , -395/3,-2400  目标a,b,c为:(2,-4,3)

class QuadraticFunc():def drew(self,w,name="show"):a,b,c = wx1 = np.array(range(-80,80))y1 = a*x1*x1 + b*x1 + cy2 = 2*x1*x1 - 4*x1 + 3plt.clf()plt.plot(x1, y1)plt.plot(x1, y2)plt.scatter(x1, y1, c='r')# set colorplt.xlim((-50,50))plt.ylim((-500,500))plt.xlabel('X Axis')plt.ylabel('Y Axis')if name == "first":plt.pause(3)else:plt.pause(0.01)plt.ioff()#计算lossdef cal_loss(self,y_out,y_true):# return np.dot((y_out - y_true),(y_out - y_true)) * 0.5return np.mean((y_out - y_true)**2)#计算梯度  def cal_grad(self,x,y_out,y_true):# x(batch),y_out(batch),y_true(batch)a_grad = (y_out-y_true)*x**2 #b_grad = (y_out-y_true)*xc_grad = (y_out-y_true)return np.array([np.mean(a_grad),np.mean(b_grad),np.mean(c_grad)])        #梯度下降法更新参数def update_theta(self,step,w,grad):new_w = w - step*gradreturn new_wdef run(self):feed_x = np.array(range(-400,400))/400feed_y = 2*feed_x*feed_x - 4*feed_x + 3step = 0.5base_lr = 0.5lr = base_lr# 初始化参数a,b,c = -11/6 , -395/3,-2400#-1,10,26w = np.array([a,b,c])self.drew(w,"first")epochs = 100for epoch in range(epochs):# 每隔10轮 降低一半的学习率lr = base_lr/(2**(int((epoch+1)/10)))for i in range(len(feed_x)):x_input = feed_x[i]y_true = feed_y[i]y_out = w[0]*x_input*x_input +w[1]*x_input + w[2]#计算lossloss = self.cal_loss(y_out,y_true)#计算梯度grad = self.cal_grad(x_input,y_out,y_true)#更新参数,梯度下降w = self.update_theta(lr,w,grad)# self.drew(w)grad = np.round(grad,2)loss = np.round(loss,2)w = np.round(w,2)print("train times is:",epoch,"  grad is:",grad,"   loss is:","%.4f"%loss, "  w is:",w,"\n")self.drew(w)if loss<1e-5:print("train finish:",w)breakdef run_batch(self):feed_x = np.array(range(-400,400))/400feed_y = 2*feed_x*feed_x - 4*feed_x + 3x_y = [[feed_x[i],feed_y[i]] for i in range(len(feed_x))]base_lr = 0.5lr = base_lr# 初始化参数a,b,c = -11/6 , -395/3,-2400#-1,10,26w = np.array([a,b,c])self.drew(w,"first")batch_size = 16data_len = len(x_y)//batch_sizeepochs = 100for epoch in range(epochs):random.shuffle(x_y)# 每隔10轮 降低一半的学习率lr = base_lr/(2**(int((epoch+1)/10)))print("epoch,lr:",epoch,lr)for i in range(data_len):x_y_list = x_y[i*batch_size:(i+1)*batch_size]x_y_np = np.array(x_y_list)x_input = x_y_np[:,0]y_true = x_y_np[:,1]y_out = w[0]*x_input*x_input +w[1]*x_input + w[2]#计算lossloss = self.cal_loss(y_out,y_true)#计算梯度grad = self.cal_grad(x_input,y_out,y_true)#更新参数,梯度下降w = self.update_theta(lr,w,grad)grad = np.round(grad,2)loss = np.round(loss,2)w = np.round(w,2)print("train times is:",epoch,"  grad is:",grad,"   loss is:","%.4f"%loss, "  w is:",w,"\n")self.drew(w)if loss<1e-5:print("train finish:",w)# breaktime.sleep(0.1)if __name__ == "__main__":qf = QuadraticFunc()qf.run()

http://www.lryc.cn/news/288054.html

相关文章:

  • Web3.0投票如何做到公平公正且不泄露个人隐私
  • 灰度图像的自动阈值分割
  • 利用Maven获取jar包
  • 将vue组件发布成npm包
  • 江科大STM32 中
  • vue+draggable+el-upload上传图片拖拽重排方法
  • 微信的新版canvas绘制的图案发生变形和偏移的问题
  • [ACM学习] 进制转换
  • redis + 拦截器 :防止数据重复提交
  • 如何进行H.265视频播放器EasyPlayer.js的中性化设置?
  • Ubuntu22.04安装4090显卡驱动
  • YOLOv8优化策略:注意力涨点系列篇 | 一种轻量级的加强通道信息和空间信息提取能力的MLCA注意力
  • 【新书推荐】2.5节 有符号整数和无符号整数
  • RT-Thread: 串口操作、增加串口、串口函数
  • 自然语言处理的新突破:如何推动语音助手和机器翻译的进步
  • vue3 + jeecgBoot 获取项目IP地址
  • Java Server-Sent Events通信
  • [蓝桥杯]真题讲解:冶炼金属(暴力+二分)
  • Fastbee开源物联网项目RoadMap
  • Linux文件管理技术实践
  • Python如何按指定列的空值删除行?
  • 【云原生】Docker的镜像创建
  • 大语言模型推理提速:TensorRT-LLM 高性能推理实践
  • 全面理解“张量”概念
  • MacOS X 安装免费的 LaTex 环境
  • 深入Amazon S3:实战指南
  • Ansible自动化运维(三)Playbook 模式详解
  • LCS板子加逆向搜索
  • 不同知识表示方法与知识图谱
  • Kotlin程序设计 扩展篇(一)