趋近智
检索文档的质量对于整个 RAG 系统很重要。尽管像领域专用嵌入模型和混合搜索这样的方法能扩大检索范围并改进初步候选选择,但重排序则像一把密齿梳,细致筛选这些候选,将最符合的结果置于首位。一个使用交叉编码器模型的进阶重排序阶段的实施和其效果评估将在此处演示。
我们将模拟一个常见情形:用户提出问题,我们的初步检索(通常称作基于双编码器的系统)获取一份可能相关文档的列表,然后重排序器(交叉编码器)重新评估这些最佳候选,以生成一份更准确的最终列表。
首先,请确保你已安装必要的库。我们将主要使用 sentence-transformers 作为初步检索器和重排序器,因为它为各种预训练模型提供了便捷接口。
# 确保你已安装这些库:
# pip install sentence-transformers torch
让我们定义一个小规模文档语料库和几个带有已知相关文档的示例查询。在实际情形中,这个语料库会大得多,用于评估的真实数据也会更广泛。
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import torch
# 示例文档(我们的知识库)
documents = [
{"id": "doc1", "text": "我们的软件支持 Windows 10、Windows 11 以及 macOS Monterey 或更新版本。"},
{"id": "doc2", "text": "要安装,请从我们的网站下载安装程序并运行。按照屏幕上的提示操作。"},
{"id": "doc3", "text": "许可证密钥可在您的购买确认邮件中找到。在‘激活’窗口中输入。"}
{"id": "doc4", "text": "如需故障排除,请查看我们的在线知识库或通过 [email protected] 联系支持。"},
{"id": "doc5", "text": "系统要求至少包含 4GB 内存和 10GB 可用磁盘空间。建议使用现代 CPU 以获得最佳性能。"},
{"id": "doc6", "text": "更新会自动下载和安装。您可以通过‘帮助’菜单手动检查更新。"}
]
doc_texts = [doc['text'] for doc in documents]
# 带有真实数据的示例查询,用于评估
queries_with_ground_truth = [
{"query": "如何安装软件?", "relevant_doc_id": "doc2", "relevant_doc_text": documents[1]["text"]},
{"query": "支持哪些操作系统?", "relevant_doc_id": "doc1", "relevant_doc_text": documents[0]["text"]},
{"query": "我的许可证密钥在哪里?", "relevant_doc_id": "doc3", "relevant_doc_text": documents[2]["text"]},
{"query": "内存要求是什么?", "relevant_doc_id": "doc5", "relevant_doc_text": documents[4]["text"]}
]
双编码器模型,类似于语义搜索中常用的模型,独立计算查询和所有文档的嵌入向量。相关性随后通过这些嵌入向量之间的相似度(例如余弦相似度)来确定。
# 加载用于初步检索的双编码器模型
bi_encoder = SentenceTransformer('all-MiniLM-L6-v2')
# 编码我们的文档语料库
doc_embeddings = bi_encoder.encode(doc_texts, convert_to_tensor=True)
# 执行初步检索的函数
def retrieve_initial_documents(query_text, top_k=3):
query_embedding = bi_encoder.encode(query_text, convert_to_tensor=True)
# 我们使用余弦相似度和 torch.topk 来找到最高分数
cos_scores = util.cos_sim(query_embedding, doc_embeddings)[0]
top_results = torch.topk(cos_scores, k=top_k)
retrieved_docs = []
print(f"\n查询: {query_text}")
print("初步最高结果(双编码器):")
for i, (score, idx) in enumerate(zip(top_results[0], top_results[1])):
retrieved_docs.append({"id": documents[idx.item()]["id"], "text": documents[idx.item()]["text"], "score": score.item()})
print(f"{i+1}. ID: {documents[idx.item()]['id']}, 分数: {score.item():.4f}, 文本: {documents[idx.item()]['text'][:100]}...")
return retrieved_docs
# 让我们测试一个查询的初步检索
sample_query = queries_with_ground_truth[0]["query"] # "如何安装软件?"
initial_candidates = retrieve_initial_documents(sample_query, top_k=3)
你会发现初步检索速度很快。然而,最高结果不一定总是将最相关的文档排在首位,或者可能包含仅是略微相关的文档。对于“如何安装软件?”,doc2 是理想的。让我们看看它是否排在首位。有时,文档 doc6(“更新会自动下载和安装...”)可能会因为共享“安装”等词语而排名靠前,即使它不是关于初次设置的。
交叉编码器模型的工作方式不同。它们不是比较独立的嵌入向量,而是将查询和文档 对 作为输入,并输出一个表示其相关性的单一分数。这使得模型能够进行更深层次、更精细的比较,通常能带来更优的相关性排序,但计算成本更高。因此,它们通常用于对初步、更快的检索阶段得到的一小组候选文档进行重排序。
# 加载用于重排序的交叉编码器模型
# 常见选择包括在 MS MARCO 或类似段落排序数据集上微调的模型。
# 'cross-encoder/ms-marco-MiniLM-L-6-v2' 是一个不错的、相对较小的模型。
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
# 使用交叉编码器重排序文档的函数
def rerank_documents(query_text, candidate_docs):
# 为交叉编码器准备对:[ (查询, doc_text1), (查询, doc_text2), ... ]
pairs = []
for doc in candidate_docs:
pairs.append((query_text, doc['text']))
# 从交叉编码器获取分数
# cross_encoder.predict() 方法接受一个对列表并返回一个分数列表。
scores = cross_encoder.predict(pairs)
# 将候选文档与其新分数结合并排序
for i in range(len(candidate_docs)):
candidate_docs[i]['cross_score'] = scores[i]
# 按新的交叉编码器分数降序排序
reranked_docs = sorted(candidate_docs, key=lambda x: x['cross_score'], reverse=True)
print("\n重排序结果(交叉编码器):")
for i, doc in enumerate(reranked_docs):
print(f"{i+1}. ID: {doc['id']}, 交叉分数: {doc['cross_score']:.4f}, 文本: {doc['text'][:100]}...")
return reranked_docs
# 重排序我们之前示例中的候选文档
reranked_candidates = rerank_documents(sample_query, initial_candidates)
观察输出。你应该会看到交叉编码器可能重新排序了 initial_candidates。理想情况下,最相关的文档(例如,针对“如何安装软件?”的 doc2)现在拥有最高的 cross_score 并排在首位。这些分数本身不同于双编码器的余弦相似度;交叉编码器分数通常是 logits,它们不限于 0 到 1 之间,而是直接反映相关性。
为了客观衡量改进,我们需要评估指标。对于排序任务,常用指标包括:
让我们实施一个简单的评估。
def calculate_mrr_and_precision_at_1(ranked_results_list, ground_truth_list):
reciprocal_ranks = []
precision_at_1_scores = []
for i, ranked_docs in enumerate(ranked_results_list):
query_info = ground_truth_list[i]
relevant_id = query_info["relevant_doc_id"]
found_rank = -1
for rank, doc in enumerate(ranked_docs):
if doc["id"] == relevant_id:
found_rank = rank + 1
break
if found_rank != -1:
reciprocal_ranks.append(1.0 / found_rank)
if found_rank == 1:
precision_at_1_scores.append(1.0)
else:
precision_at_1_scores.append(0.0)
else:
reciprocal_ranks.append(0.0) # Relevant document not found in top_k
precision_at_1_scores.append(0.0)
mrr = sum(reciprocal_ranks) / len(reciprocal_ranks) if reciprocal_ranks else 0
p_at_1 = sum(precision_at_1_scores) / len(precision_at_1_scores) if precision_at_1_scores else 0
return mrr, p_at_1
# --- 评估 ---
print("\n--- 评估性能 ---")
initial_retrieval_results_all_queries = []
reranked_results_all_queries = []
K_INITIAL = 3 # 从初步检索中考虑用于重排序的文档数量
for item in queries_with_ground_truth:
query = item["query"]
print(f"\n处理查询: {query}")
# 初步检索
initial_docs = retrieve_initial_documents(query, top_k=K_INITIAL)
initial_retrieval_results_all_queries.append(initial_docs)
# 重排序
reranked_docs = rerank_documents(query, initial_docs) # 重排序相同的初始集合
reranked_results_all_queries.append(reranked_docs)
# 计算指标
mrr_initial, p1_initial = calculate_mrr_and_precision_at_1(initial_retrieval_results_all_queries, queries_with_ground_truth)
mrr_reranked, p1_reranked = calculate_mrr_and_precision_at_1(reranked_results_all_queries, queries_with_ground_truth)
print("\n--- 评估总结 ---")
print(f"初步检索(双编码器)-> MRR: {mrr_initial:.4f}, 精确率@1: {p1_initial:.4f}")
print(f"重排序后(交叉编码器)-> MRR: {mrr_reranked:.4f}, 精确率@1: {p1_reranked:.4f}")
应用重排序阶段前后的性能比较。实际值取决于数据集和模型,但通常会呈现上升趋势。(注意:0.625、0.9375、0.50、0.75 这些值是基于样本数据的一次良好运行的说明;你的具体结果可能有所不同。)
应用重排序器后,你应该通常会看到 MRR 和精确率@1 都有所改进。这表明重排序步骤有效地将更多相关文档提升到更高位置。
K_INITIAL 个文档。这就是为什么它是一个对候选文档 子集 的重排序步骤,而不是用于大型语料库的主要检索方法。K_INITIAL: 从初步检索器传递到重排序器的文档数量(我们代码中的 K_INITIAL,通常称为 k' 或 top_n_for_reranking)是一个重要的超参数。
K_INITIAL 太小,真正相关的文档甚至可能无法进入重排序阶段。K_INITIAL 太大,重排序的延迟开销大幅增加。
典型值范围从 20 到 100,取决于应用的延迟预算和初步检索器的质量。这个实践练习呈现了一种强大技术,可大幅提升 RAG 系统检索组件的精确度。通过仔细选择初步候选集,然后应用更复杂的重排序模型,你可以确保提供给生成器的上下文具有最高相关性,这直接影响最终生成输出的质量和事实准确性。请记住,始终评估对相关性指标和系统延迟两方面的影响,以便为你的生产环境找到合适的平衡点。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造