当前位置: 首页 > news >正文

Inference with C# BERT NLP Deep Learning and ONNX Runtime

目录

效果

测试一

测试二

测试三

模型信息

项目

代码

下载


Inference with C# BERT NLP Deep Learning and ONNX Runtime

效果

测试一

Context: Bob is walking through the woods collecting blueberries and strawberries to make a pie.

Question: What is his name?

测试二

Context: Bob is walking through the woods collecting blueberries and strawberries to make a pie.

Question: What will he bring home?

测试三

Context: Bob is walking through the woods collecting blueberries and strawberries to make a pie.

Question: Where is Bob?

模型信息

Inputs
-------------------------
name:unique_ids_raw_output___9:0
tensor:Int64[-1]
name:segment_ids:0
tensor:Int64[-1, 256]
name:input_mask:0
tensor:Int64[-1, 256]
name:input_ids:0
tensor:Int64[-1, 256]
---------------------------------------------------------------

Outputs
-------------------------
name:unstack:1
tensor:Float[-1, 256]
name:unstack:0
tensor:Float[-1, 256]
name:unique_ids:0
tensor:Int64[-1]
---------------------------------------------------------------

项目

代码

using BERTTokenizers;
using Microsoft.ML.OnnxRuntime;
using System;
using System.Collections.Generic;
using System.Data;
using System.Diagnostics;
using System.Linq;
using System.Windows.Forms;

namespace Inference_with_C__BERT_NLP_Deep_Learning_and_ONNX_Runtime
{
    public struct BertInput
    {
        public long[] InputIds { get; set; }
        public long[] InputMask { get; set; }
        public long[] SegmentIds { get; set; }
        public long[] UniqueIds { get; set; }
    }

    public partial class Form1 : Form
    {
        public Form1()
        {
            InitializeComponent();
        }

        RunOptions runOptions;
        InferenceSession session;
        BertUncasedLargeTokenizer tokenizer;
        Stopwatch stopWatch = new Stopwatch();

        private void Form1_Load(object sender, EventArgs e)
        {
            string modelPath = "bertsquad-10.onnx";
            runOptions = new RunOptions();
            session = new InferenceSession(modelPath);
            tokenizer = new BertUncasedLargeTokenizer();
        }

        int MaxAnswerLength = 30;
        int bestN = 20;

        private void button1_Click(object sender, EventArgs e)
        {
            txt_answer.Text = "";
            Application.DoEvents();

            string question = txt_question.Text.Trim();
            string context = txt_context.Text.Trim();

            // Get the sentence tokens.
            var tokens = tokenizer.Tokenize(question, context);

            // Encode the sentence and pass in the count of the tokens in the sentence.
            var encoded = tokenizer.Encode(tokens.Count(), question, context);

            var padding = Enumerable
              .Repeat(0L, 256 - tokens.Count)
              .ToList();

            var bertInput = new BertInput()
            {
                InputIds = encoded.Select(t => t.InputIds).Concat(padding).ToArray(),
                InputMask = encoded.Select(t => t.AttentionMask).Concat(padding).ToArray(),
                SegmentIds = encoded.Select(t => t.TokenTypeIds).Concat(padding).ToArray(),
                UniqueIds = new long[] { 0 }
            };

            // Create input tensors over the input data.
            var inputIdsOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.InputIds,
                  new long[] { 1, bertInput.InputIds.Length });

            var inputMaskOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.InputMask,
                  new long[] { 1, bertInput.InputMask.Length });

            var segmentIdsOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.SegmentIds,
                  new long[] { 1, bertInput.SegmentIds.Length });

            var uniqueIdsOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.UniqueIds,
                  new long[] { bertInput.UniqueIds.Length });

            var inputs = new Dictionary<string, OrtValue>
              {
                  { "unique_ids_raw_output___9:0", uniqueIdsOrtValue },
                  { "segment_ids:0", segmentIdsOrtValue},
                  { "input_mask:0", inputMaskOrtValue },
                  { "input_ids:0", inputIdsOrtValue }
              };

            stopWatch.Restart();
            // Run session and send the input data in to get inference output. 
            var output = session.Run(runOptions, inputs, session.OutputNames);
            stopWatch.Stop();

            var startLogits = output[1].GetTensorDataAsSpan<float>();

            var endLogits = output[0].GetTensorDataAsSpan<float>();

            var uniqueIds = output[2].GetTensorDataAsSpan<long>();

            var contextStart = tokens.FindIndex(o => o.Token == "[SEP]");

            var bestStartLogits = startLogits.ToArray()
                .Select((logit, index) => (Logit: logit, Index: index))
                .OrderByDescending(o => o.Logit)
                .Take(bestN);

            var bestEndLogits = endLogits.ToArray()
                .Select((logit, index) => (Logit: logit, Index: index))
                .OrderByDescending(o => o.Logit)
                .Take(bestN);

            var bestResultsWithScore = bestStartLogits
                .SelectMany(startLogit =>
                    bestEndLogits
                    .Select(endLogit =>
                        (
                            StartLogit: startLogit.Index,
                            EndLogit: endLogit.Index,
                            Score: startLogit.Logit + endLogit.Logit
                        )
                     )
                )
                .Where(entry => !(entry.EndLogit < entry.StartLogit || entry.EndLogit - entry.StartLogit > MaxAnswerLength || entry.StartLogit == 0 && entry.EndLogit == 0 || entry.StartLogit < contextStart))
                .Take(bestN);

            var (item, probability) = bestResultsWithScore
                .Softmax(o => o.Score)
                .OrderByDescending(o => o.Probability)
                .FirstOrDefault();

            int startIndex = item.StartLogit;
            int endIndex = item.EndLogit;

            var predictedTokens = tokens
                          .Skip(startIndex)
                          .Take(endIndex + 1 - startIndex)
                          .Select(o => tokenizer.IdToToken((int)o.VocabularyIndex))
                          .ToList();

            // Print the result.
            string answer = "answer:" + String.Join(" ", StitchSentenceBackTogether(predictedTokens))
                + "\r\nprobability:" + probability
                + $"\r\n推理耗时:{stopWatch.ElapsedMilliseconds}毫秒";

            txt_answer.Text = answer;
            Console.WriteLine(answer);

        }

        private List<string> StitchSentenceBackTogether(List<string> tokens)
        {
            var currentToken = string.Empty;

            tokens.Reverse();

            var tokensStitched = new List<string>();

            foreach (var token in tokens)
            {
                if (!token.StartsWith("##"))
                {
                    currentToken = token + currentToken;
                    tokensStitched.Add(currentToken);
                    currentToken = string.Empty;
                }
                else
                {
                    currentToken = token.Replace("##", "") + currentToken;
                }
            }

            tokensStitched.Reverse();

            return tokensStitched;
        }
    }
}
 

using BERTTokenizers;
using Microsoft.ML.OnnxRuntime;
using System;
using System.Collections.Generic;
using System.Data;
using System.Diagnostics;
using System.Linq;
using System.Windows.Forms;

namespace Inference_with_C__BERT_NLP_Deep_Learning_and_ONNX_Runtime
{
    /// <summary>
    /// Groups the four arrays that <c>Form1</c> binds to the model's input tensors.
    /// </summary>
    public struct BertInput
    {
        /// <summary>Values bound to the <c>input_ids:0</c> tensor.</summary>
        public long[] InputIds { get; set; }

        /// <summary>Values bound to the <c>input_mask:0</c> tensor.</summary>
        public long[] InputMask { get; set; }

        /// <summary>Values bound to the <c>segment_ids:0</c> tensor.</summary>
        public long[] SegmentIds { get; set; }

        /// <summary>Values bound to the <c>unique_ids_raw_output___9:0</c> tensor.</summary>
        public long[] UniqueIds { get; set; }
    }

    /// <summary>
    /// Demo window: tokenizes the question and context typed by the user, runs the
    /// "bertsquad-10.onnx" model through ONNX Runtime and shows the best-scoring answer span.
    /// </summary>
    public partial class Form1 : Form
    {
        /// <summary>
        /// Sequence length the inputs are padded to; the model info lists the
        /// token tensors as <c>Int64[-1, 256]</c>.
        /// </summary>
        private const int MaxSequenceLength = 256;

        public Form1()
        {
            InitializeComponent();
        }

        private RunOptions runOptions;
        private InferenceSession session;
        private BertUncasedLargeTokenizer tokenizer;
        private Stopwatch stopWatch = new Stopwatch();

        /// <summary>
        /// Creates the inference session, run options and tokenizer when the form loads.
        /// </summary>
        private void Form1_Load(object sender, EventArgs e)
        {
            string modelPath = "bertsquad-10.onnx";
            runOptions = new RunOptions();
            session = new InferenceSession(modelPath);
            tokenizer = new BertUncasedLargeTokenizer();

            // Both objects wrap native ONNX Runtime handles; release them with the window.
            FormClosed += (closedSender, closedArgs) =>
            {
                session?.Dispose();
                runOptions?.Dispose();
            };
        }

        // Longest allowed distance (in tokens) between the start and end of an answer.
        private int MaxAnswerLength = 30;

        // Number of top start logits, top end logits and candidate spans that are kept.
        private int bestN = 20;

        /// <summary>
        /// Runs question answering for the current text boxes and writes the answer,
        /// its probability and the inference time to <c>txt_answer</c> and the console.
        /// </summary>
        private void button1_Click(object sender, EventArgs e)
        {
            txt_answer.Text = "";
            // Process pending window messages so the cleared box is shown before the
            // blocking inference call below.
            Application.DoEvents();

            string question = txt_question.Text.Trim();
            string context = txt_context.Text.Trim();

            // Get the sentence tokens.
            var tokens = tokenizer.Tokenize(question, context);

            // Enumerable.Repeat throws on a negative count, so reject inputs that
            // do not fit the fixed sequence length instead of crashing.
            if (tokens.Count > MaxSequenceLength)
            {
                txt_answer.Text = $"input too long: {tokens.Count} tokens (maximum {MaxSequenceLength})";
                return;
            }

            // Encode the sentence and pass in the count of the tokens in the sentence.
            var encoded = tokenizer.Encode(tokens.Count, question, context);

            var padding = Enumerable
                .Repeat(0L, MaxSequenceLength - tokens.Count)
                .ToList();

            var bertInput = new BertInput()
            {
                InputIds = encoded.Select(t => t.InputIds).Concat(padding).ToArray(),
                InputMask = encoded.Select(t => t.AttentionMask).Concat(padding).ToArray(),
                SegmentIds = encoded.Select(t => t.TokenTypeIds).Concat(padding).ToArray(),
                UniqueIds = new long[] { 0 }
            };

            float[] startLogits;
            float[] endLogits;
            RunModel(bertInput, out startLogits, out endLogits);

            // Index of the first separator token; answers must not start before it.
            var contextStart = tokens.FindIndex(o => o.Token == "[SEP]");

            var bestStartLogits = startLogits
                .Select((logit, index) => (Logit: logit, Index: index))
                .OrderByDescending(o => o.Logit)
                .Take(bestN)
                .ToList();

            // Materialized once: it is enumerated again for every start candidate.
            var bestEndLogits = endLogits
                .Select((logit, index) => (Logit: logit, Index: index))
                .OrderByDescending(o => o.Logit)
                .Take(bestN)
                .ToList();

            var candidates = bestStartLogits
                .SelectMany(start => bestEndLogits.Select(end =>
                    (StartIndex: start.Index, EndIndex: end.Index, Score: start.Logit + end.Logit)))
                .Where(span =>
                    span.EndIndex >= span.StartIndex
                    && span.EndIndex - span.StartIndex <= MaxAnswerLength
                    && !(span.StartIndex == 0 && span.EndIndex == 0)
                    && span.StartIndex >= contextStart
                    // Positions at or beyond tokens.Count are padding, not text.
                    && span.EndIndex < tokens.Count)
                // Rank by combined score before truncating; otherwise the cut keeps the
                // first bestN pairs in enumeration order and may drop a better span.
                .OrderByDescending(span => span.Score)
                .Take(bestN)
                .ToList();

            string answer;
            if (candidates.Count == 0)
            {
                // No valid span: report an empty answer instead of running Softmax on nothing.
                answer = "answer:"
                    + "\r\nprobability:0"
                    + $"\r\n推理耗时:{stopWatch.ElapsedMilliseconds}毫秒";
            }
            else
            {
                var (best, probability) = candidates
                    .Softmax(o => o.Score)
                    .OrderByDescending(o => o.Probability)
                    .FirstOrDefault();

                var predictedTokens = tokens
                    .Skip(best.StartIndex)
                    .Take(best.EndIndex + 1 - best.StartIndex)
                    .Select(o => tokenizer.IdToToken((int)o.VocabularyIndex))
                    .ToList();

                answer = "answer:" + String.Join(" ", StitchSentenceBackTogether(predictedTokens))
                    + "\r\nprobability:" + probability
                    + $"\r\n推理耗时:{stopWatch.ElapsedMilliseconds}毫秒";
            }

            // Print the result.
            txt_answer.Text = answer;
            Console.WriteLine(answer);
        }

        /// <summary>
        /// Runs the session on <paramref name="bertInput"/> and copies the two logit outputs
        /// into managed arrays. <c>stopWatch</c> measures only the <c>session.Run</c> call.
        /// </summary>
        /// <param name="bertInput">Padded model inputs.</param>
        /// <param name="startLogits">Copy of the output at index 1.</param>
        /// <param name="endLogits">Copy of the output at index 0.</param>
        private void RunModel(BertInput bertInput, out float[] startLogits, out float[] endLogits)
        {
            // Create input tensors over the input data. OrtValue is disposable, so every
            // tensor is released when this method returns.
            using (var inputIdsOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.InputIds,
                       new long[] { 1, bertInput.InputIds.Length }))
            using (var inputMaskOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.InputMask,
                       new long[] { 1, bertInput.InputMask.Length }))
            using (var segmentIdsOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.SegmentIds,
                       new long[] { 1, bertInput.SegmentIds.Length }))
            using (var uniqueIdsOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.UniqueIds,
                       new long[] { bertInput.UniqueIds.Length }))
            {
                var inputs = new Dictionary<string, OrtValue>
                {
                    { "unique_ids_raw_output___9:0", uniqueIdsOrtValue },
                    { "segment_ids:0", segmentIdsOrtValue },
                    { "input_mask:0", inputMaskOrtValue },
                    { "input_ids:0", inputIdsOrtValue }
                };

                stopWatch.Restart();
                // Run session and send the input data in to get inference output.
                using (var output = session.Run(runOptions, inputs, session.OutputNames))
                {
                    stopWatch.Stop();

                    // Copy while the outputs are alive; the spans point into memory that
                    // is released when the collection is disposed.
                    startLogits = output[1].GetTensorDataAsSpan<float>().ToArray();
                    endLogits = output[0].GetTensorDataAsSpan<float>().ToArray();
                }
            }
        }

        /// <summary>
        /// Joins word pieces back into whole words: a token starting with "##" is appended
        /// (without the prefix) to the word before it.
        /// </summary>
        /// <param name="tokens">Tokens in sentence order. The list is not modified.</param>
        /// <returns>
        /// The stitched words in sentence order. A leading "##" fragment with no preceding
        /// word is kept as its own word instead of being dropped.
        /// </returns>
        private List<string> StitchSentenceBackTogether(List<string> tokens)
        {
            var words = new List<string>();
            if (tokens == null)
            {
                return words;
            }

            foreach (var token in tokens)
            {
                bool isContinuation = token.StartsWith("##", StringComparison.Ordinal);

                if (isContinuation && words.Count > 0)
                {
                    words[words.Count - 1] += token.Substring(2);
                }
                else
                {
                    // Either a new word, or a fragment with nothing before it to attach to.
                    words.Add(isContinuation ? token.Substring(2) : token);
                }
            }

            return words;
        }
    }
}

下载

源码下载

http://www.lryc.cn/news/252770.html

相关文章:

  • 6、原型模式(Prototype Pattern,不常用)
  • 图像万物分割——Segment Anything算法解析与模型推理
  • Redis实战篇笔记(最终篇)
  • 游戏配置表的导入使用
  • ❀dialog命令运用于linux❀
  • 【算法】蓝桥杯2013国C 横向打印二叉树 题解
  • XunSearch 讯搜 error: storage size of ‘methods_bufferevent’ isn’t known
  • 基于AWS Serverless的Glue服务进行ETL(提取、转换和加载)数据分析(三)——serverless数据分析
  • 08、分析测试执行时间及获取pytest帮助
  • 视频集中存储/智能分析融合云平台EasyCVR平台接入rtsp,突然断流是什么原因?
  • JavaScript 复杂的<三元运算符和比较操作>的组合--案例(一)
  • uniapp搭建内网映射测试https域名
  • 国防科技大博士招生入学考试【50+论文主观题】
  • CUDA简介——编程模式
  • Linux 软件安装
  • flask之邮件发送
  • 【Filament】Filament环境搭建
  • 外包干了2个月,技术倒退2年。。。。。
  • 使用 python ffmpeg 批量检查 音频文件 是否损坏或不完整
  • Django:通过user-agent判断请求是来自移动端还是PC端(电脑端)
  • Linux中ssh远程登录系统和远程拷贝
  • git常用命令小记
  • 深入Android S (12.0) 探索Framework之输入系统IMS的构成与启动
  • SoC with CPLD and MCU ?
  • 基于AWS Serverless的Glue服务进行ETL(提取、转换和加载)数据分析(二)——数据清洗、转换
  • vuepress-----6、时间更新
  • C++ ini配置文件的简单读取使用
  • 【稳定检索|投稿优惠】2024年经济管理与安全科学国际学术会议(EMSSIC 2024)
  • 什么是网站?
  • pg_stat_replication.state 含义