20250720问答课题-基于BERT与混合检索问答系统代码解读
1. 配置
1.1. TagConfig 类
该类用于从 CSV 文件中读取标签数据,并把这些标签整理成一级、二级标签,并进行整理和排序。
def __init__(self, csv_path):
- 该类的构造函数 __init__ 接受一个 CSV 文件路径作为参数,说明TagConfig 类在创建时需要一个 CSV 文件的路径
- __init__ 是Python 类中的构造函数
- self 必须是第一个参数,代表该类的实例对象本身,通过 self 可以访问和修改对象的属性
- csv_path:是自定义的参数
def __init__(self, csv_path):
csv_path = csv_path # 这只是一个局部变量,函数结束后就消失
self.csv_path = csv_path # 这是对象的属性,其他方法可以访问
#初始化部分
self.LEVEL1 = set()
self.LEVEL2 = defaultdict(set)
- LEVEL1 是一个「集合」,用来存放所有的一级标签
- set():创建一个空的集合
- LEVEL2 是一个「特殊字典」,键值对,键是一级标签,值是对应二级标签的集合
- defaultdict(set):创建一个默认值为集合的字典,当访问不存在的键时,会自动创建一个新的 set 作为默认值。
#读取 CSV 文件并处理标签
with open(csv_path, 'r', encoding='utf-8') as f:reader = csv.DictReader(f)for row in reader:if '一级标签' in row and row['一级标签']:l1 = row['一级标签'].strip()self.LEVEL1.add(l1)if '二级标签' in row and row['二级标签']:l2_list = [t.strip() for t in row['二级标签'].split('/') if t.strip()]for l2 in l2_list:self.LEVEL2[l1].add(l2)
- 代码会打开并读取 CSV 文件
- 它假设 CSV 文件里有「一级标签」和「二级标签」这两列
- 遍历每一行数据:
- 先提取一级标签(确保该行包含 '一级标签' 列且值不为空), strip() 方法去掉前后空格后,添加到 LEVEL1 集合中
- 如果这一行还有二级标签,会把它们按
/
分割成多个标签,逐个添加到对应一级标签的二级标签集合中
csv.DictReader(f)
:将 CSV 文件的每一行转换为 字典,键是 CSV 的表头(第一行),值是对应列的数据。
一级标签,二级标签
水果,苹果/香蕉
蔬菜,胡萝卜/西红柿
上述为CSV文件内容,则第一行解析为:{'一级标签': '水果', '二级标签': '苹果/香蕉'}
#对标签进行排序
self.LEVEL1 = sorted(self.LEVEL1)
self.LEVEL2 = {k: sorted(v) for k, v in self.LEVEL2.items()}
sorted()
函数把一级标签和二级标签都按字母升序排列
调用,TagConfig 类在 QASystem 类的构造函数中被调用
class QASystem:def __init__(self, config: Config):self.config = configself.tag_config = TagConfig(config.QA_FILE)# ...
- 在 QASystem 类的构造函数中,创建了一个 TagConfig 类的实例,并将 Config 类中的 QA_FILE 属性作为参数传递给 TagConfig 类的构造函数。
1.2. Config
类
class Config:BERT_PATH = "./BERT"QA_FILE = "心法问答.csv"STOPWORDS_FILE = "stopwords.txt"TOP_K = 1000MIN_SIMILARITY = 0.6DEVICE = "cuda" if torch.cuda.is_available() else "cpu"LAYER_WEIGHTS = [0.15, 0.25, 0.35, 0.25]
该类定义了系统的各种配置参数,包括 BERT 模型路径、问答文件路径、停用词文件路径、返回结果数量、最小相似度阈值、设备选择和 BERT 最后四层隐藏状态的加权系数。
BERT_PATH = "./BERT"
:BERT 预训练模型的本地路径,用于文本编码或特征提取QA_FILE = "心法问答.csv"
:问答对数据文件,包含问题和对应的答案,格式为 CSVSTOPWORDS_FILE = "stopwords.txt"
:停用词文件路径,包含需要过滤的的词汇TOP_K = 1000
:在检索或排序任务中,保留得分最高的前 1000 个结果MIN_SIMILARITY = 0.6
:最小相似度阈值,只有相似度高于 0.6 的结果才会被保留DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
:自动检测是否有可用的 GPU,如果有则使用 GPU 加速计算,否则使用 CPU。
LAYER_WEIGHTS = [0.15, 0.25, 0.35, 0.25]
:BERT 不同隐藏层的权重系数,用于层加权策略,第 3 层权重最高(0.35),该层最重要
2. Flask 路由定义
是一个基于 Flask 的 Web 应用,提供了问答系统的前端交互和后端 API 服务。
定义了四个 Flask 路由:
/
:返回主页。/submit
:返回提交页面。/search
:处理搜索请求,调用search_with_tags
方法进行搜索,并返回搜索结果。/submit_question
:处理问题提交请求,对提交的问题进行验证,若问题不存在则将其添加到问答文件和内存中,并更新问题向量。
2.1. 应用初始化
app = Flask(__name__)
system = None
__name__
:当前模块的名称,是Python 内置变量,便于Flask 确定应用的根路径,找到模板和静态文件app = Flask(__name__)
:创建 Flask 应用实例system
:全局变量,用于存储问答系统的实例;在应用启动时不立即初始化system
,而在需要时再初始化;延时初始化便于测试和部署,可以在不同环境中用不同配置初始化system
2.2. 页面路由
是 Flask 框架中定义的路由规则,将 URL 路径与对应的处理函数绑定,从而返回不同的网页内容。
@app.route('/')
def home():return render_template('index.html')
/
:返回主页模板,用户搜索问题的界面@app.route('/')
:是装饰器,告诉 Flask:当用户访问网站的根路径时,触发下面的home()
函数。ef home():
:是路由对应的处理函数,自定义名字为 home,主页return render_template('index.html')
:render_template
是 Flask 提供的函数,用于渲染、生成HTML 模板文件。- 找到项目中
templates
文件夹下的index.html
文件,加载到浏览器
@app.route('/submit')
def submit_page():return render_template('submit.html')
/submit
:返回提交问题的界面
2.3. 搜索 API
定义了一个处理搜索请求的 API 接口,它接收用户的问题和标签筛选条件,调用问答系统进行语义搜索,并返回匹配的结果。
- 功能:处理用户的搜索请求
- 输入:
question
:用户输入的问题tags
:用户选择的筛选标签(一级标签和二级标签)
- 处理流程:
- 调用
system.search_with_tags
方法进行带标签的语义搜索 - 返回匹配的问题列表(包含问题、答案、相似度和标签信息)
- 调用
2.3.1. 路由定义与请求方式
@app.route('/search', methods=['POST'])
def search():# ...
- URL 路径:
/search
- 请求方法:
POST
,适合传输大量数据、复杂的搜索条件 - 功能:处理用户的搜索请求,返回匹配的问答结果
2.3.2. 接收请求数据
从 Flask 的 HTTP 请求中提取搜索所需的参数
data = request.json
query = data.get('question', '')
selected_tags = data.get('tags', {})
request.json
:获取前端通过 JSON 格式发送的数据- 参数:
question
:用户输入的问题文本tags
:用户选择的筛选标签
为什么用
.get()
而不是直接索引?# 直接索引可能报错
query = data['question'] # 如果 'question' 不存在,会抛出 KeyError
# 使用 get() 更安全
query = data.get('question', '') # 不存在时返回默认值
2.3.3. 调用搜索逻辑
调用问答系统的核心搜索方法,执行带标签筛选的语义搜索
results = system.search_with_tags(query, selected_tags)
system
:全局的问答系统实例search_with_tags
方法:- 根据问题
query
、标签selected_tags
进行语义搜索,输出匹配的问答结果列表
- 对用户问题进行向量化
- 返回排序后的匹配结果列表
- 根据标签条件过滤结果
- 在向量库中查找最相似的问题
- 根据问题
2.3.4. 处理返回结果
将搜索结果转换为 JSON 格式并返回给前端,是 API 接口的最后一步。
return jsonify([{'question': r['question'],'answer': r['answer'],'similarity': r['similarity'],'tags': r['tags']
} for r in results])
- 返回格式:JSON 数组,每个元素包含:
question
:匹配的问题文本answer
:对应的答案similarity
:相似度得分,0-1 之间,越接近 1 越相似tags
:问题关联的标签
jsonify()
函数:Flask 提供的工具,将 Python 对象转换为 JSON 响应;自动设置 HTTP 头Content-Type: application/json[{...} for r in results]
:遍历 results 列表中的每个结果对象 r,提取需要的字段,构建新的字典列表
2.4. 提交问题 API
定义了一个处理用户提交新问题的 API 接口,它接收用户输入的问题、答案和标签信息,验证数据有效性后将其保存到系统中。
2.4.1. 路由定义与请求方式
@app.route('/submit_question', methods=['POST'])
def submit_question():# ...
- URL 路径:
/submit_question
- 请求方法:
POST
(用于向服务器提交数据) - 功能:接收用户提交的新问题和答案,保存到系统中
2.4.2. 数据验证与处理
data = request.json
question = data.get('question', '').strip()
answer = data.get('answer', '').strip()
level1 = data.get('level1', '').strip()
level2 = data.get('level2', [])
request.json
:获取前端发送的 JSON 数据- 默认值:
''
(空字符串):键不存在时返回None
[]
(空列表):确保level2
始终是列表类型
if not question or not answer or not level1:return jsonify({'status': 400, 'msg': '问题、答案和一级标签不能为空'})
if level1 not in system.tag_config.LEVEL1:return jsonify({'status': 400, 'msg': '无效的一级标签'})
for tag in level2:if tag not in system.tag_config.LEVEL2.get(level1, []):return jsonify({'status': 400, 'msg': '无效的二级标签组合'})
- 检查必填字段(问题、答案、一级标签)是否为空
- 验证一级标签是否存在于系统中
- 验证二级标签是否属于所选一级标签的合法子标签
2.4.3. 文本清洗与去重
clean_q = clean_text(question)existing = [q['cleaned_question'] for q in system.qa_pairs]
if clean_q in existing:return jsonify({'status': 400, 'msg': '该问题已存在'})
clean_text
函数:对原始文本预处理existing = [q['cleaned_question'] for q in system.qa_pairs]
:构建已有的问答列表qa_pairs
是问答系统中的核心数据结构,用于存储所有问题 - 答案对及其相关元信息。- 数据结构:
qa_pairs
通常是一个列表,每个元素是一个字典,包含问题、答案、标签等信息:
- 数据结构:
qa_pairs = [{"original_question": "原始问题文本","cleaned_question": "清洗后的问题", # 经过预处理的问题"answer": "对应的答案","tags": {"level1": "一级标签", "level2": ["二级标签1", "二级标签2"] },"embedding": [...] # 问题的向量表示(可选)},# 更多问答对...
]
2.4.4. 数据持久化存储
将用户提交的新问题保存到 CSV 文件。
with open(system.config.QA_FILE, 'a', newline='', encoding='utf-8') as f:writer = csv.writer(f)writer.writerow([question,answer,level1,'/'.join(level2)])
with open(system.config.QA_FILE, 'a', newline='', encoding='utf-8') as f:
:
system.config.QA_FILE
:配置中的 CSV 文件路径'a'
模式:追加模式(Append),在文件末尾添加新行,不会覆盖原有内容newline=''
:避免 Windows 系统下额外的空行csv.writer(f)
:创建 CSV 写入器对象writer.writerow()
:写入一行数据,参数是一个列表,每个元素对应 CSV 的一列,用/
连接二级标签
2.4.5. 内存数据热更新
将新提交的问题实时添加到内存中,无需重启服务。
new_entry = {"original_question": question,"cleaned_question": clean_q,"answer": answer,"tags": {'level1': level1,'level2': level2}
}
system.qa_pairs.append(new_entry)
new_vec = system._get_embedding(clean_q).reshape(1, -1)
system.question_vectors = np.vstack([system.question_vectors, new_vec])
- 更新内存中的数据:
- 将新问题
new_entry
添加到问答对列表qa_pairs
- 计算新问题的向量表示:
_get_embedding
方法:使用 BERT 模型将文本转换为向量(数字)reshape(1, -1)
:调整向量形状,便于后续与原有向量矩阵合并reshape(2, 3)
:将数组变为 2 行 3 列reshape(1, -1)
:将数组变为 1 行,列数自动计算
- 将新向量添加到向量库
question_vectors
3. 文本清理函数
该函数用于对文本进行清理,去除非中文、非数字、非字母的字符,去除停用词,并保留名词、动词、形容词等特定词性的词语。
def clean_text(text: str) -> str:text = re.sub(r"[^\w\u4e00-\u9fa5??!!]", "", text)text = text.strip()words = pseg.cut(text)
-> str
:返回字符串re.sub(r"[^\w\u4e00-\u9fa5??!!]", "", text)
:去除无关字符,匹配非单词字符、非中文字符、非问号、非感叹号的任意字符 替换为空words = pseg.cut(text)
:分词并标记词性
正则表达式符号 | 含义 |
| 字符集,匹配方括号内的任意字符 |
| 在字符集内表示取反(匹配不在方括号内的字符) |
| 匹配单词字符(等价于 |
| 匹配所有中文字符 |
| 显式保留中文问号、英文问号、中文感叹号、英文感叹号 |
stopwords = set()with open(Config.STOPWORDS_FILE, "r", encoding='utf-8') as f:for line in f:word = line.strip()if word not in ["不", "没", "非常", "极其"]:stopwords.add(word)
- 保留否定词("不", "没")和程度副词("非常", "极其"),避免影响语义(如 "不好" ≠ "好")
keep_pos = {'n', 'v', 'a', 'nr', 'ns'}filtered_words = [word for word, flag in wordsif flag[0] in keep_pos and word not in stopwords]return " ".join(filtered_words)
- 词性过滤:仅保留
keep_pos
的词性,筛选条件:词性首字母在keep_pos
中且词不在停用表 for word, flag in words
:words
是一个包含 (词, 词性) 元组的列表,每次循环,word
得到词本身,flag
得到词性
4. 问答系统类定义🔺
QASystem
类是整个问答系统的核心,包含以下主要方法:
__init__
:初始化系统,加载模型、数据,准备问题向量并验证向量。_load_model
:加载 BERT 模型和分词器,并将模型设置为评估模式。_load_data
:从 CSV 文件中加载问答对,并进行标签验证。search_with_tags
:根据用户选择的标签对搜索结果进行筛选。_get_embedding
:获取文本的嵌入向量,采用 BERT 最后四层隐藏状态的加权融合策略。_prepare_vectors
:为所有问答对的问题生成嵌入向量。_validate_vectors
:验证问题向量的有效性,检查重复问题的向量相似度,并计算向量空间的平均相似度。search
:实现搜索功能,采用三重去重机制和语义 - 关键词协同架构,计算查询问题与问答对问题之间的相似度,并返回相似度最高的结果。
4.1. 系统初始化
class QASystem:def __init__(self, config: Config):self.config = configself.tag_config = TagConfig(config.QA_FILE) # 加载标签配置self.qa_pairs = [] # 存储问答对self._load_model() # 加载BERT模型self._load_data() # 加载问答数据self._prepare_vectors() # 生成问题向量self._validate_vectors() # 验证向量质量
4.2. BERT 模型加载与文本向量化
def _get_embedding(self, text: str) -> np.ndarray:# 1. 分词与编码inputs = self.tokenizer(text, ...).to(self.config.DEVICE)# 2. 模型推理(无梯度计算)with torch.no_grad():outputs = self.model(**inputs)# 3. 提取最后四层隐藏状态hidden_states = outputs.hidden_states[-4:]# 4. 加载权重并归一化weights = torch.tensor(self.config.LAYER_WEIGHTS).to(self.config.DEVICE)weights /= weights.sum()
self.tokenizer(text, ...)
:分词器torch.no_grad()
:禁用梯度计算,不存储中间变量的梯度信息、速度更快self.model(**inputs)
:让 BERT 模型处理输入的数字文本,输出语义特征**
是 Python 的 “拆包” 操作,能把字典里的键值对变成函数的参数
torch.tensor(...)
:将 Python 列表转换为 PyTorch 张量(一种可在 GPU 上计算的数字格式).to(self.config.DEVICE)
:将张量放到与模型相同的设备CPU/GPU上,否则计算会报错
# 5. 加权融合CLS向量
cls_vectors = torch.stack([layer[:, 0, :] for layer in hidden_states])
fused_vector = torch.sum(cls_vectors * weights.view(-1, 1, 1), dim=0)
hidden_states
:BERT 输出的最后四层隐藏状态,形状为[4个层, batch_size, 序列长度, 隐藏层维度],(4, 1, 10, 768) 表示 4 层、1 个样本、10 个 token、每个 token768 维向量。
layer[:, 0, :]
:提取每层的 CLS 向量,即第 0 个 token 的向量layer
:单一层的隐藏状态(如(1, 10, 768)
)[:, 0, :]
:取所有样本(:
)、第 0 个 token(0
,即 CLS 标记)、所有维度(:
)- 结果:单个 CLS 向量,形状为
(1, 768)
torch.stack(...)
:将不同层的 CLS 向量堆叠成三维张量- 输入是 4 个形状为
(1, 768)
的 CLS 向量 - 输出堆叠后形状:
(4, 1, 768)
:4层,1个样本,向量维度768维
- 输入是 4 个形状为
# 6. 归一化并返回numpy数组
return fused_vector.cpu().numpy().squeeze() / np.linalg.norm(fused_vector)
fused_vector.cpu()
:将张量从 GPU移到 CPU 内存,numpy()
方法只能处理 CPU 上的张量.numpy()
:将 PyTorch 张量转换为 NumPy 数组.squeeze()
:移除数组中维度为 1 的轴,(1, 768)
→(768,)
4.3. 语义搜索与结果筛选🔺
def search(self, query: str) -> List[Dict]:cleaned_query = clean_text(query) # 文本清洗query_emb = self._get_embedding(cleaned_query) # 查询向量化# 计算余弦相似度raw_similarities = np.dot(self.question_vectors, query_emb)#语义-关键词协同架构# 1.语义相似度校准:#通过Sigmoid变换,将相似度压缩到 (0,1) 区间,重点放大 0.88 附近的差异calibrated = 1 / (1 + np.exp(-25*(raw_similarities-0.88)))similarities = 0.2*raw_similarities + 0.8*calibrated# 2.关键词匹配度query_keywords = set(jieba.lcut(cleaned_query)) #分词去重:jieba将清洗后的查询文本分词,set()去除重复关键词keyword_weights = np.array([#len():问题关键词集合、问答库关键词集合的交集的长度len(set(jieba.lcut(q["cleaned_question"])) & query_keywords)/ max(len(query_keywords), 1)for q in self.qa_pairs #遍历问答库中所有候选问题])# 语义+关键词混合评分-7:3权重similarities = 0.7*similarities + 0.3*keyword_weightssimilarities = similarities + 0.0266 # 微调偏移# 三重去重机制sorted_indices = np.argsort(-similarities)#argsort 是 NumPy 的升序排列函数,此处"-"为降序results = []seen_hashes = set()for idx in sorted_indices:# 1.相似度阈值过滤:丢弃相关度过低的答案if similarities[idx] < self.config.MIN_SIMILARITY:continue# 2. 内容哈希去重#将问题文本和答案的前20个字符组成字符串,通过hash()函数生成哈希值content_hash = hash(data["cleaned_question"] + data["answer"][:20])if content_hash in seen_hashes:#去重:如果某个结果的哈希值已在seen_hashes集合中,则跳过该结果continue# 3. 向量相似度去重#dot():计算当前候选向量与已选结果向量的点积,即余弦相似度#any(...):只要存在一个已选结果满足条件,立即跳过当前候选if any(np.dot(self.question_vectors[idx], self.question_vectors[r["index"]]) > 0.93 for r in results):continueresults.append({...}) #将符合条件的候选结果添加到最终返回列表#按相似度得分降序排列,取前TOP_K个结果return sorted(results, key=lambda x: -x['similarity'])[:self.config.TOP_K]
4.4. 标签筛选功能
def search_with_tags(self, query: str, selected_tags: Dict) -> List[Dict]:# 1. 执行基础搜索base_results = self.search(query)# 2. 初始化过滤结果列表filtered = []# 3. 遍历基础结果,进行标签匹配for item in base_results:# 一级标签匹配:完全匹配或未指定l1_match = not selected_tags.get('level1') or item['tags']['level1'] in selected_tags['level1']# 二级标签匹配:部分匹配或未指定l2_match = not selected_tags.get('level2') or any(tag in item['tags']['level2'] for tag in selected_tags['level2'])# 4. 同时满足两个条件则加入结果if l1_match and l2_match:filtered.append(item)# 5. 排序并返回Top_K结果return sorted(filtered, key=lambda x: -x['similarity'])[:self.config.TOP_K]
4.5. 向量质量验证
def _validate_vectors(self):# 1. 检查重复问题的向量相似度for i in range(len(self.qa_pairs)):for j in range(i+1, len(self.qa_pairs)):# 找到清洗后文本相同的重复问题if self.qa_pairs[i]["cleaned_question"] == self.qa_pairs[j]["cleaned_question"]:# 计算向量相似度(点积,因为向量已归一化)sim = np.dot(self.question_vectors[i], self.question_vectors[j])# 如果相似度显著偏离1.0,发出警告if abs(sim - 1.0) > 1e-6:print(f"警告:重复问题向量差异过大 [{i}] vs [{j}]: {sim:.4f}")# 2. 计算向量空间平均相似度# 计算所有向量对的相似度矩阵sim_matrix = np.dot(self.question_vectors, self.question_vectors.T)# 取矩阵对角线以上的元素(避免重复计算),计算平均值avg_sim = np.mean(sim_matrix[np.triu_indices(len(sim_matrix), k=1)])print(f"向量空间平均相似度: {avg_sim:.2f}")
5. 主函数
def main():global systemjieba.initialize()system = QASystem(Config())app.run(host='0.0.0.0', port=5000, debug=False)if __name__ == "__main__":main()
主函数初始化结巴分词器,创建 QASystem
实例,并启动 Flask 应用。