当前位置: 首页 > news >正文

20250720问答课题-基于BERT与混合检索问答系统代码解读

1. 配置

1.1. TagConfig 类

该类用于从 CSV 文件中读取标签数据,并把这些标签整理成一级、二级标签,并进行整理和排序。

def __init__(self, csv_path):
  • 该类的构造函数 __init__ 接受一个 CSV 文件路径作为参数,说明TagConfig 类在创建时需要一个 CSV 文件的路径
  • __init__ 是Python 类中的构造函数
    • self 必须是第一个参数,代表该类的实例对象本身,通过 self 可以访问和修改对象的属性
    • csv_path:是自定义的参数

def __init__(self, csv_path):

csv_path = csv_path # 这只是一个局部变量,函数结束后就消失

self.csv_path = csv_path # 这是对象的属性,其他方法可以访问

#初始化部分
self.LEVEL1 = set()
self.LEVEL2 = defaultdict(set)
  • LEVEL1 是一个「集合」,用来存放所有的一级标签
    • set():创建一个空的集合
  • LEVEL2 是一个「特殊字典」,键值对,键是一级标签,值是对应二级标签的集合
    • defaultdict(set):创建一个默认值为集合的字典,当访问不存在的键时,会自动创建一个新的 set 作为默认值。

#读取 CSV 文件并处理标签
with open(csv_path, 'r', encoding='utf-8') as f:reader = csv.DictReader(f)for row in reader:if '一级标签' in row and row['一级标签']:l1 = row['一级标签'].strip()self.LEVEL1.add(l1)if '二级标签' in row and row['二级标签']:l2_list = [t.strip() for t in row['二级标签'].split('/') if t.strip()]for l2 in l2_list:self.LEVEL2[l1].add(l2)
  • 代码会打开并读取 CSV 文件
  • 它假设 CSV 文件里有「一级标签」和「二级标签」这两列
  • 遍历每一行数据:
    • 先提取一级标签(确保该行包含 '一级标签' 列且值不为空), strip() 方法去掉前后空格后,添加到 LEVEL1 集合中
    • 如果这一行还有二级标签,会把它们按 / 分割成多个标签,逐个添加到对应一级标签的二级标签集合中
  • csv.DictReader(f):将 CSV 文件的每一行转换为 字典,键是 CSV 的表头(第一行),值是对应列的数据。

一级标签,二级标签

水果,苹果/香蕉

蔬菜,胡萝卜/西红柿

上述为CSV文件内容,则第一行解析为:{'一级标签': '水果', '二级标签': '苹果/香蕉'}

#对标签进行排序
self.LEVEL1 = sorted(self.LEVEL1)
self.LEVEL2 = {k: sorted(v) for k, v in self.LEVEL2.items()}
  • sorted() 函数把一级标签和二级标签都按字母升序排列

调用,TagConfig 类在 QASystem 类的构造函数中被调用

class QASystem:def __init__(self, config: Config):self.config = configself.tag_config = TagConfig(config.QA_FILE)# ...
  • 在 QASystem 类的构造函数中,创建了一个 TagConfig 类的实例,并将 Config 类中的 QA_FILE 属性作为参数传递给 TagConfig 类的构造函数。

1.2. Config

class Config:BERT_PATH = "./BERT"QA_FILE = "心法问答.csv"STOPWORDS_FILE = "stopwords.txt"TOP_K = 1000MIN_SIMILARITY = 0.6DEVICE = "cuda" if torch.cuda.is_available() else "cpu"LAYER_WEIGHTS = [0.15, 0.25, 0.35, 0.25]

该类定义了系统的各种配置参数,包括 BERT 模型路径、问答文件路径、停用词文件路径、返回结果数量、最小相似度阈值、设备选择和 BERT 最后四层隐藏状态的加权系数。

  • BERT_PATH = "./BERT":BERT 预训练模型的本地路径,用于文本编码或特征提取
  • QA_FILE = "心法问答.csv":问答对数据文件,包含问题和对应的答案,格式为 CSV
  • STOPWORDS_FILE = "stopwords.txt":停用词文件路径,包含需要过滤的的词汇
  • TOP_K = 1000:在检索或排序任务中,保留得分最高的前 1000 个结果
  • MIN_SIMILARITY = 0.6:最小相似度阈值,只有相似度高于 0.6 的结果才会被保留
  • DEVICE = "cuda" if torch.cuda.is_available() else "cpu":自动检测是否有可用的 GPU,如果有则使用 GPU 加速计算,否则使用 CPU。
  • LAYER_WEIGHTS = [0.15, 0.25, 0.35, 0.25]:BERT 不同隐藏层的权重系数,用于层加权策略,第 3 层权重最高(0.35),该层最重要

2. Flask 路由定义

是一个基于 Flask 的 Web 应用,提供了问答系统的前端交互和后端 API 服务。

定义了四个 Flask 路由:

  • /:返回主页。
  • /submit:返回提交页面。
  • /search:处理搜索请求,调用 search_with_tags 方法进行搜索,并返回搜索结果。
  • /submit_question:处理问题提交请求,对提交的问题进行验证,若问题不存在则将其添加到问答文件和内存中,并更新问题向量。

2.1. 应用初始化

app = Flask(__name__)
system = None
  • __name__:当前模块的名称,是Python 内置变量,便于Flask 确定应用的根路径,找到模板和静态文件
  • app = Flask(__name__):创建 Flask 应用实例
  • system :全局变量,用于存储问答系统的实例;在应用启动时不立即初始化 system,而在需要时再初始化;延时初始化便于测试和部署,可以在不同环境中用不同配置初始化 system

2.2. 页面路由

是 Flask 框架中定义的路由规则,将 URL 路径与对应的处理函数绑定,从而返回不同的网页内容。

@app.route('/')
def home():return render_template('index.html')
  • /:返回主页模板,用户搜索问题的界面
  • @app.route('/'):是装饰器,告诉 Flask:当用户访问网站的根路径时,触发下面的 home() 函数。
  • ef home()::是路由对应的处理函数,自定义名字为 home,主页
  • return render_template('index.html')
    • render_template 是 Flask 提供的函数,用于渲染、生成HTML 模板文件。
    • 找到项目中 templates 文件夹下的 index.html 文件,加载到浏览器

@app.route('/submit')
def submit_page():return render_template('submit.html')
  • /submit:返回提交问题的界面

2.3. 搜索 API

定义了一个处理搜索请求的 API 接口,它接收用户的问题和标签筛选条件,调用问答系统进行语义搜索,并返回匹配的结果。

  • 功能:处理用户的搜索请求
  • 输入:
    • question:用户输入的问题
    • tags:用户选择的筛选标签(一级标签和二级标签)
  • 处理流程:
    1. 调用 system.search_with_tags 方法进行带标签的语义搜索
    2. 返回匹配的问题列表(包含问题、答案、相似度和标签信息)

2.3.1. 路由定义与请求方式
@app.route('/search', methods=['POST'])
def search():# ...
  • URL 路径:/search
  • 请求方法:POST,适合传输大量数据、复杂的搜索条件
  • 功能:处理用户的搜索请求,返回匹配的问答结果

2.3.2. 接收请求数据

从 Flask 的 HTTP 请求中提取搜索所需的参数

data = request.json
query = data.get('question', '')
selected_tags = data.get('tags', {})
  • request.json:获取前端通过 JSON 格式发送的数据
  • 参数:
    • question:用户输入的问题文本
    • tags:用户选择的筛选标签

为什么用 .get() 而不是直接索引?

# 直接索引可能报错

query = data['question'] # 如果 'question' 不存在,会抛出 KeyError

# 使用 get() 更安全

query = data.get('question', '') # 不存在时返回默认值

2.3.3. 调用搜索逻辑

调用问答系统的核心搜索方法,执行带标签筛选的语义搜索

results = system.search_with_tags(query, selected_tags)
  • system:全局的问答系统实例
  • search_with_tags 方法:
    • 根据问题query、标签selected_tags进行语义搜索,输出匹配的问答结果列表
    1. 对用户问题进行向量化
    2. 返回排序后的匹配结果列表
    3. 根据标签条件过滤结果
    4. 在向量库中查找最相似的问题

2.3.4. 处理返回结果

将搜索结果转换为 JSON 格式并返回给前端,是 API 接口的最后一步。

return jsonify([{'question': r['question'],'answer': r['answer'],'similarity': r['similarity'],'tags': r['tags']
} for r in results])
  • 返回格式:JSON 数组,每个元素包含:
    • question:匹配的问题文本
    • answer:对应的答案
    • similarity:相似度得分,0-1 之间,越接近 1 越相似
    • tags:问题关联的标签
  • jsonify() 函数:Flask 提供的工具,将 Python 对象转换为 JSON 响应;自动设置 HTTP 头Content-Type: application/json
  • [{...} for r in results]:遍历 results 列表中的每个结果对象 r,提取需要的字段,构建新的字典列表

2.4. 提交问题 API

定义了一个处理用户提交新问题的 API 接口,它接收用户输入的问题、答案和标签信息,验证数据有效性后将其保存到系统中。

2.4.1. 路由定义与请求方式
@app.route('/submit_question', methods=['POST'])
def submit_question():# ...
  • URL 路径:/submit_question
  • 请求方法:POST(用于向服务器提交数据)
  • 功能:接收用户提交的新问题和答案,保存到系统中

2.4.2. 数据验证与处理
data = request.json
question = data.get('question', '').strip()
answer = data.get('answer', '').strip()
level1 = data.get('level1', '').strip()
level2 = data.get('level2', [])
  • request.json:获取前端发送的 JSON 数据
  • 默认值:
    • ''(空字符串):键不存在时返回 None
    • [](空列表):确保 level2 始终是列表类型

if not question or not answer or not level1:return jsonify({'status': 400, 'msg': '问题、答案和一级标签不能为空'})
if level1 not in system.tag_config.LEVEL1:return jsonify({'status': 400, 'msg': '无效的一级标签'})
for tag in level2:if tag not in system.tag_config.LEVEL2.get(level1, []):return jsonify({'status': 400, 'msg': '无效的二级标签组合'})
  • 检查必填字段(问题、答案、一级标签)是否为空
  • 验证一级标签是否存在于系统中
  • 验证二级标签是否属于所选一级标签的合法子标签

2.4.3. 文本清洗与去重
clean_q = clean_text(question)existing = [q['cleaned_question'] for q in system.qa_pairs]
if clean_q in existing:return jsonify({'status': 400, 'msg': '该问题已存在'})
  • clean_text 函数:对原始文本预处理
  • existing = [q['cleaned_question'] for q in system.qa_pairs]:构建已有的问答列表
  • qa_pairs 是问答系统中的核心数据结构,用于存储所有问题 - 答案对及其相关元信息。
    • 数据结构:qa_pairs 通常是一个列表,每个元素是一个字典,包含问题、答案、标签等信息:
qa_pairs = [{"original_question": "原始问题文本","cleaned_question": "清洗后的问题",  # 经过预处理的问题"answer": "对应的答案","tags": {"level1": "一级标签", "level2": ["二级标签1", "二级标签2"] },"embedding": [...]  # 问题的向量表示(可选)},# 更多问答对...
]

2.4.4. 数据持久化存储

将用户提交的新问题保存到 CSV 文件。

with open(system.config.QA_FILE, 'a', newline='', encoding='utf-8') as f:writer = csv.writer(f)writer.writerow([question,answer,level1,'/'.join(level2)])
  • with open(system.config.QA_FILE, 'a', newline='', encoding='utf-8') as f:
  • system.config.QA_FILE:配置中的 CSV 文件路径
  • 'a' 模式:追加模式(Append),在文件末尾添加新行,不会覆盖原有内容
  • newline='':避免 Windows 系统下额外的空行
  • csv.writer(f):创建 CSV 写入器对象
  • writer.writerow():写入一行数据,参数是一个列表,每个元素对应 CSV 的一列,用 / 连接二级标签
2.4.5. 内存数据热更新

将新提交的问题实时添加到内存中,无需重启服务。

new_entry = {"original_question": question,"cleaned_question": clean_q,"answer": answer,"tags": {'level1': level1,'level2': level2}
}
system.qa_pairs.append(new_entry)
new_vec = system._get_embedding(clean_q).reshape(1, -1)
system.question_vectors = np.vstack([system.question_vectors, new_vec])
  • 更新内存中的数据:
  • 将新问题new_entry添加到问答对列表 qa_pairs
  • 计算新问题的向量表示:
    • _get_embedding 方法:使用 BERT 模型将文本转换为向量(数字)
    • reshape(1, -1):调整向量形状,便于后续与原有向量矩阵合并
    • reshape(2, 3):将数组变为 2 行 3 列
    • reshape(1, -1):将数组变为 1 行,列数自动计算
  • 将新向量添加到向量库 question_vectors

3. 文本清理函数

该函数用于对文本进行清理,去除非中文、非数字、非字母的字符,去除停用词,并保留名词、动词、形容词等特定词性的词语。

def clean_text(text: str) -> str:text = re.sub(r"[^\w\u4e00-\u9fa5??!!]", "", text)text = text.strip()words = pseg.cut(text)
  • -> str:返回字符串
  • re.sub(r"[^\w\u4e00-\u9fa5??!!]", "", text):去除无关字符,匹配非单词字符、非中文字符、非问号、非感叹号的任意字符 替换为空
  • words = pseg.cut(text):分词并标记词性

正则表达式符号

含义

[]

字符集,匹配方括号内的任意字符

^

在字符集内表示取反(匹配不在方括号内的字符)

\w

匹配单词字符(等价于 [a-zA-Z0-9_],即字母、数字、下划线)

\u4e00-\u9fa5

匹配所有中文字符

??!!

显式保留中文问号、英文问号、中文感叹号、英文感叹号

 stopwords = set()with open(Config.STOPWORDS_FILE, "r", encoding='utf-8') as f:for line in f:word = line.strip()if word not in ["不", "没", "非常", "极其"]:stopwords.add(word)
  • 保留否定词("不", "没")和程度副词("非常", "极其"),避免影响语义(如 "不好" ≠ "好")

 keep_pos = {'n', 'v', 'a', 'nr', 'ns'}filtered_words = [word for word, flag in wordsif flag[0] in keep_pos and word not in stopwords]return " ".join(filtered_words)
  • 词性过滤:仅保留keep_pos的词性,筛选条件:词性首字母在keep_pos中且词不在停用表
  • for word, flag in wordswords 是一个包含 (词, 词性) 元组的列表,每次循环,word 得到词本身,flag 得到词性

4. 问答系统类定义🔺

QASystem 类是整个问答系统的核心,包含以下主要方法:

  • __init__:初始化系统,加载模型、数据,准备问题向量并验证向量。
  • _load_model:加载 BERT 模型和分词器,并将模型设置为评估模式。
  • _load_data:从 CSV 文件中加载问答对,并进行标签验证。
  • search_with_tags:根据用户选择的标签对搜索结果进行筛选。
  • _get_embedding:获取文本的嵌入向量,采用 BERT 最后四层隐藏状态的加权融合策略。
  • _prepare_vectors:为所有问答对的问题生成嵌入向量。
  • _validate_vectors:验证问题向量的有效性,检查重复问题的向量相似度,并计算向量空间的平均相似度。
  • search:实现搜索功能,采用三重去重机制和语义 - 关键词协同架构,计算查询问题与问答对问题之间的相似度,并返回相似度最高的结果。

4.1. 系统初始化

class QASystem:def __init__(self, config: Config):self.config = configself.tag_config = TagConfig(config.QA_FILE)  # 加载标签配置self.qa_pairs = []  # 存储问答对self._load_model()  # 加载BERT模型self._load_data()  # 加载问答数据self._prepare_vectors()  # 生成问题向量self._validate_vectors()  # 验证向量质量

4.2. BERT 模型加载与文本向量化

def _get_embedding(self, text: str) -> np.ndarray:# 1. 分词与编码inputs = self.tokenizer(text, ...).to(self.config.DEVICE)# 2. 模型推理(无梯度计算)with torch.no_grad():outputs = self.model(**inputs)# 3. 提取最后四层隐藏状态hidden_states = outputs.hidden_states[-4:]# 4. 加载权重并归一化weights = torch.tensor(self.config.LAYER_WEIGHTS).to(self.config.DEVICE)weights /= weights.sum()
  • self.tokenizer(text, ...):分词器
  • torch.no_grad() :禁用梯度计算,不存储中间变量的梯度信息、速度更快
  • self.model(**inputs):让 BERT 模型处理输入的数字文本,输出语义特征
    • ** 是 Python 的 “拆包” 操作,能把字典里的键值对变成函数的参数
  • torch.tensor(...):将 Python 列表转换为 PyTorch 张量(一种可在 GPU 上计算的数字格式)
  • .to(self.config.DEVICE):将张量放到与模型相同的设备CPU/GPU上,否则计算会报错

# 5. 加权融合CLS向量
cls_vectors = torch.stack([layer[:, 0, :] for layer in hidden_states])
fused_vector = torch.sum(cls_vectors * weights.view(-1, 1, 1), dim=0)
  • hidden_states :BERT 输出的最后四层隐藏状态,形状为[4个层, batch_size, 序列长度, 隐藏层维度],(4, 1, 10, 768) 表示 4 层、1 个样本、10 个 token、每个 token768 维向量。
  • layer[:, 0, :]:提取每层的 CLS 向量,即第 0 个 token 的向量
    • layer:单一层的隐藏状态(如 (1, 10, 768)
    • [:, 0, :]:取所有样本(:)、第 0 个 token(0,即 CLS 标记)、所有维度(:
    • 结果:单个 CLS 向量,形状为 (1, 768)
  • torch.stack(...):将不同层的 CLS 向量堆叠成三维张量
    • 输入是 4 个形状为 (1, 768) 的 CLS 向量
    • 输出堆叠后形状:(4, 1, 768):4层,1个样本,向量维度768维

# 6. 归一化并返回numpy数组
return fused_vector.cpu().numpy().squeeze() / np.linalg.norm(fused_vector)
  • fused_vector.cpu():将张量从 GPU移到 CPU 内存, numpy() 方法只能处理 CPU 上的张量
  • .numpy():将 PyTorch 张量转换为 NumPy 数组
  • .squeeze():移除数组中维度为 1 的轴,(1, 768)(768,)

4.3. 语义搜索与结果筛选🔺

def search(self, query: str) -> List[Dict]:cleaned_query = clean_text(query)  # 文本清洗query_emb = self._get_embedding(cleaned_query)  # 查询向量化# 计算余弦相似度raw_similarities = np.dot(self.question_vectors, query_emb)#语义-关键词协同架构# 1.语义相似度校准:#通过Sigmoid变换,将相似度压缩到 (0,1) 区间,重点放大 0.88 附近的差异calibrated = 1 / (1 + np.exp(-25*(raw_similarities-0.88)))similarities = 0.2*raw_similarities + 0.8*calibrated# 2.关键词匹配度query_keywords = set(jieba.lcut(cleaned_query))	#分词去重:jieba将清洗后的查询文本分词,set()去除重复关键词keyword_weights = np.array([#len():问题关键词集合、问答库关键词集合的交集的长度len(set(jieba.lcut(q["cleaned_question"])) & query_keywords)/ max(len(query_keywords), 1)for q in self.qa_pairs	#遍历问答库中所有候选问题])# 语义+关键词混合评分-7:3权重similarities = 0.7*similarities + 0.3*keyword_weightssimilarities = similarities + 0.0266  # 微调偏移# 三重去重机制sorted_indices = np.argsort(-similarities)#argsort 是 NumPy 的升序排列函数,此处"-"为降序results = []seen_hashes = set()for idx in sorted_indices:# 1.相似度阈值过滤:丢弃相关度过低的答案if similarities[idx] < self.config.MIN_SIMILARITY:continue# 2. 内容哈希去重#将问题文本和答案的前20个字符组成字符串,通过hash()函数生成哈希值content_hash = hash(data["cleaned_question"] + data["answer"][:20])if content_hash in seen_hashes:#去重:如果某个结果的哈希值已在seen_hashes集合中,则跳过该结果continue# 3. 向量相似度去重#dot():计算当前候选向量与已选结果向量的点积,即余弦相似度#any(...):只要存在一个已选结果满足条件,立即跳过当前候选if any(np.dot(self.question_vectors[idx], self.question_vectors[r["index"]]) > 0.93 for r in results):continueresults.append({...}) 	#将符合条件的候选结果添加到最终返回列表#按相似度得分降序排列,取前TOP_K个结果return sorted(results, key=lambda x: -x['similarity'])[:self.config.TOP_K]

4.4. 标签筛选功能

def search_with_tags(self, query: str, selected_tags: Dict) -> List[Dict]:# 1. 执行基础搜索base_results = self.search(query)# 2. 初始化过滤结果列表filtered = []# 3. 遍历基础结果,进行标签匹配for item in base_results:# 一级标签匹配:完全匹配或未指定l1_match = not selected_tags.get('level1') or item['tags']['level1'] in selected_tags['level1']# 二级标签匹配:部分匹配或未指定l2_match = not selected_tags.get('level2') or any(tag in item['tags']['level2'] for tag in selected_tags['level2'])# 4. 同时满足两个条件则加入结果if l1_match and l2_match:filtered.append(item)# 5. 排序并返回Top_K结果return sorted(filtered, key=lambda x: -x['similarity'])[:self.config.TOP_K]

4.5. 向量质量验证

def _validate_vectors(self):# 1. 检查重复问题的向量相似度for i in range(len(self.qa_pairs)):for j in range(i+1, len(self.qa_pairs)):# 找到清洗后文本相同的重复问题if self.qa_pairs[i]["cleaned_question"] == self.qa_pairs[j]["cleaned_question"]:# 计算向量相似度(点积,因为向量已归一化)sim = np.dot(self.question_vectors[i], self.question_vectors[j])# 如果相似度显著偏离1.0,发出警告if abs(sim - 1.0) > 1e-6:print(f"警告:重复问题向量差异过大 [{i}] vs [{j}]: {sim:.4f}")# 2. 计算向量空间平均相似度# 计算所有向量对的相似度矩阵sim_matrix = np.dot(self.question_vectors, self.question_vectors.T)# 取矩阵对角线以上的元素(避免重复计算),计算平均值avg_sim = np.mean(sim_matrix[np.triu_indices(len(sim_matrix), k=1)])print(f"向量空间平均相似度: {avg_sim:.2f}")

5. 主函数

def main():global systemjieba.initialize()system = QASystem(Config())app.run(host='0.0.0.0', port=5000, debug=False)if __name__ == "__main__":main()

主函数初始化结巴分词器,创建 QASystem 实例,并启动 Flask 应用。

http://www.lryc.cn/news/595434.html

相关文章:

  • 企业开发转型 | 前端AI化数字化自动化现状
  • 自动化商品监控:利用淘宝API开发实时价格库存采集接口
  • 【unitrix】 6.11 二进制数字标准化模块(normalize.rs)
  • G7打卡——Semi-Supervised GAN
  • Acrobat JavaScript 中的 `app.response()` 方法
  • 【学习路线】C#企业级开发之路:从基础语法到云原生应用
  • 基于MySQL实现分布式调度系统的选举算法
  • 一文速通《矩阵的特征值和特征向量》
  • Tomcat的部署、单体架构、session会话、spring
  • PostgreSQL高可用架构Repmgr部署流程
  • 计算机网络中:传输层和网络层之间是如何配合的
  • socket编程(UDP)
  • vue2使用v-viewer图片预览:打开页面自动预览,禁止关闭预览,解决在微信浏览器的页面点击事件老是触发预览初始化的问题
  • Linux 721 创建实现镜像的逻辑卷
  • 网络数据分层封装与解封过程的详细说明
  • 讯飞输入法3.0.1742功能简介
  • AI Agent开发学习系列 - langchain之LCEL(3):Prompt+LLM
  • 20250721
  • 【React】npm install报错npm : 无法加载文件 D:\APP\nodejs\npm.ps1,因为在此系统上禁止运行脚本。
  • 2x2矩阵教程
  • kafka 日志索引 AbstractIndex
  • 前端包管理工具深度对比:npm、yarn、pnpm 全方位解析
  • maven下载地址以及setting.xml配置
  • 【科研绘图系列】R语言绘制棒棒图和哑铃图
  • Pytorch01:深度学习中的专业名词及基本介绍
  • k8s查看某个pod的svc
  • 【高等数学】第五章 定积分——第一节 定积分的概念与性质
  • PostgreSQL SysCache RelCache
  • OCR 身份识别:让身份信息录入场景更高效安全
  • 低代码/无代码平台如何重塑开发生态