当前位置: 首页 > news >正文

深入浅出Pytorch函数——torch.nn.init.kaiming_normal_

分类目录:《深入浅出Pytorch函数》总目录
相关文章:
· 深入浅出Pytorch函数——torch.nn.init.calculate_gain
· 深入浅出Pytorch函数——torch.nn.init.uniform_
· 深入浅出Pytorch函数——torch.nn.init.normal_
· 深入浅出Pytorch函数——torch.nn.init.constant_
· 深入浅出Pytorch函数——torch.nn.init.ones_
· 深入浅出Pytorch函数——torch.nn.init.zeros_
· 深入浅出Pytorch函数——torch.nn.init.eye_
· 深入浅出Pytorch函数——torch.nn.init.dirac_
· 深入浅出Pytorch函数——torch.nn.init.xavier_uniform_
· 深入浅出Pytorch函数——torch.nn.init.xavier_normal_
· 深入浅出Pytorch函数——torch.nn.init.kaiming_uniform_
· 深入浅出Pytorch函数——torch.nn.init.kaiming_normal_
· 深入浅出Pytorch函数——torch.nn.init.trunc_normal_
· 深入浅出Pytorch函数——torch.nn.init.orthogonal_
· 深入浅出Pytorch函数——torch.nn.init.sparse_


torch.nn.init模块中的所有函数都用于初始化神经网络参数,因此它们都在torc.no_grad()模式下运行,autograd不会将其考虑在内。

根据He, K等人于2015年在《Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification》中描述的方法,用一个正态分布生成值,填充输入的张量或变量。结果张量中的值采样自 N ( 0 , std 2 ) N(0, \text{std}^2) N(0,std2),其中:
std = gain fan_mode \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}} std=fan_mode gain

这种方法也被称为He initialisation。

语法

torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

参数

  • tensor:[Tensor] 一个 N N N维张量torch.Tensor
  • a:[float] 这层之后使用的rectifier的斜率系数(ReLU的默认值为0)
  • mode:[str] 可以为fan_infan_out。若为fan_in则保留前向传播时权值方差的量级,若为fan_out则保留反向传播时的量级,默认值为fan_in
  • nonlinearity:[str] 一个非线性函数,即一个nn.functional的名称,推荐使用relu或者leaky_relu,默认值为leaky_relu

返回值

一个torch.Tensor且参数tensor也会更新

实例

w = torch.empty(3, 5)
nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')

函数实现

def kaiming_normal_(tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'
):r"""Fills the input `Tensor` with values according to the methoddescribed in `Delving deep into rectifiers: Surpassing human-levelperformance on ImageNet classification` - He, K. et al. (2015), using anormal distribution. The resulting tensor will have values sampled from:math:`\mathcal{N}(0, \text{std}^2)` where.. math::\text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}Also known as He initialization.Args:tensor: an n-dimensional `torch.Tensor`a: the negative slope of the rectifier used after this layer (onlyused with ``'leaky_relu'``)mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``preserves the magnitude of the variance of the weights in theforward pass. Choosing ``'fan_out'`` preserves the magnitudes in thebackwards pass.nonlinearity: the non-linear function (`nn.functional` name),recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).Examples:>>> w = torch.empty(3, 5)>>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')"""if 0 in tensor.shape:warnings.warn("Initializing zero-element tensors is a no-op")return tensorfan = _calculate_correct_fan(tensor, mode)gain = calculate_gain(nonlinearity, a)std = gain / math.sqrt(fan)with torch.no_grad():return tensor.normal_(0, std)
http://www.lryc.cn/news/132095.html

相关文章:

  • D. Anton and School - 2
  • xcode把包打到高版本的iPhone里
  • PMP项目管理考试小结
  • 【NAS群晖drive异地访问】使用cpolar远程访问内网Synology Drive「内网穿透」
  • 【傅里叶级数与傅里叶变换】数学推导——2、[Part2:T = 2 π的周期函数的傅里叶级数展开] 及 [Part3:周期为2L的函数展开]
  • 【IMX6ULL驱动开发学习】06.DHT11温湿度传感器驱动程序编写与测试
  • sip开发从理论到实践,让你快速入门sip
  • 十三、Linux中必须知道的几个快捷键!!!
  • Django进阶-文件上传
  • clickhouse-数据导入导出方案
  • [JavaWeb]【一】入门JavaWeb开发总概及HTML、CSS、JavaScript
  • Python自动化小技巧18——自动化资产月报(word设置字体表格样式,查找替换文字)
  • FFmpeg5.0源码阅读——VideoToobox硬件解码
  • IDEA 中Tomcat源码环境搭建
  • MATLAB | 七夕节用MATLAB画个玫瑰花束叭
  • 嵌入式开发之configure
  • 深入浅出Pytorch函数——torch.nn.Module
  • 【100天精通python】Day38:GUI界面编程_PyQt 从入门到实战(中)_数据库操作与多线程编程
  • STM32--TIM定时器(3)
  • 爬虫框架- feapder + 爬虫管理系统 - feaplat 的学习简记
  • 设计模式详解-享元模式
  • BDA初级分析——用SQL筛选数据
  • (成功踩坑)electron-builder打包过程中报错
  • 【STM32】 工程
  • Git概述
  • ubuntu 编译安装nginx及安装nginx_upstream_check_module模块
  • 近 2000 台 Citrix NetScaler 服务器遭到破坏
  • MySQL MVCC的详解之Read View
  • 基于springboot+vue的考研资讯平台(前后端分离)
  • 学习网络编程No.3【socket理论实战】