当前位置: 首页 > news >正文

神经网络基础-神经网络补充概念-37-其他正则化方法

概念

L1 正则化(Lasso Regularization):L1 正则化通过在损失函数中添加参数的绝对值之和作为惩罚项,促使部分参数变为零,实现特征选择。适用于稀疏性特征选择问题。

L2 正则化(Ridge Regularization):L2 正则化通过在损失函数中添加参数的平方和作为惩罚项,使得参数值保持较小。适用于减小参数大小,减轻参数之间的相关性。

弹性网络正则化(Elastic Net Regularization):弹性网络是 L1 正则化和 L2 正则化的结合,综合了两者的优势。适用于同时进行特征选择和参数限制。

数据增强(Data Augmentation):数据增强是通过对训练数据进行随机变换来扩展数据集,以提供更多的样本。这有助于模型更好地泛化到不同的数据变化。

早停(Early Stopping):早停是一种简单的正则化方法,它通过在训练过程中监控验证集上的性能,并在性能不再改善时停止训练,从而避免模型过拟合训练数据。

批标准化(Batch Normalization):批标准化是一种在每个小批次数据上进行标准化的技术,有助于稳定网络的训练,减少内部协变量偏移,也可以视为一种正则化方法。

权重衰减(Weight Decay):权重衰减是在损失函数中添加参数的权重平方和或权重绝对值之和,以限制参数的大小。

DropConnect:类似于 Dropout,DropConnect 随机地将神经元与其输入连接断开,而不是将神经元的输出置为零。

代码实现

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler# 加载数据
data = load_iris()
X = data.data
y = data.target# 数据预处理
scaler = StandardScaler()
X = scaler.fit_transform(X)
y = keras.utils.to_categorical(y, num_classes=3)# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 定义模型
def build_model(regularization=None):model = keras.Sequential([layers.Input(shape=(X_train.shape[1],)),layers.Dense(64, activation='relu', kernel_regularizer=regularization),layers.Dense(32, activation='relu', kernel_regularizer=regularization),layers.Dense(3, activation='softmax')])return model# L1 正则化
model_l1 = build_model(keras.regularizers.l1(0.01))
model_l1.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model_l1.fit(X_train, y_train, epochs=50, batch_size=8, validation_split=0.1)# L2 正则化
model_l2 = build_model(keras.regularizers.l2(0.01))
model_l2.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model_l2.fit(X_train, y_train, epochs=50, batch_size=8, validation_split=0.1)# 弹性网络正则化
model_elastic = build_model(keras.regularizers.l1_l2(l1=0.01, l2=0.01))
model_elastic.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model_elastic.fit(X_train, y_train, epochs=50, batch_size=8, validation_split=0.1)# 早停(Early Stopping)
early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
model_early = build_model()
model_early.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model_early.fit(X_train, y_train, epochs=100, batch_size=8, validation_split=0.1, callbacks=[early_stopping])# 评估模型
print("L1 Regularization:")
model_l1.evaluate(X_test, y_test)print("L2 Regularization:")
model_l2.evaluate(X_test, y_test)print("Elastic Net Regularization:")
model_elastic.evaluate(X_test, y_test)print("Early Stopping:")
model_early.evaluate(X_test, y_test)
http://www.lryc.cn/news/127022.html

相关文章:

  • 掌握Python的X篇_36_定义类、名称空间
  • 回归预测 | MATLAB实现GRU门控循环单元多输入多输出
  • 数据结构--拓扑排序
  • 算法竞赛备赛之搜索与图论训练提升,暑期集训营培训
  • Linux驱动入门(6.2)按键驱动和LED驱动 --- 将逻辑电平与物理电平分离
  • CentOS系统环境搭建(十四)——CentOS7.9安装elasticsearch-head
  • 设计HTML5图像和多媒体
  • 基于YOLOv8模型和Caltech数据集的行人检测系统(PyTorch+Pyside6+YOLOv8模型)
  • Flutter 宽高自适应
  • LeetCode 0833. 字符串中的查找与替换
  • Redis对象和五种常用数据类型
  • 常用的Elasticsearch查询DSL
  • 计算机网络笔记
  • 高效反编译luac文件
  • 密码湘军,融合创新!麒麟信安参展2023商用密码大会,铸牢数据安全坚固堡垒
  • 关于视频监控平台EasyCVR视频汇聚平台建设“明厨亮灶”具体实施方案以及应用
  • 区块链系统探索之路:私钥的压缩和WIF格式详解
  • DiffusionDet: Diffusion Model for Object Detection
  • CH01_重构、第一个示例
  • 学习篇之React Fiber概念及原理
  • 商城-学习整理-高级-全文检索-ES(九)
  • 无人机跟随一维高度避障场景--逻辑分析
  • Android Studio Giraffe控制台乱码
  • 云原生 envoy xDS 动态配置 java控制平面开发 支持restful grpc实现 EDS 动态endpoint配置
  • Linux--实用指令与方法(部分)
  • 常见期权策略类型有哪些?
  • tomcat服务七层搭建动态页面查看
  • sql A表(含有部分B表字段) 向B表插入A表数据
  • 如何用思维导图+Markdown提升工作效率?
  • 睿趣科技:抖音开网店现在做还来得及吗