当前位置: 首页 > article >正文

pytorch小记(二十二):全面解读 PyTorch 的 `torch.cumprod`——累积乘积详解与实战示例

pytorch小记(二十二):全面解读 PyTorch 的 `torch.cumprod`——累积乘积详解与实战示例

    • 一、函数签名与参数说明
    • 二、基础用法
      • 1. 一维张量累积乘积
      • 2. 二维张量按行/按列累积
    • 三、`dtype` 参数:避免整数溢出与提升精度
    • 四、典型应用场景
      • 1. 几何序列生成
      • 2. 概率分布的累积乘积
      • 3. 模型门控或权重衰减
    • 五、进阶示例:预分配 `out` 张量
    • 六、小结


在深度学习与科学计算中,往往需要沿某个维度追踪“前面所有元素的乘积”,比如几何序列计算、概率分布构建、模型门控/权重衰减等场景。PyTorch 提供的 torch.cumprod 函数可以一行代码搞定这一需求。本文将从函数签名、参数含义、基础用法,到进阶示例、典型应用场景,为你带来最全面的讲解,并附上丰富示例助你快速上手。


一、函数签名与参数说明

torch.cumprod(input: Tensor,dim: int,*,dtype: Optional[torch.dtype] = None,out: Optional[Tensor] = None
) → Tensor
  • input:任意维度的输入张量。
  • dim:指定沿哪个维度做累积乘积(0 表示第一个维度,以此类推)。
  • dtype(可选):输出张量的数据类型。如果原张量为整数且会溢出,可通过将其提升到更宽数据类型来避免溢出。
  • out(可选):预先分配好的张量,用于存储输出,避免额外内存分配。

二、基础用法

1. 一维张量累积乘积

import torchx = torch.tensor([1, 2, 3, 4])
y = torch.cumprod(x, dim=0)
print(y)  # tensor([ 1,  2,  6, 24])
  • y[0] = 1
  • y[1] = 1 * 2 = 2
  • y[2] = 1 * 2 * 3 = 6
  • y[3] = 1 * 2 * 3 * 4 = 24

2. 二维张量按行/按列累积

x2 = torch.tensor([[1, 2, 3],[4, 5, 6]])
# 沿行(dim=1)累积
row_prod = torch.cumprod(x2, dim=1)
print(row_prod)
# tensor([[  1,   2,   6],
#         [  4,  20, 120]])# 沿列(dim=0)累积
col_prod = torch.cumprod(x2, dim=0)
print(col_prod)
# tensor([[1, 2,  3],
#         [4, 10, 18]])

三、dtype 参数:避免整数溢出与提升精度

input 为大整数且乘积超出类型范围时,会导致溢出。此时可指定更宽的数据类型:

x_int = torch.tensor([1000, 1000, 1000], dtype=torch.int32)
# 默认 int32 会溢出
print(torch.cumprod(x_int, dim=0))
# tensor([1000,  -727,  -728], dtype=torch.int32)# 改为 int64 避免溢出
print(torch.cumprod(x_int, dim=0, dtype=torch.int64))
# tensor([      1000,    1000000, 1000000000])

四、典型应用场景

1. 几何序列生成

几何序列 a , a r , a r 2 , … a, ar, ar^2, … a,ar,ar2, 可用累积乘积实现:

a, r, n = 2.0, 0.5, 5
ratios = torch.full((n,), r)               # [r, r, r, r, r]
geom = a * torch.cumprod(ratios, dim=0)
print(geom)
# tensor([1.0000, 0.5000, 0.2500, 0.1250, 0.0625])

2. 概率分布的累积乘积

在构建离散分布的乘积模型时,用累乘来得到联合概率:

probs = torch.tensor([0.2, 0.3, 0.5])
# 标准化(确保和为1)
probs = probs / probs.sum()
# 获取依次乘积(注意:乘积非累加,因此并非 CDF)
joint = torch.cumprod(probs, dim=0)
print(joint)
# tensor([0.2000, 0.0600, 0.0300])

3. 模型门控或权重衰减

在 RNN、Transformer 等模型中,若需要对前 n 层或时间步做指数衰减,可用累积乘积计算衰减系数:

decay_rates = torch.linspace(0.9, 0.5, steps=4)  # 每层不同衰减率
coeffs = torch.cumprod(decay_rates, dim=0)      # 累积得到层间总衰减
print(coeffs)
# tensor([0.9000, 0.7200, 0.5040, 0.2520])

五、进阶示例:预分配 out 张量

为了在高性能场景下避免额外内存分配,可以先分配好输出张量,再将结果写入:

x = torch.arange(1, 1001, dtype=torch.float32)
out = torch.empty_like(x)
torch.cumprod(x, dim=0, out=out)
print(out[:5])  # tensor([1., 2., 6., 24., 120.])

六、小结

  • 功能torch.cumprod 沿指定维度计算输入张量的累计乘积,返回新张量。

  • 关键参数

    • dim:累积轴;
    • dtype:避免整数溢出/提升精度;
    • out:预分配输出提高性能。
  • 常见应用

    1. 几何序列生成;
    2. 概率分布乘积;
    3. 模型门控/权重衰减;
    4. 其它需要“前缀乘积”场景。
http://www.lryc.cn/news/2378785.html

相关文章:

  • TTS:F5-TTS 带有 ConvNeXt V2 的扩散变换器
  • 强化学习笔记(一)基本概念
  • 大型语言模型中的QKV与多头注意力机制解析
  • 基于地图的数据可视化:解锁地理数据的真正价值
  • 利用自适应双向对比重建网络与精细通道注意机制实现图像去雾化技术的PyTorch代码解析
  • 分布式链路跟踪
  • 刷leetcodehot100返航版--二叉树
  • chmod 777含义:
  • AGI大模型(21):混合检索之混合搜索
  • 双重差分模型学习笔记4(理论)
  • Mysql 8.0.32 union all 创建视图后中文模糊查询失效
  • Jenkins 执行器(Executor)如何调整限制?
  • Android 中 权限分类及申请方式
  • 编程错题集系列(一)
  • 【原创】基于视觉大模型gemma-3-4b实现短视频自动识别内容并生成解说文案
  • Spark(32)SparkSQL操作Mysql
  • 基于 Python 的界面程序复现:标准干涉槽型设计计算及仿真
  • c++成员函数返回类对象引用和直接返回类对象的区别
  • AGI大模型(20):混合检索之rank_bm25库来实现词法搜索
  • 数字化转型- 数字化转型路线和推进
  • 字体样式集合
  • IP68防水Type-C连接器实测:水下1米浸泡72小时的生存挑战
  • 【技术追踪】InverseSR:使用潜在扩散模型进行三维脑部 MRI 超分辨率重建(MICCAI-2023)
  • React学习(二)-变量
  • list重点接口及模拟实现
  • 【自然语言处理与大模型】大模型(LLM)基础知识④
  • 系统架构设计(九):分布式架构与微服务
  • Java 框架配置自动化:告别冗长的 XML 与 YAML 文件
  • vue使用Pinia实现不同页面共享token
  • 遨游科普:三防平板是什么?有什么功能?