当前位置: 首页 > news >正文

7.神经网络基础

7.1 构造网络模型

#自定义自己的网络模块
import torch
from torch import nn
from torch.nn import functional as F
class MLP(nn.Module):def __init__(self):super().__init__()self.hidden=nn.Linear(20,256)self.out=nn.Linear(256,10)#定义模型的前向传播,即如何根据输入X返回所需的模型输出def forward(self,x):x=self.hidden(x)x=F.relu(x)x=self.out(x)return x
X=torch.normal(mean=0.5,std=2.0,size=(2,20))
net=MLP()
net(X)

7.2 参数管理与初始化

#访问参数
import torch
from torch import nn
net=nn.Sequential(nn.Linear(4,8),nn.ReLU(),nn.Linear(8,1))
X=torch.rand(size=(2,4))
print(net[2].state_dict())
print(type(net[2].bias))
print(net[2].bias)
print(net[2].bias.data)
print(net[2].bias.grad)
print(*[(name,param.shape) for name,param in net[0].named_parameters()])
print(*[(name,param.shape) for name,param in net.named_parameters()])
print(net.state_dict()['2.bias'].data)
import torch
from torch import nn
net=nn.Sequential(nn.Linear(4,8),nn.ReLU(),nn.Linear(8,1))
def init_weight(m):if type(m)==nn.Linear:nn.init.normal_(m.weight,mean=0,std=0.01)nn.init.zeros_(m.bias)
def init_constant(m):if type(m)==nn.Linear:nn.init.constant_(m.weight,1)nn.init.zeros_(m.bias)
net.apply(init_weight)
print(net[0].weight.data[0],net[0].bias.data[0])
net.apply(init_constant)
print(net[0].weight.data[0],net[0].bias.data[0])
#参数绑定
shared=nn.Linear(8,8)
net=nn.Sequential(nn.Linear(4,8),nn.ReLU(),shared,nn.ReLU(),shared,nn.ReLU(),nn.Linear(8,1))
net(X)
print(net[2].weight.data[0]==net[4].weight.data[0])
net[2].weight.data[0,0]=100
print(net[2].weight.data[0]==net[4].weight.data[0])

7.3 自定义层

#如何自定义一个层
import torch
from torch import nn
class centeredLayer(nn.Module):def __init__(self):super().__init__()def forward(self,X):return X-X.mean()
layer=centeredLayer()
layer(torch.FloatTensor([1,2,3,4,5]))
net=nn.Sequential(nn.Linear(8,128),centeredLayer())
Y=net(torch.rand(4,8))
Y.mean()

7.4 文件读写与模型保存

#文件读写
import torch
from torch import nn
from torch.nn import functional as F
x=torch.arange(4)
y=torch.zeros(4)
torch.save(x,'x-file')
x2=torch.load('x-file')
print(x2)
torch.save([x,y],'x-files')
x2,y2=torch.load('x-files')
print(x2,y2)
#模型保存
import torch
from torch import nn
from torch.nn import functional as F
class MLP(nn.Module):def __init__(self):super().__init__()self.hidden=nn.Linear(20,256)self.out=nn.Linear(256,10)#定义模型的前向传播,即如何根据输入X返回所需的模型输出def forward(self,x):x=self.hidden(x)x=F.relu(x)x=self.out(x)return x
X=torch.normal(mean=0.5,std=2.0,size=(2,20))
net=MLP()
Y=net(X)
torch.save(net.state_dict(),'mlp.params')
#读取权重
clone=MLP()
clone.load_state_dict(torch.load('mlp.params'))
clone.eval()
Y_clone=clone(X)
Y_clone==Y
http://www.lryc.cn/news/583165.html

相关文章:

  • 【深度学习】【入门】Sequential的使用和简单神经网络搭建
  • 【机器学习】BeamSearch算法
  • 华为OD机试_2025 B卷_观看文艺汇演问题(Python,100分)(附详细解题思路)
  • 七牛云C++开发面试题及参考答案
  • Vue 3 中父子组件双向绑定的 4 种方式
  • mysql互为主从失效,重新同步
  • qml加载html以及交互
  • HarmonyOS中各种动画的使用介绍
  • C语言extern的用法(非常详细,通俗易懂)
  • 〔从零搭建〕数据湖平台部署指南
  • 17.Spring Boot的Bean详解(新手版)
  • OpenCV颜色矩哈希算法------cv::img_hash::ColorMomentHash
  • STM32-待机唤醒实验
  • [Leetcode] 预处理 | 多叉树bfs | 格雷编码 | static_cast | 矩阵对角线
  • User手机上如何抓取界面的布局uiautomatorviewer
  • 【机器人】Aether 多任务世界模型 | 4D动态重建 | 视频预测 | 视觉规划
  • 速卖通跨境运营破局:亚矩阵云手机如何用“本地化黑科技”撬动俄罗斯市场25%客单价增长
  • React 编译器与性能优化:告别手动 Memoization
  • 开始读 PostgreSQL 16 Administration Cookbook
  • 苍穹外卖项目日记(day04)
  • 【Netty+WebSocket详解】WebSocket全双工通信与Netty的高效结合与实战
  • 冷冻电镜重构的GPU加速破局:从Relion到CryoSPARC的并行重构算法
  • 《重构项目》基于Apollo架构设计的项目重构方案(多种地图、多阶段、多任务、状态机管理)
  • 仓颉语言 1.0.0 升级指南:工具链适配、collection 操作重构与 Map 遍历删除避坑
  • IT系统安全刚需:绝缘故障定位系统
  • Tailwind CSS纵向滚动条设置
  • PiscTrace深蹲计数功能实现:基于 YOLO-Pose 和人体关键点分析
  • 基于Java+Maven+Testng+Selenium+Log4j+Allure+Jenkins搭建一个WebUI自动化框架(4)集成Allure报表
  • JavaScript数组方法——梳理和考点
  • SpringBoot实现动态Job实战