当前位置: 首页 > news >正文

深度学习_4_实战_直线最优解

梯度
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

实战

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

代码:

# %matplotlib inline
import random
import torch
import matplotlib.pyplot as plt
# from d21 import torch as d21def synthetic_data(w, b, num_examples):"""生成 Y = XW + b + 噪声。"""X = torch.normal(0, 1, (num_examples, len(w)))# 均值为0,方差为1的随机数,n个样本,列数为w的长度y = torch.matmul(X, w) + b # y = x * w + by += torch.normal(0, 0.01, y.shape) # 加入随机噪音,均值为0.。形状与y的一样return X, y.reshape((-1, 1))# x, y做成列向量返回true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)
#读取小批量,输出batch_size的小批量,随机选取
def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))#转成listrandom.shuffle(indices)#打乱for i in range(0, num_examples, batch_size):#batch_indices = torch.tensor(indices[i:min(i + batch_size, num_examples)])#取yield features[batch_indices], labels[batch_indices]#不断返回# #print(features)
# #print(labels)
#
#
batch_size = 10
#
# for x, y in data_iter(batch_size, features,labels):
#      print(x, '\n', y)
#      break
# # 提取第一列特征作为x轴,第二列特征作为y轴
# x = features[:, 1].detach().numpy() #将特征和标签转换为NumPy数组,以便能够在Matplotlib中使用。
# y = labels.detach().numpy()
#
# # 绘制散点图
# plt.scatter(x, y, 1)
# plt.xlabel('Feature 1')
# plt.ylabel('Feature 2')
# plt.title('Synthetic Data')
# plt.show()
#
# #定义初始化模型w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True)
b = torch.zeros(1, requires_grad = True)def linreg(x, w, b):return torch.matmul(x, w) + b#定义损失函数def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape))**2 / 2 #弄成一样的形状# 定义优化算法
def sgd(params, lr, batch_size):"""小批量随梯度下降"""with torch.no_grad():#节省内存和计算资源。for param in params:param -= lr * param.grad / batch_sizeparam.grad.zero_()#用于清空张量param的梯度信息。print("训练函数")lr = 0.03 #学习率
num_ecopchs = 300 #数据扫描三遍
net = linreg #指定模型
loss = squared_loss #损失for epoch in range(num_ecopchs):#扫描数据for x, y in data_iter(batch_size, features, labels): #拿出x, yl = loss(net(x, w, b), y)#求损失,预测net,真实yl.sum().backward()#算梯度sgd([w, b], lr, batch_size)#使用参数的梯度更新参数with torch.no_grad():train_l = loss(net(features, w, b), labels)print(f'epoch {epoch + 1},loss {float(train_l.mean()):f}')

运行效果:

在这里插入图片描述

http://www.lryc.cn/news/207269.html

相关文章:

  • 《视觉SLAM十四讲》公式推导(三)
  • pnpm、npm、yarn的区别
  • 搞定蓝牙——第四章(GATT协议)
  • Go语言入门心法(十四): Go操作Redis实战
  • Java学习笔记(三)
  • Flutter笔记:GetX模块中不使用 Get.put 怎么办
  • 2023前端面试整理
  • 文化融合:TikTok如何弥合跨文化差异
  • asp.net core获取config和env
  • Git不常用命令(持续更新)
  • PostPreSql 数据库的一些用法
  • 小工具推荐:FastGithub的下载及使用
  • 硬件信息查看工具 EtreCheckpro mac中文版功能介绍
  • 宝塔Python3.7安装模块报错ModuleNotFoundError: No module named ‘Crypto‘解决办法
  • 优化改进YOLOv5算法:加入ODConv+ConvNeXt提升小目标检测能力——(超详细)
  • ElasticSearch安装、插件介绍及Kibana的安装与使用详解
  • JVM | 命令行诊断与调优 jhsdb jmap jstat jps
  • SQL 表达式
  • Unity3D 打包发布时生成文件到打包目录
  • Elasticsearch中使用join来进行父子关联
  • 提供一个springboot使用h2数据库是无法使用脚本并报错的处理方案
  • 【组合计数】CF1866 H
  • JavaSpringbootmysql农产品销售管理系统47627-计算机毕业设计项目选题推荐(附源码)
  • 一文5000字从0到1使用Jmeter实现轻量级的接口自动化测试(图文并茂)
  • 蓝桥杯每日一题0223.10.23
  • php危险函数及rce漏洞
  • 4. 寻找两个正序数组的中位数
  • Stable Diffusion AI绘图
  • MR混合现实情景实训教学系统在旅游管理专业中的应用
  • CentOS 使用线程库Pthread 库