当前位置: 首页 > news >正文

使用 Numpy 自定义数据集,使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测,对预测结果计算精确度和召回率及F1分数

1. 导入必要的库

首先,导入我们需要的库:Numpy、Pytorch 和相关工具包。

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, recall_score, f1_score
2. 自定义数据集

使用 Numpy 创建一个简单的线性可分数据集,并将其转换为 Pytorch 张量。

# 创建数据集
X = np.random.rand(100, 2)  # 100 个样本,2 个特征
y = (X[:, 0] + X[:, 1] > 1).astype(int)  # 标签,若特征之和大于1则为 1,否则为 0# 转换为 PyTorch 张量
X_train = torch.tensor(X, dtype=torch.float32)
y_train = torch.tensor(y, dtype=torch.long)
3. 定义逻辑回归模型

在 Pytorch 中定义一个简单的逻辑回归模型。

class LogisticRegressionModel(nn.Module):def __init__(self, input_dim):super(LogisticRegressionModel, self).__init__()self.linear = nn.Linear(input_dim, 2)  # 二分类问题def forward(self, x):return self.linear(x)
4. 初始化模型、损失函数和优化器
# 初始化模型
model = LogisticRegressionModel(input_dim=2)# 损失函数与优化器
criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01)
5. 训练模型

训练模型并保存训练好的权重。

epochs = 100
for epoch in range(epochs):# 前向传播outputs = model(X_train)loss = criterion(outputs, y_train)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()if (epoch+1) % 20 == 0:print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")# 保存模型
torch.save(model.state_dict(), 'logistic_regression.pth')
6. 加载模型并进行预测

加载保存的模型并进行预测。

# 加载模型
model = LogisticRegressionModel(input_dim=2)
model.load_state_dict(torch.load('logistic_regression.pth'))
model.eval()  # 设为评估模式# 预测
with torch.no_grad():y_pred = model(X_train)_, predicted = torch.max(y_pred, 1)
7. 计算精确度、召回率和 F1 分数

使用 sklearn 中的评估函数计算精确度、召回率和 F1 分数。

accuracy = accuracy_score(y_train, predicted)
recall = recall_score(y_train, predicted)
f1 = f1_score(y_train, predicted)print(f"Accuracy: {accuracy:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")
8. 总结

这篇博客展示了如何使用 Numpy 自定义数据集,利用 Pytorch 框架实现逻辑回归模型,并进行训练。训练后的模型被保存,并在加载后进行预测,最后计算了精确度、召回率和 F1 分数。

http://www.lryc.cn/news/530385.html

相关文章:

  • MapReduce简单应用(一)——WordCount
  • c语言(关键字)
  • 蓝桥杯思维训练营(一)
  • 【C语言】结构体对齐规则
  • 2025-工具集合整理
  • 快速提升网站收录:利用网站用户反馈机制
  • 图漾相机——Sample_V1示例程序
  • 如何使用C#的using语句释放资源?什么是IDisposable接口?与垃圾回收有什么关系?
  • HTML 字符实体
  • Ubuntu 下 nginx-1.24.0 源码分析 - ngx_strerror_init()函数
  • 【c++】类与对象详解
  • nginx目录结构和配置文件
  • MacBook Pro(M1芯片)Qt环境配置
  • Kotlin 使用 Springboot 反射执行方法并自动传参
  • 网络安全技术简介
  • nginx 报错404
  • 【1.安装ubuntu22.04】
  • 【设计模式-行为型】备忘录模式
  • Linux环境下的Java项目部署技巧:安装 Mysql
  • 云原生(五十三) | SQL查询操作
  • 【前端知识】常用CSS样式举例
  • 硕成C语言1笔记
  • [SAP ABAP] Debug Skill
  • 理解 InnoDB 如何处理崩溃恢复
  • UE5 蓝图学习计划 - Day 8:触发器与交互事件
  • 根据接口规范封装网络请求和全局状态管理
  • Unet 改进:在encoder和decoder间加入TransformerBlock
  • work-stealing算法 ForkJoinPool
  • DeepSeek Janus-Pro:多模态AI模型的突破与创新
  • STM32-时钟树