当前位置: 首页 > news >正文

自定义数据集 使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测,对预测结果计算精确度和召回率及F1分数

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import precision_score, recall_score, f1_score# 数据准备
class1_points = np.array([[1.9, 1.2],[1.5, 2.1],[1.9, 0.5],[1.5, 0.9],[0.9, 1.2],[1.1, 1.7],[1.4, 1.1]])
class2_points = np.array([[3.2, 3.2],[3.7, 2.9],[3.2, 2.6],[1.7, 3.3],[3.4, 2.6],[4.1, 2.3],[3.0, 2.9]])x_train = np.concatenate((class1_points, class2_points), axis=0)
y_train = np.concatenate((np.zeros(len(class1_points)), np.ones(len(class2_points))))x_train_tensor = torch.tensor(x_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32)# 设置随机种子
seed = 42
torch.manual_seed(seed)# 定义模型
class LogisticRegreModel(nn.Module):def __init__(self):super(LogisticRegreModel, self).__init__()self.fc = nn.Linear(2, 1)def forward(self, x):x = self.fc(x)x = torch.sigmoid(x)return xmodel = LogisticRegreModel()# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.05)# 训练模型
epochs = 1000
for epoch in range(1, epochs + 1):y_pred = model(x_train_tensor)loss = criterion(y_pred, y_train_tensor.unsqueeze(1))optimizer.zero_grad()loss.backward()optimizer.step()if epoch % 50 == 0 or epoch == 1:print(f"epoch: {epoch}, loss: {loss.item()}")# 保存模型
torch.save(model.state_dict(), 'model.pth')# 加载模型
model = LogisticRegreModel()
model.load_state_dict(torch.load('model.pth'))
# 设置模型为评估模式
model.eval()# 进行预测
with torch.no_grad():y_pred = model(x_train_tensor)y_pred_class = (y_pred > 0.5).float().squeeze()# 计算精确度、召回率和F1分数
precision = precision_score(y_train_tensor.numpy(), y_pred_class.numpy())
recall = recall_score(y_train_tensor.numpy(), y_pred_class.numpy())
f1 = f1_score(y_train_tensor.numpy(), y_pred_class.numpy())print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")

http://www.lryc.cn/news/530093.html

相关文章:

  • 【论文笔记】Fast3R:前向并行muti-view重建方法
  • 谈谈你所了解的AR技术吧!
  • upload labs靶场
  • 搜索引擎友好:设计快速收录的网站架构
  • 基于 oneM2M 标准的空气质量监测系统的互操作性
  • 春晚舞台上的人形机器人:科技与文化的奇妙融合
  • 零基础学习书生.浦语大模型-入门岛
  • Gurobi基础语法之 addConstr, addConstrs, addQConstr, addMQConstr
  • 数据结构---图的遍历
  • Qwen 模型自动构建知识图谱,生成病例 + 评价指标优化策略
  • .Net Web API 访问权限限定
  • 项目架构调整,切换版本并发布到中央仓库
  • 考试知识点位运算
  • matlab快速入门(2)-- 数据处理与可视化
  • Kafka中文文档
  • Python-列表
  • 51单片机开发:定时器中断
  • 【HarmonyOS之旅】基于ArkTS开发(三) -> 兼容JS的类Web开发(二)
  • 算法【混合背包】
  • WordPress eventon-lite插件存在未授权信息泄露漏洞(CVE-2024-0235)
  • 基于微信小程序的医院预约挂号系统设计与实现(LW+源码+讲解)
  • C++初阶 -- 手撕string类(模拟实现string类)
  • 【Postman接口测试】Postman的安装和使用
  • miniconda学习笔记
  • 区块链项目孵化与包装设计:从概念到市场的全流程指南
  • JavaScript的基本组成
  • [Linux]从零开始的STM32MP157 U-Boot移植
  • 【Unity3D】实现横版2D游戏——攀爬绳索(简易版)
  • 【llm对话系统】大模型 Llama 源码分析之 LoRA 微调
  • 算法随笔_35: 每日温度