当前位置: 首页 > news >正文

pytorch 中_call_impl()函数

记录pytorch 版本中的 nn.Module() 重要函数

1. _call_impl()

1.1 torch1.7.1 版本

    def _call_impl(self, *input, **kwargs):for hook in itertools.chain(_global_forward_pre_hooks.values(),self._forward_pre_hooks.values()):result = hook(self, input)if result is not None:if not isinstance(result, tuple):result = (result,)input = resultif torch._C._get_tracing_state():result = self._slow_forward(*input, **kwargs)else:result = self.forward(*input, **kwargs)for hook in itertools.chain(_global_forward_hooks.values(),self._forward_hooks.values()):hook_result = hook(self, input, result)if hook_result is not None:result = hook_resultif (len(self._backward_hooks) > 0) or (len(_global_backward_hooks) > 0):var = resultwhile not isinstance(var, torch.Tensor):if isinstance(var, dict):var = next((v for v in var.values() if isinstance(v, torch.Tensor)))else:var = var[0]grad_fn = var.grad_fnif grad_fn is not None:for hook in itertools.chain(_global_backward_hooks.values(),self._backward_hooks.values()):wrapper = functools.partial(hook, self)functools.update_wrapper(wrapper, hook)grad_fn.register_hook(wrapper)return result

以上的函数的功能作用解释如下:

提供的代码是 PyTorch 模块方法的 _call_impl 实现。当模块用作可调用对象时,通常使用输入数据调用它时,将调用此方法。让我们逐步分解代码以了解其功能:

for hook in itertools.chain(_global_forward_pre_hooks.values(), self._forward_pre_hooks.values()):
此循环遍历两个钩子集合:
_global_forward_pre_hooks 和 _forward_pre_hooks 。
钩子是可以注册为在神经网络向前或向后传递期间在特定点执行的函数。
这些预置挂钩将在模块的实际前向传递之前执行。
result = hook(self, input):
对于每个钩子,使用 self (模块)和 input 参数调用 hook 该函数。
if result is not None:
如果钩子返回非 None 值,则表示钩子修改了输入数据,并且此修改后的数据将成为循环中下一个挂钩的新输入。
if not isinstance(result, tuple): result = (result,)
钩子的结果将转换为元组(如果它还没有元组)。这是为了处理钩子可能返回单个值而不是元组的情况。
input = result
修改后的输入将成为循环中下一个钩子的新输入。

if torch._C._get_tracing_state(): result = self._slow_forward(*input, **kwargs)
此条件块检查是否正在跟踪当前正向传递。如果是,则使用修改后的输入数据调用该方_slow_forward 。
else: result = self.forward(*input, **kwargs)
如果未跟踪正向传递,则使用修改后的输入数据调用模块的正常 forward 方法。

for hook in itertools.chain(_global_forward_hooks.values(), self._forward_hooks.values()):
此循环遍历两个前向钩子集合: _global_forward_hooks 和 _forward_hooks 。这些钩子在模块的正向传递之后执行。
hook_result = hook(self, input, result):
对于每个钩子,使用 self 、 input 和 result 作为参数调用 hook 函数。
if hook_result is not None: result = hook_result
如果钩子返回非 None 值,则表示钩子修改了正向传递的结果,并且此修改后的结果将成为循环中下一个挂钩的新结果。

if (len(self._backward_hooks) > 0) or (len(_global_backward_hooks) > 0):
此条件块检查是否有任何全局或为此特定模块注册的向后钩子。
var = result
正向传递的结果存储在变量 var 中。

while not isinstance(var, torch.Tensor):
此循环迭代 till 是 var 的 torch.Tensor 实例。
if isinstance(var, dict): var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
如果是一个字典,它 var 查找字典中的第一个值,即 torch.Tensor .
else: var = var[0]
如果不是字典,它 var 假定它是一个序列(例如,列表,元组)并获取其第一个元素。
grad_fn = var.grad_fn
grad_fn torch.Tensor 实例的属性被分配给变量 grad_fn 。此属性表示在反向传播期间计算张量梯度的函数。

if grad_fn is not None:
如果 不是 grad_fn None,则表示张量参与了需要梯度的计算,我们需要向其附加向后钩子。
for hook in itertools.chain(_global_backward_hooks.values(), self._backward_hooks.values()):
此循环遍历两个向后钩子集合: _global_backward_hooks 和 _backward_hooks 。这些钩子在模块的向后传递期间执行。
wrapper = functools.partial(hook, self)
对于每个向后钩子,通过将钩子函数与模块 self 作为参数部分应用来创建一个新函数。这样做是为了确保钩子函数可以访问模块。
functools.update_wrapper(wrapper, hook)
包装器函数使用原始钩子函数中的信息进行更新,例如其名称和文档字符串。
grad_fn.register_hook(wrapper)
包装器函数注册为挂接到 grad_fn .这意味着在反向传播期间计算梯度时,将执行钩子以对梯度执行其他操作。
return result
返回前向传递的最终结果。

总之,该方法 _call_impl 执行前向挂钩(前向前和后向前),执行前向传递,执行向后挂钩(如果需要),并返回向前传递的结果。它还处理钩子修改输入或结果数据的情况,并确保向后钩子附加到相关张量,以便在反向传播期间进行梯度计算。

http://www.lryc.cn/news/98766.html

相关文章:

  • openGauss学习笔记-22 openGauss 简单数据管理-HAVING子句
  • 干货 | 常见电路板GND与外壳GND之间接一个电阻一个电容,为什么?
  • 网络层协议总览
  • C++模拟实现list
  • PostgreSQL PG16 逻辑复制在STANDBY 上工作 (译)
  • 《零基础入门学习Python》第058讲:论一只爬虫的自我修养6:正则表达式2
  • 第一堂棒球课:MLB棒球大联盟的主要战术·棒球1号位
  • 【论文阅读】利用道路目标特征的多期车载激光点云配准
  • L---泰拉瑞亚---2023河南萌新联赛第(三)场:郑州大学
  • windows无盘启动技术开发之使用本地镜像文件启动电脑
  • PoseiSwap 即将开启质押,利好刺激下 POSE通证短时涨超 30%
  • Linux文本编辑器-vim
  • vscode使用g++编译.c文件或.cpp文件
  • 云计算的服务模式包括哪些|PetaExpress云服务商
  • iOS--通知、代理、单例模式总结
  • 选择最佳安全文件传输方法的重要性
  • IBM LSF 集群虚拟化和工作负载管理解决方案
  • C++(14):重载运算与类型转换
  • 【深度学习】基于图形的机器学习:概述
  • 内存泄漏是什么?有什么危害
  • 【项目设计】基于负载均衡的在线oj平台
  • 生产环境Session解决方案、Session服务器之Redis
  • SPECjvm2008_1_01 openjdk8 x86_64 ARM64 运行时长、成绩 Run is valid, but not compliant
  • 安卓:百度地图开发(超详细)
  • DDSv1.4规范(中文版)
  • oracle,获取每日24*60,所有分钟数
  • vue elementui table去掉滚动条与实现表格自动滚动且无滚动条
  • SDK命令行工具配置
  • 【数字IC基础】竞争与冒险
  • 速成版-带您一天学完python自动化测试(selenium)