当前位置: 首页 > news >正文

吴恩达2022机器学习专项课程(一) 第一周课程实验:模型表示(Lab_03)

目标

学习如何使用一个变量实现线性回归模型。

导入需要的库

在这里插入图片描述

存储特征x和目标变量y

这是真实的训练集,[1.0,2.0]是房子的大小,[300,500]是房子的价格。
在这里插入图片描述
使用数组存储训练集的数据:
在这里插入图片描述

  • x_train:存储的是所有特征,[1.0,2.0],也就是房子的大小。
  • y_train:存储的是所有目标变量,[300,500],也就是是房子的价格。

获取训练样本的数量

由于我们要计算每一行训练样本的预测值,所以要知道一共有多少行训练样本,也就是求出m的值。
在这里插入图片描述

  • shape[0]:查看特征数组里有几个特征,示例中有2个特征,代表2行训练样本,因此m=2。
  • len:查看特征数组长度,数组里有2个特征,因此长度为2,也就是m=2。

获取每一组训练样本

在这里插入图片描述

  • x_train[i]/y_train[i]:查看第i行训练样本的特征或目标变量。

绘制训练集的数据点

把训练集的每行训练样本,以数据点的形式绘制在图表里。
在这里插入图片描述

初始化w和b

根据线性回归的函数,我们需要先知道w和b的值,这里先不讨论w和b的计算过程,直接给出一个初始值。
在这里插入图片描述

计算线性回归的预测值

每一行的训练样本都需要计算出一个预测值,因此m的作用体现出来了,用于循环。
在这里插入图片描述

  • f_wb = np.zeros(m):为了方便存储每一行训练样本的预测值,因此需要创建一个初始值为0,元素数量为m个的数组。示例中的m=2,它就是这个样子:【0. 0.】
  • w*x[i]+b:线性回归的计算公式,x[i]表示第一行训练样本的特征。
  • f_wb[i]:存储第i-1行训练样本的预测值。

绘图线性回归预测值的图表

代码
在这里插入图片描述
蓝色线条是这段代码绘制的 plt.plot(x_train, tmp_f_wb, c=‘b’, label=‘Our Prediction’)。
在这里插入图片描述
输出结果含义:蓝色线条没有拟合数据点,也就是说,模型预测的y值和真实的数据点y值差距很大。因为我们设置的w和b不合适。

尝试不同的w和b

由于给出了正确的w和b,线性回归模型才能够完美预测,但如果使用不同的w和b,例如更换成w=200和b=100,结果如何?
在这里插入图片描述
先修改w和b
在这里插入图片描述
-输出结果在这里插入图片描述
-输出结果含义:按照我们初始化好的w和b的值,通过给定的x[1,2]和线性回归函数,计算出的预测值是准确的(预测值y帽[300,500]和训练集的y[300,500]),也就是说,我们的线性回归模型能够完美预测。

使用模型示例

经过上述的调整,我们找到了合适的w和b,因此我们的模型就可以用来预测房价了。在这里插入图片描述

总结

首先我们将训练集的数据标记在图表上,然后要找到线性回归函数中合适的w和b。

设置到合适的w和b,我们需要计算每行训练样本的预测值。

通过计算好的预测值,我们要将线性回归函数的趋势绘制在图表中,通过是否能拟合训练集的数据点,来检测w和b是否合适。

由此可以看出,线性回归中比较重要的是要找到合适的w和b,才能比较完美的预测出靠近真实数据的预测值y帽。

这也引出了后面的课程,如何判断w和b是否合适?

http://www.lryc.cn/news/325096.html

相关文章:

  • 流畅的 Python 第二版(GPT 重译)(十)
  • 【自然语言处理七-经典论文-attention is all you need】
  • 【嵌入式】STM32和I2C通信
  • 如何使用Harmony OS控制外设——输入输出?
  • 1.1-数组-704. 二分查找★
  • 人物百度百科怎么做?需要什么资料?
  • 在基于Android相机预览的CV应用程序中使用 OpenCL
  • 网络分类简述与数据链路层协议(PPP)
  • Linux文件系列:磁盘,文件系统,软硬链接
  • GPT4.0
  • 软件工程(双语)
  • 网络——套接字编程UDP
  • FPGA_AD9361
  • 探讨Java代码混淆加固工具
  • Linux——du, df命令查看磁盘空间使用情况
  • 数据库实验(一)SQL Server触发器
  • 添加网址到主页
  • 消息中间件如何实现高可用
  • Hbase 王者荣耀数据表 HBase常用Shell命令
  • 家用智能洗地机哪个牌子好?4款型号让你解锁高效省力生活体验
  • Linux--进程(1)
  • Qt登录页面
  • 软件工程-第8章 软件测试
  • 专业135+总分400+重庆邮电大学801信号与系统考研经验重邮电子信息与通信工程,真题,大纲,参考书。
  • 主干网络篇 | YOLOv8改进之在主干网络中引入密集连接卷积网络DenseNet
  • lavarel的php程序是顺序执行,用pdo mysql连接池好像没有什么用啊。没有办法挂起等待啊,为什么要用连接池,应用场景是什么
  • spring maven项目 实时接口请求次数及时间发送到grafana监控_亲测成功
  • 银行数字人民币系统应用架构设计
  • 流畅的 Python 第二版(GPT 重译)(三)
  • 06-验证浮点数输入