当前位置: 首页 > news >正文

【pytorch】全连接网络简单二次函数拟合

下面是一个使用PyTorch实现全连接网络来拟合简单二次函数 y = x 2 y = x^2 y=x2 的示例。我们将创建一个简单的神经网络,定义损失函数和优化器,并进行训练。
在这里插入图片描述

下面是完整的代码示例:

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# 生成数据
x = torch.linspace(-10, 10, 100).unsqueeze(1)
y = x**2# 定义全连接网络
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(1, 128)self.fc2 = nn.Linear(128, 64)self.fc3 = nn.Linear(64, 1)def forward(self, x):x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))x = self.fc3(x)return x# 实例化模型、损失函数和优化器
model = SimpleNN()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)# 训练模型
num_epochs = 1000
for epoch in range(num_epochs):model.train()optimizer.zero_grad()outputs = model(x)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch+1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# 预测
model.eval()
predicted = model(x).detach()# 可视化结果
plt.plot(x.numpy(), y.numpy(), 'ro', label='Original data')
plt.plot(x.numpy(), predicted.numpy(), 'b-', label='Fitted line')
plt.legend()
plt.show()

这段代码的具体步骤如下:

  1. 生成数据:创建输入数据 ( x ) 和对应的标签 ( y )。
  2. 定义网络结构:创建一个简单的全连接神经网络,包括三层线性层。
  3. 实例化模型、损失函数和优化器:使用均方误差损失函数和Adam优化器。
  4. 训练模型:在1000个epoch上训练模型,并在每100个epoch打印一次损失。
  5. 预测和可视化:使用训练好的模型进行预测,并将原始数据和拟合结果进行可视化。

运行这段代码后,你将看到一个图形,其中红点表示原始的二次函数数据,蓝线表示神经网络拟合的结果。

http://www.lryc.cn/news/416942.html

相关文章:

  • git提交到本地仓库了,怎么撤回
  • lua学习(1)
  • SQL报错注入之updatexml
  • 单元测试的重要性
  • mysql线上查询数据注意锁表问题
  • UE5 右键菜单缺少Generate Visual Studio project files
  • 前端性能优化-webpack构建优化
  • Traefik:部署与实战
  • [Spring] SpringBoot统一功能处理与图书管理系统
  • 实现吸顶效果,一个页面多个元素吸顶效果
  • 【C++入门(下)】—— 我与C++的不解之缘(二)
  • 【数据结构】哈希应用-STL-位图
  • Unbuntu 服务器- Anaconda安装激活 + GPU配置
  • python 装饰器记录函数用时
  • 实验10 任何一个非0自然数m的立方均可写成m个连续奇数之和。
  • Jenkins的安装方式
  • 网络之华为S5700S-52P-LI交换机系统恢复
  • 蜂窝网络架构
  • 培训第二十二天(mysql数据库主从搭建)
  • 速盾:CDN回源失败都有什么原因?
  • C语言 | Leetcode C语言题解之第328题奇偶链表
  • 8月6日笔记
  • 爱可声助听器:在全球听力市场中破冰前行
  • 华为OD面试 - 最佳升级时间窗(Java JS Python C C++)
  • LE-50821F/FA激光扫描传感器|360°避障雷达之性能参数与配置清单说明
  • 精准洞察农田生态,智慧农业物联网环境监测与数据采集系统来袭
  • sql注入复现(1-14关)
  • Spring Boot-12
  • 【Linux】进程详解
  • python的多线程