当前位置: 首页 > news >正文

PyTorch神经网络打印存储所有权重+激活值(运行时中间值)

很多时候嵌入式或者新硬件需要纯净的权重模型和激活值(运行时中间值),本文提供一种最简洁的方法。
假设已经有模型model和pt文件了,在当前目录下新建weights文件夹,运行这段代码,就可以得到模型的权重(文本形式和二进制形式)

model.load_state_dict(state_dict)global_index = 0
for name, param in model.named_parameters():print(name, param.size())print(param.data.numpy(),file=open(f"weights/{global_index}-{name}.txt", "w"))param.data.numpy().tofile(f"weights/{global_index}-{name}.bin")global_index += 1

对于二进制形式的文件,可以通过od -t f4 <binary file name> 查看其对应的浮点数值。f4表示fp32.

打印forward的中间值:(这么复杂是必要的)

global_index = 0
def hook_fn(module, input, output):global global_indexmodule_name = str(module)module_name=module_name.replace(" ", "")module_name=module_name.replace("\n", "")# print(name)intermediate_outputs = {}# input is a tuple, output is a tensorfor i, inp in enumerate(input):intermediate_outputs[f"{global_index}-{module_name}-input-{i}"] = inpintermediate_outputs[f"{global_index}-{module_name}-output"] = outputmodule_name = module_name[0:200]  # make sure full path <= 255print(intermediate_outputs)print(f"Size input:",end=" ")if(type(input) == tuple):for i, inp in enumerate(input):if type(inp) == torch.Tensor:print(f"{i}-th Size: {inp.size()}", end=", ")inp.numpy().tofile(f"activations/{global_index}-{module_name}-input-{i}.bin")else:print(f"{i}-th : {inp}", end=", ")elif type(input) == torch.Tensor:print(f"Size: {input.size()}")input.numpy().tofile(f"activations/{global_index}-{module_name}-input.bin")print(f"Size output: {output.size()}")global_index += 1output.numpy().tofile(f"activations/{global_index}-{module_name}-output.bin")def register_hooks(model):for name, layer in model.named_children():# print(name, layer) # dump all layers, > layers.txt# Register the hook to the current layerlayer.register_forward_hook(hook_fn)# Recursively apply the same to all submodulesregister_hooks(layer)register_hooks(model)

其中regster_hooks和以下等价(不需要recursive了)

def register_hooks(model):for name, layer in model.named_modules():# print(name, layer) # dump all layerslayer.register_forward_hook(hook_fn)

其中nn.sequential作为一个整体,目前没办法拆开来看其内部的中间值。

http://www.lryc.cn/news/337715.html

相关文章:

  • grpc-教程(golang版)
  • Spring与Spring Boot的区别:从框架设计到应用开发
  • React Hooks 全解: 常用 Hooks 及使用场景详解
  • 第十三届蓝桥杯真题:x进制减法,数组切分,gcd,青蛙过河
  • JavaEE初阶Day 6:多线程(4)
  • 微信小程序 django+nodejs电影院票务售票选座系统324kd
  • 基于springboot实现桂林旅游景点导游平台管理系统【项目源码+论文说明】计算机毕业设计
  • idea 开发serlvet汽车租赁管理系统idea开发sqlserver数据库web结构计算机java编程layUI框架开发
  • Unity之PUN实现多人联机射击游戏的优化(Section 3)
  • PDF锐化
  • 【python和java】
  • C盘满了怎么办,清理工具TreeSize
  • 【vue】watch 侦听器
  • 校招生如何准备软件测试、测试开发岗位的面试?
  • 蓝桥杯抱佛脚篇~
  • 基于springboot的大学城水电管理系统源码数据库
  • AI大模型探索之路-应用篇2:Langchain框架ModelIO模块—数据交互的秘密武器
  • 【SSH】群晖开启ssh访问
  • Vue 移动端(H5)项目怎么实现页面缓存(即列表页面进入详情返回后列表页面缓存且还原页面滚动条位置)keep-alive缓存及清除keep-alive缓存
  • 【MVCC】深入浅出彻底理解MVCC
  • 【问题解决】ubuntu安装新版vscode报code-insiders相关错误
  • 【Python】面向对象(专版提升2)
  • Vscode设置滚轮进行字体大小的调节
  • 【QT入门】Qt自定义控件与样式设计之控件提升与自定义控件
  • Spring Validation解决后端表单校验
  • Harmony鸿蒙南向驱动开发-UART接口使用
  • 【示例】MySQL-事务控制示例:账户转账-savepoint关键字
  • STM32使用标准版RT-Thread,移植bsp中的板文件后,想使用I/O设备模型,使用串口3或者串口4收发时,发现串口3或者串口4没反应
  • MVCC(解决MySql中的并发事务的隔离性)
  • 第四十八章 为 Web 应用程序实现 HTTP 身份验证 - 在处理请求之前在 CSP 中进行身份验证