当前位置: 首页 > news >正文

深入浅出Pytorch函数——torch.nn.init.dirac_

分类目录:《深入浅出Pytorch函数》总目录
相关文章:
· 深入浅出Pytorch函数——torch.nn.init.calculate_gain
· 深入浅出Pytorch函数——torch.nn.init.uniform_
· 深入浅出Pytorch函数——torch.nn.init.normal_
· 深入浅出Pytorch函数——torch.nn.init.constant_
· 深入浅出Pytorch函数——torch.nn.init.ones_
· 深入浅出Pytorch函数——torch.nn.init.zeros_
· 深入浅出Pytorch函数——torch.nn.init.eye_
· 深入浅出Pytorch函数——torch.nn.init.dirac_
· 深入浅出Pytorch函数——torch.nn.init.xavier_uniform_
· 深入浅出Pytorch函数——torch.nn.init.xavier_normal_
· 深入浅出Pytorch函数——torch.nn.init.kaiming_uniform_
· 深入浅出Pytorch函数——torch.nn.init.kaiming_normal_
· 深入浅出Pytorch函数——torch.nn.init.trunc_normal_
· 深入浅出Pytorch函数——torch.nn.init.orthogonal_
· 深入浅出Pytorch函数——torch.nn.init.sparse_


torch.nn.init模块中的所有函数都用于初始化神经网络参数,因此它们都在torc.no_grad()模式下运行,autograd不会将其考虑在内。

该函数用 Dirac δ \text{Dirac}\delta Diracδ 函数来填充3-5维输入张量或变量,在卷积层尽可能多的保存输入通道特征。

语法

torch.nn.init.dirac_(tensor, groups=1)

参数

  • tensor:[Tensor] 一个3~5维张量torch.Tensor
  • groups:[int] conv层中的组数,默认值为1

返回值

一个torch.Tensor且参数tensor也会更新

实例

w = torch.empty(3, 16, 5, 5)
nn.init.dirac_(w)
w = torch.empty(3, 24, 5, 5)
nn.init.dirac_(w, 3)

函数实现

def dirac_(tensor, groups=1):r"""Fills the {3, 4, 5}-dimensional input `Tensor` with the Diracdelta function. Preserves the identity of the inputs in `Convolutional`layers, where as many input channels are preserved as possible. In caseof groups>1, each group of channels preserves identityArgs:tensor: a {3, 4, 5}-dimensional `torch.Tensor`groups (int, optional): number of groups in the conv layer (default: 1)Examples:>>> w = torch.empty(3, 16, 5, 5)>>> nn.init.dirac_(w)>>> w = torch.empty(3, 24, 5, 5)>>> nn.init.dirac_(w, 3)"""dimensions = tensor.ndimension()if dimensions not in [3, 4, 5]:raise ValueError("Only tensors with 3, 4, or 5 dimensions are supported")sizes = tensor.size()if sizes[0] % groups != 0:raise ValueError('dim 0 must be divisible by groups')out_chans_per_grp = sizes[0] // groupsmin_dim = min(out_chans_per_grp, sizes[1])with torch.no_grad():tensor.zero_()for g in range(groups):for d in range(min_dim):if dimensions == 3:  # Temporal convolutiontensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2] = 1elif dimensions == 4:  # Spatial convolutiontensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2,tensor.size(3) // 2] = 1else:  # Volumetric convolutiontensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2,tensor.size(3) // 2, tensor.size(4) // 2] = 1return tensor
http://www.lryc.cn/news/132528.html

相关文章:

  • [Go版]算法通关村第十三关青铜——数字数学问题之统计问题、溢出问题、进制问题
  • GPT-4一纸重洗:从97.6%降至2.4%的巨大挑战
  • 大数据Flink学习圣经:一本书实现大数据Flink自由
  • 什么是微服务?
  • 【C++入门到精通】C++入门 —— 容器适配器、stack和queue(STL)
  • 系统架构设计专业技能 · 软件工程之需求工程
  • 2023国赛数学建模E题思路模型代码 高教社杯
  • Baumer工业相机堡盟工业相机如何通过BGAPISDK设置相机的Bufferlist序列(C++)
  • 从 Ansible Galaxy 使用角色
  • ROS与STM32通信(二)-pyserial
  • [oneAPI] 使用Bert进行中文文本分类
  • 【数据治理】什么是数据库归档
  • AI代码补全 案例 - 阿里云智能编码插件Cosy
  • 【Linux】进程信号篇Ⅰ:信号的产生(signal、kill、raise、abort、alarm)、信号的保存(core dump)
  • 漏洞指北-VulFocus靶场专栏-中级03
  • 【leetcode 力扣刷题】数组交集(数组、set、map都可实现哈希表)
  • MySQL 8.0.31 登录提示caching_sha2_password问题解决方法
  • [Google] DeepMind Gemini: 新一代LLM结合AlphaGo技术将力压 GPT-4|未来 AI 领域的新巨头
  • Maven高级
  • 【视觉SLAM入门】5.2. 2D-3D PNP 3D-3D ICP BA非线性优化方法 数学方法SVD DLT
  • 人脸老化预测(Python)
  • AWS SDK 3.x for .NET Framework 4.0 可行性测试
  • 两个list。如何使用流的写法将一个list中的对象中的某些属性根据另外一个list中的属性值赋值进去?
  • 美国陆军希望大数据技术能够帮助保护其云安全
  • vue 文字跑马灯
  • 开源ChatGPT系统源码 采用NUXT3+Laravel9后端开发 前后端分离版本
  • 【LeetCode|数据结构】剑指 Offer 33. 二叉搜索树的后序遍历序列
  • 自定义协程
  • 【Atcoder】 [ABC240Ex] Sequence of Substrings
  • 真机二阶段之堆叠技术