当前位置: 首页 > news >正文

人工智能开发框架 04.网络构建

目录

步骤一 、全连接层

步骤二 、卷积层

步骤三 、ReLU层

步骤四 、池化层

步骤五 、Flatten层

步骤六 、定义模型类并查看参数


MindSpore将构建网络层的接口封装在nn模块中,我们将通过调用来构建不同类型的神经网络层。

步骤一 、全连接层

全连接层:mindspore.nn.Dense

  1. in_channels:输入通道
  2. out_channels:输出通道
  3. weight_init:权重初始化,Default 'normal'.
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
import numpy as np# 构造输入张量
input_a = Tensor(np.array([[1, 1, 1], [2, 2, 2]]), ms.float32)
print(input_a)
# 构造全连接网络,输入通道为3,输出通道为3
net = nn.Dense(in_channels=3, out_channels=3, weight_init=1)
output = net(input_a)
print(output)

 

步骤二 、卷积层

conv2d = nn.Conv2d(1, 6, 5, has_bias=False, weight_init='normal', pad_mode='valid')
input_x = Tensor(np.ones([1, 1, 32, 32]), ms.float32)print(conv2d(input_x).shape)

 

步骤三 、ReLU层

relu = nn.ReLU()
input_x = Tensor(np.array([-1, 2, -3, 2, -1]), ms.float16)
output = relu(input_x)print(output)

 

步骤四 、池化层

max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
input_x = Tensor(np.ones([1, 6, 28, 28]), ms.float32)print(max_pool2d(input_x).shape)

步骤五 、Flatten层

flatten = nn.Flatten()
input_x = Tensor(np.ones([1, 16, 5, 5]), ms.float32)
output = flatten(input_x)print(output.shape)

步骤六 、定义模型类并查看参数

MindSpore的Cell类是构建所有网络的基类,也是网络的基本单元。当用户需要神经网络时,需要继承Cell类,并重写__init__方法和construct方法。

class LeNet5(nn.Cell):"""Lenet网络结构"""def __init__(self, num_class=10, num_channel=1):super(LeNet5, self).__init__()# 定义所需要的运算self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')self.fc1 = nn.Dense(16 * 4 * 4, 120)self.fc2 = nn.Dense(120, 84)self.fc3 = nn.Dense(84, num_class)self.relu = nn.ReLU()self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)self.flatten = nn.Flatten()def construct(self, x):# 使用定义好的运算构建前向网络x = self.conv1(x)x = self.relu(x)x = self.max_pool2d(x)x = self.conv2(x)x = self.relu(x)x = self.max_pool2d(x)x = self.flatten(x)x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.relu(x)x = self.fc3(x)return x
#实例化模型,利用parameters_and_names方法查看模型的参数
modelle = LeNet5()
for m in modelle.parameters_and_names():print(m)

http://www.lryc.cn/news/601208.html

相关文章:

  • spring gateway 配置http和websocket路由转发规则
  • Linux驱动21 --- FFMPEG 音频 API
  • Spring Boot + @RefreshScope:动态刷新配置的终极指南
  • mysql 快速上手
  • 发布 VS Code 扩展的流程:以颜色主题为例
  • 详解力扣高频SQL50题之1164. 指定日期的产品价格【中等】
  • MCP + LLM + Agent 8大架构:Agent能力、系统架构及技术实践
  • 2025年7月25日-7月26日 · AI 今日头条
  • 【测试报告】博客系统(Java+Selenium+Jmeter自动化测试)
  • Jmeter的元件使用介绍:(八)断言器详解
  • OpenResty 高并发揭秘:架构优势与 Linux 优化实践
  • 零基础学习性能测试第六章:性能难点-Jmeter实现海量用户压测
  • 人工智能与城市:城市生活的集成智能
  • FastAPI入门:查询参数模型、多个请求体参数
  • 元宇宙背景下治理模式:自治的乌托邦
  • 北大区块链技术与应用 笔记
  • solidity从入门到精通 第六章:安全第一
  • 【前后端】使用 PM2 管理 Node 进程
  • Baumer工业相机堡盟工业相机如何通过YoloV8深度学习模型实现标签条码一维码的检测(C#代码,UI界面版)
  • vue3.6更新哪些内容
  • 学习游戏制作记录(改进投掷剑的行为)7.27
  • Python 使用 asyncio 包处理并 发(避免阻塞型调用)
  • 创建属于自己的github Page主页
  • 【自动化运维神器Ansible】Ansible常用模块之archive模块详解
  • github上传本地项目过程记录
  • 【C语言网络编程基础】DNS 协议与请求详解
  • STM32的蓝牙通讯(HAL库)
  • 飞牛NAS本地化部署n8n打造个人AI工作流中心
  • 用 Flask 打造宠物店线上平台:从 0 到 1 的全栈开发实践
  • idea总结