当前位置: 首页 > news >正文

pytorch线性回归模型预测房价例子

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np# 1. 创建线性回归模型类
class LinearRegressionModel(nn.Module):def __init__(self):super(LinearRegressionModel, self).__init__()self.linear = nn.Linear(1, 1)  # 1个输入特征,1个输出def forward(self, x):return self.linear(x)# 2. 生成训练数据
area = np.array([1000, 1500, 1800, 2400, 3000], dtype=np.float32).reshape(-1, 1)  # 房屋面积(平方英尺)
price = np.array([250000, 300000, 350000, 500000, 600000], dtype=np.float32).reshape(-1, 1)  # 房价# 标准化房屋面积
area = area / 3000  # 假设最大面积为3000平方英尺# 转换为 PyTorch 张量
x_train = torch.from_numpy(area)
y_train = torch.from_numpy(price)# 3. 实例化模型、定义损失函数和优化器
model = LinearRegressionModel()
criterion = nn.MSELoss()  # 均方误差损失函数
optimizer = optim.SGD(model.parameters(), lr=0.001)  # 学习率调低# 4. 训练模型
epochs = 1000
for epoch in range(epochs):# 前向传播outputs = model(x_train)loss = criterion(outputs, y_train)# 反向传播optimizer.zero_grad()  # 清零梯度loss.backward()  # 计算梯度optimizer.step()  # 更新权重# 每100次输出一次损失值if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')# 5. 保存训练好的模型
torch.save(model.state_dict(), 'linear_regression_model.pth')
print("模型已保存!")# 6. 加载模型并进行预测
loaded_model = LinearRegressionModel()
loaded_model.load_state_dict(torch.load('linear_regression_model.pth'))
loaded_model.eval()  # 设置为评估模式# 进行预测
new_area = torch.tensor([[2500 / 3000]], dtype=torch.float32)  # 假设新房屋面积为2500平方英尺,标准化处理
predicted_price = loaded_model(new_area)
print(f"Predicted price for area 2500 sq.ft: ${predicted_price.item():,.2f}")
  • 创建模型LinearRegressionModel 是一个简单的线性回归模型,只有一个线性层 (nn.Linear)。
  • 数据准备:我们生成了一个简单的示例数据集,包含房屋面积和房价数据。数据被转换为 PyTorch 张量格式。
  • 模型训练:使用均方误差损失函数 (MSELoss) 和随机梯度下降优化器 (SGD) 来训练模型。模型在1000个迭代中进行训练,并在每100次迭代后输出损失值。
  • 保存模型:训练完成后,使用 torch.save 保存模型的参数。
  • 加载模型并进行预测:使用 torch.load 加载模型参数,并将模型设置为评估模式 (eval)。然后,我们通过模型对一个新的房屋面积值进行预测。
http://www.lryc.cn/news/528608.html

相关文章:

  • 练习题 - DRF 3.x Caching 缓存使用示例和配置方法
  • 如何解压7z文件?8种方法(Win/Mac/手机/网页端)
  • python学opencv|读取图像(五十)使用addWeighted()函数实现图像加权叠加效果
  • window中80端口被占用问题
  • 06-机器学习-数据预处理
  • 电梯系统的UML文档12
  • 萌新学 Python 之运算符
  • 嵌入式知识点总结 Linux驱动 (五)-linux内核
  • zabbix7 配置字体 解决中文乱码问题(随手记)
  • 预测不规则离散运动的下一个结构
  • CTFSHOW-WEB入门-命令执行29-32
  • SQL Server 建立每日自动log备份的维护计划
  • doris:HLL
  • 双层Git管理项目,github托管显示正常
  • 准备知识——旋转机械的频率和振动基础
  • 知识库管理驱动企业知识流动与工作协同创新模式
  • CMake常用命令指南(CMakeList.txt)
  • 【回溯+剪枝】找出所有子集的异或总和再求和 全排列Ⅱ
  • 中国技术突破对国际格局的多维影响与回应
  • 【漫话机器学习系列】068.网格搜索(GridSearch)
  • 元宇宙下的Facebook:虚拟现实与社交的结合
  • 记忆力训练day08
  • 崇州市街子古镇正月初一繁华剪影
  • websocket webworker教程及应用
  • 【后端】Flask
  • 【cran Archive R包的安装方式】
  • 如何用matlab画一条蛇
  • Greenplum临时表未清除导致库龄过高处理
  • 【Linux】gdb——Linux调试器
  • C++ 中用于控制输出格式的操纵符——setw 、setfill、setprecision、fixed