当前位置: 首页 > news >正文

【李沐】3.3线性回归的简洁实现

1、生成数据集

import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l
true_w = torch.tensor([2, -3.4])  # 定义真实权重 true_w,其中 [2, -3.4] 表示两个特征的权重值
true_b = 4.2  # 定义真实偏差 true_b,表示模型的截距项# 调用 synthetic_data 函数生成合成数据集,传入真实权重 true_w、偏差 true_b 和样本数量 1000
# 这将返回特征矩阵 features 和目标值 labels,用于训练和测试模型
features, labels = d2l.synthetic_data(true_w, true_b, 1000)

2、读取数据集

def load_array(data_arrays, batch_size, is_train=True):  # 定义函数 load_array,接受数据数组、批量大小和是否训练标志 is_train 作为参数"""构造一个 PyTorch 数据迭代器"""dataset = data.TensorDataset(*data_arrays)  # 创建一个 PyTorch 数据集,使用给定的数据数组# 使用 data.DataLoader 构造数据迭代器,传入数据集、批量大小和是否训练标志# 当 is_train 为 True 时,数据会被随机打乱,用于训练;否则,数据不会被打乱,用于测试或验证return data.DataLoader(dataset, batch_size, shuffle=is_train)
batch_size = 10
data_iter = load_array((features, labels), batch_size)
batch_size = 10
data_iter = load_array((features, labels), batch_size)

3、定义模型
线性层输入2,输出1

# nn是神经⽹络的缩写
from torch import nn
net = nn.Sequential(nn.Linear(2, 1))

4、初始化模型
通过net[0]选择⽹络中的第⼀个图层,然后使⽤weight.data和bias.data⽅法访问参数。我们还可以使⽤替换⽅法normal_和fill_来重写参数值。
0,0.0.01的意思是均值为0、标准差为0.01的正态分布中随机采样

net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)

5、定义损失函数
均方误差,L2范数

loss = nn.MSELoss()

6、定义优化函数
net.parameters() 返回神经网络模型中需要被优化的参数列表

trainer = torch.optim.SGD(net.parameters(), lr=0.03)

7、训练
主要是注意里面的写法,看到的别的代码知道啥意思就可以

num_epochs = 3  # 定义训练轮次数量为 3for epoch in range(num_epochs):  # 迭代每个训练轮次for X, y in data_iter:  # 遍历数据迭代器中的每个数据批次l = loss(net(X), y)  # 计算模型预测值与真实标签之间的损失trainer.zero_grad()  # 清零梯度,以便进行下一轮的梯度计算l.backward()  # 对损失进行反向传播,计算参数的梯度trainer.step()  # 使用优化器更新模型参数l = loss(net(features), labels)  # 在整个训练集上计算损失print(f'epoch {epoch + 1}, loss {l:f}')  # 打印当前训练轮次和损失值
http://www.lryc.cn/news/130248.html

相关文章:

  • Ghost-free High Dynamic Range Imaging withContext-aware Transformer
  • 过来,我告诉你个秘密:送给程序员男友最好的礼物,快教你对象学习磁盘分区啦!小点声哈,别让其他人学会了!
  • Cadence+硬件每日学习十个知识点(38)23.8.18 (Cadence的使用,界面介绍)
  • React Native Expo项目,复制文本到剪切板
  • React源码解析18(5)------ 实现函数组件【修改beginWork和completeWork】
  • vscode ssh 远程 gdb 调试
  • 云原生 AI 工程化实践之 FasterTransformer 加速 LLM 推理
  • PHP酒店点菜管理系统mysql数据库web结构apache计算机软件工程网页wamp
  • 【面试复盘】知乎暑期实习算法工程师二面
  • 内网穿透和服务器+IP 实现公网访问内网的区别
  • JAVA权限管理 助力企业精细化运营
  • 金融语言模型:FinGPT
  • LeetCode--HOT100题(30)
  • Springboot 实践(3)配置DataSource及创建数据库
  • 【问题整理】Ubuntu 执行 apt-get install xxx 报错
  • Java课题笔记~ SpringBoot简介
  • 一种基于springboot、redis的分布式任务引擎的实现(一)
  • 基于IDE Eval Resetter延长IntelliJ IDEA等软件试用期的方法(包含新版本软件的操作方法)
  • RocketMQ消费者可以手动消费但无法主动消费问题,或生成者发送超时
  • 【数据库系统】--【2】DBMS架构
  • 第三章 图论 No.13拓扑排序
  • 喜报 | 擎创再度入围IDC中国FinTech 50榜单
  • 【C++ 记忆站】引用
  • Hlang--用Python写个编程语言-变量的实现
  • 多维时序 | MATLAB实现PSO-CNN-BiLSTM多变量时间序列预测
  • 实现Java异步调用的高效方法
  • 批量提取文件名到excel,详细的提取步骤
  • C#中的泛型约束可以用在以下几个地方?
  • Linux Vm上部署Docker
  • ubuntu bind dns服务配置