当前位置: 首页 > news >正文

前馈神经网络正则化例子

直接看代码:

import torch  
import numpy as np  
import random  
from IPython import display  
from matplotlib import pyplot as plt  
import torchvision  
import torchvision.transforms as transforms   mnist_train = torchvision.datasets.MNIST(root='/MNIST', train=True, download=True, transform=transforms.ToTensor())  
mnist_test = torchvision.datasets.MNIST(root='./MNIST', train=False,download=True, transform=transforms.ToTensor())  batch_size = 256 train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True,num_workers=0)  
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False,num_workers=0)  num_inputs,num_hiddens,num_outputs =784, 256,10def init_param():W1 = torch.tensor(np.random.normal(0, 0.01, (num_hiddens,num_inputs)), dtype=torch.float32)  b1 = torch.zeros(1, dtype=torch.float32)  W2 = torch.tensor(np.random.normal(0, 0.01, (num_outputs,num_hiddens)), dtype=torch.float32)  b2 = torch.zeros(1, dtype=torch.float32)  params =[W1,b1,W2,b2]for param in params:param.requires_grad_(requires_grad=True)  return W1,b1,W2,b2def relu(x):x = torch.max(input=x,other=torch.tensor(0.0))  return xdef net(X):  X = X.view((-1,num_inputs))  H = relu(torch.matmul(X,W1.t())+b1)  #myrelu =((matmal x,w1)+b1),return  matmal(myrelu,w2 )+ b2return relu(torch.matmul(H,W2.t())+b2 )return torch.matmul(H,W2.t())+b2def SGD(paras,lr):  for param in params:  param.data -= lr * param.grad  def l2_penalty(w):return (w**2).sum()/2def train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr=None,optimizer=None,mylambda=0):  train_ls, test_ls = [], []for epoch in range(num_epochs):ls, count = 0, 0for X,y in train_iter :X = X.reshape(-1,num_inputs)l=loss(net(X),y)+ mylambda*l2_penalty(W1) + mylambda*l2_penalty(W2)optimizer.zero_grad()l.backward()optimizer.step()ls += l.item()count += y.shape[0]train_ls.append(ls)ls, count = 0, 0for X,y in test_iter:X = X.reshape(-1,num_inputs)l=loss(net(X),y) + mylambda*l2_penalty(W1) + mylambda*l2_penalty(W2)ls += l.item()count += y.shape[0]test_ls.append(ls)if(epoch)%2==0:print('epoch: %d, train loss: %f, test loss: %f'%(epoch+1,train_ls[-1],test_ls[-1]))return train_ls,test_lslr = 0.01num_epochs = 20Lamda = [0,0.1,0.2,0.3,0.4,0.5]Train_ls, Test_ls = [], []for lamda in Lamda:print("current lambda is %f"%lamda)W1,b1,W2,b2 = init_param()loss = torch.nn.CrossEntropyLoss()optimizer = torch.optim.SGD([W1,b1,W2,b2],lr = 0.001)train_ls, test_ls = train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr,optimizer,lamda)   Train_ls.append(train_ls)Test_ls.append(test_ls)x = np.linspace(0,len(Train_ls[1]),len(Train_ls[1]))plt.figure(figsize=(10,8))for i in range(0,len(Lamda)):plt.plot(x,Train_ls[i],label= f'L2_Regularization:{Lamda [i]}',linewidth=1.5)plt.xlabel('different epoch')plt.ylabel('loss')plt.legend(loc=2, bbox_to_anchor=(1.1,1.0),borderAxesPad = 0.)plt.title('train loss with L2_penalty')plt.show()

运行结果:

在这里插入图片描述

疑问和心得:

  1. 画图的实现和细节还是有些模糊。
  2. 正则化系数一般是一个可以根据算法有一定变动的常数。
  3. 前馈神经网络中,二分类最后使用logistic函数返回,多分类一般返回softmax值,若是一般的回归任务,一般是直接relu返回。
  4. 前馈神经网络的实现,从物理层上应该是全连接的,但是网上的代码一般都是两层单个神经元,这个容易产生误解。个人感觉,还是要使用nn封装的函数比较正宗。
http://www.lryc.cn/news/130365.html

相关文章:

  • spring的核心技术---bean的生命周期加案例分析详细易懂
  • 【Maven教程】(一)入门介绍篇:Maven基础概念与其他构建工具:理解构建过程与Maven的多重作用,以及与敏捷开发的关系 ~
  • 今天,谷歌Chrome浏览器部署抗量子密码
  • SUMO traci接口控制电动车前往充电站充电
  • 现代CSS中的换行布局技术
  • 简单理解Python中的深拷贝与浅拷贝
  • C++之std::pair<uint64_t, size_t>应用实例(一百七十七)
  • 前端打开后端返回的HTML格式的数据
  • How to deal with document-oriented data
  • Http 状态码汇总
  • mysql自定义实体类框架
  • 批量将Excel中的第二列内容从拼音转换为汉字
  • 消息推送:精准推送,提升运营效果,增添平台活力
  • [保研/考研机试] KY43 全排列 北京大学复试上机题 C++实现
  • Java将时间戳转化为特定时区的日期字符串
  • 【算法挨揍日记】day03——双指针算法_有效三角形的个数、和为s的两个数字
  • 通过 kk 创建 k8s 集群和 kubesphere
  • 感觉和身边其他人有差距怎么办?
  • 【C语言基础】宏定义的用法详解
  • 微服务系列文章之 SpringBoot 最佳实践
  • C++并发多线程--std::async、std::packaged_task和std::promise的使用
  • opencv-目标追踪
  • 【数据结构】 单链表面试题讲解
  • C++ string类的模拟实现
  • Qt实现简单的漫游器
  • 【c语言】文件操作
  • 【Unity】坐标转换经纬度方法(应用篇)
  • element时间选择器el-date-picter使用disabledDate指定禁用的日期
  • 出学校干了 5 年外包,已经废了
  • day-23 代码随想录算法训练营(19)part09