Dive into Deep Learning (PyTorch Edition): Notes on Section 3.3

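Note: The book does not explain its code in much detail, so these notes annotate many of the finer points. The book's source code was written to run in Jupyter Notebook and is spread across many cells; here it is gathered into a single script with some additions, and all of it was tested in VS Code under Python 3.9.18.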

Chapter 3 Linear Neural Networks

3.3 Concise Implementations of Linear Regression

import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l
true_w=torch.tensor([2,-3.4])
true_b=4.2
features,labels=d2l.synthetic_data(true_w,true_b,1000)
#Build a PyTorch data iterator
def load_array(data_arrays,batch_size,is_train=True): #@save
    #"TensorDataset" is a class provided by the torch.utils.data module; it is a dataset wrapper that builds a dataset from a sequence of tensors.
    #Here, data_arrays is expected to be a tuple (or list) containing the input features and the corresponding labels.
    #The '*' operator performs iterable unpacking, so "*data_arrays" passes each tensor as a separate argument.
    dataset=data.TensorDataset(*data_arrays)
    #Constructs a PyTorch DataLoader, an iterable that yields batches of data during training or testing.
    return data.DataLoader(dataset,batch_size,shuffle=is_train)
batch_size=10
data_iter=load_array([features,labels],batch_size)
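
To see what TensorDataset and DataLoader do with the unpacked tensors, here is a minimal standalone sketch. It does not use d2l; the toy tensors X and y and the sizes (12 examples, batches of 4) are assumptions chosen only for illustration, and shuffling is turned off so the batches are easy to inspect.

```python
import torch
from torch.utils import data

# Toy data (illustrative assumption): 12 examples with 2 features each, one label per example
X = torch.arange(24, dtype=torch.float32).reshape(12, 2)
y = torch.arange(12, dtype=torch.float32).reshape(12, 1)

data_arrays = (X, y)
# TensorDataset(*data_arrays) is the same call as TensorDataset(X, y)
dataset = data.TensorDataset(*data_arrays)
# Indexing the dataset returns one (features, label) pair
first_x, first_y = dataset[0]

# shuffle=False keeps the original order, as load_array does when is_train=False
loader = data.DataLoader(dataset, batch_size=4, shuffle=False)
batches = list(loader)  # a DataLoader is iterable; one pass yields every batch once
Xb, yb = batches[0]
print(len(dataset), len(loader), Xb.shape, yb.shape)
```

Each batch is a pair of tensors whose first dimension is the batch size, which is why the training loop can unpack it as X,y.
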
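print(next(iter(data_iter)))#iter() builds an iterator over data_iter; next() returns its next item (here, the first minibatch) and advances the iterator's internal state for the following call
#Define the model; nn is short for "neural networks"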
from torch import nn
net=nn.Sequential(nn.Linear(2,1))
#Creates a sequential neural network with one linear layer.
#Input size (in_features) is 2, indicating the network expects input with 2 features.
#Output size (out_features) is 1, indicating the network produces 1 output.
#Initialize the model parameters
net[0].weight.data.normal_(0,0.01)#The underscore at the end (normal_) indicates that this operation is performed in-place, modifying the existing tensor in memory.
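
The in-place convention noted above can be checked directly. The sketch below is a standalone illustration, not the book's code: it builds its own nn.Linear(2,1) layer (the name layer is an assumption made here) and also shows nn.init.normal_ and nn.init.zeros_, which are an alternative way to write the same kind of initialization without going through .data.

```python
import torch
from torch import nn

layer = nn.Linear(2, 1)        # same shape as the layer used in these notes
weight_before = layer.weight   # keep a reference to the parameter object

# Trailing underscore: the tensor is overwritten in place and then returned
returned = layer.weight.data.normal_(0, 0.01)
same_storage = returned.data_ptr() == layer.weight.data_ptr()

# Alternative spelling via torch.nn.init (these functions also modify the tensor in place)
nn.init.normal_(layer.weight, mean=0.0, std=0.01)
nn.init.zeros_(layer.bias)

print(layer.weight.shape, layer.bias.shape)
print(same_storage, layer.weight is weight_before)
```

No new parameter is created by either spelling; the existing weight and bias tensors are simply refilled.
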
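net[0].bias.data.fill_(0)
#Define the mean squared error loss, also called the squared L2 norm; by default it returns the mean of the losses over all samples
loss=nn.MSELoss()#MSE: mean squared error
#Define the optimization algorithm (still minibatch stochastic gradient descent)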
#The optimizer updates the parameters of the neural network (net.parameters()) using the gradients computed during backpropagation.
trainer=torch.optim.SGD(net.parameters(),lr=0.03)#SGD: stochastic gradient descent
#Training
num_epochs=3
for epoch in range(num_epochs):
    for X,y in data_iter:
        l=loss(net(X),y)
        trainer.zero_grad()#Clears the gradients left over from the previous step
        l.backward()
        trainer.step()#Updates the model parameters using the computed gradients and the optimization algorithm.
    l=loss(net(features),labels)
    print(f'epoch {epoch+1},loss {l:.6f}')#{l:.6f} formats the variable l as a floating-point number with 6 digits after the decimal point.
w=net[0].weight.data
print('error in estimating w:',true_w-w.reshape(true_w.shape))
b=net[0].bias.data
print('error in estimating b:',true_b-b)
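
Two claims in the comments above can be verified on a tiny example: that nn.MSELoss() averages the squared errors, and that one trainer.step() of plain SGD (no momentum, no weight decay) amounts to param -= lr * grad. This is a minimal sketch under those assumptions; the toy tensors and the name model are made up for illustration and are not part of the book's code.

```python
import torch
from torch import nn

torch.manual_seed(0)      # fixed seed so the run is repeatable
X = torch.randn(10, 2)    # toy minibatch (illustrative)
y = torch.randn(10, 1)

model = nn.Sequential(nn.Linear(2, 1))
loss = nn.MSELoss()       # default reduction is 'mean'
trainer = torch.optim.SGD(model.parameters(), lr=0.03)

l = loss(model(X), y)
# The mean reduction is the average of the element-wise squared errors
manual_l = ((model(X) - y) ** 2).mean()

trainer.zero_grad()
l.backward()
# What plain SGD is expected to produce: w - lr * grad
w_before = model[0].weight.detach().clone()
expected_w = w_before - 0.03 * model[0].weight.grad
trainer.step()

print(torch.allclose(l, manual_l))
print(torch.allclose(model[0].weight.detach(), expected_w))
```

Both checks should agree up to floating-point tolerance, which is why torch.allclose is used rather than exact equality.
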
