当前位置: 首页 > news >正文

【pytorch】深度学习准备:基本配置

深度学习中常用包

import os 
import numpy as np 
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optimizer

超参数设置
2种设置方式:将超参数直接设置在训练的代码中;用yaml、json,dict等文件来存储超参数

# 批次的大小
batch_size = 16
# 优化器的学习率
lr = 1e-4
# 训练次数
max_epochs = 100

GPU设置

# 方案一:使用os.environ,这种情况如果使用GPU不需要设置
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' # 指明调用的GPU为0,1号# 方案二:使用“device”,后续对要使用GPU的变量用.to(device)即可
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") # 指明调用的GPU为1号

使用argparse和yaml文件

  1. argparse的使用:
import argparse
"""argparse.ArgumentParser()创建了一个对象add_argument()添加参数parse_args()将参数封装在opt内,各个参数通过.运算符调用
"""def main(opt):print(opt.num_batches)if __name__ == '__main__':parse = argparse.ArgumentParser()parse.add_argument('--num_batches', type=int, default=50, help='the num of batch')parse.add_argument('--num_window', type=int, default=5, help='the num of window')parse.add_argument('--weight', type=str, default= '../pretrain.pth', help='the path of pretrained model')opt = parse.parse_args()main(opt)
  1. yaml文件的使用
    下面是一个yaml文件的例子,参数呈现层级结构
device: 'cpu'data:train_path: 'data/train'test_path: 'test/train'num: 1000

读取yaml文件

def read_yaml(path):
"""read()读入yaml文件中的内容safe_load()加载yaml格式的内容并转换为字典
"""file = open(path, 'r', encoding='utf-8')string = file.read()file.close()dict = yaml.safe_load(string)return dictpath = 'config.yaml'
Dict = read_yaml(path)
device = Dict['device']
print(device)
train_path = Dict['data']['train_path']
print(train_path)
  1. 使用方法
    在yaml文件中给全部参数设置默认值,使用argparse库设置待调参数的值

参考资料

  1. 深度学习代码中的argparse以及yaml文件的使用
  2. datawhale的thorough-pytorch repo
http://www.lryc.cn/news/190601.html

相关文章:

  • etcd随笔
  • 0基础学习VR全景平台篇 第107篇:全景图调色和细节处理(上,地拍)
  • Verilog功能模块——同步FIFO
  • Unity ToLua热更框架使用教程(1)
  • 车载相关名词--车载数据中心方案
  • helm使用
  • Python in Visual Studio Code 2023年10月发布
  • Webmin远程命令执行漏洞复现报告
  • webstorm自定义文件模板(Vue + Scss)
  • 楔子-写在之前
  • 第 5 章 数组和广义表(稀疏矩阵的三元组顺序表存储实现)
  • 【RabbitMQ 实战】11 队列的结构和惰性队列
  • Python3-批量重命名指定目录中的一组文件,更改其扩展名
  • 渗透测试KAILI系统的安装环境(第八课)
  • 如何正确方便的理解双指针?力扣102 (二叉树的层序遍历)
  • Vue或uniapp引入自定义字体
  • ​力扣:LCR 122. 路径加密​ 题目:剑指Offer 05.替换空格(c++)
  • cJson堆内存释放问题
  • 论文阅读/写作扫盲
  • 一文拿捏对象内存布局及JMM(JAVA内存模型)
  • Android组件通信——ActivityGroup(二十五)
  • js的继承的方式
  • 聊聊HttpClient的重试机制
  • 北邮22级信通院数电:Verilog-FPGA(4)第三周实验:按键消抖、呼吸灯、流水灯 操作流程注意事项
  • Ghidra101再入门(上?)-Ghidra架构介绍
  • Vue3路由引入报错解决:无法找到模块“xxx.vue”的声明文件 xxx隐式拥有 “any“ 类型。
  • 基于若依ruoyi-nbcio支持flowable流程分类里增加流程应用类型
  • JS之同步异步promise、async、await
  • 【OpenCV • c++】自定义直方图 | 灰度直方图均衡 | 彩色直方图均衡
  • el-tree目录和el-table实现搜索定位高亮方法