昇思学习营-模型推理和性能优化学习心得
一、模型推理与 JIT 优化教程
(一)JIT 优化推理:提升模型响应速度
- 环境准备
首先需安装指定版本的 MindSpore 和 MindNLP,确保环境兼容性:
pip uninstall mindspore -ypip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/2.6.0/MindSpore/unified/aarch64/mindspore-2.6.0-cp39-cp39-linux_aarch64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simplepip uninstall mindnlp -y
pip install https://xihe.mindspore.cn/coderepo/web/v1/file/MindSpore/mindnlp/main/media/mindnlp-0.4.1-py3-none-any.whl
- JIT 优化配置
通过 MindSpore 的上下文配置开启 JIT 编译和图算融合,提升推理效率:
import mindsporemindspore.set_context(enable_graph_kernel=True, # 图算融合加速mode=mindspore.GRAPH_MODE, # 静态图模式jit_config={"jit_level": "O2"} # O2级JIT优化
)
- 核心推理流程
模型与分词器加载:使用mindnlp.transformers加载预训练模型和分词器
from mindnlp.transformers import AutoTokenizer, AutoModelForCausalLMmodel_id = "MindSpore-Lab/DeepSeek-R1-Distill-Qwen-1.5B-FP16"
tokenizer = AutoTokenizer.from_pretrained(model_id, mirror="modelers")
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, mirror="modelers")
model.jit() # 模型全图静态图化
采样函数实现:采用 Top-p 采样确保生成文本的多样性,优先使用 NumPy 实现提升边缘设备效率
import numpy as np
from mindnlp.core import opsdef sample_top_p(probs, p=0.9):probs_np = probs.asnumpy()sorted_indices = np.argsort(-probs_np, axis=-1) # 降序排序sorted_probs = np.take_along_axis(probs_np, sorted_indices, axis=-1)cumulative_probs = np.cumsum(sorted_probs, axis=-1)mask = cumulative_probs - sorted_probs > p # 过滤累积概率超阈值的tokensorted_probs[mask] = 0.0sorted_probs /= np.sum(sorted_probs, axis=-1, keepdims=True) # 归一化# 转换为MindSpore张量并采样sorted_probs_tensor = mindspore.Tensor(sorted_probs, dtype=mindspore.float32)sorted_indices_tensor = mindspore.Tensor(sorted_indices, dtype=mindspore.int32)next_token_idx = ops.multinomial(sorted_probs_tensor, 1)return mindspore.ops.gather(sorted_indices_tensor, next_token_idx, axis=1, batch_dims=1)
自回归生成:结合静态缓存(StaticCache)加速长序列生成,通过 JIT 装饰器优化单步解码函数
from mindnlp.transformers import StaticCachepast_key_values = StaticCache(config=model.config, max_batch_size=2, max_cache_len=512, dtype=model.dtype
)
JIT优化单步解码
@mindspore.jit(jit_config=mindspore.JitConfig(jit_syntax_level='STRICT'))
def get_decode_one_tokens_logits(model, cur_token, input_pos, cache_position, past_key_values):logits = model(cur_token,position_ids=input_pos,cache_position=cache_position,past_key_values=past_key_values,return_dict=False,use_cache=True)[0]return logits
(二)交互式对话部署:构建对话机器人
- 模型加载与配置
加载模型和分词器,并配置 Peft 适配器(若有微调需求):
from mindnlp.transformers import AutoModelForCausalLM, AutoTokenizer
from mindnlp.peft import PeftModeltokenizer = AutoTokenizer.from_pretrained("MindSpore-Lab/DeepSeek-R1-Distill-Qwen-1.5B-FP16", mirror="modelers")
if tokenizer.pad_token is None:tokenizer.pad_token = tokenizer.eos_token # 补充pad_token
model = AutoModelForCausalLM.from_pretrained("MindSpore-Lab/DeepSeek-R1-Distill-Qwen-1.5B-FP16", mirror="modelers")# 加载微调适配器(可选)
model = PeftModel.from_pretrained(model, "./output/DeepSeek-R1-Distill-Qwen-1.5B/checkpoint-3/adapter_model/")
- 对话流程实现
对话历史处理:将历史对话转换为模型输入格式
def build_input_from_chat_history(chat_history, msg: str):messages = [{'role': role, 'content': content} for role, content in chat_history]messages.append({'role': 'user', 'content': msg}) # 新增用户输入return messages
流式生成:使用TextIteratorStreamer实现实时输出,提升交互体验
from mindnlp.transformers import TextIteratorStreamer
from threading import Threaddef inference(message, history):messages = build_input_from_chat_history(history, message)input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="ms")# 流式输出配置streamer = TextIteratorStreamer(tokenizer, timeout=300, skip_prompt=True, skip_special_tokens=True)generate_kwargs = dict(input_ids=input_ids,streamer=streamer,max_new_tokens=1024,use_cache=True)# 多线程生成,避免阻塞Thread(target=model.generate, kwargs=generate_kwargs).start()partial_message = ""for new_token in streamer:partial_message += new_tokenprint(new_token, end="", flush=True) # 实时打印return messages + [{'role': 'assistant', 'content': partial_message}]
- 命令行交互
实现简单的命令行交互界面,支持清空历史和终止程序:
import os
import platformos_name = platform.system()
clear_command = 'cls' if os_name == 'Windows' else 'clear'
history = []
print("欢迎使用对话机器人,输入clear清空历史,stop终止程序")
while True:query = input("\n用户:")if query == "stop":breakif query == "clear":os.system(clear_command)continueprint("\n机器人:", end="")history = inference(query, history)
二、学习心得
JIT 优化的核心价值:
通过 MindSpore 的 JIT 编译(尤其是 O2 级别优化)和图算融合,能显著降低大模型单次推理耗时。实际使用中,model.jit()和带 JIT 装饰器的函数可将计算图静态化,减少动态图的解释开销,这对边缘设备(如香橙派)上的部署尤为重要。
推理效率的关键技巧:
静态缓存(StaticCache)通过复用历史计算结果,避免自回归生成中重复计算,大幅提升长文本生成效率;
Top-p 采样在保证生成多样性的同时,通过 NumPy 实现降低了 MindSpore 张量操作的开销,更适配资源有限的设备。
交互式部署的用户体验设计:
流式生成(TextIteratorStreamer)配合多线程,解决了大模型生成耗时导致的界面卡顿问题,让用户能实时看到输出内容,接近人类对话的自然节奏。
结合 Peft 适配器加载功能,可轻松将预训练模型与微调权重结合,实现特定场景的定制化对话能力,为后续模型微调与应用落地提供了灵活的扩展路径。