当前位置: 首页 > news >正文

【PyTorch】(三)模型的创建、参数初始化、保存和加载

文章目录

  • 1. 模型的创建
    • 1.1. 创建方法
      • 1.1.1. 通过使用模型组件
      • 1.1.2. 通过继承nn.Module类
    • 1.2. 模型组件
      • 1.2.1. 网络层
      • 1.2.2. 函数包
      • 1.2.3. 容器
    • 1.3. 将模型转移到GPU
  • 2. 模型参数初始化
  • 3. 模型的保存与加载
    • 3.1. 只保存参数
    • 3.2. 保存模型和参数

1. 模型的创建

1.1. 创建方法

1.1.1. 通过使用模型组件

可以直接使用模型组件快速创建模型。

import torch.nn as nnmodel =	nn.Linear(10, 10),
print(model)

输出结果:

Linear(in_features=10, out_features=10, bias=True)

1.1.2. 通过继承nn.Module类

在__init__方法中使用模型组件定义模型各层。在forward方法中实现前向传播。

import torch.nn as nnclass Model(nn.Module):def __init__(self):super().__init__()self.layer1 = nn.Linear(10, 10)self.layer2 = nn.Linear(10, 10)self.layer3 = nn.Sequential(nn.Linear(10, 10),nn.ReLU(),nn.Linear(10, 10))def forward(self, x):x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)return xmodel = Model()
print(model)

输出结果:

Model((layer1): Linear(in_features=10, out_features=10, bias=True)(layer2): Linear(in_features=10, out_features=10, bias=True)(layer3): Sequential((0): Linear(in_features=10, out_features=10, bias=True)(1): ReLU()(2): Linear(in_features=10, out_features=10, bias=True))
)

1.2. 模型组件

1.2.1. 网络层

1.2.2. 函数包

1.2.3. 容器

1.3. 将模型转移到GPU

方法与将数据转移到GPU类似,都有两种方法:

  1. model.to(device)
  2. mode.cuda()
import torch
import torch.nn as nn# 创建模型实例
model = nn.Sequential(nn.Linear(10, 10),nn.ReLU(),nn.Linear(10, 10)
)# 将模型移动到GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# 也可以
model = model.cuda()

2. 模型参数初始化

3. 模型的保存与加载

模型保存和加载使用的python内置的pickle模块。

3.1. 只保存参数

import torch
import torch.nn as nn# 创建模型实例
model1 = nn.Sequential(nn.Linear(10, 10),nn.ReLU(),nn.Linear(10, 10)
)# 保存和加载参数
torch.save(model1.state_dict(), '../model/model_params.pkl')
model1.load_state_dict(torch.load('../model/model_params.pkl'))

3.2. 保存模型和参数

import torch
import torch.nn as nn# 创建模型实例
model1 = nn.Sequential(nn.Linear(10, 10),nn.ReLU(),nn.Linear(10, 10)
)# 保存和加载模型和参数
torch.save(model1, '../model/model.pt')
model2 = torch.load('../model/model.pt')
print(model2)
http://www.lryc.cn/news/249514.html

相关文章:

  • 高效开发之:判断复杂list中的对象属性是否包含某个值
  • MacOS + Android Studio 通过 USB 数据线真机调试
  • 部署jekins遇到的问题
  • SQLY优化
  • 设计模式——行为型模式(一)
  • Rust语言入门教程(六) - 字符串类型
  • 【MATLAB源码-第92期】基于simulink的QPSK调制解调仿真,采用相干解调对比原始信号和解调信号。
  • 关于C语言控制浮点数输出精度问题
  • 【Linux 静态IP配置】
  • 【Linux 操作系统配置 SFTP】
  • 信贷专员简历模板
  • Python自动化测试面试经典题
  • java+springboot物流管理系统设计与实现wl-ssmj+jsp
  • 概念理论类-k8s :架构篇
  • window10家庭版中文转专业版流程
  • Chrome显示分享按钮
  • GPTS-生成一个动漫图像GPT
  • 在gazebo里搭建一个livox mid360 + 惯导仿真平台测试 FAST-LIO2
  • SpringMVC文件下载
  • 前端项目打包放到springboot项目时,访问不带index.html
  • Tomcat注册为服务后,如何配置Tomcat内存大小
  • C语言入门实战教程——嵌入式必备教程(2023年版最全最新整理)
  • Chatbot开发三剑客:LLAMA、LangChain和Python
  • 【Spring之AOP底层源码解析】
  • 【UCAS自然语言处理作业二】训练FFN, RNN, Attention机制的语言模型,并计算测试集上的PPL
  • RabbitMQ消息模型之Sample
  • 安全技术与防火墙
  • Windows系统搭建Appium 2 和 Appium Inspector 环境
  • 计算机应用基础_错题集_OutLook操作题_操作系统应用题_电子表格---网络教育统考工作笔记005
  • 2023-11-26 LeetCode每日一题(统计子串中的唯一字符)