当前位置: 首页 > news >正文

神经网络基础-神经网络补充概念-12-向量化逻辑回归的梯度输出

代码实现

import numpy as npdef sigmoid(z):return 1 / (1 + np.exp(-z))def compute_loss(X, y, theta):m = len(y)h = sigmoid(X.dot(theta))loss = (-1/m) * np.sum(y * np.log(h) + (1 - y) * np.log(1 - h))return lossdef compute_gradient(X, y, theta):m = len(y)h = sigmoid(X.dot(theta))gradient = X.T.dot(h - y) / mreturn gradientdef batch_gradient_descent(X, y, theta, learning_rate, num_iterations):m = len(y)losses = []for _ in range(num_iterations):gradient = compute_gradient(X, y, theta)theta -= learning_rate * gradientloss = compute_loss(X, y, theta)losses.append(loss)return theta, losses# 生成一些模拟数据
np.random.seed(42)
m = 100
n = 2
X = np.random.randn(m, n)
X = np.hstack((np.ones((m, 1)), X))
theta_true = np.array([1, 2, 3])
y = (X.dot(theta_true) + np.random.randn(m) * 0.2) > 0# 初始化参数和超参数
theta = np.zeros(X.shape[1])
learning_rate = 0.01
num_iterations = 1000# 执行批量梯度下降(向量化)
theta_optimized, losses = batch_gradient_descent(X, y, theta, learning_rate, num_iterations)# 打印优化后的参数
print("优化后的参数:", theta_optimized)# 绘制损失函数下降曲线
import matplotlib.pyplot as plt
plt.plot(losses)
plt.xlabel('迭代次数')
plt.ylabel('损失')
plt.title('损失函数下降曲线')
plt.show()

我们首先定义了 compute_gradient 函数,它计算梯度向量。然后,在 batch_gradient_descent 函数中使用向量化的梯度计算,从而避免了循环操作。
这种向量化的梯度计算方法可以有效地处理多个样本,从而提高代码的性能。

http://www.lryc.cn/news/127532.html

相关文章:

  • 2023-08-16力扣每日一题
  • 耗资170亿美元?三星电子在得克萨斯州建设新的半导体工厂
  • 黑马项目一阶段面试58题 Web14题(一)
  • 多线程与高并发--------线程池
  • 深度学习实战48-【未来的专家团队】基于AutoCompany模型的自动化企业概念设计与设想
  • 深入剖析:如何通过API优化云计算架构?快来看!
  • 基于STM32设计的中药分装系统
  • 消息队列学习笔记
  • 贝锐蒲公英:助力企业打造稳定高效的智能安防监控网络
  • SASS 学习笔记
  • Web菜鸟教程 - Springboot接入认证授权模块
  • 【深入理解ES6】块级作用域绑定
  • 使用fake为数据库生成随机数据
  • 树结构转List
  • Android复习(Android基础-四大组件)——Broadcast
  • Ubuntu下mysql8开启远程连接
  • java对象和json类型转换
  • elasticsearch-head 插件
  • Neo4j之FOREACH基础
  • 【SpringBoot】| 接口架构风格—RESTful
  • CentOS系统环境搭建(十)——CentOS7定时任务
  • 如何在安卓设备上安装并使用 ONLYOFFICE 文档
  • 【制作npm包1】申请npm账号、认识个人包和组织包
  • linux学习(文件描述符)[11]
  • 影响力再度提升,Smartbi多次蝉联Gartner、IDC等权威认可
  • 【动态map】牛客挑战赛67 B
  • mysql(2)
  • 介绍 Apache Spark 的基本概念和在大数据分析中的应用
  • Vue CLI创建Vue项目详细步骤
  • 机器学习算法之-逻辑回归(2)