```python
# rag/rerank.py
import torch
from sentence_transformers import CrossEncoder

RERANKER_ID = "Qwen/Qwen3-Reranker-8B"   # smaller 0.6B / 4B variants exist if you're short on VRAM
# RERANKER_ID = "BAAI/bge-reranker-base" # or -large if you've got VRAM

# device: "cuda" | "cpu" | "mps"
def get_rerank_model() -> CrossEncoder:
    return CrossEncoder(
        RERANKER_ID,
        device="cuda",
        model_kwargs={
            "attn_implementation": "flash_attention_2",  # requires flash-attn; drop on unsupported hardware
            "device_map": "auto",
            "dtype": torch.float16,  # spelled "torch_dtype" on older transformers versions
        },
        tokenizer_kwargs={"padding_side": "left"},
    )

def rerank_cross_encoder(
    reranker: CrossEncoder,
    query: str,
    candidates: list[tuple[int, str]],
    batch_size: int = 32,
) -> list[tuple[int, str, float]]:
    """
    candidates: [(id, text), ...]
    returns: [(id, text, score)] sorted desc by score
    """
    if not candidates:
        return []
    ids, texts = zip(*candidates)
    # Score every (query, passage) pair with the cross-encoder; higher = more relevant.
    pairs = [(query, t) for t in texts]
    scores = reranker.predict(pairs, batch_size=batch_size)  # np.ndarray of shape [N]
    ranked = sorted(zip(ids, texts, scores), key=lambda x: x[2], reverse=True)
    return ranked
```
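
A minimal usage sketch of the two helpers above. It assumes a first-stage retriever (BM25 or dense search) has already produced candidate `(id, text)` pairs; the query and candidate texts here are made up purely for illustration:

```python
# Hypothetical usage sketch; candidate texts are illustrative only.
from rag.rerank import get_rerank_model, rerank_cross_encoder

reranker = get_rerank_model()  # load the cross-encoder once and reuse it across queries

# In a real pipeline these come from the first-stage retriever.
candidates = [
    (101, "The capital of France is Paris."),
    (102, "Paris Hilton is an American media personality."),
    (103, "France borders Germany, Spain, and Italy."),
]

ranked = rerank_cross_encoder(
    reranker,
    query="What is the capital of France?",
    candidates=candidates,
)

for doc_id, text, score in ranked[:2]:  # keep only the top-k for the generation prompt
    print(f"{score:.3f}  [{doc_id}]  {text}")
```

Loading the model once and passing it in keeps the expensive initialization out of the per-query path; the reranker itself is stateless across calls.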