当前位置: 首页 > article >正文

【深度学习-pytorch篇】4. 正则化方法(Regularization Techniques)

正则化方法(Regularization Techniques)

1. 目标

  • 理解什么是过拟合及其影响
  • 掌握常见正则化技术:L2 正则化、Dropout、Batch Normalization、Early Stopping
  • 能够使用 PyTorch 编程实现这些正则化方法并进行比较分析

2. 数据构造与任务设定

本实验是一个带噪声的回归任务,目标函数为 y = x + N ( 0 , σ 2 ) y = x + \mathcal{N}(0, \sigma^2) y=x+N(0,σ2)。使用均匀分布采样输入 x ∈ [ − 1 , 1 ] x \in [-1, 1] x[1,1]

import numpy as np
import torch
import torch.utils.data as DataN_SAMPLES = 20
NOISE_RATE = 0.4train_x = np.linspace(-1, 1, N_SAMPLES)[:, np.newaxis]
train_y = train_x + np.random.normal(0, NOISE_RATE, train_x.shape)validate_x = np.linspace(-1, 1, N_SAMPLES // 2)[:, np.newaxis]
validate_y = validate_x + np.random.normal(0, NOISE_RATE, validate_x.shape)test_x = np.linspace(-1, 1, N_SAMPLES)[:, np.newaxis]
test_y = test_x + np.random.normal(0, NOISE_RATE, test_x.shape)# 转换为 Tensor
train_x = torch.tensor(train_x, dtype=torch.float32)
train_y = torch.tensor(train_y, dtype=torch.float32)
validate_x = torch.tensor(validate_x, dtype=torch.float32)
validate_y = torch.tensor(validate_y, dtype=torch.float32)
test_x = torch.tensor(test_x, dtype=torch.float32)
test_y = torch.tensor(test_y, dtype=torch.float32)train_dataset = Data.TensorDataset(train_x, train_y)
train_loader = Data.DataLoader(dataset=train_dataset, batch_size=10, shuffle=True)

3. 模型定义

3.1 原始 MLP(无正则化)

import torch.nn as nn
import torch.nn.init as initclass FC_Classifier(nn.Module):def __init__(self, input_dim=1, hidden_dim=100, output_dim=1):super().__init__()self.fc1 = nn.Linear(input_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, output_dim)self.activation = nn.ReLU()self._init_weights()def _init_weights(self):init.normal_(self.fc1.weight, mean=0.0, std=0.1)init.constant_(self.fc1.bias, 0)init.normal_(self.fc2.weight, mean=0.0, std=0.1)init.constant_(self.fc2.bias, 0)def forward(self, x):x = self.activation(self.fc1(x))return self.fc2(x)

3.2 Dropout MLP

class DropoutMLP(nn.Module):def __init__(self, dropout_rate=0.5):super().__init__()self.fc1 = nn.Linear(1, 100)self.dropout = nn.Dropout(dropout_rate)self.fc2 = nn.Linear(100, 1)self.activation = nn.ReLU()self._init_weights()def _init_weights(self):init.normal_(self.fc1.weight, mean=0.0, std=0.1)init.constant_(self.fc1.bias, 0)init.normal_(self.fc2.weight, mean=0.0, std=0.1)init.constant_(self.fc2.bias, 0)def forward(self, x):x = self.dropout(self.fc1(x))x = self.activation(x)return self.fc2(x)

3.3 Batch Normalization MLP

class BNMLP(nn.Module):def __init__(self):super().__init__()self.bn_input = nn.BatchNorm1d(1)self.fc1 = nn.Linear(1, 100)self.bn_hidden = nn.BatchNorm1d(100)self.fc2 = nn.Linear(100, 1)self.activation = nn.ReLU()def forward(self, x):x = self.bn_input(x)x = self.fc1(x)x = self.bn_hidden(x)x = self.activation(x)return self.fc2(x)

4. Early Stopping 策略

当验证集误差连续若干轮无提升时,提前停止训练,避免过拟合。

max_patience = 5
patience = 0
best_val_loss = float("inf")
is_early_stop = False

5. RMSNorm 实现与讲解

5.1 原理说明

RMSNorm 是一种替代 LayerNorm 的轻量化归一化方法:

  • 不减均值
  • 仅用激活值的均方根进行归一化
  • 不依赖 batch 维度

数学公式:

RMS ( x ) = 1 n ∑ i = 1 n x i 2 \text{RMS}(x) = \sqrt{\frac{1}{n} \sum_{i=1}^n x_i^2} RMS(x)=n1i=1nxi2

RMSNorm ( x ) = x RMS ( x ) + ϵ ⋅ γ \text{RMSNorm}(x) = \frac{x}{\text{RMS}(x) + \epsilon} \cdot \gamma RMSNorm(x)=RMS(x)+ϵxγ

其中 γ \gamma γ 为可学习参数, ϵ \epsilon ϵ 是一个很小的数避免除以 0。

5.2 代码实现

class RMSNorm(nn.Module):def __init__(self, hidden_size, eps=1e-6):super().__init__()self.weight = nn.Parameter(torch.ones(hidden_size))self.eps = epsdef forward(self, x):rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)return self.weight * x / rms

5.3 与其他归一化对比

方法是否减均值是否除方差是否依赖 batch
BatchNorm
LayerNorm
RMSNorm是 (仅 RMS)

6. 实验建议

  • 尝试不同的 Dropout 比例(如 0.1 / 0.3 / 0.5)并观察效果;
  • 对比是否每层都加 BatchNorm 是否更优;
  • 比较 L2 正则项中 weight decay 的不同取值;
  • 使用 RMSNorm 替代 LayerNorm 做对比实验。
http://www.lryc.cn/news/2392442.html

相关文章:

  • ESP8266+STM32 AT驱动程序,心知天气API 记录时间: 2025年5月26日13:24:11
  • WPF【11_5】WPF实战-重构与美化(MVVM 实战)
  • ⭐️⭐️⭐️ 模拟题及答案 ⭐️⭐️⭐️ 大模型Clouder认证:RAG应用构建及优化
  • kali系统的安装及配置
  • CSS--background-repeat详解
  • Redis的大Key问题如何解决?
  • 影楼精修-AI追色算法解析
  • node入门:安装和npm使用
  • ‘js@https://registry.npmmirror.com/JS/-/JS-0.1.0.tgz‘ is not in this registry
  • el-table-column如何获取行数据的值
  • leetcode450.删除二叉搜索树中的节点:迭代法巧用中间节点应对多场景删除
  • java虚拟机2
  • 自监督软提示调优:跨域NLP新突破
  • Pydantic 学习与使用
  • PCB设计教程【入门篇】——电路分析基础-基本元件(二极管三极管场效应管)
  • 能按需拆分 PDF 为多个文档的工具
  • Apifox 5 月产品更新|数据模型支持查看「引用资源」、调试 AI 接口可实时预览 Markdown、性能优化
  • LiveGBS海康、大华、宇视、华为摄像头GB28181国标语音对讲及语音喊话:摄像头设备与服务HTTPS准备
  • Sqlalchemy 连mssql坑
  • Prompt Engineering 提示工程介绍与使用/调试技巧
  • LLaMaFactory - 支持的模型和模板 常用命令
  • 大模型深度学习之双塔模型
  • MySQL 8主从同步实战指南:从原理到高可用架构落地
  • 瑞数6代jsvmp简单分析(天津电子税x局)
  • 缓存架构方案:Caffeine + Redis 双层缓存架构深度解析
  • AI笔记 - 模型调试 - 调试方式
  • 榕壹云物品回收系统实战案例:基于ThinkPHP+MySQL+UniApp的二手物品回收小程序开发与优化
  • 《软件工程》第 9 章 - 软件详细设计
  • WebVm:无需安装,一款可以在浏览器运行的 Linux 来了
  • 王树森推荐系统公开课 排序06:粗排模型