当前位置: 首页 > article >正文

pytorch简单线性回归模型

模型五步走

1、获取数据

     1. 数据预处理

     2.归一化

     3.转换为张量

2、定义模型

3、定义损失函数和优化器

4、模型训练

5、模型评估和调优

调优方法

6、可视化(可选)

示例代码

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error, r2_score# print(np.__config__.show())##1、生成数据
np.random.seed(42)
def generate_data(x, slope=2.0, intercept=1.0, noise_std=2.0):"""生成带有噪声的线性数据 y = a*x + b + ε:param x: 输入特征:param slope: 斜率 a:param intercept: 截距 b:param noise_std: 噪声标准差:return: y 数据,以及真实参数 (slope, intercept)"""y = slope * x + intercept + np.random.randn(len(x)) * noise_stdreturn y, (slope, intercept)# 使用示例
x = np.linspace(0, 10, 100)
y, true_params = generate_data(x, slope=2, intercept=1, noise_std=2)
print("真实参数:", true_params)#归一化
x_norm = (x - x.min()) / (x.max() - x.min())
y_norm = (y - y.min()) / (y.max() - y.min())#转换为pytorch张量
x_tensor = torch.tensor(x_norm, dtype=torch.float32).view(-1, 1)
y_tensor = torch.tensor(y_norm, dtype=torch.float32).view(-1, 1)#2、定义模型
class LinearRegression(nn.Module):def __init__(self,input_size,output_size):super(LinearRegression, self).__init__()self.linear = nn.Linear(input_size,output_size)def forward(self, x):out = self.linear(x)return out#实例化模型
model = LinearRegression(1,1)#3、定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)#4、训练模型
num_epochs = 10000
torch.nn.init.xavier_normal_(model.linear.weight)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)for epoch in range(num_epochs):#前向传播outputs = model(x_tensor)loss = criterion(outputs,y_tensor)#反向传播optimizer.zero_grad()loss.backward()optimizer.step()if (epoch+1) % 1000 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')#5、输出测试结果
print('训练完成!')
print(f'权重: {model.linear.weight.item():.4f}, 偏置: {model.linear.bias.item():.4f}')#6、可视化
predicted = model(x_tensor).detach().numpy()
# 反归一化
predicted_unscaled = predicted * (y.max() - y.min()) + y.min()
y_true_unscaled = y_tensor.numpy() * (y.max() - y.min()) + y.min()# 评估指标
mae = mean_absolute_error(y_true_unscaled, predicted_unscaled)
r2 = r2_score(y_true_unscaled, predicted_unscaled)print(f'均方误差(MSE): {loss.item():.4f}')
print(f'平均绝对误差(MAE): {mae:.4f}')
print(f'R²决定系数(R²): {r2:.4f}')
r22 = r2_score(y_tensor.numpy(), predicted)
print(f"Model R² score: {r22:.4f}")#中文乱码
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.plot(x_tensor, y_tensor, 'ro', label='Original data')
plt.plot(x_tensor, predicted, label='拟合曲线')
plt.legend()
plt.show()

http://www.lryc.cn/news/2392659.html

相关文章:

  • 在 HTML 文件中添加图片的常用方法
  • 四、web安全-行业术语
  • Kafka核心技术解析与最佳实践指南
  • Unity基础学习(十二)Unity 物理系统之范围检测
  • JVM 的垃圾回收机制 GC
  • TypeScript 针对 iOS 不支持 JIT 的优化策略总结
  • 00 QEMU源码中文注释与架构讲解
  • ansible template 文件中如果包含{{}} 等非ansible 变量处理
  • Screen 连接远程服务器(Ubuntu)
  • 路由器、网关和光猫三种设备有啥区别?
  • vscode实时预览编辑markdown
  • 2505软考高项第一、二批真题终极汇总
  • 云原生安全基础:Linux 文件权限管理详解
  • A类地址中最小网络号(0.x.x.x) 默认路由 / 无效/未指定地址
  • [嵌入式实验]实验二:LED控制
  • 6.4.2_3最短路径问题_Floyd算法
  • <PLC><socket><西门子>基于西门子S7-1200PLC,实现手机与PLC通讯(通过websocket转接)
  • day 33 python打卡
  • 开发时如何通过Service暴露应用?ClusterIP、NodePort和LoadBalancer类型的使用场景分别是什么?
  • 【机械视觉】Halcon—【六、交集并集差集和仿射变换】
  • 深度学习核心网络架构详解(续):从 Transformers 到生成模型
  • AI智能混剪视频大模型开发方案:从文字到视频的自动化生成·优雅草卓伊凡
  • allWebPlugin中间件VLC专用版之截图功能介绍
  • 【JavaSE】异常处理学习笔记
  • Scratch节日 | 六一儿童节
  • 深度解析:跨学科论文 +“概念迁移表” 模板写作全流程
  • 深度剖析Node.js的原理及事件方式
  • VScode-使用技巧-持续更新
  • 主流 AI IDE 之一的 Windsurf 使用入门
  • 大数据量下的数据修复与回写Spark on Hive 的大数据量主键冲突排查:COUNT(DISTINCT) 的陷阱