当前位置: 首页 > news >正文

多层感知机的简洁实现

详解net = nn.Sequential(nn.Flatten(), nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))?

网络结构总览

这个网络包含 3 层(1 个输入层、1 个隐藏层、1 个输出层),整体结构如下:

输入层 (784) → 隐藏层 (256) → ReLU激活 → 输出层 (10)

逐层详解

1. nn.Flatten()
  • 作用:将输入的多维张量(如 [batch_size, 1, 28, 28] 的图像)展平为一维向量([batch_size, 784])。
  • 参数:无(自动计算输入维度)。
  • 示例
    输入形状:[256, 1, 28, 28](批量大小 256,1 通道,28×28 像素)
    输出形状:[256, 784](256 个样本,每个样本 784 维)。
2. nn.Linear(784, 256)
  • 作用:全连接层,实现线性变换 y = xW + b
  • 参数
    • in_features=784:输入特征数(对应图像展平后的 784 维)。
    • out_features=256:输出特征数(隐藏层神经元数量)。
  • 权重矩阵 W:形状为 [784, 256],随机初始化。
  • 偏置向量 b:形状为 [256],默认初始化为零。
  • 计算过程
    H = X @ W + b,其中 X 是输入,H 是隐藏层输出。
3. nn.ReLU()
  • 作用:非线性激活函数,引入非线性能力,解决线性模型无法拟合复杂模式的问题。
  • 数学公式ReLU(x) = max(0, x)
  • 特性
    • 当输入 x > 0 时,输出等于输入;当 x ≤ 0 时,输出为 0。
    • 缓解梯度消失问题,加速网络训练。
    • 使网络具有稀疏性(部分神经元输出为 0)。
4. nn.Linear(256, 10)
  • 作用:输出层,将隐藏层的 256 维表示映射到 10 个类别(对应分类任务的 10 个标签)。
  • 参数
    • in_features=256:输入特征数(来自上一层的 256 维隐藏层)。
    • out_features=10:输出特征数(对应 10 个类别)。
  • 输出含义
    输出 10 个未归一化的分数(logits),通常需要通过 Softmax 转换为概率分布。

数据流动过程

  1. 输入:28×28 像素的图像(展平为 784 维向量)。
  2. 第一层线性变换784 → 256,生成隐藏层表示。
  3. ReLU 激活:对隐藏层输出逐元素应用 ReLU,引入非线性。
  4. 第二层线性变换256 → 10,生成 10 个类别的预测分数。
  5. 后续处理(通常在训练 / 推理时):
    • 训练时:通过 nn.CrossEntropyLoss 计算损失,该损失函数内部包含 Softmax。
    • 推理时:直接取输出的最大值对应的索引作为预测类别(无需 Softmax)。

为什么这样设计多层感知机? 

  1. Flatten 层:将图像的空间结构转换为一维向量,适合全连接网络处理。
  2. 隐藏层 (256 神经元):增加模型复杂度,学习图像的抽象特征(如边缘、纹理)。
  3. ReLU 激活:避免梯度消失,使网络能够学习复杂的非线性关系。
  4. 输出层 (10 神经元):对应 10 个类别,符合分类任务需求。

倘若移除ReLU 激活函数会如何? 

如果移除 ReLU 激活函数,网络将退化为线性模型,无法学习复杂模式:

  1. 只能拟合线性关系

    对于非线性可分的数据(如 MNIST 图像),线性模型的准确率会极低(约 10% 随机猜测水平)。
  2. 表达能力严重受限

    无法学习图像中的边缘、纹理等多层次特征,只能处理简单的线性分类问题。
  3. 训练可能停滞

    没有非线性激活,梯度在多层传递后会变得非常小,导致参数更新缓慢甚至停滞。

线性模型的局限性非线性激活的作用?

 

线性模型为什么不够用?

假设我们有一个两层线性网络:

y = W₂(W₁x + b₁) + b₂

展开后得到:

y = (W₂W₁)x + (W₂b₁ + b₂)

这本质上等价于一个单层线性变换 y = W'x + b',其中:

  • W' = W₂W₁(权重矩阵相乘)
  • b' = W₂b₁ + b₂(偏置向量线性组合)

无论堆叠多少层线性层,最终都可以简化为一个单层线性模型。这种模型只能学习线性可分的模式,例如:

  • 二维空间中的直线分类边界
  • 三维空间中的平面分割

但现实世界的问题(如图像、语音、自然语言)通常是非线性的,需要更复杂的函数来拟合。

非线性激活如何打破限制?

如果在两层线性层之间插入非线性激活函数(如 ReLU),模型就不再能简化为单层线性模型:

        y = W₂·ReLU(W₁x + b₁) + b₂ # 注意:ReLU无法被合并到线性变换中

ReLU 的数学形式是:

ReLU(x) = max(0, x)​​​​​​​

它的关键作用是:

  • 分段线性:在x>0时输出x,在x≤0时输出 0
  • 引入非线性:通过分段的方式,使整个函数不再是线性的

直观来说,ReLU 让网络能够学习分段线性函数,而足够多的分段可以逼近任意复杂的曲线!

完整代码

"""
文件名: 4.3  多层感知机的简洁实现
作者: 墨尘
日期: 2025/7/12
项目名: dl_env
备注: 
"""
import torch
from torch import nn
from d2l import torch as d2l
# 手动显示图像(关键)
import matplotlib.pyplot as plt# 模型
def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)if __name__ == '__main__':# 定义网络net = nn.Sequential(nn.Flatten(),  # 输入: [batch_size, 1, 28, 28] → [batch_size, 784]nn.Linear(784, 256),  # 线性变换: 784 → 256nn.ReLU(),  # ReLU激活nn.Linear(256, 10)  # 线性变换: 256 → 10)net.apply(init_weights)batch_size, lr, num_epochs = 256, 0.1, 10# 采用train_ch3需要自定义损失函数,而train_ch6则不需要,本身自带有损失函数loss = nn.CrossEntropyLoss(reduction='none')# sgd函数训练trainer = torch.optim.SGD(net.parameters(), lr=lr)# 数据集train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)# train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)# 训练模型d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())# 显示图像plt.show(block=True)  # block=True 确保窗口阻塞,直到手动关闭

 实验结果

 

 

http://www.lryc.cn/news/586247.html

相关文章:

  • Spring Cloud Gateway中常见的过滤器
  • 【时间之外】尘封的智能套件复活记
  • 【QGC】深入解析 QGC 配置管理
  • Gas and Gas Price
  • 闲庭信步使用图像验证平台加速FPGA的开发:第十课——图像gamma矫正的FPGA实现
  • Git企业级开发(最终篇)
  • 闲庭信步使用图像验证平台加速FPGA的开发:第十一课——图像均值滤波的FPGA实现
  • TCP的socket编程
  • OneCode 3.0架构深度剖析:工程化模块管理与自治UI系统的设计与实现
  • 多路选择器的学习
  • 前端面试专栏-算法篇:24. 算法时间与空间复杂度分析
  • TCP与UDP协议详解:网络世界的可靠信使与高速快递
  • 苍穹外卖-day06
  • docker—— harbor私有仓库部署管理
  • Linux进程管理的核心:task_struct中的双链表与网状数据结构
  • Linux驱动08 --- 数据库
  • C++ Map 和 Set 详解:从原理到实战应用
  • 【Spring AOP】什么是AOP?切点、连接点、通知和切面
  • Python 实战:构建 Git 自动化助手
  • RabbitMQ面试精讲 Day 1:RabbitMQ核心概念与架构设计
  • 网络安全初级第一次作业
  • 医疗AI前端开发中的常见问题分析和解决方法
  • Filament引擎(三) ——引擎渲染流程
  • 【GESP】C++ 2025年6月一级考试-客观题真题解析
  • Apache Iceberg数据湖高级特性及性能调优
  • PyTorch神经网络实战:从零构建图像分类模型
  • 【文献阅读】DEPTH PRO: SHARP MONOCULAR METRIC DEPTH IN LESS THAN A SECOND
  • Rust Web 全栈开发(五):使用 sqlx 连接 MySQL 数据库
  • Spring 框架中的设计模式:从实现到思想的深度解析
  • 单链表的题目,咕咕咕