当前位置: 首页 > news >正文

torch显存分析——对生成模型清除显存

torch显存分析——对生成模型清除显存

  • 1. 问题介绍
  • 2. 应对方法

1. 问题介绍

本文主要针对生成场景下,如何方便快捷地清除当前进程占用的显存。文章的重点不止是对显存的管理,还包括怎样灵活的使用自定义组件来控制生成过程。

在之前的文章torch显存分析——如何在不关闭进程的情况下释放显存中,通过一个实验,分析了torch的显存占用情况,以及如何在不关闭进程的前提下,利用代码将显存释放掉。然而,在近期的实验中,却发现之前所介绍的显存释放方法对生成模型并不好用。

在前文中,所使用的方法是:

real_inputs = inputs['input_ids'][..., : 2, ...].to(model.device)
with torch.no_grad():logits = model(real_inputs, tail)
del real_inputs
del logits
torch.cuda.empty_cache()

然而,如果对生成模型,直接将model的forward替换成generate的话,即如下的替换方法,则会遇到问题。

with torch.no_grad():logits = model.generate(real_inputs)
del real_inputs
del logits
torch.cuda.empty_cache()

因为生成过程中,会有新的token生成,model.generate很可能不止一次在调用forward,所以这种方法就不灵了。

2. 应对方法

既然是模拟一边模型的forward方法,那就想办法让forward方法只被调用一次。或许直接还是使用model.forward就可以解决这个问题。但是这里我采用了另一种方法——使用Stopping Criteria。

既然只希望它生成执行一次,那就可以直接使用一个默认的criteria:

from transformers.generation.stopping_criteria import MaxNewTokensCriteria, StoppingCriteriaListempty_cache_helper = StoppingCriteriaList()
empty_cache_helper.append(MaxNewTokensCriteria(start_length=0, max_new_tokens=1))

这个东西的作用就是,最多只生成一个新的token,然后立即停止生成。

那么在清除显存时,只需要将它加上就好了:

with torch.no_grad():logits = model.generate(real_inputs, stopping_criteria=self.empty_cache_helper)
del real_inputs
del logits
torch.cuda.empty_cache()

如果不了解stopping criteria的话,可以去回顾之前的两篇文章:

以beam search为例,详解transformers中generate方法(上)
以beam search为例,详解transformers中generate方法(下)

今后的博客中,可能会结合一些例子,对自定义的logits processor和stopping criteria的使用进行介绍,感兴趣的同学可以关注一下。

http://www.lryc.cn/news/106234.html

相关文章:

  • electron+vue+ts窗口间通信
  • 基于Fringe-Projection环形投影技术的人脸三维形状提取算法matlab仿真
  • 如何使用Webman框架实现多语言支持和国际化功能?
  • 接受平庸,特别是程序员
  • HTML兼容性
  • Java日期和时间处理入门指南
  • anndata k折交叉
  • 深入解析项目管理中的用户流程图
  • Vue使用QrcodeVue生成二维码并下载
  • “用户登录”测试用例总结
  • 适应于Linux系统的三种安装包格式 .tar.gz、.deb、rpm
  • Linux lvs负载均衡
  • Tomcat 创建https
  • 超导电性的基本现象和相关理论
  • 在 PHP 中单引号(‘ ‘)和双引号(“ “)用法的区别
  • SpringCloudAlibaba:服务网关之Gateway的cors跨域问题
  • react中的高阶组件理解与使用
  • “从零开始学习Spring Boot:构建高效的Java应用程序“
  • 容器部署jenkins定时构建于本地时间不一致
  • 生成指定网段的IP字典自动化脚本
  • Java版工程行业管理系统源码-专业的工程管理软件- 工程项目各模块及其功能点清单 em
  • 《向量数据库指南》——大模型时代,为什么向量数据库成为标配?
  • Pytorch个人学习记录总结 10
  • 18款奔驰S320升级后排座椅加热功能,提升后排乘坐舒适性
  • Vue中的插值表达式
  • 背包问题(模板)
  • docker容器创建私有仓库(第三篇)
  • Eureka 学习笔记4:客户端 DiscoveryClient
  • 【方法】PDF可以转换成Word文档吗?如何操作?
  • AlphaControls crack