当前位置: 首页 > news >正文

通过 PromptTemplate 生成干净的 SQL 查询语句并执行SQL查询语句

问题描述

在使用 LangChain 和 Llama 模型生成 SQL 查询时,遇到了 sqlite3.OperationalError 错误。错误信息如下:

OperationalError: (sqlite3.OperationalError) near "```sql
SELECT Name 
FROM MediaType 
LIMIT 5;
```": syntax error
[SQL: ```sql
SELECT Name 
FROM MediaType 
LIMIT 5;
```]

错误发生的原因是生成的 SQL 查询包含了不必要的 Markdown 代码块标记 ```,也就是在生成SQL语句的过程中,产生了其他的不干净文本,导致 SQL 语法错误。

最终解决方案

通过修改 PromptTemplate 来生成干净的 SQL 查询,确保生成的查询不包含任何 Markdown 代码块标记或附加评论。以下是解决方案的详细步骤和代码实现:

1. 初始化环境

首先,初始化所需的环境变量和模型:

import getpass
import os
from langchain.chat_models import init_chat_model
from langchain_core.prompts import PromptTemplate
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool# 如果没有设置 GROQ_API_KEY,则提示用户输入
if not os.environ.get("GROQ_API_KEY"):os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")# 初始化 Llama 模型,使用 Groq 后端
llm = init_chat_model("llama-3.3-70b-versatile", model_provider="groq", temperature=0)

2. 定义自定义提示模板

定义一个自定义的 PromptTemplate,用于生成干净的 SQL 查询:

custom_prompt = PromptTemplate(input_variables=["dialect", "input", "table_info", "top_k"],template="""You are a SQL expert using {dialect}.
Given the following table schema:
{table_info}
Generate a syntactically correct SQL query to answer the question: "{input}".
Limit the results to at most {top_k} rows.
Return only the SQL query without any additional commentary or Markdown formatting.
"""
)

3. 创建 SQL 查询链

创建一个 SQL 查询链,并使用自定义提示模板:

write_query = create_sql_query_chain(llm, db, prompt=custom_prompt)

4. 构造输入数据字典

构造输入数据字典,其中包含方言、表结构、问题和行数限制:

input_data = {"dialect": db.dialect,                    # 数据库方言,如 "sqlite""table_info": db.get_table_info(),        # 表结构信息"input": "What name of MediaType is?",    # 问题"top_k": 5                                # 行数限制
}

5. 调用链生成并执行 SQL 查询

调用链生成 SQL 查询,确保生成的查询不包含 Markdown 代码块标记,然后执行查询并打印结果:

response = write_query.invoke(input_data)
query = response["query"]# 执行 SQL 查询并打印结果
execute_query = QuerySQLDataBaseTool(db=db)
result = execute_query.invoke({"query": query})
print(result)

总结

通过修改 PromptTemplate 来生成 SQL 查询时,明确要求返回的 SQL 查询不包含任何附加评论或 Markdown 格式,确保生成的 SQL 查询是干净的、可执行的。这样可以避免由多余的标记导致的 SQL 语法错误。

最后提供完整代码:

import getpass
import os
from langchain.chat_models import init_chat_model
from langchain_core.prompts import PromptTemplate
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from dotenv import load_dotenv
from pyprojroot import here
from langchain.chains import create_sql_query_chain
from langchain_community.agent_toolkits import create_sql_agent
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabaseload_dotenv()# 如果没有设置 GROQ_API_KEY,则提示用户输入
if not os.environ.get("GROQ_API_KEY"):os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")sqldb_directory = here("data/Chinook.db")
db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
table_info = db.get_table_info(["Album"])  # 注意需要传递列表
print(f"\n Original table info: {table_info}")#  初始化 Llama 模型,使用 Groq 后端
llm = init_chat_model("llama-3.3-70b-specdec", model_provider="groq", temperature=0)
# 定义自定义提示模板,用于生成 SQL 查询
custom_prompt = PromptTemplate(input_variables=["dialect", "input", "table_info", "top_k"],template="""You are a SQL expert using {dialect}.
Given the following table schema:
{table_info}
Generate a syntactically correct SQL query to answer the question: "{input}".
Limit the results to at most {top_k} rows.
Return only the SQL query without any additional commentary or Markdown formatting.
"""
)write_query  = create_sql_query_chain(llm, db,prompt=custom_prompt)
# 构造输入数据字典,其中包含方言、表结构、问题和行数限制
input_data = {"dialect": db.dialect,                    # 数据库方言,如 "sqlite""table_info": db.get_table_info(),          # 表结构信息"question": "What name of MediaType is?","top_k": 5
}# 调用链生成 SQL 查询,返回结果为一个字典,包含键 "query"
write_query_response = write_query.invoke(input_data)
print('\n write_query result:',write_query_response)#执行SQL语句
execute_query = QuerySQLDataBaseTool(db=db)
execute_response = execute_query.invoke(write_query_response)
print('\n execute_response result:',execute_response)#两个动作合起来搞成链
chain = write_query | execute_query
result_chain = chain.invoke(input_data)
print('\n result_chain==',result_chain)

输出:
在这里插入图片描述

http://www.lryc.cn/news/544422.html

相关文章:

  • 用大白话解释缓存Redis +MongoDB是什么有什么用怎么用
  • 计算机毕业设计SpringBoot+Vue.js汽车销售网站(源码+文档+PPT+讲解)
  • 【0010】HTML水平线标签详解
  • FastExcel与Reactor响应式编程深度集成技术解析
  • Netty是如何实现零拷贝的?
  • 【大模型➕知识图谱】大模型结合医疗知识图谱:解锁智能辅助诊疗系统新范式
  • Spring Boot @Component注解介绍
  • MulFS-CAP: Multimodal Fusion-supervisedCross-modal
  • WordPress多语言插件GTranslate
  • wordpress子分类调用父分类名称和链接的3种方法
  • Prometheus + Grafana 监控
  • 初学STM32之简单认识IO口配置(学习笔记)
  • springboot2.7.18升级springboot3.3.0遇到的坑
  • gtest 和 gmock讲解
  • GC垃圾回收介绍及GC算法详解
  • 2020 年英语(一)考研真题 笔记(更新中)
  • 【springboot】Spring 官方抛弃了 Java 8!新idea如何创建java8项目
  • playbin之autoplug_factories源码剖析
  • 正浩创新内推:校招、社招EcoFlow社招内推码: FRQU1CY
  • 一文了解:部署 Deepseek 各版本的硬件要求
  • 有没有什么免费的AI工具可以帮忙做简单的ppt?
  • python绘图之灰度图
  • 华为 VRP 系统简介配置SSH,TELNET远程登录
  • 1.14 重叠因子:TRIMA三角移动平均线(Triangular Moving Average, TRIMA)概念与Python实战
  • 【tplink】校园网接路由器如何单独登录自己的账号,wan-lan和lan-lan区别
  • PC 端连接安卓手机恢复各类数据:安装、操作步骤与实用指南
  • 【折线图 Line】——1
  • SpringBoot 整合mongoDB并自定义连接池,实现多数据源配置
  • TCP/IP的分层结构、各层的典型协议,以及与ISO七层模型的差别
  • FreeRTOS-中断管理