当前位置: 首页 > news >正文

前馈神经网络 (Feedforward Neural Network, FNN)

代码功能

网络定义:
使用 torch.nn 构建了一个简单的前馈神经网络。
隐藏层使用 ReLU 激活函数,输出层使用 Sigmoid 函数(适用于二分类问题)。
数据生成:
使用经典的 XOR 问题作为数据集。
数据点为二维输入,目标为 0 或 1。
训练过程:
使用二分类交叉熵损失函数 BCELoss。
优化器为 Adam,具有较快的收敛速度。
损失可视化:
每次训练后记录损失并绘制损失曲线。
结果输出:
显示最终预测值,并与真实标签进行比较。
在这里插入图片描述

代码

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt# 1. 定义前馈神经网络
class FeedforwardNN(nn.Module):def __init__(self, input_dim, hidden_dim, output_dim):super(FeedforwardNN, self).__init__()self.fc = nn.Sequential(nn.Linear(input_dim, hidden_dim),  # 输入层到隐藏层nn.ReLU(),  # 激活函数nn.Linear(hidden_dim, output_dim),  # 隐藏层到输出层nn.Sigmoid()  # 输出层的激活函数(适用于二分类问题))def forward(self, x):return self.fc(x)# 2. 创建 XOR 数据集
def create_xor_data():X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)y = np.array([[0], [1], [1], [0]], dtype=np.float32)return X, y# 3. 训练前馈神经网络
def train_fnn():# 数据准备X, y = create_xor_data()X = torch.tensor(X, dtype=torch.float32)y = torch.tensor(y, dtype=torch.float32)# 初始化网络、损失函数和优化器input_dim = X.shape[1]hidden_dim = 10output_dim = 1model = FeedforwardNN(input_dim, hidden_dim, output_dim)criterion = nn.BCELoss()  # 二分类交叉熵损失optimizer = optim.Adam(model.parameters(), lr=0.01)# 训练网络epochs = 1000loss_history = []for epoch in range(epochs):# 前向传播outputs = model(X)loss = criterion(outputs, y)# 反向传播与优化optimizer.zero_grad()loss.backward()optimizer.step()# 记录损失loss_history.append(loss.item())if (epoch + 1) % 100 == 0:print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}")# 绘制损失曲线plt.plot(loss_history)plt.xlabel('Epoch')plt.ylabel('Loss')plt.title('Training Loss Curve')plt.show()# 输出训练结果with torch.no_grad():predictions = model(X).round()print("Predictions:", predictions.numpy())print("Ground Truth:", y.numpy())# 运行训练
if __name__ == "__main__":train_fnn()
http://www.lryc.cn/news/486086.html

相关文章:

  • 【Python进阶】Python中的数据库交互:使用SQLite进行本地数据存储
  • ZooKeeper单机、集群模式搭建教程
  • 函数指针示例
  • vue如何实现组件切换
  • 计算机视觉 1-8章 (硕士)
  • 整数唯一分解定理
  • Grass脚本2倍速多账号
  • 15分钟学 Go 第 56 天:架构设计基本原则
  • HTML5 Video(视频)
  • 开源模型应用落地-qwen模型小试-Qwen2.5-7B-Instruct-tool usage入门-串行调用多个tools(三)
  • MySQL:表设计
  • 173. 二叉搜索树迭代器【 力扣(LeetCode) 】
  • 大三学生实习面试经历(1)
  • 【论文复现】STM32设计的物联网智能鱼缸
  • 常见长选项和短选项对应表
  • Ubuntu24 上安装搜狗输入法
  • 【AI图像生成网站Golang】JWT认证与令牌桶算法
  • 关于强化学习的一份介绍
  • Python3.11.9+selenium,获取图片验证码以及输入验证码数字
  • Flutter:事件队列,异步操作,链式调用。
  • 从零开始学习 sg200x 多核开发之 eth0 自动使能并配置静态IP
  • 《TCP/IP网络编程》学习笔记 | Chapter 11:进程间通信
  • 开源模型应用落地-qwen模型小试-Qwen2.5-7B-Instruct-tool usage入门-集成心知天气(二)
  • 通过声纹或者声波来切分一段音频
  • sql专场练习(二)(16-20)完结
  • [ 网络安全介绍 2 ] 网络安全发展现状
  • 《基于Oracle的SQL优化》读书笔记
  • 零基础利用实战项目学会Pytorch
  • Go八股(Ⅵ)Goroutine 以及其中的锁和思想
  • 向潜在安全信息和事件管理 SIEM 提供商提出的六个问题