当前位置: 首页 > news >正文

昇思学习营-模型推理和性能优化学习心得

一、模型推理与 JIT 优化教程
(一)JIT 优化推理:提升模型响应速度

  1. 环境准备
    首先需安装指定版本的 MindSpore 和 MindNLP,确保环境兼容性:
pip uninstall mindspore -ypip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/2.6.0/MindSpore/unified/aarch64/mindspore-2.6.0-cp39-cp39-linux_aarch64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simplepip uninstall mindnlp -y
pip install https://xihe.mindspore.cn/coderepo/web/v1/file/MindSpore/mindnlp/main/media/mindnlp-0.4.1-py3-none-any.whl
  1. JIT 优化配置
    通过 MindSpore 的上下文配置开启 JIT 编译和图算融合,提升推理效率:
import mindsporemindspore.set_context(enable_graph_kernel=True,  # 图算融合加速mode=mindspore.GRAPH_MODE,  # 静态图模式jit_config={"jit_level": "O2"}  # O2级JIT优化
)
  1. 核心推理流程
    模型与分词器加载:使用mindnlp.transformers加载预训练模型和分词器
from mindnlp.transformers import AutoTokenizer, AutoModelForCausalLMmodel_id = "MindSpore-Lab/DeepSeek-R1-Distill-Qwen-1.5B-FP16"
tokenizer = AutoTokenizer.from_pretrained(model_id, mirror="modelers")
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, mirror="modelers")
model.jit()  # 模型全图静态图化

采样函数实现:采用 Top-p 采样确保生成文本的多样性,优先使用 NumPy 实现提升边缘设备效率

import numpy as np
from mindnlp.core import opsdef sample_top_p(probs, p=0.9):probs_np = probs.asnumpy()sorted_indices = np.argsort(-probs_np, axis=-1)  # 降序排序sorted_probs = np.take_along_axis(probs_np, sorted_indices, axis=-1)cumulative_probs = np.cumsum(sorted_probs, axis=-1)mask = cumulative_probs - sorted_probs > p  # 过滤累积概率超阈值的tokensorted_probs[mask] = 0.0sorted_probs /= np.sum(sorted_probs, axis=-1, keepdims=True)  # 归一化# 转换为MindSpore张量并采样sorted_probs_tensor = mindspore.Tensor(sorted_probs, dtype=mindspore.float32)sorted_indices_tensor = mindspore.Tensor(sorted_indices, dtype=mindspore.int32)next_token_idx = ops.multinomial(sorted_probs_tensor, 1)return mindspore.ops.gather(sorted_indices_tensor, next_token_idx, axis=1, batch_dims=1)

自回归生成:结合静态缓存(StaticCache)加速长序列生成,通过 JIT 装饰器优化单步解码函数

from mindnlp.transformers import StaticCachepast_key_values = StaticCache(config=model.config, max_batch_size=2, max_cache_len=512, dtype=model.dtype
)

JIT优化单步解码

@mindspore.jit(jit_config=mindspore.JitConfig(jit_syntax_level='STRICT'))
def get_decode_one_tokens_logits(model, cur_token, input_pos, cache_position, past_key_values):logits = model(cur_token,position_ids=input_pos,cache_position=cache_position,past_key_values=past_key_values,return_dict=False,use_cache=True)[0]return logits

(二)交互式对话部署:构建对话机器人

  1. 模型加载与配置
    加载模型和分词器,并配置 Peft 适配器(若有微调需求):
from mindnlp.transformers import AutoModelForCausalLM, AutoTokenizer
from mindnlp.peft import PeftModeltokenizer = AutoTokenizer.from_pretrained("MindSpore-Lab/DeepSeek-R1-Distill-Qwen-1.5B-FP16", mirror="modelers")
if tokenizer.pad_token is None:tokenizer.pad_token = tokenizer.eos_token  # 补充pad_token
model = AutoModelForCausalLM.from_pretrained("MindSpore-Lab/DeepSeek-R1-Distill-Qwen-1.5B-FP16", mirror="modelers")# 加载微调适配器(可选)
model = PeftModel.from_pretrained(model, "./output/DeepSeek-R1-Distill-Qwen-1.5B/checkpoint-3/adapter_model/")
  1. 对话流程实现
    对话历史处理:将历史对话转换为模型输入格式
def build_input_from_chat_history(chat_history, msg: str):messages = [{'role': role, 'content': content} for role, content in chat_history]messages.append({'role': 'user', 'content': msg})  # 新增用户输入return messages

流式生成:使用TextIteratorStreamer实现实时输出,提升交互体验

from mindnlp.transformers import TextIteratorStreamer
from threading import Threaddef inference(message, history):messages = build_input_from_chat_history(history, message)input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="ms")# 流式输出配置streamer = TextIteratorStreamer(tokenizer, timeout=300, skip_prompt=True, skip_special_tokens=True)generate_kwargs = dict(input_ids=input_ids,streamer=streamer,max_new_tokens=1024,use_cache=True)# 多线程生成,避免阻塞Thread(target=model.generate, kwargs=generate_kwargs).start()partial_message = ""for new_token in streamer:partial_message += new_tokenprint(new_token, end="", flush=True)  # 实时打印return messages + [{'role': 'assistant', 'content': partial_message}]
  1. 命令行交互
    实现简单的命令行交互界面,支持清空历史和终止程序:
import os
import platformos_name = platform.system()
clear_command = 'cls' if os_name == 'Windows' else 'clear'
history = []
print("欢迎使用对话机器人,输入clear清空历史,stop终止程序")
while True:query = input("\n用户:")if query == "stop":breakif query == "clear":os.system(clear_command)continueprint("\n机器人:", end="")history = inference(query, history)

二、学习心得
JIT 优化的核心价值:
通过 MindSpore 的 JIT 编译(尤其是 O2 级别优化)和图算融合,能显著降低大模型单次推理耗时。实际使用中,model.jit()和带 JIT 装饰器的函数可将计算图静态化,减少动态图的解释开销,这对边缘设备(如香橙派)上的部署尤为重要。
推理效率的关键技巧:
静态缓存(StaticCache)通过复用历史计算结果,避免自回归生成中重复计算,大幅提升长文本生成效率;
Top-p 采样在保证生成多样性的同时,通过 NumPy 实现降低了 MindSpore 张量操作的开销,更适配资源有限的设备。
交互式部署的用户体验设计:
流式生成(TextIteratorStreamer)配合多线程,解决了大模型生成耗时导致的界面卡顿问题,让用户能实时看到输出内容,接近人类对话的自然节奏。

结合 Peft 适配器加载功能,可轻松将预训练模型与微调权重结合,实现特定场景的定制化对话能力,为后续模型微调与应用落地提供了灵活的扩展路径。

http://www.lryc.cn/news/610158.html

相关文章:

  • MS-DOS 常用指令集
  • 【清除pip缓存】Windows上AppData\Local\pip\cache内容
  • 我的世界进阶模组开发教程——附魔(2)
  • (二)软件工程
  • 论文阅读笔记:《Dataset Distillation by Matching Training Trajectories》
  • 在CentOS 7上安装配置MySQL 8.0完整指南
  • PyTorch :三角函数与特殊运算
  • MFC-Ribbbon-图标-PS
  • 【秋招笔试】2025.08.03虾皮秋招笔试-第二题
  • 蜜汁整体二分——区间 kth
  • Next.js 中的文件路由:工作原理
  • 秋招笔记-8.4
  • 软件需求关闭前的质量评估标准是什么
  • Java项目:基于SSM框架实现的商铺租赁管理系统【ssm+B/S架构+源码+数据库+毕业论文+开题报告+任务书+远程部署】
  • 优化 Unity ConstantForce2D 性能的简单方法【资料】
  • 2025年08月04日Github流行趋势
  • 无偿分享120套开源数据可视化大屏H5模板
  • WSL安装Ubuntu与Docker环境,比VMware香
  • Flutter 对 Windows 不同版本的支持及 flutter_tts 兼容性指南
  • 离线Docker项目移植全攻略
  • Oracle 在线重定义
  • [GYCTF2020]FlaskApp
  • 【编程实践】点云曲率计算与可视化
  • 八股——Kafka相关
  • 【Pytorch✨】LSTM04 l理解长期记忆和短期记忆
  • 第12届蓝桥杯Scratch_选拔赛_初级组_真题2020年8月23日
  • 神经网络---非线性激活
  • C++进阶-封装红黑树模拟实现map和set(难度较高)
  • 李沐写作笔记
  • 嵌入式 C 语言入门:函数指针基础笔记 —— 从计算器优化到指针本质