
15. Manually Implementing BatchNorm (BN)

15.1 Manual Implementation of the BatchNorm Operation

import torch
from torch import nn

def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    if not torch.is_grad_enabled():
        # Inference mode: normalize with the accumulated moving statistics
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # Fully connected input: per-feature statistics over the batch
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # Convolutional input: per-channel statistics over batch, H, W
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # Update the moving averages of the mean and variance
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta
    return Y, moving_mean.data, moving_var.data
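As a quick sanity check (a sketch I added, with illustrative shapes, not part of the original post): in training mode the normalized output should have per-channel mean close to 0 and variance close to 1 on random input.

X = torch.randn(8, 3, 5, 5)  # [batch, channels, H, W]
gamma = torch.ones((1, 3, 1, 1))
beta = torch.zeros((1, 3, 1, 1))
moving_mean = torch.zeros((1, 3, 1, 1))
moving_var = torch.ones((1, 3, 1, 1))
# Grad mode is on by default, so this takes the training branch
Y, moving_mean, moving_var = batch_norm(X, gamma, beta, moving_mean,
                                        moving_var, eps=1e-5, momentum=0.9)
print(Y.mean(dim=(0, 2, 3)))                 # each entry close to 0
print(Y.var(dim=(0, 2, 3), unbiased=False))  # each entry close to 1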
class BatchNorm(nn.Module):
    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # The two learnable parameters, updated during training
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.ones(shape)  # must not be 0: we divide by sqrt(var)

    def forward(self, X):
        # Keep the moving statistics on the same device as the input
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean, self.moving_var,
            eps=1e-5, momentum=0.9)
        return Y
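To double-check the layer, one can compare its training-mode output with PyTorch's built-in nn.BatchNorm2d, which uses the same default eps=1e-5 and also normalizes with the biased batch variance. A minimal sketch (keep in mind that the custom layer switches branches via torch.is_grad_enabled(), not self.training):

torch.manual_seed(0)
X = torch.randn(4, 6, 10, 10)
custom = BatchNorm(6, num_dims=4)
builtin = nn.BatchNorm2d(6)  # defaults: eps=1e-5, weight=1, bias=0
# Both modules are in training mode by default
print(torch.allclose(custom(X), builtin(X), atol=1e-5))  # expected: True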
model = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Flatten(),  # after Flatten() the tensor is 2-D: [batch_size, features]
    nn.Linear(16 * 4 * 4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),
    nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),
    nn.Linear(84, 10))
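To see why the first fully connected layer takes 16*4*4 = 256 inputs, it helps to push a dummy 28x28 image through the network and print each layer's output shape. This check is a sketch I added, not part of the original code:

X = torch.rand(1, 1, 28, 28)
for layer in model:
    X = layer(X)
    print(layer.__class__.__name__, 'output shape:', X.shape)
# Conv2d: 28 -> 24, MaxPool2d: 24 -> 12, Conv2d: 12 -> 8, MaxPool2d: 8 -> 4,
# so Flatten() produces 16 * 4 * 4 = 256 features per sample.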

15.2 BatchNorm Experiment Results

################################################################################################################
"""BatchNorm"""
################################################################################################################
import torch
import torchvision
from torch import nn
import matplotlib.pyplot as plt
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from torch.nn import functional as F
plt.rcParams['font.family']=['Times New Roman']
class Reshape(torch.nn.Module):
    def forward(self, x):
        return x.view(-1, 1, 28, 28)  # [bs, 1, 28, 28]
def plot_metrics(train_loss_list, train_acc_list, test_acc_list, title='Training Curve'):
    epochs = range(1, len(train_loss_list) + 1)
    plt.figure(figsize=(4, 3))
    plt.plot(epochs, train_loss_list, label='Train Loss')
    plt.plot(epochs, train_acc_list, label='Train Acc', linestyle='--')
    plt.plot(epochs, test_acc_list, label='Test Acc', linestyle='--')
    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()
def train_model(model, train_data, test_data, num_epochs):
    train_loss_list = []
    train_acc_list = []
    test_acc_list = []
    for epoch in range(num_epochs):
        total_loss = 0
        total_acc_sample = 0
        total_samples = 0
        loop = tqdm(train_data, desc=f"EPOCHS[{epoch+1}/{num_epochs}]")
        for X, y in loop:
            X = X.to(device)
            y = y.to(device)
            y_hat = model(X)
            loss = CEloss(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Accumulate the loss, weighted by batch size
            total_loss += loss.item() * X.shape[0]
            y_pred = y_hat.argmax(dim=1).detach().cpu().numpy()
            y_true = y.detach().cpu().numpy()
            total_acc_sample += accuracy_score(y_true, y_pred) * X.shape[0]
            total_samples += X.shape[0]
        # Evaluation: disable gradients so the custom batch_norm takes its
        # inference branch (moving statistics) instead of batch statistics
        test_acc_samples = 0
        test_samples = 0
        with torch.no_grad():
            for X, y in test_data:
                X = X.to(device)
                y = y.to(device)
                y_hat = model(X)
                y_pred = y_hat.argmax(dim=1).cpu().numpy()
                y_true = y.cpu().numpy()
                test_acc_samples += accuracy_score(y_true, y_pred) * X.shape[0]
                test_samples += X.shape[0]
        avg_train_loss = total_loss / total_samples
        avg_train_acc = total_acc_sample / total_samples
        avg_test_acc = test_acc_samples / test_samples
        train_loss_list.append(avg_train_loss)
        train_acc_list.append(avg_train_acc)
        test_acc_list.append(avg_test_acc)
        print(f"Epoch {epoch+1}: Loss: {avg_train_loss:.4f}, Train Accuracy: {avg_train_acc:.4f}, Test Accuracy: {avg_test_acc:.4f}")
    plot_metrics(train_loss_list, train_acc_list, test_acc_list)
    return model
def init_weights(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
        nn.init.xavier_uniform_(m.weight)
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    if not torch.is_grad_enabled():
        # Inference mode: normalize with the accumulated moving statistics
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # Fully connected input: per-feature statistics over the batch
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # Convolutional input: per-channel statistics over batch, H, W
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # Update the moving averages of the mean and variance
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta
    return Y, moving_mean.data, moving_var.data
class BatchNorm(nn.Module):
    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # The two learnable parameters, updated during training
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.ones(shape)  # must not be 0: we divide by sqrt(var)

    def forward(self, X):
        # Keep the moving statistics on the same device as the input
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean, self.moving_var,
            eps=1e-5, momentum=0.9)
        return Y
################################################################################################################
transform = transforms.Compose([
    transforms.Resize(28),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))])
train_img = torchvision.datasets.FashionMNIST(root="./data", train=True, transform=transform, download=True)
test_img = torchvision.datasets.FashionMNIST(root="./data", train=False, transform=transform, download=True)
train_data=DataLoader(train_img,batch_size=128,num_workers=4,shuffle=True)
test_data=DataLoader(test_img,batch_size=128,num_workers=4,shuffle=False)
################################################################################################################
device=torch.device("cuda:1" if torch.cuda.is_available() else 'cpu')
model = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Flatten(),  # after Flatten() the tensor is 2-D: [batch_size, features]
    nn.Linear(16 * 4 * 4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),
    nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),
    nn.Linear(84, 10)).to(device)
model.apply(init_weights)
optimizer=torch.optim.SGD(model.parameters(),lr=0.01,momentum=0.9)
CEloss=nn.CrossEntropyLoss()
model=train_model(model,train_data,test_data,num_epochs=15)
################################################################################################################
print("BatchNorm算法学习参数效果:")
print("gamma:",model[1].gamma.reshape((-1,)))
print("beta:",model[1].beta.reshape((-1,)))

[Figure: training loss, training accuracy, and test accuracy curves over the 15 epochs, as produced by plot_metrics]

