当前位置: 首页 > news >正文

BN体系理解——类封装复现

 

 

 

 

 

from pathlib import Path
from typing import Optionalimport torch
import torch.nn as nn
from torch import Tensorclass BN(nn.Module):def __init__(self,num_features,momentum=0.1,eps=1e-8):##num_features是通道数"""初始化方法:param num_features:特征属性的数量,也就是通道数目C"""super(BN, self).__init__()##register_buffer:将属性当成parameter进行处理,唯一的区别就是不参与反向传播的梯度求解self.register_buffer('running_mean', torch.zeros(1, num_features, 1, 1))self.register_buffer('running_var', torch.zeros(1, num_features, 1, 1))self.running_mean: Optional[Tensor]self.running_var: Optional[Tensor]self.running_mean=torch.zeros([1,num_features,1,1])self.running_var=torch.zeros([1,num_features,1,1])self.gamma=nn.Parameter(torch.ones([1,num_features,1,1]))self.beta=nn.Parameter(torch.zeros(1,num_features,1,1))self.eps=epsself.momentum=momentumdef forward(self,x):"""前向过程output=(x-μ)/α*γ+β:param x: [N,C,H,W]:return: [N,C,H,W]"""if self.training:#训练阶段--》使用当前批次的数据_mean=torch.mean(x,dim=(0,2,3),keepdim=True)_var = torch.var(x, dim=(0, 2, 3), keepdim=True)#将训练过程中的均值和方差保存下来--方便推理的时候使用--》滑动平均self.running_mean=self.momentum*self.running_mean+(1.0-self.momentum)*_meanself.running_var=self.momentum*self.running_var+(1.0-self.momentum)*_varelse:#推理阶段-->使用的是训练过程中的累积数据_mean=self.running_mean_var=self.running_varz=(x-_mean)/torch.sqrt(_var+self.eps)*self.gamma+self.betareturn zif __name__ == '__main__':torch.manual_seed(28)path_dir=Path("./output/models")path_dir.mkdir(parents=True,exist_ok=True)device=torch.device("cuda" if torch.cuda.is_available() else "cpu")bn=BN(num_features=12)bn.to(device)#只针对子模块和参数进行转换#模拟训练过程bn.train()xs=[torch.randn(8,12,32,32).to(device) for _ in range(10)]for _x in xs:bn(_x)print(bn.running_mean.view(-1))print(bn.running_var.view(-1))#模拟推理过程bn.eval()_r=bn(xs[0])print(_r.shape)bn=bn.cpu()#保存都是以cpu保存,恢复再自己转回GPU上#模拟模型保存torch.save(bn,str(path_dir/'bn_model.pkl'))#state_dict:获取当前模块的所有参数(Parameter+register_buffer)torch.save(bn.state_dict(),str(path_dir/"bn_params.pkl"))#pt结构的保存traced_script_module=torch.jit.trace(bn.eval(),xs[0].cpu())traced_script_module.save("./output/bn_model.pt")#模拟模型恢复bn_model=torch.load(str(path_dir/"bn_model.pkl"),map_location='cpu')bn_params=torch.load(str(path_dir/"bn_params.pkl"),map_location='cpu')print(len(bn_params))

http://www.lryc.cn/news/189494.html

相关文章:

  • 请求和响应的概述
  • (深度学习快速入门)A Gentle Introduction to Graph Neural Networks 笔记
  • VIM指令
  • Android 10.0 framework层实现app默认全屏显示
  • 【计算机网络黑皮书】传输层
  • 轻量限制流量?阿里云轻量应用服务器月流量包收费说明
  • Linux手记
  • springboot配置
  • 大数据中的一些词汇解释
  • 10月11-12日上课内容 Ansible
  • android studio 我遇到的Task :app:compileDebugJavaWithJavac FAILED问题及解决过程
  • PLC电梯控制系统
  • FastAPI学习-27 使用@app.api_route() 设置多种请求方式
  • 08. 机器学习- 线性回归
  • 好奇喵 | PT(Private Tracker)——什么是P2P,什么是BT,啥子是PT?
  • 【Node.js】crypto 模块
  • vue父组件向子组件传值的方法
  • MATLAB算法实战应用案例精讲-【优化算法】高尔夫优化算法(GOA)(附MATLAB代码实现)
  • 数组的reduce和reduceRight方法
  • 自动监控网站可用性并发送通知的 Bash 脚本
  • go 项目打包部署到服务器
  • 整理mongodb文档:副本集成员可以为偶数
  • PHP - 遇到的Bug - 总结
  • 统计子岛屿的数量
  • IntelliJ IDEA Maven 项目的依赖分析
  • 数学建模、统计建模、计量建模整体框架的理解以及建模的步骤
  • WaitGroup原理分析
  • java直播源码:如何使用Java构建一个高效的直播系统
  • Websocket获取B站直播间弹幕教程——第二篇、解包/拆包
  • 膝关节检测之1设计目标手势与物体交互的动画