当前位置: 首页 > news >正文

【TORCH】torch.normal()中的size参数

torch.normal() 函数中,size 参数用于指定生成张量的形状。torch.normal() 函数用于从正态(高斯)分布中生成随机数。函数的基本形式是:

torch.normal(mean, std, size)
  • mean:均值,可以是标量或张量。如果是标量,表示生成的所有元素的均值;如果是张量,表示对应位置元素的均值。
  • std:标准差,可以是标量或张量。如果是标量,表示生成的所有元素的标准差;如果是张量,表示对应位置元素的标准差。
  • size:生成张量的形状。

以下是一些示例,展示了如何使用 size 参数生成不同形状的张量:

示例代码

import torch# 生成一个形状为(3,)的一维张量
mean = 0.0
std = 1.0
size = (3,)
tensor_1d = torch.normal(mean, std, size)
print("1D Tensor:", tensor_1d)# 生成一个形状为(2, 3)的二维张量
size = (2, 3)
tensor_2d = torch.normal(mean, std, size)
print("2D Tensor:", tensor_2d)# 生成一个形状为(2, 3, 4)的三维张量
size = (2, 3, 4)
tensor_3d = torch.normal(mean, std, size)
print("3D Tensor:", tensor_3d)# 生成一个形状为(3, 3)的二维张量,均值和标准差为张量
mean_tensor = torch.tensor([[0.0, 1.0, 2.0],[0.0, 1.0, 2.0],[0.0, 1.0, 2.0]])
std_tensor = torch.tensor([[1.0, 1.0, 1.0],[1.0, 1.0, 1.0],[1.0, 1.0, 1.0]])
size = (3, 3)
tensor_2d_with_tensor_mean_std = torch.normal(mean_tensor, std_tensor)
print("2D Tensor with tensor mean and std:", tensor_2d_with_tensor_mean_std)

输出示例

1D Tensor: tensor([ 0.0343, -0.4731,  1.1844])
2D Tensor: tensor([[ 0.1239,  1.1049,  0.4560],[-0.3104,  0.6228,  0.2698]])
3D Tensor: tensor([[[ 0.0793, -0.2101,  0.7634,  0.1921],[-0.1220, -0.9352, -1.3496, -0.6405],[ 0.3821,  0.2745,  0.1925,  0.4075]],[[-0.8833,  1.1430,  0.3650, -0.7995],[ 0.1403, -0.2226, -0.2483,  0.5914],[-0.3337,  0.3735, -0.0515, -1.1255]]])
2D Tensor with tensor mean and std: tensor([[-0.2971,  1.5936,  2.3287],[ 1.0322,  1.3414,  1.7221],[-0.6370,  1.5202,  1.3766]])

说明

  • 一维张量size = (3,) 生成一个形状为 (3,) 的一维张量。
  • 二维张量size = (2, 3) 生成一个形状为 (2, 3) 的二维张量。
  • 三维张量size = (2, 3, 4) 生成一个形状为 (2, 3, 4) 的三维张量。
  • 均值和标准差为张量:如果 meanstd 是张量,那么生成的张量每个元素的均值和标准差分别由对应位置的值决定。

通过指定不同的 size 参数,可以生成不同形状的张量。这对于初始化神经网络的权重特别有用,因为不同层的权重通常具有不同的形状。

如果您有更多问题或需要进一步的帮助,请告诉我!

http://www.lryc.cn/news/393295.html

相关文章:

  • 【第20章】MyBatis-Plus逻辑删除支持
  • 【IT领域新生必看】 Java编程中的重载(Overloading):初学者轻松掌握的全方位指南
  • python转文本为语音并播放
  • 解锁高效软件测试:虚拟机助力提升测试流程的秘诀
  • 创建vue3项目
  • 中国网络安全审查认证和市场监管大数据中心数据合规官CCRC-DCO
  • Web漏洞扫描工具AppScan与AWVS测评及使用体验
  • 瞰景Smart3D使用体验分享
  • Android系统adb shell dumpsys activity processes
  • vue侦听器watch()
  • 如何用Python向PPT中批量插入图片
  • C# Socket
  • node的下载、安装、配置和使用(node.js下载安装和配置、npm命令汇总、cnpm的使用)
  • 深度卷积神经网络 AlexNet
  • 【刷题汇总--大数加法、 链表相加(二)、大数乘法】
  • 基于Java的网上花店系统
  • uniApp 封装VUEX
  • 最长公共子序列求长度和输出子序列C代码
  • 安卓Framework开发快速分析日志及定位源码
  • 数据结构算法之B树
  • 【图卷积网络】GCN基础原理简单python实现
  • 【话题】AI是在帮助开发者还是取代他们
  • 精通Perl正则表达式修饰符:提升文本处理能力的艺术
  • 【web前端HTML+CSS+JS】--- HTML学习笔记01
  • Go 语言入门(一)
  • 爬虫笔记20——票星球抢票脚本的实现
  • DDR3(三)
  • JDK都出到20多了,你还不会使用JDK8的Stream流写代码吗?
  • QT slots 函数
  • pycharm如何使用jupyter