当前位置: 首页 > news >正文

pytorch | 使用vmap对自定义函数进行并行化/ 向量化的执行

0. 参考

  1. pytorch官方文档:https://pytorch.org/docs/stable/generated/torch.func.vmap.html#torch-func-vmap
  2. 关于if语句如何执行:https://github.com/pytorch/functorch/issues/257

1. 问题背景

  1. 笔者现在需要执行如下的功能:
    root_ls = [func(x,b) for x in input]
    因此突然想到pytorch或许存在对于自定义的函数的向量化执行的支持

  2. 一顿搜索发现了from functorch import vmap这种好东西,虽然还在开发中,但是很多功能已经够用了

2. 具体例子

  1. 这里只介绍笔者需要的一个方面,vmap的其他支持还请参阅pytorch官方文档
  2. 自定义函数及其输入:
# 自定义函数
def func_2(t,b):return torch.where((t>5.),t*b,-t)
# 输入t = torch.tensor([1.,2.,3.,4.,5.,6.,7.,8.])
b = torch.tensor([1.],requires_grad=True)
  • 注意1:自定义函数不要出现if,用torch.where替代。至于为什么参阅这个issue,大概的原因是“if isn’t a differentiability requirement;”,强行使用会报错error of Data-dependent control flow
  1. 然后对于b,我们需要扩张到和t同样的大小:
    b_extend = torch.expand_copy(b,size=t.shape) # 必须把b扩张到和t同一个size否则报错

  2. 利用vmap,它返回一个新的函数func_vec ,具有向量化执行的支持,也可以利用autograd求导

# Use vmap() to construct a new function.  
func_vec = vmap(func_2)  				# [N, D], [N, D] -> [N]
ans = func_vec(t,b_extend)
ans.sum().backward()   # 等价于: ans.backward(torch.ones(b_extend.shape))
b_extend.grad          # 可以预见:b的导数是t:在t>5.时导数是t,在t<=5.时导数是0
  1. 全部代码:
import torch
from functorch import vmap# if分支isn't a differentiability requirement;
def func(t,b):tmp = t*bif tmp > 5:     # error: Data-dependent control flowroot = t*belse:root = -treturn rootdef func_2(t,b):return torch.where((t>5.),t*b,-t)t = torch.tensor([1.,2.,3.,4.,5.,6.,7.,8.])
b = torch.tensor([1.],requires_grad=True)
b_extend = torch.expand_copy(b,size=t.shape)    # 必须把b扩张到和t同一个size否则报错
b_extend.retain_grad()print(f"shape of t:{t.shape}, shape of b_extend:{b_extend.shape}")
# shape of t:torch.Size([8]), shape of b_extend:torch.Size([8])# Use vmap() to construct a new function.  # [D], [D] -> []
func_vec = vmap(func_2)  # [N, D], [N, D] -> [N]
ans = func_vec(t,b_extend)
ans.sum().backward()   # 等价于: ans.backward(torch.ones(b_extend.shape))b_extend.grad          # 可以预见:b的导数是t:在t>5.时导数是t,在t<=5.时导数是0
# tensor([0., 0., 0., 0., 0., 6., 7., 8.])
  1. 问题在于,它真的比root_ls = [func(x,b) for x in input]这种快吗?在笔者的设计中确实是使用vmap更快一些,但是不见得总是好用,只是在pytorch中写大量的for实在是太愚蠢了QAQ

感谢阅读,欢迎交流

http://www.lryc.cn/news/59645.html

相关文章:

  • Docker部署RabbitMQ(单机,集群,仲裁队列)
  • 生活污水处理设备选购指南
  • 奥威BI数据可视化大屏分享|多场景、多风格
  • 超越时空:加速预训练语言模型的训练
  • 数据库管理系统PostgreSQL部署安装完整教程
  • 有学生问我,重构是什么?我应该如何回答?
  • 交际场合---英文单词
  • 【网络安全】文件上传漏洞及中国蚁剑安装
  • [Java]面向对象高级篇
  • 苹果应用商店上架流程
  • 基于Eclipse下使用arm gcc开发GD32调用printf
  • 5个降低云成本并提高IT运营效率的优先事项
  • 95-拥塞控制
  • Linux常见操作命令【二】
  • Linux驱动中断和定时器
  • 表达式和函数
  • C#基础复习
  • Windows服务器使用代码SSH免密登录并执行脚本
  • (Deep Learning)交叉验证(Cross Validation)
  • 通俗举例讲解动态链接】静态链接
  • K8S部署常见问题归纳
  • Redis高可用
  • Hyperledger Fabric 2.2版本环境搭建
  • macOS Monterey 12.6.5 (21G531) Boot ISO 原版可引导镜像
  • 【软件设计师13】数据库设计
  • SpringMVC的全注解开发
  • C# | 导出DataGridView中的数据到Excel、CSV、TXT
  • 新规拉开中国生成式AI“百团大战”序幕?
  • 日撸 Java 三百行day31
  • 在线绘制思维导图