当前位置: 首页 > news >正文

自定义数据集 使用paddlepaddle框架实现逻辑回归

导入必要的库

import numpy as np
import paddle
import paddle.nn as nn

数据准备:

seed=1
paddle.seed(seed)# 1.散点输入 定义输入数据
data = [[-0.5, 7.7], [1.8, 98.5], [0.9, 57.8], [0.4, 39.2], [-1.4, -15.7], [-1.4, -37.3], [-1.8, -49.1], [1.5, 75.6], [0.4, 34.0], [0.8, 62.3]]
#转化为数组
data=np.array(data)
# 提取x 和y
x_data=data[:,0]
y_data=data[:,1]
#转成张量 转成paddlepaddle张量
x_train=paddle.to_tensor(x_data,dtype=paddle.float32)
y_train=paddle.to_tensor(y_data,dtype=paddle.float32)

定义模型:

class LinearModel(nn.Layer):def __init__(self):super(LinearModel,self).__init__()self.linear=nn.Linear(1,1)def forward(self,x):x=self.linear(x)return x
#定义模型的对象
model=LinearModel()

损失函数和优化器:

#3.1损失函数
criterion=paddle.nn.MSELoss()
#3.2 优化器
optimizer=paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())

模型训练和保存:

epochs=500
final_checkpoint={}
for epoch in range(1,epochs+1):#前向传播#unsqueeze()扩展一维y_prd=model(x_train.unsqueeze(1))loss=criterion(y_prd.squeeze(1),y_train)#清除之前计算的梯度optimizer.clear_grad()#自动计算梯度loss.backward()#更新参数optimizer.step()# 5.显示频率的设置if epoch % 10==0 or epoch==1:#可以使用float(loss)或者 loss.numpy()会报警告print(f"epoch:{epoch},loss:{float(loss)}")#添加检查点程序if epoch==epochs:#把迭代次数写入final_checkpoint['epoch']=epoch#把训练损失写入final_checkpoint['loss']=loss#基础API模型的保存
paddle.save(model.state_dict(),'./基础API/model.pdparams')
#保存检查点checkpoint信息 是序列化的文件
paddle.save(final_checkpoint, "./基础API/final_checkpoint.pkl")

模型加载及预测:

#基础API模型的加载
model_state_dict=paddle.load('./基础API/model.pdparams')
# optimizer_state_dict=paddle.load('./基础API/optimizer.pdopt')
final_checkpoint_state_dict=paddle.load('./基础API/final_checkpoint.pkl')
print(final_checkpoint_state_dict)#模型和参数联系起来
model.set_state_dict(model_state_dict)#训练 评估 和推理
# 模型验证模式
model.eval()
#使用TensorDateset 和DateLoader封装
dataloader_test=DataLoader(TensorDataset([paddle.to_tensor([1.5],dtype=paddle.float32)]),batch_size=1)#迭代
for x_test in dataloader_test:predict=model(x_test[0])print(predict)

结果展示:

http://www.lryc.cn/news/530958.html

相关文章:

  • Docker入门篇(Docker基础概念与Linux安装教程)
  • c/c++高级编程
  • 2024-我的学习成长之路
  • vscode软件操作界面UI布局@各个功能区域划分及其名称称呼
  • xmind使用教程
  • Day33【AI思考】-分层递进式结构 对数学数系的 终极系统分类
  • k8s二进制集群之ETCD集群证书生成
  • MySQL5.5升级到MySQL5.7
  • Golang Gin系列-9:Gin 集成Swagger生成文档
  • 利用Python高效处理大规模词汇数据
  • 【PyQt】超级超级笨的pyqt计算器案例
  • Git 的起源与发展
  • 预防和应对DDoS的方法
  • 51单片机开发:独立按键实验
  • 02.04 数据类型
  • FPGA学习篇——开篇之作
  • 【Cadence仿真技巧学习笔记】求解65nm库晶体管参数un, e0, Cox
  • 【RocketMQ】RocketMq之IndexFile深入研究
  • 小白零基础--CPP多线程
  • 利用deepseek参与软件测试 基本架构如何 又该在什么环节接入deepseek
  • 大模型微调技术总结及使用GPU对VisualGLM-6B进行高效微调
  • WPF进阶 | WPF 样式与模板:打造个性化用户界面的利器
  • Java 大视界 -- Java 大数据在自动驾驶中的数据处理与决策支持(68)
  • 自动化构建-make/Makefile 【Linux基础开发工具】
  • python学opencv|读取图像(五十二)使用cv.matchTemplate()函数实现最佳图像匹配
  • 通信方式、点对点通信、集合通信
  • TCP编程
  • OpenAI 实战进阶教程 - 第七节: 与数据库集成 - 生成 SQL 查询与优化
  • Apache Iceberg数据湖技术在海量实时数据处理、实时特征工程和模型训练的应用技术方案和具体实施步骤及代码
  • QT交叉编译环境搭建(Cmake和qmake)