使用onnxruntime推理Bert模型
Bert模型类别:onnx
输入输出数据格式:.npz
import onnxruntime
import numpy as np
import os# 加载 ONNX 模型
ort_session = onnxruntime.InferenceSession('bert-base-uncased_final.onnx')# 指定输入文件夹和输出文件夹
input_folder = ''
output_folder = ''# 确保输出文件夹存在
os.makedirs(output_folder, exist_ok=True)# 遍历输入文件
input_files = os.listdir(input_folder)
for input_file in input_files:if input_file.endswith('.npz'):input_path = os.path.join(input_folder, input_file)output_path = os.path.join(output_folder, input_file)print('input path:', input_path)# 加载 npz 格式的输入数据input_data = np.load(input_path)input_ids = input_data['input_ids']attention_mask = input_data['attention_mask']token_type_ids = input_data['token_type_ids']# 执行推理input_dict = {'input_ids': input_ids,'attention_mask': attention_mask,'token_type_ids': token_type_ids}outputs = ort_session.run(None, input_dict)# 获取推理结果output_start_logits = outputs[0]output_end_logits = outputs[1]# 保存推理结果为 npz 格式output_data = {'output_start_logits': output_start_logits,'output_end_logits': output_end_logits}np.savez(output_path, **output_data)print('output path:', output_path)