当前位置: 首页 > news >正文

深入浅出Pytorch函数——torch.nn.init.orthogonal_

分类目录:《深入浅出Pytorch函数》总目录
相关文章:
· 深入浅出Pytorch函数——torch.nn.init.calculate_gain
· 深入浅出Pytorch函数——torch.nn.init.uniform_
· 深入浅出Pytorch函数——torch.nn.init.normal_
· 深入浅出Pytorch函数——torch.nn.init.constant_
· 深入浅出Pytorch函数——torch.nn.init.ones_
· 深入浅出Pytorch函数——torch.nn.init.zeros_
· 深入浅出Pytorch函数——torch.nn.init.eye_
· 深入浅出Pytorch函数——torch.nn.init.dirac_
· 深入浅出Pytorch函数——torch.nn.init.xavier_uniform_
· 深入浅出Pytorch函数——torch.nn.init.xavier_normal_
· 深入浅出Pytorch函数——torch.nn.init.kaiming_uniform_
· 深入浅出Pytorch函数——torch.nn.init.kaiming_normal_
· 深入浅出Pytorch函数——torch.nn.init.trunc_normal_
· 深入浅出Pytorch函数——torch.nn.init.orthogonal_
· 深入浅出Pytorch函数——torch.nn.init.sparse_


torch.nn.init模块中的所有函数都用于初始化神经网络参数,因此它们都在torc.no_grad()模式下运行,autograd不会将其考虑在内。

根据Saxe, A等人在《Exact solutions to the nonlinear dynamics of learning in deep linear neural networks》中描述的方法,用(半)正交矩阵填充输入的张量或变量。输入张量必须至少是2维的,对于更高维度的张量,超出的维度会被展平,视作行等于第一个维度,列等于稀疏矩阵乘积的2维表示,其中非零元素生成自 N ( 0 , std 2 ) N(0, \text{std}^2) N(0,std2)

语法

torch.nn.init.orthogonal_(tensor, gain=1)

参数

  • tensor:[Tensor] 一个 N N N维张量torch.Tensor,其中 N ≥ 2 N\geq 2 N2
  • gain:[可选] 比例因子

返回值

一个torch.Tensor且参数tensor也会更新

实例

w = torch.empty(3, 5)
nn.init.orthogonal_(w)

函数实现

def orthogonal_(tensor, gain=1):r"""Fills the input `Tensor` with a (semi) orthogonal matrix, asdescribed in `Exact solutions to the nonlinear dynamics of learning in deeplinear neural networks` - Saxe, A. et al. (2013). The input tensor must haveat least 2 dimensions, and for tensors with more than 2 dimensions thetrailing dimensions are flattened.Args:tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`gain: optional scaling factorExamples:>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)>>> w = torch.empty(3, 5)>>> nn.init.orthogonal_(w)"""if tensor.ndimension() < 2:raise ValueError("Only tensors with 2 or more dimensions are supported")if tensor.numel() == 0:# no-opreturn tensorrows = tensor.size(0)cols = tensor.numel() // rowsflattened = tensor.new(rows, cols).normal_(0, 1)if rows < cols:flattened.t_()# Compute the qr factorizationq, r = torch.linalg.qr(flattened)# Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdfd = torch.diag(r, 0)ph = d.sign()q *= phif rows < cols:q.t_()with torch.no_grad():tensor.view_as(q).copy_(q)tensor.mul_(gain)return tensor
http://www.lryc.cn/news/131344.html

相关文章:

  • ORACLE中UNION、UNION ALL、MINUS、INTERSECT学习
  • 【k8s、云原生】基于metrics-server弹性伸缩
  • 回归预测 | MATLAB实现WOA-SVM鲸鱼算法优化支持向量机多输入单输出回归预测(多指标,多图)
  • VSCode快捷键
  • 贪心算法求数组中能组成三角形的最大周长
  • VMWare Workstation 17 Pro 网络设置 桥接模式 网络地址转换(NAT)模式 仅主机模式
  • 拒绝摆烂!C语言练习打卡第四天
  • KubeSphere 社区双周报 | Java functions framework 支持 SkyWalking | 2023.8.4-8.17
  • 【学习笔记之java】使用RestTemplate调用第三方接口
  • 数据集成革新:去中心化微服务集群的无限潜能
  • 后端返回可下载的xlsx文件,但是前端接收下载后为乱码
  • 提升资源管理效率必备工具推荐
  • HJ23 删除字符串中出现次数最少的字符
  • 文心一言 VS 讯飞星火 VS chatgpt (76)-- 算法导论7.3 1题
  • Leetcode - 滑动窗口
  • 如何保证数据传输的安全?
  • 政务、商务数据资源有效共享:让数据上“链”,记录每一个存储过程!
  • xml转map工具类
  • C++并发多线程--std::future_status、std::shared_future和std::atomic的使用
  • Redis在Java中的基本使用
  • 4.2 C++ Boost 内存池管理库
  • Django模型基础
  • 导读-Linux简介
  • 判断平面中两射线是否相交的高效方法
  • 基于VUE3+Layui从头搭建通用后台管理系统(前端篇)八:自定义组件封装上
  • RabbitMq交换机类型介绍
  • 中国电信秋招攻略,考试内容分析
  • prompt-engineering-note(面向开发者的ChatGPT提问工程学习笔记)
  • 2011-2021年数字普惠金融指数Bartik工具变量法(含原始数据和Bartik工具变量法代码)
  • [ MySQL ] — 常见函数的使用