当前位置: 首页 > news >正文

深度学习训练中的种子设置


文章目录

  • 深度学习训练中的种子设置
    • 1. 为什么需要设置随机种子
    • 2. 随机种子的设置及使用


深度学习训练中的种子设置

1. 为什么需要设置随机种子

在神经网络训练过程中,经常会通过随机的方式对一些数据进行初始化:

1、随机权重,网络有些部分的权重没有预训练,它的值则是随机初始化的,每次随机初始化不同会导致结果不同。
2、随机数据增强,一般来讲网络训练会进行数据增强,特别是少量数据的情况下,数据增强一般会随机变化光照、对比度、扭曲等,也会导致结果不同。
3、随机数据读取,喂入训练数据的顺序也会影响结果。

如果每次的实验都进行随机操作,那么实验的结果也会具有随机性,即相同的训练数据,相同的超参数,但是最终的结果可能会相差好几个百分点。

如何解决随机带来的问题呢?即,使得我们每次实验具有可复现性?

在计算机中的随机,其实不是真随机,都是伪随机,通过设置一个随机数种子,就能使得每次随机产生的结果都相同。

2. 随机种子的设置及使用

一般训练会用到多个库包含有关random的内容。

在pytorch构建的网络中,一般都是使用下面三个库来获得随机数,我们需要对三个库都设置随机种子:
1、torch库;
2、numpy库;
3、random库。

通常只会在两个地方使用这些random操作:初始化操作和数据加载操作,只需要在这两个操作之前对种子进行设置即可。

#---------------------------------------------------#
#   设置种子
#---------------------------------------------------#
def seed_everything(seed=11):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = False

在初始化操作之前使用seed_everything()进行种子设置。

torch.backends.cudnn.deterministic=True用于保证CUDA 卷积运算的结果确定。
torch.backends.cudnn.benchmark=False是用于保证数据变化的情况下,减少网络效率的变化。为True的话容易降低网络效率。

Pytorch一般使用Dataloader来加载数据,Dataloader一般会使用多worker加载多进程来加载数据,此时我们需要使用Dataloader自带的worker_init_fn函数初始化Dataloader启动的多进程,这样才能保证多进程数据加载时数据的确定性。

#---------------------------------------------------#
#   设置Dataloader的种子
#---------------------------------------------------#
def worker_init_fn(worker_id, rank, seed):worker_seed = rank + seedrandom.seed(worker_seed)np.random.seed(worker_seed)torch.manual_seed(worker_seed)

小结:

  • 通过设置随机数种子,利用伪随机,使得每次实验具有可复现性
  • 设置种子环节:在初始化参数之前设置随机数种子,在使用Dataloader加载数据时配置worker_seed

参考文章:神经网络学习小记录74——Pytorch 设置随机种子Seed来保证训练结果唯一_pytorch seed-CSDN博客

http://www.lryc.cn/news/329625.html

相关文章:

  • LLM:函数调用(Function Calling)
  • ssm 房屋销售管理系统开发mysql数据库web结构java编程计算机网页源码eclipse项目
  • MySQL使用ALTER命令创建与修改索引
  • 54 npm run serve 和 npm run build 输出的关联和差异
  • iOS —— 初识KVO
  • 什么是HTTP? HTTP 和 HTTPS 的区别?
  • 微信小程序如何进行npm导入组件
  • MySQL编程实战LeetCode经典考题
  • 发生播放错误,即将重试 jellyfin
  • BIONIOAIO
  • SpringSecurity学习总结(三更草堂)
  • C++20中的jthread
  • Xception模型详解
  • 【合合TextIn】AI构建新质生产力,合合信息Embedding模型助力专业知识应用
  • Flutter 拦截系统键盘,显示自定义键盘
  • 内存泄漏是什么?如何避免内存泄漏?
  • linux 中的syslog的含义和用法
  • kubernetes(K8S)学习(一):K8S集群搭建(1 master 2 worker)
  • 巧克力(蓝桥杯)
  • Python爬虫之pyquery和parsel的使用
  • 移动硬盘怎么加密?移动硬盘加密软件有哪些?
  • openEuler 22.03 安装 .NET 8.0
  • 【转载】OpenCV ECC图像对齐实现与代码演示(Python / C++源码)
  • 每日一题(相交链表 )
  • C#WPF控件大全
  • 好书推荐 《AIGC重塑金融》
  • 【Linux】权限理解
  • 插入排序、归并排序、堆排序和快速排序的稳定性分析
  • 【pytest、playwright】多账号同时操作
  • 软考 系统架构设计师系列知识点之云原生架构设计理论与实践(8)