当前位置: 首页 > news >正文

Pytorch中关于forward函数的理解与用法

目录

  • 前言
  • 1. 问题所示
  • 2. 原理分析
    • 2.1 forward函数理解
    • 2.2 forward函数用法

前言

深入深度学习框架的代码,发现forward函数没有被显示调用

但代码确重写了forward函数,于是好奇是不是python的魔术方法作用

1. 问题所示

代码如下所示:

class Module(nn.Module):# 初始化def __init__(self):super(Module, self).__init__()# ......# 前向传播def forward(self, x):# ......return x# 输入数据
data = .....  # 实例化
module = Module()# 前向传播
module(data)  

整个代码串没有显示调用forward函数
由此引发疑问:

  1. 谁去调用forward函数?
  2. 什么时候调用forward函数?

2. 原理分析

回顾python的基础知识:python 类和对象的详细分析
可以清楚知道对象需要执行方法,在方法中传入参数即可,类似 module.forward(data),但是执行对象(参数)就可成功。

这也说明:module(data) 等价于 module.forward(data)
即该代码块调用了forward函数(那他是怎样实现什么时候调用的呢)

本身Pytorch大部分操作都是通过继承nn.Module类实现,查看其源代码:

class Module(object):def __init__(self):def forward(self, *input):def add_module(self, name, module):def cuda(self, device=None):def cpu(self):def __call__(self, *input, **kwargs):def parameters(self, recurse=True):def named_parameters(self, prefix='', recurse=True):def children(self):def named_children(self):def modules(self):  def named_modules(self, memo=None, prefix=''):def train(self, mode=True):def eval(self):def zero_grad(self):def __repr__(self):def __dir__(self):

内部中有个def __call__(self, *input, **kwargs):函数,默认父类会执行该函数

大致如下:

class Module():def __call__(self, data):        print(data)module = Module()# 输出 1
module(1)

这正说明,深度学习的模型继承了nn.Module类,内部的__call__方法有对forward方法的调用,才不用显式地调用forward方法。
对此,深度学习的模型框架需要重写构造函数中的__init__函数和forward函数。

2.1 forward函数理解

  1. 通过module中的__call__方法
  2. __call__方法调用module中的forward方法
  3. forward方法
    —若碰到Module子类,则迭代回馈第一步;
    —若碰到Function子类,则执行第四步;
  4. 调用Function子类中的call方法
  5. __call__方法调用Function中的forward方法
  6. 由于层层嵌套,现在只需回馈上一层的值即可
    ( Function中的forward返回值 ->
    module中的forward返回值 ->
    module中的__call__进行forward_hook返回值)

代码逻辑如下:

def __call__(self, *input, **kwargs):# 此处执行forward函数result = self.forward(*input, **kwargs)for hook in self._forward_hooks.values():#将注册的hook拿出来用hook_result = hook(self, input, result)return result
  • 围观角度:所谓的__call__为函数调用,只需要将该类型的对象当做函数使用即可,即 module(data) 等价于 module.forward(data)

  • 宏观角度:当一个类默认实现特殊方法__call__,该类的实例就变成可调用的类型,即对象名() 等价于 对象名.__call__()

2.2 forward函数用法

CNN可学习的参数层和不可学习的参数层,大致如下:

  • 可学习的参数:卷积层和全连接层的权重、bias、BatchNorm的β和γ等。
  • 不可学习的参数(超参数):学习率、batch size、weight decay、模型的深度宽度分辨率等。
  • Module类中的init构造函数一般放置可学习的参数,其不可学习的参数如果不放置在init层,则在forward函数中可用nn.functional来代替。
  • forward函数必须重写(实现模型功能,链接各层之间的功能)
http://www.lryc.cn/news/179785.html

相关文章:

  • vite跨域proxy设置与开发、生产环境的接口配置,接口在生产环境下,还能使用proxy代理地址吗
  • 【嵌入式】使用MultiButton开源库驱动按键并控制多级界面切换
  • 【数据结构】树的概念理解和性质推导(保姆级详解,小白必看系列)
  • 融合之力:数字孪生、人工智能和数据分析的创新驱动
  • Spring的注解开发-Spring配置类的开发
  • Linux系统编程系列之进程间通信-信号量组
  • centos 6使用yum安装软件
  • maven无法下载时的解决方法——笔记
  • Java Spring Boot 开发框架
  • Pytorch学习记录-1-张量
  • paddle2.3-基于联邦学习实现FedAVg算法-CNN
  • nuiapp保存canvas绘图
  • Object.defineProperty()方法详解,了解vue2的数据代理
  • Linux 磁盘管理
  • 大数据与人工智能的未来已来
  • 【AI视野·今日Robot 机器人论文速览 第四十一期】Tue, 26 Sep 2023
  • [NOIP2012 提高组] 开车旅行
  • 数据库设计流程---以案例熟悉
  • Miniconda创建paddlepaddle环境
  • postgresql实现单主单从
  • 提取PDF数据:Documents for PDF ( GcPdf )
  • adb连接切换到模拟器端口
  • 为何每个开发者都在谈论Go?
  • 【Leetcode】 501. 二叉搜索树中的众数
  • 怎样给Ubuntu系统安装vmware-tools
  • DDS信号发生器波形发生器VHDL
  • Python3操作SQLite3创建表主键自增长|CRUD基本操作
  • B. Comparison String
  • python端口扫描
  • 国庆第二天