当前位置: 首页 > news >正文

【Deep-ML系列】Linear Regression Using Gradient Descent(手写梯度下降)

题目链接:Deep-ML

这道题主要是要考虑矩阵乘法的维度,保证维度正确,就可以获得最终的theata

import numpy as np
def linear_regression_gradient_descent(X: np.ndarray, y: np.ndarray, alpha: float, iterations: int) -> np.ndarray:"""Linear regression:param X: m * n:param y::param alpha::param iterations::return:"""m, n = X.shapetheta = np.zeros((n, 1))y = y.reshape(m, 1)     # 保证y是列向量for i in range(iterations):prediction = np.dot(X, theta)   # m * 1error = prediction - y          # m * 1gradient = np.dot(X.T, error)   # n * 1theta = theta - alpha * (1 / m) * gradienttheta = np.round(theta, decimals=4)return thetaif __name__ == '__main__':X = np.array([[1, 1], [1, 2], [1, 3]])y = np.array([1, 2, 3])alpha = 0.01iterations = 1000print(linear_regression_gradient_descent(X, y, alpha, iterations))
http://www.lryc.cn/news/418655.html

相关文章:

  • NVIDIA A100 和 H100 硬件架构学习
  • 企业研发设计协同解决方案
  • iOS 18(macOS 15)Vision 中新增的任意图片智能评分功能试玩
  • 如何实现若干子任务一损俱损--浅谈errgroup
  • 并查集的基础题
  • [论文翻译] LTAChecker:利用注意力时态网络基于 Dalvik 操作码序列的轻量级安卓恶意软件检测
  • HTTPS链接建立的过程
  • 文档控件DevExpress Office File API v24.1 - 支持基于Unix系统的打印
  • IP地址封装类(InetAddress类)
  • 数据库设计规范化
  • 预约咨询小程序搭建教程,源码获取,从0到1完成开发并部署上线
  • leetcode217. 存在重复元素,哈希表秒解
  • QT:QString 支持 UTF-8 编码吗?
  • 我主编的电子技术实验手册(13)——电磁元件之继电器
  • odoo from样式更新
  • Oracle(52)分区表有哪些类型?
  • 大黄蜂能飞的起来吗?
  • 虹科新品 | PDF记录仪新增蓝牙®接口型号HK-LIBERO CL-Y
  • Bytebase 2.22.1 - SQL 编辑器展示更丰富的 Schema 信息
  • SQL Server Management Studio的使用
  • Python 爬虫项目实战一:抖音视频下载与网易云音乐下载
  • CAMDS=中国汽车MDS
  • 【Golang 面试 - 进阶题】每日 3 题(十七)
  • ROS 7上实现私网互通方案
  • iOS企业签名过程中APP频繁出现闪退是什么原因?
  • Unity dots IJobParallelFor并行的数据写入问题
  • 媒体资讯视频数据采集-yt-dlp-python实际使用-下载视频
  • MySQL 8
  • Android进阶之路 - app后台切回前台触发超时保护退出登录
  • 论文阅读笔记:Semi-supervised Semantic Segmentation with Error Localization Network