当前位置: 首页 > news >正文

《动手深度学习》 线性回归从零开始实现实例


🎈 作者:Linux猿

🎈 简介:CSDN博客专家🏆,华为云享专家🏆,Linux、C/C++、云计算、物联网、面试、刷题、算法尽管咨询我,关注我,有问题私聊!

🎈 欢迎小伙伴们点赞👍、收藏⭐、留言💬


本文是《动手深度学习》线性回归从零开始实现实例的实现和分析。

一、代码实现

实现代码如下所示。

# random 模块 调用 random() 方法返回随机生成的一个实数,值在[0,1)范围内
import random
# 机器学习框架 pythorch,类似于 TensorFlow 和 Keras
import torch
# 线性回归函数 y = Xw + b + e(噪音)'''
一系列封装的函数
'''
# 批量获取数据函数
def synthetic_data(w, b, num_examples):  #@save# 生成 y=Xw+b+噪声'''返回一个张量,张量里面的随机数是从相互独立的正态分布中随机生成的参与 1: 均值参与 2: 标准差参数 3: 张量的大小 [num_examples, len(w)]'''X = torch.normal(0, 1, (num_examples, len(w)))# torch.matmul 两个张量元素相乘y = torch.matmul(X, w) + b# 加上噪声y += torch.normal(0, 0.01, y.shape)return X, y.reshape((-1, 1))# 随机批量取数据函数
def data_iter(batch_size, features, labels):num_examples = len(features)# 生成存储值 0 ~ num_examples 值的列表,不重复indices = list(range(num_examples))# 在原列表 indices 中随机打乱所有元素random.shuffle(indices)# range() 第三个参数是步长for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])# yield 相当于不断的 return 的作用yield features[batch_indices], labels[batch_indices]# 计算预测值,网络模型
def linreg(X, w, b):# 线性回归模型return torch.matmul(X, w) + b# 计算损失
def squared_loss(y_hat, y):# 均方损失return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2# 梯度更新
def sgd(params, lr, batch_size):# 小批量随机梯度下降with torch.no_grad():for param in params:param -= lr * param.grad / batch_sizeparam.grad.zero_() # 清除 param 的梯度值为 0'''
1. 生成数据集
包含 1000 条数据,每条 [x1, x2]
'''
# 用于生成数据临时的 true_w 和 true_b
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)
# features: [1000, 2], labels: [1000, 1]'''
2. 初始化 w 和 b
w: 2 x 1, b: [0]
'''
# requires_grad 在计算中保留梯度信息
w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)
# 初始化张量为全零
b = torch.zeros(1, requires_grad=True)'''
3. 开始训练
'''
# 设置超参数 学习率
lr = 0.03
# 设置超参数 训练批次/迭代周期
num_epochs = 3
# 设置超参数 每次训练的数据量
batch_size = 10# 重命名函数
net = linreg
loss = squared_lossfor epoch in range(num_epochs): # num_epochs 个迭代周期for X, y in data_iter(batch_size, features, labels): # 每次随机取 10 条数据一起训练l = loss(net(X, w, b), y)  # X 和 y 的小批量损失,计算损失l.sum().backward() # 损失求和后,根据构建的计算图,计算关于[w,b]的梯度,反向传播算法一定要是一个标量才能进行计算,所以进行 sum 操作后 backwardsgd([w, b], lr, batch_size)  # 使用参数的梯度更新参数# 不自动求导with torch.no_grad():train_l = loss(net(features, w, b), labels) # 使用更新后的 [w, b] 计算所有训练数据的 lossprint(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}') # 通过 mean 函数取平均值'''
with torch.no_grad():
在使用 pytorch 时,并不是所有的操作都需要进行计算图的生成(计算过程的构建,以便梯度反向传播等操作)。
而对于 tensor 的计算操作,默认是要进行计算图的构建的,在这种情况下,可以使用 with torch.no_grad():,
强制之后的内容不进行计算图构建。
'''

二、实现解析

2.1 参数和超参数

参数是需要通过训练来得到的结果,最常见的就是神经网络的权重 W 和 b。训练模型的目的就是要找到一套好的模型参数,用于预测未知的结果。这些参数我们是不用调的,是模型来训练的过程中自动更新生成的。

超参数是我们控制我们模型结构、功能、效率等的 调节旋钮,常见超参数:

(1)learning rate(学习率)

(2)epochs(迭代次数,也可称为 num of iterations)

(3)num of hidden layers(隐层数目)

(4)num of hidden layer units(隐层的单元数/神经元数)

(5)activation function(激活函数)

(6)batch-size(用mini-batch SGD的时候每个批量的大小)

(7)optimizer(选择什么优化器,如SGD、RMSProp、Adam)

(8)用诸如RMSProp、Adam优化器的时候涉及到的β1,β2等等

2.2 模型训练

整体的模型训练思路如下所示。

1. 数据集生成,包括:训练数据、测试数据;

2. 初始化参数 w 和 b;

3. 训练模型,设置超参数,开始训练模型;

参考链接:

深度学习中的超参数调节(learning rate、epochs、batch-size...) - 知乎

loss.sum().backward()中对于sum()的理解


🎈 感觉有帮助记得「一键三连支持下哦!有问题可在评论区留言💬,感谢大家的一路支持!🤞猿哥将持续输出「优质文章回馈大家!🤞🌹🌹🌹🌹🌹🌹🤞


http://www.lryc.cn/news/143685.html

相关文章:

  • Redis 命令
  • Linux网络编程:线程池并发服务器 _UDP客户端和服务器_本地和网络套接字
  • nvm安装electron开发与编译环境
  • 玩转Mysql系列 - 第7篇:玩转select条件查询,避免采坑
  • 启动程序结束程序打开指定网页
  • 从零开始学习 Java:简单易懂的入门指南之包装类(十九)
  • leetcode分类刷题:哈希表(Hash Table)(一、数组交集问题)
  • UML四大关系
  • forms组件(钩子函数(局部钩子、全局钩子)、三种页面的渲染方式、数据校验的使用)、form组件的参数以及单选多选形式
  • 跨专业申请成功|金融公司经理赴美国密苏里大学访学交流
  • 第十一章 CUDA的NMS算子实战篇(下篇)
  • R语言01-数据类型
  • 【网络基础实战之路】基于三层架构实现一个企业内网搭建的实战详解
  • C++11相较于C++98多了哪些可调用对象?--《包装器》篇
  • 栈与队列:常见的线性数据结构
  • android framework之AMS的启动管理与职责
  • Decoupling Knowledge from Memorization: Retrieval-augmented Prompt Learning
  • 腾讯云coding平台平台inda目录遍历漏洞复现
  • 无法正常访问服务器
  • 解决css英文内容不自动换行的问题
  • python语言学习
  • 1. 深度学习介绍
  • 【现场问题】oracle 11g 和12c 使用jdbc链接,兼容的问题
  • 嵌入式底层驱动需要知道的基本知识
  • 《软件开发的201个原则》阅读笔记 120-161条
  • JVM——类加载与字节码技术—类文件结构
  • C语言学习之main函数两个参数的应用
  • 本地部署 Stable Diffusion(Windows 系统)
  • Java源码分析(二)Double
  • 文件上传漏洞之条件竞争