当前位置: 首页 > news >正文

pytorch笔记:named_parameters

  • named_parameters 是 PyTorch 中一个非常有用的函数,用于访问模型中所有定义的参数及其对应的名称。
  • 它是 torch.nn.Module 类的方法之一,返回一个生成器,生成 (name, parameter) 对,name 是参数的名称,parameter 是对应的参数张量。

1 举例

1.0 创建模型


import torch
import torch.nn as nn# 定义一个简单的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.conv1 = nn.Conv2d(1, 20, 5)self.conv2 = nn.Conv2d(20, 64, 5)self.fc1 = nn.Linear(64 * 4 * 4, 500)self.fc2 = nn.Linear(500, 10)def forward(self, x):x = torch.relu(self.conv1(x))x = torch.relu(self.conv2(x))x = x.view(-1, 64 * 4 * 4)x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 实例化模型
model_tst = SimpleModel()

1.1 应用1:打印模型的所有参数及其名称

for name, param in model_tst.named_parameters():print(name, param.shape)'''
conv1.weight torch.Size([20, 1, 5, 5])
conv1.bias torch.Size([20])
conv2.weight torch.Size([64, 20, 5, 5])
conv2.bias torch.Size([64])
fc1.weight torch.Size([500, 1024])
fc1.bias torch.Size([500])
fc2.weight torch.Size([10, 500])
fc2.bias torch.Size([10])
conv1.weight torch.Size([20, 1, 5, 5])
conv1.bias torch.Size([20])
conv2.weight torch.Size([64, 20, 5, 5])
conv2.bias torch.Size([64])
fc1.weight torch.Size([500, 1024])
fc1.bias torch.Size([500])
fc2.weight torch.Size([10, 500])
fc2.bias torch.Size([10])
'''

1.2 应用2:冻结特定层的参数

假设我们只想训练全连接层,而冻结卷积层的参数:

for name, param in model_tst.named_parameters():if 'conv' in name:param.requires_grad = False

1.3 应用3:自定义优化器参数

可以使用 named_parameters 创建自定义的参数组,以便对不同的参数组应用不同的学习率:

optimizer = torch.optim.SGD([{'params': [param for name, param in model_tst.named_parameters() if 'conv' in name], 'lr': 0.01},{'params': [param for name, param in model_tst.named_parameters() if 'fc' in name], 'lr': 0.1}
], momentum=0.9)

http://www.lryc.cn/news/383598.html

相关文章:

  • uniapp——H5添加支付宝授权登录,报错:系统异常,请联系商家。REDIRECT_URI_ILLEAGAL
  • 群辉NAS使用Kodi影视墙
  • 如何实现HPC数据传输的高效流转,降本增效?
  • redis 定时任务锁 分布式锁
  • 了解支付行业基本专业语
  • 408数据结构-图的应用1-最小生成树 自学知识点整理
  • Ubuntu18.04操作系统使用pip3安装open cv
  • 为什么变量不可以在 switch 语句中声明定义?
  • 手机定位技术全解析:原理、发展与应用
  • 深入探索Kylin的Cube构建:数据魔方的构建之旅
  • web渗透-CSRF漏洞
  • Python数据分析-电信客户流量预测与分析
  • 动态人物抠图换背景 MediaPipe
  • Vue3 vite使用postcss-px-to-viewport(适配vant)
  • MCU复位时GPIO是什么状态?
  • 领先GPT-4o:Anthropic 推出新一代模型 Claude 3.5 Sonnet|TodayAI
  • 使用AES,前端加密,后端解密,spring工具类了
  • 通过Spring-Data-Redis操作Redis
  • 自动驾驶ADAS
  • Python+Pytest+Allure+Yaml接口自动化测试框架详解
  • python turtle 001画两只小狗
  • 『亚马逊云科技产品测评』程序员最值得拥有的第一台专属服务器 “亚马逊EC2实例“
  • python 趣味习题_递归函数(炸弹迷宫路径计算)
  • 免费翻译API及使用指南——百度、腾讯
  • 深度测试中的隐藏面消除技术
  • oracle merge的使用
  • 《数字图像处理》实验报告四
  • 算法04 模拟算法之一维数组相关内容详解【C++实现】
  • 【技术解码】百数SRM:如何助力企业快速优化供应链管理?
  • 想要用tween实现相机的移动,three.js渲染的canvas画布上相机位置一点没动,如何解决??