当前位置: 首页 > news >正文

自定义数据集使用框架的线性回归方法对其进行拟合

代码

import torch
import numpy as np
import torch.nn as nncriterion = nn.MSELoss()data = np.array([[-0.5, 7.7],[1.8, 98.5],[0.9, 57.8],[0.4, 39.2],[-1.4, -15.7],[-1.4, -37.3],[-1.8, -49.1],[1.5, 75.6],[0.4, 34.0],[0.8, 62.3]])x_data = data[:, 0]
y_data = data[:, 1]x_train = torch.tensor(x_data, dtype=torch.float32)
y_train = torch.tensor(y_data, dtype=torch.float32)model = nn.Sequential(nn.Linear(1, 1))optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
epoch = 500
for n in range(1,epoch+1):y_prd = model(x_train.unsqueeze(1))loss = criterion(y_prd.squeeze(1), y_train)optimizer.zero_grad()loss.backward()optimizer.step()if n % 10 == 0 or n == 1:print(f'epoch: {n}, loss: {loss}')

运行结果:

epoch: 1, loss: 2957.510986328125
epoch: 10, loss: 1785.2095947265625
epoch: 20, loss: 1031.366455078125
epoch: 30, loss: 606.9622192382812
epoch: 40, loss: 366.8883361816406
epoch: 50, loss: 230.34017944335938
epoch: 60, loss: 152.19020080566406
epoch: 70, loss: 107.1496810913086
epoch: 80, loss: 80.98954772949219
epoch: 90, loss: 65.66667175292969
epoch: 100, loss: 56.6099967956543
epoch: 110, loss: 51.205726623535156
epoch: 120, loss: 47.949058532714844
epoch: 130, loss: 45.966827392578125
epoch: 140, loss: 44.74833297729492
epoch: 150, loss: 43.99201583862305
epoch: 160, loss: 43.51824188232422
epoch: 170, loss: 43.218894958496094
epoch: 180, loss: 43.02821731567383
epoch: 190, loss: 42.90589141845703
epoch: 200, loss: 42.82688903808594
epoch: 210, loss: 42.77559280395508
epoch: 220, loss: 42.74211120605469
epoch: 230, loss: 42.72018051147461
epoch: 240, loss: 42.705726623535156
epoch: 250, loss: 42.696205139160156
epoch: 260, loss: 42.68989181518555
epoch: 270, loss: 42.68572235107422
epoch: 280, loss: 42.68293380737305
epoch: 290, loss: 42.68109130859375
epoch: 300, loss: 42.67984390258789
epoch: 310, loss: 42.679046630859375
epoch: 320, loss: 42.678489685058594
epoch: 330, loss: 42.67811584472656
epoch: 340, loss: 42.67786407470703
epoch: 350, loss: 42.67770004272461
epoch: 360, loss: 42.677608489990234
epoch: 370, loss: 42.677520751953125
epoch: 380, loss: 42.67747116088867
epoch: 390, loss: 42.677452087402344
epoch: 400, loss: 42.67742156982422
epoch: 410, loss: 42.67741775512695
epoch: 420, loss: 42.67739486694336
epoch: 430, loss: 42.67738342285156
epoch: 440, loss: 42.67737579345703
epoch: 450, loss: 42.677391052246094
epoch: 460, loss: 42.67737579345703
epoch: 470, loss: 42.67738723754883
epoch: 480, loss: 42.67738342285156
epoch: 490, loss: 42.67738342285156
epoch: 500, loss: 42.67737579345703

http://www.lryc.cn/news/527004.html

相关文章:

  • 15天基础内容-5
  • 82,【6】BUUCTF WEB .[CISCN2019 华东南赛区]Double Secret
  • Android WebView 中网页被劫持的原因及解决方案
  • 特朗普政府将开展新网络攻击
  • 快递代取项目Uniapp+若依后端管理
  • arcgis短整型变为长整型的处理方式
  • 06、Redis相关概念:缓存击穿、雪崩、穿透、预热、降级、一致性等
  • 嵌入式基础 -- PCIe 控制器中断管理之MSI与MSI-X简介
  • websocket实现
  • unity学习20:time相关基础 Time.time 和 Time.deltaTime
  • 【C++】特殊类设计、单例模式与类型转换
  • scratch七彩六边形 2024年12月scratch三级真题 中国电子学会 图形化编程 scratch三级真题和答案解析
  • 代码随想录刷题day16|(哈希表篇)349.两个数组的交集
  • Synology 群辉NAS安装(6)安装mssql
  • 2025年美赛B题-结合Logistic阻滞增长模型和SIR传染病模型研究旅游可持续性-成品论文
  • Hook 函数
  • 蓝桥杯模拟算法:蛇形方阵
  • DeepSeek-R1解读:纯强化学习,模型推理能力提升的新范式?
  • 深度解析:基于Vue 3的教育管理系统架构设计与优化实践
  • 【PyTorch】3.张量类型转换
  • Spring Boot整合JavaMail实现邮件发送
  • 字节跳动发布UI-TARS,超越GPT-4o和Claude,能接管电脑完成复杂任务
  • 数据的秘密:如何用大数据分析挖掘商业价值
  • OAuth1和OAuth2授权协议
  • AI学习(vscode+deepseek+cline)
  • 04-机器学习-网页数据抓取
  • 计网week1+2
  • 重定向与缓冲区
  • 练习题 - Django 4.x File 文件上传使用示例和配置方法
  • [VSCode] vscode下载安装及安装中文插件详解(附下载链接)