当前位置: 首页 > article >正文

Pytorch的梯度控制

在之前的实验中遇到一些问题,因为之前计算资源有限,我就想着微调其中一部分参数做,于是我误打误撞使用了with torch.no_grad,可是发现梯度传递不了,于是写下此文来记录梯度控制的两个方法与区别。

在PyTorch中,控制梯度计算对于模型训练和微调至关重要。这里区分两个常用方法:

1. tensor.requires_grad = False

  • 目标: 单个张量(通常是模型参数 nn.Parameter)。
  • 行为:
    • “参数冻结”:这个张量本身不会计算梯度 (.gradNone)。
    • “参数不更新”:优化器不会更新这个张量。
    • “梯度可穿透”:如果它参与的运算的输入是 requires_grad=True 的,梯度仍然会通过这个运算传递给输入。它不阻碍梯度流向更早的可训练层。
  • 场景:
    • 微调:冻结预训练模型的某些层,只训练其他层。
    • 例子:pretrained_layer.weight.requires_grad = False

2. with torch.no_grad():

  • 目标: 一个代码块 (with 语句块内部)。
  • 行为:
    • “全局梯度关闭”(块内):块内所有新创建的张量默认 requires_grad=False
    • “不记录计算图”:块内的运算不被追踪,不构建反向传播所需的计算图。
    • “梯度截断”:梯度流到这个块的边界就会停止,无法通过块内的操作继续反向传播
  • 场景:
    • 模型评估/推理 (Inference/Evaluation):不需要梯度,节省内存和计算。
    • 执行不需要梯度的任何计算。
    • 例子:
     with torch.no_grad():outputs = model(inputs)# ...其他评估代码
    

核心区别速记:

特性requires_grad=Falsewith torch.no_grad():
谁不更新?这个参数自己(块内)没人更新
梯度能过吗?能过!不能过! (被截断)
影响范围?单个张量整个代码块

一句话总结:

  • 想让某个参数不更新但梯度能流过,用 requires_grad=False
  • 想让一段代码完全不计算梯度也不让梯度流过,用 with torch.no_grad()

搞清楚这两者的区别,能在PyTorch中更灵活地控制模型的训练过程!

http://www.lryc.cn/news/2397220.html

相关文章:

  • linux驱动开发(1)-内核模块
  • AI产品风向标:从「工具属性」到「认知引擎」的架构跃迁​
  • 前端八股之CSS
  • ps自然饱和度调整
  • 有公网ip但外网访问不到怎么办?内网IP端口映射公网连接常见问题和原因
  • InlineHook的原理与做法
  • 微服务-Sentinel
  • DNS缓存
  • MySQL垂直分库(基于MyCat)
  • Rust 变量与可变性
  • 深入理解 C++ 中的 list 容器:从基础使用到模拟实现
  • 状态机实现文件单词统计
  • 从0开始学习R语言--Day13--混合效应与生存分析
  • 基于mediapipe深度学习的虚拟画板系统python源码
  • 复变函数 $w = z^2$ 的映射图像演示
  • Python实现P-PSO优化算法优化循环神经网络LSTM回归模型项目实战
  • 复合机器人:纠偏算法如何重塑工业精度与效率?
  • 审计- 1- 审计概述
  • 在MDK中自动部署LVGL,在stm32f407ZGT6移植LVGL-8.4,运行demo,显示label
  • 模块二:C++核心能力进阶(5篇) 篇一:《STL源码剖析:vector扩容策略与迭代器失效》
  • 计算机组成原理核心剖析:CPU、存储、I/O 与总线系统全解
  • 数据分类分级的实践与反思:源自数据分析、治理与安全交叉视角的洞察
  • 自动化立体仓库WCS的设计与实现
  • 百度蜘蛛池的作用是什么?技术@baidutopseo
  • 8.linux文件与文件夹内处理命令cp,mv,rm
  • JavaScript性能优化:实战技巧提升10倍速度
  • 核函数:解锁支持向量机的强大能力
  • UE5 2D地图曝光太亮怎么修改
  • C# 类和继承(基类访问)
  • 帕金森带来的生活困境