当前位置: 首页 > news >正文

神经网络基础-神经网络搭建和参数计算

文章目录

    • 1.构建神经网络
    • 2. 神经网络的优缺点

1.构建神经网络

在 pytorch 中定义深度神经网络其实就是层堆叠的过程,继承自nn.Module,实现两个方法:

  • __init__方法中定义网络中的层结构,主要是全连接层,并进行初始化。
  • forward方法,在实例化模型的时候,底层会自动调用该函数。该函数中可以定义学习率,为初始化定义的layer传入数据等。

我们来构建如下图所示的神经网络模型:
在这里插入图片描述

编码设计如下:

  1. 第1个隐藏层:权重初始化采用标准化的xavier初始化 激活函数使用sigmoid。
  2. 第2个隐藏层:权重初始化采用标准化的He初始化 激活函数采用relu。
  3. out输出层线性层 假若二分类,采用softmax做数据归一化。
# 创建神经网络
import torch
import torch.nn as nn
# pip install torchsummary
from torchsummary import summary # 计算模型参数,查看模型结构 pip install torchsummary
# 创建神经网络模型类
class Model(nn.Module):# 初始化属性值def __init__(self):# 调用父类的初始化属性值super(Model, self).__init__()# 创建第一个隐藏层模型,3个输入特征,3个输出特征self.linear1 = nn.Linear(3, 3)# 初始化权重 xavier 均匀分布初始化nn.init.xavier_uniform_(self.linear1.weight)# 创建第二个隐藏层,3个输入特征(上一层的输出特征),2个输出特征self.linear2 = nn.Linear(3, 2)# 初始化权重 kaiming 正太分布初始化nn.init.kaiming_normal_(self.linear2.weight)# 创建输出层模型self.out = nn.Linear(2, 2)# 创建向前传播方法,自动执行 forward()方法def forward(self, x):# 数据经过第一个线性层x = self.linear1(x)# 使用 sigmoid 激活函数x = torch.sigmoid(x)# 数据经过第二个线性层x = self.linear2(x)# 使用 relu 激活函数x = torch.relu(x)# 数据经过输出层x = self.out(x)# 使用 softmax 激活函数# dim=-1:每一维度行数据相机为1x = torch.softmax(x, dim=-1)return xif __name__ == '__main__':# 实例化model对象model = Model()# 随机产生数据data = torch.randn(5,3)print('data.shape',data.shape)# 数据经过神经网络模型训练out = model(data)print('out.shape',out.shape)# 计算模型参数# 计算每层每个神经元的 w 和 b 个数总和summary(model,input_size=(3,),batch_size=5)# 查看模型参数print("======查看模型参数w和b======")for name, param in model.named_parameters():print(name, param)
  • 神经网络的输入数据是为[batch_size, in_features]的张量经过网络处理后获取了[batch_size, out_features]的输出张量。

  • 在上述例子中,batch_size=5, in_features=3,out_features=2,结果如下所示:

    data.shape torch.Size([5, 3])
    out.shape torch.Size([5, 2])
    

    模型参数输出:

    ----------------------------------------------------------------Layer (type)               Output Shape         Param #
    ================================================================Linear-1                     [5, 3]              12Linear-2                     [5, 2]               8Linear-3                     [5, 2]               6
    ================================================================
    Total params: 26
    Trainable params: 26
    Non-trainable params: 0
    ----------------------------------------------------------------
    Input size (MB): 0.00
    Forward/backward pass size (MB): 0.00
    Params size (MB): 0.00
    Estimated Total Size (MB): 0.00
    ----------------------------------------------------------------
    ======查看模型参数w和b======
    linear1.weight Parameter containing:
    tensor([[ 0.3857,  0.4809, -0.0346],[ 0.3645,  0.2803, -0.6291],[ 0.1999, -0.6617,  0.7724]], requires_grad=True)
    linear1.bias Parameter containing:
    tensor([0.3084, 0.5636, 0.4501], requires_grad=True)
    linear2.weight Parameter containing:
    tensor([[ 0.1063,  0.7494,  0.4311],[-1.4152,  0.3396, -0.8590]], requires_grad=True)
    linear2.bias Parameter containing:
    tensor([-0.3771,  0.2937], requires_grad=True)
    out.weight Parameter containing:
    tensor([[-0.6012,  0.4727],[-0.2953, -0.5854]], requires_grad=True)
    out.bias Parameter containing:
    tensor([-0.3271,  0.4940], requires_grad=True)
    

模型参数的计算:

  1. 以第一个隐层为例:该隐层有3个神经元,每个神经元的参数为:4个(w1,w2,w3,b1),所以一共用3x4=12个参数。
  2. 输入数据和网络权重是两个不同的事儿!对于初学者理解这一点十分重要,要分得清。
    在这里插入图片描述

2. 神经网络的优缺点

  1. 优点
    ➢ 精度高,性能优于其他的机器学习算法,甚至在某些领域超过了人类。
    ➢ 可以近似任意的非线性函数。
    ➢ 近年来在学界和业界受到了热捧,有大量的框架和库可供调。
  2. 缺点
    ➢ 黑箱,很难解释模型是怎么工作的。
    ➢ 训练时间长,需要大量的计算资源。
    ➢ 网络结构复杂,需要调整超参数。
    ➢ 部分数据集上表现不佳,容易发生过拟合。
http://www.lryc.cn/news/504985.html

相关文章:

  • Linux入门攻坚——41、Linux集群系统入门-lvs(2)
  • 音视频入门基础:MPEG2-TS专题(17)——FFmpeg源码中,解析TS program map section的实现
  • 了解https原理,对称加密/非对称加密原理,浏览器与服务器加密的进化过程,https做了些什么
  • 山西省第十八届职业院校技能大赛高职组 5G 组网与运维赛项规程
  • tcpdump编译 wireshark远程抓包
  • Web开发 -前端部分-CSS
  • 用 Python Turtle 绘制流动星空:编程中的璀璨星河
  • Java从入门到工作2 - IDEA
  • fastadmin批量压缩下载远程视频文件
  • 【保姆级】Mac如何安装+切换Java环境
  • 2024首届世界酒中国菜国际地理标志产品美食文化节成功举办篇章
  • Springboot静态资源
  • MTK修改配置更改产品类型ro.build.characteristics
  • SQL 查询中的动态字段过滤
  • 数字IC后端零基础入门基础理论(Day1)
  • 【LC】240. 搜索二维矩阵 II
  • Redis应用—4.在库存里的应用
  • selenium获取请求头
  • Rust中自定义Debug调试输出
  • docker离线安装、linux 安装docker
  • 卓易通:鸿蒙Next系统的蜜糖还是毒药?
  • AI大模型学习笔记|神经网络与注意力机制(逐行解读)
  • Linux 操作系统中的管道与共享内存
  • 恢复删除的文件:6个免费Windows电脑数据恢复软件
  • linux网络编程 | c | select实现多路IO转接服务器
  • 基于前后端分离的食堂采购系统源码:从设计到开发的全流程详解
  • 小程序自定义tab-bar,踩坑记录
  • 游戏引擎学习第52天
  • 【热力学与工程流体力学】流体静力学实验,雷诺实验,沿程阻力实验,丘里流量计流量系数测定,局部阻力系数的测定,稳态平板法测定材料的导热系数λ
  • 【HTML】根据不同域名设置不同的网站图标(替换 link 中 href 地址)