当前位置: 首页 > news >正文

机器学习day3

自定义数据集使用框架的线性回归方法对其进行拟合

import matplotlib.pyplot as plt
import torch
import numpy as np
# 1.散点输入
# 1、散点输入
# 定义输入数据
data = [[-0.5, 7.7], [1.8, 98.5], [0.9, 57.8], [0.4, 39.2], [-1.4, -15.7], [-1.4, -37.3], [-1.8, -49.1], [1.5, 75.6], [0.4, 34.0], [0.8, 62.3]]
# 转换为 NumPy 数组
data = np.array(data)
# 提取 x_data 和 y_data
x_data = data[:, 0]
y_data = data[:, 1]#将x_data 和y_data 转化成tensor
x_train=torch.tensor(x_data,dtype=torch.float32)
y_train=torch.tensor(y_data,dtype=torch.float32)
print(x_train)
# 2.定义前向模型
import torch.nn as nn
#定义损失
criterion=nn.MSELoss()# 方案4
# 最常用的网络结构
# 直接重写继承nn.Module
class LinearModel(nn.Module):#初始化def __init__(self):super(LinearModel,self).__init__()#定义一个nn.ModuleListself.layers=nn.Linear(1,1)#前向传播def forward(self, x):x=self.layers(x)return x
#初始化一下模型,返回模型对象
model=LinearModel()#优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)#4.开始迭代
epoches =500
for n in range(1,epoches+1):#现在x_train 相当于10个样本,但是现在维度,添加一个维度#10x1   变成样本 x 维度形式y_prd=model(x_train.unsqueeze(1))#计算损失#y_prd在前面,y_true 是后面loss=criterion(y_prd.squeeze(1),y_train)#梯度更新#清空之前存储在优化器中的梯度optimizer.zero_grad()#损失函数对模型参数的梯度loss.backward()#根据优化算法更新参数optimizer.step()# 5、显示频率设置if n % 10 == 0 or n == 1:print(f"epoches:{n},loss:{loss}")x_np = x_train.numpy()
y_np = y_train.numpy()
y_pred_np = model(x_train.unsqueeze(1)).detach().numpy()# 绘制数据点和拟合直线
plt.scatter(x_np, y_np, color='blue', label='Data Points')  # 原始数据点
plt.plot(x_np, y_pred_np, color='red', label='Fitted Line')  # 拟合直线
plt.xlabel('x')
plt.ylabel('y')
plt.title('Linear Regression Fit (PyTorch)')
plt.legend()
plt.show()

结果展示

http://www.lryc.cn/news/526396.html

相关文章:

  • 追剧记单词之:国色芳华与单词速记
  • AIGC浪潮下,图文内容社区数据指标体系构建探索
  • 总线、UART、IIC、SPI
  • 戴尔电脑设置u盘启动_戴尔电脑设置u盘启动多种方法
  • 【python】四帧差法实现运动目标检测
  • JVM学习指南(48)-JVM即时编译
  • office 2019 关闭word窗口后卡死未响应
  • [操作系统] 深入进程地址空间
  • CVE-2025-0411 7-zip 漏洞复现
  • leetcode151-反转字符串中的单词
  • 若依 v-hasPermi 自定义指令失效场景
  • vue3中自定一个组件并且能够用v-model对自定义组件进行数据的双向绑定
  • 使用 Python 和 Tesseract 实现验证码识别
  • 谈一谈前端构建工具的本地代理配置(Webpack与Vite)
  • CentOS7非root用户离线安装Docker及常见问题总结、各种操作系统docker桌面程序下载地址
  • Alibaba Spring Cloud 十三 Nacos,Gateway,Nginx 部署架构与负载均衡方案
  • +-*/运算符优先级计算模板
  • GPT 结束语设计 以nanogpt为例
  • FastDFS的安装及使用
  • C++ lambda表达式
  • react页面定时器调用一组多个接口,如果接口请求返回令牌失效,清除定时器不再触发这一组请求
  • Python的泛型(Generic)与协变(Covariant)
  • Python Typing: 实战应用指南
  • OpenEuler学习笔记(六):OpenEuler与其他Linux服务器的区别是什么?
  • 如何使用CRM数据分析和洞察来支持业务决策和市场营销?
  • MyBatis和JPA区别详解
  • SVN客户端使用手册
  • VsCode安装文档
  • 豆包MarsCode 蛇年编程大作战 | 高效开发“蛇年运势预测系统”
  • 【动态规划】--- 斐波那契数模型