Datawhale AI 夏令营:RAG多模态检索(Baseline解读)
一、RAG技术概述
RAG(Retrieval-Augmented Generation)是一种结合检索与生成的混合模型技术,旨在通过检索外部知识库增强生成式模型的准确性和事实性。其核心思想是在生成答案前,先从大规模文档库中检索相关段落,再将检索结果作为上下文输入生成模型,从而减少幻觉(hallucination)问题。
1、技术原理
-
检索模块
使用稠密检索(如DPR)或稀疏检索(如BM25)从知识库中提取与输入问题相关的文档片段。检索器通常基于向量相似度匹配,将查询和文档映射到同一向量空间。 -
生成模块
将检索到的文档与原始问题拼接,输入生成模型(如GPT、T5等)。模型基于联合上下文生成最终答案。
2、关键优势
- 动态知识更新:无需重新训练模型,通过更新检索库即可适应新领域。
- 可解释性:生成结果可追溯至检索到的参考文档。
- 减少幻觉:生成受检索内容约束,降低虚构信息的概率。
3、典型应用场景
- 开放域问答(如客服机器人)
- 事实核查与数据报告生成
- 长文本生成(如论文摘要辅助)
4、实现示例(伪代码)
# 检索阶段
retriever = DenseRetriever(index_path="knowledge_base")
relevant_docs = retriever.search(query, top_k=3)# 生成阶段
generator = T5Generator()
input_text = f"Query: {query}\nContext: {' '.join(relevant_docs)}"
output = generator.generate(input_text)
4、挑战与改进方向
- 检索效率:大规模知识库需优化近似最近邻搜索(ANN)
- 噪声过滤:检索结果可能存在无关内容
- 多模态扩展:结合图像、表格等非文本数据
二、多模态RAG检索
1、多模态RAG检索的概念
多模态RAG(Retrieval-Augmented Generation)是一种结合检索与生成的混合模型,能够处理文本、图像、音频、视频等多种模态的数据。其核心思想是通过检索相关的外部知识库内容,增强生成模型的输出质量,确保生成内容的准确性和丰富性。
2、多模态RAG的核心组件
多模态编码器:将不同模态的数据(如文本、图像)映射到统一的向量空间,便于后续的相似性匹配。
检索模块:基于查询从多模态知识库中检索最相关的片段,支持跨模态检索(如文本搜图像、图像搜文本)。
生成模块:融合检索结果与输入查询,生成符合多模态上下文的输出(如生成带图像的描述文本)。
3、多模态RAG的工作流程
- 多模态查询输入:用户输入可以是任意模态(如文本提问或上传图片)。
- 跨模态检索:系统将查询编码为向量,从知识库中检索相关的多模态内容(如文本段落、图像片段)。
- 生成增强输出:生成模型结合检索结果与输入,产生最终回答(如生成图文结合的解答)。
4、应用场景
- 视觉问答(VQA):根据图像内容生成自然语言回答。
- 跨模态推荐:基于用户输入的文本或图像推荐相关商品或内容。
- 医疗诊断辅助:结合医学影像与文本报告生成诊断建议。
5、技术挑战
- 模态对齐:不同模态数据的语义对齐需要高质量的多模态预训练模型(如CLIP、Flamingo)。
- 检索效率:大规模多模态数据的实时检索对计算资源要求较高。
- 生成一致性:确保生成内容与检索结果在逻辑和语义上保持一致。
三、Baseline拆解
Step1、解压PDF文件
!unzip -n datas/财报数据库.zip -d datas/
赛题数据是行业研究报告或者财务报告的PDF文件,发布来源与证券研究所。
Step2、安装依赖
!pip install -r requirements.txt
# requirements.txt
# Core dependencies
tqdm
python-dotenv
openai
numpy
loguru
PyMuPDF# 需要高硬件资源
# mineru[all]>=2.1.9
# xinference>=1.8.0
baseline 中使用的依赖库,可以看到这里使用的是 PyMuPDF
解析PDF文件。
Step3、解析PDF为文本
!python fitz_pipeline_all.py
# fitz_pipeline_all.py
import fitz # PyMuPDF
import json
from pathlib import Pathdef process_pdfs_to_chunks(datas_dir: Path, output_json_path: Path):"""使用 PyMuPDF 直接从 PDF 提取每页文本,并生成最终的 JSON 文件。Args:datas_dir (Path): 包含 PDF 文件的输入目录。output_json_path (Path): 最终输出的 JSON 文件路径。"""all_chunks = []# 递归查找 datas_dir 目录下的所有 .pdf 文件pdf_files = list(datas_dir.rglob('*.pdf'))if not pdf_files:print(f"警告:在目录 '{datas_dir}' 中未找到任何 PDF 文件。")returnprint(f"找到 {len(pdf_files)} 个 PDF 文件,开始处理...")for pdf_path in pdf_files:file_name_stem = pdf_path.stem # 文件名(不含扩展名)full_file_name = pdf_path.name # 完整文件名(含扩展名)print(f" - 正在处理: {full_file_name}")try:# 使用 with 语句确保文件被正确关闭with fitz.open(pdf_path) as doc:# 遍历 PDF 的每一页for page_idx, page in enumerate(doc):# 提取当前页面的所有文本content = page.get_text("text")# 如果页面没有文本内容,则跳过if not content.strip():continue# 构建符合最终格式的 chunk 字典chunk = {"id": f"{file_name_stem}_page_{page_idx}","content": content,"metadata": {"page": page_idx, # 0-based page index"file_name": full_file_name}}all_chunks.append(chunk)except Exception as e:print(f"处理文件 '{pdf_path}' 时发生错误: {e}")# 确保输出目录存在output_json_path.parent.mkdir(parents=True, exist_ok=True)# 将所有 chunks 写入一个 JSON 文件with open(output_json_path, 'w', encoding='utf-8') as f:json.dump(all_chunks, f, ensure_ascii=False, indent=2)print(f"\n处理完成!所有内容已保存至: {output_json_path}")def main():base_dir = Path(__file__).parentdatas_dir = base_dir / 'datas'chunk_json_path = base_dir / 'all_pdf_page_chunks.json'process_pdfs_to_chunks(datas_dir, chunk_json_path)if __name__ == '__main__':main()
可以看到核心逻辑非常简单,就是读取目录下的PDF文件,针对每个PDF逐页提取文本内容,最终把所有提取的文本内容整合为 json 文件。
Step4、文本向量存储
!python rag_from_page_chunks.py
# rag_from_page_chunks.py
import json
import osimport hashlib
from typing import List, Dict, Any
from tqdm import tqdm
import sys
import concurrent.futures
import randomfrom get_text_embedding import get_text_embeddingfrom dotenv import load_dotenv
from openai import OpenAI
# 统一加载项目根目录的.env
load_dotenv()class PageChunkLoader:def __init__(self, json_path: str):self.json_path = json_pathdef load_chunks(self) -> List[Dict[str, Any]]:with open(self.json_path, 'r', encoding='utf-8') as f:return json.load(f)class EmbeddingModel:def __init__(self, batch_size: int = 64):self.api_key = os.getenv('LOCAL_API_KEY')self.base_url = os.getenv('LOCAL_BASE_URL')self.embedding_model = os.getenv('LOCAL_EMBEDDING_MODEL')self.batch_size = batch_sizeif not self.api_key or not self.base_url:raise ValueError('请在.env中配置LOCAL_API_KEY和LOCAL_BASE_URL')def embed_texts(self, texts: List[str]) -> List[List[float]]:return get_text_embedding(texts,api_key=self.api_key,base_url=self.base_url,embedding_model=self.embedding_model,batch_size=self.batch_size)def embed_text(self, text: str) -> List[float]:return self.embed_texts([text])[0]class SimpleVectorStore:def __init__(self):self.embeddings = []self.chunks = []def add_chunks(self, chunks: List[Dict[str, Any]], embeddings: List[List[float]]):self.chunks.extend(chunks)self.embeddings.extend(embeddings)def search(self, query_embedding: List[float], top_k: int = 3) -> List[Dict[str, Any]]:from numpy import dotfrom numpy.linalg import normimport numpy as npif not self.embeddings:return []emb_matrix = np.array(self.embeddings)query_emb = np.array(query_embedding)sims = emb_matrix @ query_emb / (norm(emb_matrix, axis=1) * norm(query_emb) + 1e-8)idxs = sims.argsort()[::-1][:top_k]return [self.chunks[i] for i in idxs]class SimpleRAG:def __init__(self, chunk_json_path: str, model_path: str = None, batch_size: int = 32):self.loader = PageChunkLoader(chunk_json_path)self.embedding_model = EmbeddingModel(batch_size=batch_size)self.vector_store = SimpleVectorStore()def setup(self):print("加载所有页chunk...")chunks = self.loader.load_chunks()print(f"共加载 {len(chunks)} 个chunk")print("生成嵌入...")embeddings = self.embedding_model.embed_texts([c['content'] for c in chunks])print("存储向量...")self.vector_store.add_chunks(chunks, embeddings)print("RAG向量库构建完成!")def query(self, question: str, top_k: int = 3) -> Dict[str, Any]:q_emb = self.embedding_model.embed_text(question)results = self.vector_store.search(q_emb, top_k)return {"question": question,"chunks": results}def generate_answer(self, question: str, top_k: int = 3) -> Dict[str, Any]:"""检索+大模型生成式回答,返回结构化结果"""qwen_api_key = os.getenv('LOCAL_API_KEY')qwen_base_url = os.getenv('LOCAL_BASE_URL')qwen_model = os.getenv('LOCAL_TEXT_MODEL')if not qwen_api_key or not qwen_base_url or not qwen_model:raise ValueError('请在.env中配置LOCAL_API_KEY、LOCAL_BASE_URL、LOCAL_TEXT_MODEL')q_emb = self.embedding_model.embed_text(question)chunks = self.vector_store.search(q_emb, top_k)# 拼接检索内容,带上元数据context = "\n".join([f"[文件名]{c['metadata']['file_name']} [页码]{c['metadata']['page']}\n{c['content']}" for c in chunks])# 明确要求输出JSON格式 answer/page/filenameprompt = (f"你是一名专业的金融分析助手,请根据以下检索到的内容回答用户问题。\n"f"请严格按照如下JSON格式输出:\n"f'{{"answer": "你的简洁回答", "filename": "来源文件名", "page": "来源页码"}}'"\n"f"检索内容:\n{context}\n\n问题:{question}\n"f"请确保输出内容为合法JSON字符串,不要输出多余内容。")client = OpenAI(api_key=qwen_api_key, base_url=qwen_base_url)completion = client.chat.completions.create(model=qwen_model,messages=[{"role": "system", "content": "你是一名专业的金融分析助手。"},{"role": "user", "content": prompt}],temperature=0.2,max_tokens=1024)import json as pyjsonfrom extract_json_array import extract_json_arrayraw = completion.choices[0].message.content.strip()# 用 extract_json_array 提取 JSON 对象json_str = extract_json_array(raw, mode='objects')if json_str:try:arr = pyjson.loads(json_str)# 只取第一个对象if isinstance(arr, list) and arr:j = arr[0]answer = j.get('answer', '')filename = j.get('filename', '')page = j.get('page', '')else:answer = rawfilename = chunks[0]['metadata']['file_name'] if chunks else ''page = chunks[0]['metadata']['page'] if chunks else ''except Exception:answer = rawfilename = chunks[0]['metadata']['file_name'] if chunks else ''page = chunks[0]['metadata']['page'] if chunks else ''else:answer = rawfilename = chunks[0]['metadata']['file_name'] if chunks else ''page = chunks[0]['metadata']['page'] if chunks else ''# 结构化输出return {"question": question,"answer": answer,"filename": filename,"page": page,"retrieval_chunks": chunks}if __name__ == '__main__':# 路径可根据实际情况调整chunk_json_path = "./all_pdf_page_chunks.json"rag = SimpleRAG(chunk_json_path)rag.setup()# 控制测试时读取的题目数量,默认只随机抽取10个,实际跑全部时设为NoneTEST_SAMPLE_NUM = 10 # 设置为None则全部跑FILL_UNANSWERED = True # 未回答的也输出默认内容# 批量评测脚本:读取测试集,检索+大模型生成,输出结构化结果test_path = "./datas/test.json"if os.path.exists(test_path):with open(test_path, 'r', encoding='utf-8') as f:test_data = json.load(f)# 记录所有原始索引all_indices = list(range(len(test_data)))# 随机抽取部分题目用于测试selected_indices = all_indicesif TEST_SAMPLE_NUM is not None and TEST_SAMPLE_NUM > 0:if len(test_data) > TEST_SAMPLE_NUM:selected_indices = sorted(random.sample(all_indices, TEST_SAMPLE_NUM))def process_one(idx):item = test_data[idx]question = item['question']tqdm.write(f"[{selected_indices.index(idx)+1}/{len(selected_indices)}] 正在处理: {question[:30]}...")result = rag.generate_answer(question, top_k=5)return idx, resultresults = []if selected_indices:with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:results = list(tqdm(executor.map(process_one, selected_indices), total=len(selected_indices), desc='并发批量生成'))# 先输出一份未过滤的原始结果(含 idx)raw_out_path = "./rag_top1_pred_raw.json"with open(raw_out_path, 'w', encoding='utf-8') as f:json.dump(results, f, ensure_ascii=False, indent=2)print(f'已输出原始未过滤结果到: {raw_out_path}')# 只保留结果部分,并去除 retrieval_chunks 字段idx2result = {idx: {k: v for k, v in r.items() if k != 'retrieval_chunks'} for idx, r in results}filtered_results = []for idx, item in enumerate(test_data):if idx in idx2result:filtered_results.append(idx2result[idx])elif FILL_UNANSWERED:# 未被回答的,补默认内容filtered_results.append({"question": item.get("question", ""),"answer": "","filename": "","page": "",})# 输出结构化结果到jsonout_path = "./rag_top1_pred.json"with open(out_path, 'w', encoding='utf-8') as f:json.dump(filtered_results, f, ensure_ascii=False, indent=2)print(f'已输出结构化检索+大模型生成结果到: {out_path}')else:print("datas/test.json 不存在")
这里的逻辑也较为简单,首先将 json 文件的所有信息进行文本向量化,这里使用的是 bge-m3 嵌入模型;然后从测试集中提取部分问题,逐个将问题进行向量化,从向量存储中搜索与问题向量最接近的知识片段;最后把问题与知识片段附加到提示词,让模型根据上下文回答问题。
四、Baseline不足分析
1、PDF 解析粗糙
baseline中使用PyMuPDF提取每个PDF文件的文本,但是忽略了信息结构与图片,因此提取出来的文本作为知识库有很多噪声数据并且丢失了重要信息
2、向量存储粗糙
仅使用数组对向量进行存储,搜索方法也非常简陋,因此在检索步骤会丢失准确度。
3、文本未进行切割
大块文本直接进行向量化,若文本包含的信息量很大,而只用一个向量表示多个特征,显然在精细化检索场景下是有较大误差的,因为文本之间会相互影响,一个很好的例子就是:大海捞针。
4、溯源页码错误
baseline中直接使用的是pdf页码的下标,从0开始。而实际上,在训练集中的示例,下标从1开始。因此,页码错误会导致溯源准确度这一块得分为0。
5、大模型未适配金融领域
从赛题要求中,我们可得知这是金融领域的RAG检索场景,而baseline中使用的是通用大模型,因此对金融领域的兼容并不够完善。