当前位置: 首页 > news >正文

练习2-线性回归迭代(李沐函数简要解析)

环境:再练习1中
视频链接:https://www.bilibili.com/video/BV1PX4y1g7KC/?spm_id_from=333.999.0.0

代码与详解

数据库
numpy 数据处理处理
torch.utils 数据加载与数据
d2l 专门的库
nn 包含各种层与激活函数

import numpy as np
import torch 
from torch.utils import data
from d2l import torch as d2l
from torch import nn

生成数据集
w=torch.tensor([2,-3.4]) 生成一维两个向量的张量
features,labels=d2l.synthetic_data(w,b,nume) 生成nume个w为权重,b为偏置的数据

w=torch.tensor([2,-3.4])
b=4.2
features,labels=d2l.synthetic_data(w,b,100)

定义对数据集的读取
data.TensorDataset(*data_arrays) 将多个张量合并为一个 通常用于合并特征值与标签 data_arrays=(features,labels)
data.DataLoader(dataset,batchsize,shuffle=true)每次根据上一个函数返回的对象读取batchsize个值 并打乱数据

def load_arrays(data_arrays,batch_size,is_train=True):dataset=data.TensorDataset(*data_arrays)return data.DataLoader(dataset,batch_size,is_train)

定义数据加载器 并 调用
next(iter(已初始化的数据加载器)) 重新调用数据加载器

batch_size=10
data_iter=load_arrays((features,labels),batch_size)
next(iter(data_iter))

定义模型

定义为线性模型且只有一层
nn.Sequential() 用于包装层
nn.linear(2,1) 用于定义两输入一输出的线性层

net=nn.Sequential(nn.Linear(2,1))

初始化参数 w,b,lr,epoch,batch_size
net[0].weight.data.normal 正态分布
net[0].bias.data.fill_(0) b赋值

net[0].weight.data.normal_(0,0.01)
net[0].bias.data.fill_(0)

定义损失函数 平方误差
nn.MSELoss()

Loss=nn.MSELoss()

优化算法 小批量梯度下降 torch.optim.SGD(net.parameters(), lr=0.03)

trainer=torch.optim.SGD(net.parameters(),lr=0.03)

训练

epochs=3
for epoch in range(epochs):for X,y in data_iter:l=Loss(net(X),y)# 将梯度清零   trainer.zero_grad()# 反向传播l.backward()#更新参数trainer.step()l=Loss(net(features),labels)print(f'epoch {epoch + 1}, loss {l:f}')

相关函数与组成部分

定义模型

定义线性回归模型
from torch import nn
net=nn.Sequential(nn.Linear(2,1))

为模型赋值

w,b正态分布
net[0].weight.data.normal_(0,0.01)
net[0].bias.data.fill_(0)

定义损失函数

Loss=nn.MSELoss()

定义优化算法

trainer = torch.optim.SGD(net.parameters(),lr=0.03)

(训练与反向传播不太了解)

相关的Python语法

def 函数名(变量=True):return for epoch in range(epochs):
http://www.lryc.cn/news/309327.html

相关文章:

  • 人像背景分割SDK,智能图像处理
  • 100M服务器能同时容纳多少人访问
  • Mysql 的高可用详解
  • Acwing枚举、模拟与排序(一)
  • MySQL的主从同步原理
  • naive-ui-admin 表格去掉工具栏toolbar
  • C++之结构体
  • 分布式ID选型对比(1)
  • T-SQL 高阶语法之存储过程
  • 解决鸿蒙模拟器卡顿的问题
  • 【LeetCode每日一题】【BFS模版与例题】863.二叉树中所有距离为 K 的结点
  • 设计模式-结构模式-装饰模式
  • MySQL:一行记录如何
  • ‘grafana.ini‘ is read only ‘defaults.ini‘ is read only
  • 博途PLC 面向对象系列之“输送带控制功能块“(SCL代码)
  • 2024-02学习笔记
  • 最新消息:英特尔宣布成立全新独立运营的FPGA公司——Altera
  • RC正弦波振荡电路
  • 【Git学习笔记】提交PR
  • 线程池的相关参数
  • 图书推荐||Word文稿之美
  • 前端导出word文件的多种方式、前端导出excel文件
  • Linux和Windows操作系统在腾讯云幻兽帕鲁服务器上的内存占用情况如何?
  • 腾讯云4核8G的云服务器性能水平?使用场景说明
  • 1_SQL
  • PoC免写攻略
  • c1-周考2
  • express+mysql+vue,从零搭建一个商城管理系统7--文件上传,大文件分片上传
  • markdown的使用(Typora)
  • 【python】json转成成yaml中文编码异常显示成:\u5317\u4EAC\u8DEF123\u53F7