
[GPU RAM] Real-Time Monitoring of GPU Memory Allocation (Part 1)

Capturing memory snapshots with torch.cuda.memory

Commands used

torch.cuda.memory._record_memory_history(max_entries=100000)   # start recording; keep at most 100,000 entries
torch.cuda.memory._dump_snapshot(file_name)                     # save a snapshot to file_name
torch.cuda.memory._record_memory_history(enabled=None)          # stop recording

max_entries=100000: keep at most 100,000 GPU memory allocation/free events in the recorded history. An event is recorded whenever a block of GPU memory is allocated (for example, a tensor is created) or freed (for example, a tensor is deleted or its lifetime ends).
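Putting the three calls together, a minimal sketch looks like this (the matrix multiply is only a placeholder for your own GPU workload):

import torch

# Start recording allocation/free events (at most 100,000 are kept).
torch.cuda.memory._record_memory_history(max_entries=100000)

# Any GPU work done here shows up in the recorded history.
x = torch.randn(1024, 1024, device="cuda")
y = x @ x
del x, y

# Save the recorded history to a pickle file, then stop recording.
torch.cuda.memory._dump_snapshot("snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)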

Helper functions for recording memory allocations

These helpers only record when CUDA is available; otherwise they log a message and return.


import logging
import socket
from datetime import datetime

import torch

logger = logging.getLogger(__name__)

# Keep a max of 100,000 alloc/free events in the recorded history
# leading up to the snapshot.
MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT: int = 100000

# Timestamp format used in the snapshot file name, e.g. Jun_24_14_02_33.
TIME_FORMAT_STR: str = "%b_%d_%H_%M_%S"


def start_record_memory_history() -> None:
    if not torch.cuda.is_available():
        logger.info("CUDA unavailable. Not recording memory history")
        return

    logger.info("Starting snapshot record_memory_history")
    torch.cuda.memory._record_memory_history(
        max_entries=MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT
    )


def stop_record_memory_history() -> None:
    if not torch.cuda.is_available():
        logger.info("CUDA unavailable. Not recording memory history")
        return

    logger.info("Stopping snapshot record_memory_history")
    torch.cuda.memory._record_memory_history(enabled=None)


def export_memory_snapshot() -> None:
    if not torch.cuda.is_available():
        logger.info("CUDA unavailable. Not exporting memory snapshot")
        return

    # Prefix for file names: <hostname>_<timestamp>.
    host_name = socket.gethostname()
    timestamp = datetime.now().strftime(TIME_FORMAT_STR)
    file_prefix = f"{host_name}_{timestamp}"

    try:
        logger.info(f"Saving snapshot to local file: {file_prefix}.pickle")
        torch.cuda.memory._dump_snapshot(f"{file_prefix}.pickle")
    except Exception as e:
        logger.error(f"Failed to capture memory snapshot {e}")
        return
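In practice these helpers are simply wrapped around the code to be profiled: call start_record_memory_history() before it runs, export_memory_snapshot() once it has finished, and stop_record_memory_history() afterwards, which is exactly what the training example below does once per epoch.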

Example

Using BERT pretraining code as the example (the lines marked with #!!!!!!!!!!!!!!!!!!! are the snapshot-recording calls):

date_str = datetime.now().strftime("%Y-%m-%d")
logging.basicConfig(format="%(levelname)s:%(asctime)s %(message)s",
                    level=logging.INFO,
                    filename="log_mem_snap.txt",
                    datefmt="%Y-%m-%d %H:%M:%S")
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)

TIME_FORMAT_STR: str = "%b_%d_%H_%M_%S"

# Keep a max of 100,000 alloc/free events in the recorded history
# leading up to the snapshot.
MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT: int = 100000

# start_record_memory_history, stop_record_memory_history and
# export_memory_snapshot are the same helper functions shown above.


def train(config):
    model = BertForPretrainingModel(config)
    t_min, t_max = -1, 0  # for the input layer
    t_min, t_max = model.bert.bert_embeddings.set_params(t_min, t_max)
    for i, layer in enumerate(model.bert.bert_encoder.bert_layers):
        t_tuple = layer.set_params(t_min, t_max)
        t_min, t_max = t_tuple[-2:]
        # print(t_min, t_max, "----------")
    config.T_RES = t_max
    last_epoch = -1
    if os.path.exists(config.model_save_path):
        checkpoint = torch.load(config.model_save_path)
        last_epoch = checkpoint['last_epoch']
        loaded_paras = checkpoint['model_state_dict']
        model.load_state_dict(loaded_paras)
        logging.info("## Successfully loaded the existing model and continue training. ......")
    model = model.to(config.device)
    model.train()

    bert_tokenize = BertTokenizer.from_pretrained(config.pretrained_model_dir).tokenize
    data_loader = LoadBertPretrainingDataset(vocab_path=config.vocab_path,
                                             tokenizer=bert_tokenize,
                                             batch_size=config.batch_size,
                                             max_sen_len=config.max_sen_len,
                                             max_position_embeddings=config.max_position_embeddings,
                                             pad_index=config.pad_index,
                                             is_sample_shuffle=config.is_sample_shuffle,
                                             random_state=config.random_state,
                                             data_name=config.data_name,
                                             masked_rate=config.masked_rate,
                                             masked_token_rate=config.masked_token_rate,
                                             masked_token_unchanged_rate=config.masked_token_unchanged_rate)
    train_iter, test_iter, val_iter = \
        data_loader.load_train_val_test_data(test_file_path=config.test_file_path,
                                             train_file_path=config.train_file_path,
                                             val_file_path=config.val_file_path)

    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         "weight_decay": config.weight_decay,
         "initial_lr": config.learning_rate},
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0,
         "initial_lr": config.learning_rate},
    ]
    optimizer = AdamW(optimizer_grouped_parameters)
    scheduler = get_polynomial_decay_schedule_with_warmup(optimizer,
                                                          int(len(train_iter) * 0),
                                                          int(config.epochs * len(train_iter)),
                                                          last_epoch=last_epoch)
    max_acc = 0
    state_dict = None
    for epoch in range(config.epochs):
        losses = 0
        start_time = time.time()

        # Start recording memory snapshot history           #!!!!!!!!!!!!!!!!!!!
        start_record_memory_history()

        for idx, (b_token_ids, b_segs, b_mask, b_mlm_label, b_nsp_label) in enumerate(train_iter):
            b_token_ids = b_token_ids.to(config.device)  # [src_len, batch_size]
            b_segs = b_segs.to(config.device)
            b_mask = b_mask.to(config.device)
            b_mlm_label = b_mlm_label.to(config.device)
            b_nsp_label = b_nsp_label.to(config.device)
            with record_function("## forward ##"):
                loss, mlm_logits, nsp_logits = model(input_ids=b_token_ids,
                                                     attention_mask=b_mask,
                                                     token_type_ids=b_segs,
                                                     masked_lm_labels=b_mlm_label,
                                                     next_sentence_labels=b_nsp_label)
            optimizer.zero_grad()
            with record_function("## backward ##"):
                loss.backward()
            with record_function("## optimizer ##"):
                optimizer.step()
                scheduler.step()
            losses += loss.item()
            mlm_acc, _, _, nsp_acc, _, _ = accuracy(mlm_logits, nsp_logits, b_mlm_label,
                                                    b_nsp_label, data_loader.PAD_IDX)
            if idx % 20 == 0:
                logging.info(f"Epoch: [{epoch + 1}/{config.epochs}], Batch[{idx}/{len(train_iter)}], "
                             f"Train loss :{loss.item():.3f}, Train mlm acc: {mlm_acc:.3f}, "
                             f"nsp acc: {nsp_acc:.3f}")
                config.writer.add_scalar('Training/Loss', loss.item(), scheduler.last_epoch)
                config.writer.add_scalar('Training/Learning Rate',
                                         scheduler.get_last_lr()[0], scheduler.last_epoch)
                config.writer.add_scalars(main_tag='Training/Accuracy',
                                          tag_scalar_dict={'NSP': nsp_acc, 'MLM': mlm_acc},
                                          global_step=scheduler.last_epoch)

        # Create the memory snapshot file                   #!!!!!!!!!!!!!!!!!!!
        export_memory_snapshot()

        # Stop recording memory snapshot history            #!!!!!!!!!!!!!!!!!!!
        stop_record_memory_history()

        end_time = time.time()
        train_loss = losses / len(train_iter)
        logging.info(f"Epoch: [{epoch + 1}/{config.epochs}], Train loss: "
                     f"{train_loss:.3f}, Epoch time = {(end_time - start_time):.3f}s")
        if (epoch + 1) % config.model_val_per_epoch == 0:
            mlm_acc, nsp_acc = evaluate(config, val_iter, model, data_loader.PAD_IDX)
            logging.info(f" ### MLM Accuracy on val: {round(mlm_acc, 4)}, "
                         f"NSP Accuracy on val: {round(nsp_acc, 4)}")
            config.writer.add_scalars(main_tag='Testing/Accuracy',
                                      tag_scalar_dict={'NSP': nsp_acc, 'MLM': mlm_acc},
                                      global_step=scheduler.last_epoch)
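A note on the names used above: record_function comes from torch.profiler (from torch.profiler import record_function) and simply labels regions of the loop, while BertForPretrainingModel, LoadBertPretrainingDataset, accuracy and evaluate belong to the author's BERT pretraining project and are not shown here.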

Viewing the results

Drag the generated .pickle file onto https://docs.pytorch.org/memory_viz. As training proceeds, memory usage first grows, then shrinks again after optimizer.step(), once the gradients that are no longer needed have been released.
(Figure: how memory allocations change over the course of training)
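The snapshot is an ordinary pickle file, so besides the drag-and-drop viewer you can also inspect it from Python. A rough sketch; the dictionary layout assumed here ('segments' entries with a 'total_size' field) is an internal PyTorch detail and may change between versions:

import pickle

# Load the file written by torch.cuda.memory._dump_snapshot().
with open("snapshot.pickle", "rb") as f:
    snapshot = pickle.load(f)

# Top-level keys of the snapshot dictionary.
print(list(snapshot.keys()))

# Assuming each entry in snapshot["segments"] carries a "total_size" field,
# sum them for a rough view of how much memory the allocator had reserved.
total_bytes = sum(seg["total_size"] for seg in snapshot.get("segments", []))
print(f"Reserved in recorded segments: {total_bytes / 1024 ** 2:.1f} MiB")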

Reference

  • https://pytorch.org/blog/understanding-gpu-memory-1/
