当前位置: 首页 > news >正文

机器学习拟合过程

import numpy as np
import matplotlib.pyplot as plt# 步骤1: 生成模拟数据
np.random.seed(0)
X = 2 * np.random.rand(100, 1)
y = 4 + 3 * X + 2 * X**2 + np.random.randn(100, 1)# 步骤2: 定义线性模型 (我们从随机权重开始)
w = np.random.randn(2, 1)
b = np.random.randn(1)# 步骤3: 定义损失函数 (均方误差)
def mse(y_pred, y_true):return ((y_pred - y_true) ** 2).mean()# 步骤4: 使用梯度下降来优化模型参数
learning_rate = 0.01
epochs = 100# 记录每次迭代的损失,用于画图
losses = []# 准备一个颜色列表,用于绘制不同颜色的直线
colors = plt.cm.rainbow(np.linspace(0, 1, epochs))fig, ax = plt.subplots()# 绘制数据点
ax.scatter(X, y, label="Data")# 扩展X值以绘制连续的线条
X_plot = np.linspace(X.min(), X.max(), 100).reshape(-1, 1)
X_plot_ext = np.hstack((X_plot, X_plot**2))for epoch in range(epochs):# 扩展特征X_ext = np.hstack((X, X**2))# 预测y_pred = X_ext @ w + b# 计算损失loss = mse(y_pred, y)losses.append(loss)# 梯度下降dy_pred = 2 * (y_pred - y)dw = X_ext.T @ dy_pred / len(X_ext)db = dy_pred.sum(axis=0) / len(X_ext)# 更新权重和偏置w -= learning_rate * dwb -= learning_rate * db# 绘制当前拟合的直线,使用不同的颜色y_plot_pred = X_plot_ext @ w + bax.plot(X_plot, y_plot_pred, color=colors[epoch], label=f"Epoch {epoch + 1}")# 设置图例和标签
ax.set_xlabel("X")
ax.set_ylabel("y")
ax.legend()plt.show()# 绘制损失下降的曲线
plt.plot(losses)
plt.title("Loss over epochs")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()

http://www.lryc.cn/news/462618.html

相关文章:

  • 如何快速部署一套智能化openGauss测试环境
  • 【设计模式】深入理解Python中的原型设计模式
  • Django CORS配置方案
  • 2024年开放式耳机哪个牌子好?推荐最好的顶级开放式耳机品牌
  • 零基础读懂Stable Diffusion!
  • Hash Join 和 Index Join工作原理和性能差异
  • Apifox简介及使用
  • 十、IPD 实施细节(产品设计与开发管理)
  • MySQL-13.DQL-聚合函数
  • 为什么跟别人学习如何证明定理要远比使用定理更有意义
  • Qt在Win,Mac和Linux的开机自启设置
  • spring boot热部署
  • 网关与蓝牙网关有什么不同之处?
  • JAVA计算双十一多产品实付款优惠券的省钱方案
  • 零售行业的数字化营销转型之路
  • js的for in 和 for of的详解
  • 前端工具函数库
  • Java程序设计:Spring boot(4)——Freemarker Thymeleaf视图技术集成
  • JavaScript 第19章:Web Storage
  • [山河2024] week2
  • 无限可能LangChain——开启大模型世界
  • URL路径以及Tomcat本身引入的jar包会导致的 SpringMVC项目 404问题、Tomcat调试日志的开启及总结
  • 如何引起Java中的System.in.read()函数的异常
  • 深入理解Flutter鸿蒙next版本 中的Widget继承:使用extends获取数据与父类约束
  • Loss:Focal Loss for Dense Object Detection
  • Unity3D中Excel表格的数据处理模块详解
  • 【python】OpenCV—Fun Mirrors
  • QT IEEE754 16进制浮点数据转成10进制
  • 无人机+视频推流直播EasyCVR视频汇聚/EasyDSS平台在森林防护巡检中的解决方案
  • Rancher—多集群Kubernetes管理平台