当前位置: 首页 > news >正文

AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比

博客导读:

《AI—工程篇》

AI智能体研发之路-工程篇(一):Docker助力AI智能体开发提效

AI智能体研发之路-工程篇(二):Dify智能体开发平台一键部署

AI智能体研发之路-工程篇(三):大模型推理服务框架Ollama一键部署

AI智能体研发之路-工程篇(四):大模型推理服务框架Xinference一键部署

AI智能体研发之路-工程篇(五):大模型推理服务框架LocalAI一键部署

《AI—模型篇》

AI智能体研发之路-模型篇(一):大模型训练框架LLaMA-Factory在国内网络环境下的安装、部署及使用

AI智能体研发之路-模型篇(二):DeepSeek-V2-Chat 训练与推理实战

AI智能体研发之路-模型篇(三):中文大模型开、闭源之争 

AI智能体研发之路-模型篇(四):一文入门pytorch开发

AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比

目录

一、引言

二、pytorch模型结构定义

三、tensorflow模型结构定义

四、总结


一、引言

本文是上一篇AI智能体研发之路-模型篇(四):一文入门pytorch开发的番外篇,对上文中pytorch的网络结构和tensorflow的模型结构部分进一步详细对比与说明(水一篇为了得到当天的流量卷哈哈,如果想更详细的了解pytorch,辛苦移步上一篇哈。

二、pytorch模型结构定义

def __init__(self, input_size, hidden_size, output_size):super(ThreeLayerDNN, self).__init__()self.fc1 = nn.Linear(input_size, hidden_size)  # 第一层全连接层self.fc2 = nn.Linear(hidden_size, hidden_size)  # 第二层全连接层self.fc3 = nn.Linear(hidden_size, output_size)  # 输出层self.sigmoid = nn.Sigmoid()

首先定义了一个名为`ThreeLayerDNN`的类,它是基于PyTorch框架的,用于构建一个具有三个全连接层(也称为密集层)的深度神经网络,特别适用于二分类问题。下面是对代码的详细解释:

  • `__init__`: 这是Python中的构造函数,当创建`ThreeLayerDNN`类的新实例时会被调用。
  •  `super(ThreeLayerDNN, self).__init__()`: 这行代码调用父类的初始化方法。因为`ThreeLayerDNN`继承自PyTorch的`nn.Module`类,这一步确保了`ThreeLayerDNN`具有`nn.Module`的所有基本属性和方法。
  •  `self.fc1 = nn.Linear(input_size, hidden_size)`: 这里定义了神经网络的第一层全连接层(fully connected layer)。`input_size`是输入数据的特征数量,`hidden_size`是这一层的神经元数量。全连接层意味着输入数据的每个特征都将与这一层的每个神经元相连接。
  •  `self.fc2 = nn.Linear(hidden_size, hidden_size)`: 定义了第二层全连接层,结构与第一层相同,保持了相同的隐藏层大小,这在某些架构中用于加深网络而不立即增加模型复杂度。
  •  `self.fc3 = nn.Linear(hidden_size, output_size)`: 这是网络的输出层,其输入大小与隐藏层相同,输出大小为`output_size`,对于二分类问题,通常为1。
  •  `self.sigmoid = nn.Sigmoid()`: 这行代码定义了一个Sigmoid激活函数,它将在网络的输出层之后应用。Sigmoid函数将输出映射到(0, 1)之间,非常适合二分类问题,其中输出可以解释为属于正类的概率。

综上所述,这段代码构建了一个基础的神经网络结构,适合进行二分类任务,通过全连接层提取特征,并使用Sigmoid函数将网络输出转换为概率估计。

三、tensorflow模型结构定义

model = Sequential([Dense(512, input_shape=(X_train.shape[1],)),  # 第一层Activation('relu'),Dense(512),  # 第二层Activation('relu'),Dense(1),  # 输出层Activation('sigmoid')  # 二分类使用sigmoid
])

使用Keras库(现在是TensorFlow的一个部分)定义了一个简单的深度学习模型,具体来说是一个顺序(Sequential)模型,适用于进行二分类任务。下面是对这段代码的详细解释:

  • Sequential模型: 这是一种线性堆叠层的模型,适合于简单的前向传播神经网络。

  • Dense层: 也称为全连接层,每个神经元都与前一层的所有神经元相连。

    • Dense(512, input_shape=(X_train.shape[1],)): 第一层,有512个神经元,input_shape=(X_train.shape[1],)指定了输入数据的形状,这里假设X_train是一个二维数组,其中每一行是一个样本,X_train.shape[1]表示每个样本的特征数量。
    • Dense(512): 第二层,同样有512个神经元,由于是在Sequential模型中,它自动接收前一层的输出作为输入。
    • Dense(1): 输出层,只有一个神经元,适用于二分类问题。
  • Activation层: 激活函数层,为神经网络引入非线性。

    • Activation('relu'): 使用ReLU(Rectified Linear Unit)作为激活函数,它在输入大于0时输出输入值,小于0时输出0,有助于解决梯度消失问题。
    • 最后一层使用Activation('sigmoid'): 二分类任务中,输出层常用sigmoid激活函数,将输出映射到(0, 1)之间,便于解释为概率。

四、总结

两种框架在定义模型结构时思路基本相同,pytorch基于动态图,更加灵活。tensorflow基于静态图,更加稳定。 

如果还有时间,可以看看我的其他文章:

 《AI—工程篇》

AI智能体研发之路-工程篇(一):Docker助力AI智能体开发提效

AI智能体研发之路-工程篇(二):Dify智能体开发平台一键部署

AI智能体研发之路-工程篇(三):大模型推理服务框架Ollama一键部署

AI智能体研发之路-工程篇(四):大模型推理服务框架Xinference一键部署

AI智能体研发之路-工程篇(五):大模型推理服务框架LocalAI一键部署

《AI—模型篇》

AI智能体研发之路-模型篇(一):大模型训练框架LLaMA-Factory在国内网络环境下的安装、部署及使用

AI智能体研发之路-模型篇(二):DeepSeek-V2-Chat 训练与推理实战

AI智能体研发之路-模型篇(三):中文大模型开、闭源之争​​​​​​​ 

AI智能体研发之路-模型篇(四):一文入门pytorch开发

AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比

http://www.lryc.cn/news/358959.html

相关文章:

  • 电商物流查询解决方案助力提升消费者体验
  • 【深度密码】神经网络算法在机器学习中的前沿探索
  • 搭载算能 BM1684 芯片,面向AI推理计算加速卡
  • Python开发 我的世界 Painting-the-World: Minecraft 像素图片生成器
  • 【经验分享】盘点“食用“的写文素材
  • 实习碰到的问题w1
  • c#实现BPM系统网络传输接口,http协议,post
  • 如何修改开源项目中发现的bug?
  • 结构设计模式 - 代理设计模式 - JAVA
  • 企业了解这些cad图纸加密方法,再也不怕图纸被盗了!
  • # 详解 JS 中的事件循环、宏/微任务、Primise对象、定时器函数,以及其在工作中的应用和注意事项
  • 神经网络与深度学习——第14章 深度强化学习
  • centOS 编译C/C++
  • java——网络原理初识
  • js怎么判断是否为手机号?js格式校验方法
  • 深入理解Java中的方法重载:让代码更灵活的秘籍
  • 鸿蒙ArkTS声明式开发:跨平台支持列表【显隐控制】 通用属性
  • 每日一题——Java编程练习题
  • java编辑器中如何调试程序?
  • 第四范式Q1业务进展:驰而不息 用科技锻造不朽价值
  • SpringBoot整合Kafka的快速使用教程
  • 低边驱动与高边驱动
  • 【C++】入门(二):引用、内联、auto
  • 编程学习 (C规划) 6 {24_4_18} 七 ( 简单扫雷游戏)
  • 【AI】llama-fs的 安装与运行
  • Android NDK系列(五)内存监控
  • 软件设计师,下午题 ——试题六
  • 《Kubernetes部署篇:基于麒麟V10+ARM64架构部署harbor v2.4.0镜像仓库》
  • 远程工作/线上兼职网站整理(数字游民友好)
  • elasticsearch7.15实现用户输入自动补全