diff options
Diffstat (limited to 'rag/rerank.py')
-rw-r--r-- | rag/rerank.py | 31 |
1 files changed, 31 insertions, 0 deletions
# rag/rerank.py
"""Cross-encoder reranking helpers: load the reranker and score candidates."""
import torch
from sentence_transformers import CrossEncoder

# Model id passed to CrossEncoder by get_rerank_model().
RERANKER_ID = "Qwen/Qwen3-Reranker-8B"
# Alternative:
# RERANKER_ID = "BAAI/bge-reranker-base"  # or -large if you've got VRAM


def get_rerank_model() -> CrossEncoder:
    """Construct the CrossEncoder for ``RERANKER_ID``.

    The model is requested on CUDA, in float16, with the FlashAttention-2
    attention implementation, and the tokenizer is configured to pad on the
    left.

    Returns:
        A new ``CrossEncoder`` instance; nothing is cached, so every call
        constructs the model again.
    """
    return CrossEncoder(
        RERANKER_ID,
        # device: "cuda" | "cpu" | "mps"
        # NOTE(review): an explicit device is combined with
        # device_map="auto" below -- confirm the two do not conflict.
        device="cuda",
        model_kwargs={
            "attn_implementation": "flash_attention_2",
            "device_map": "auto",
            # NOTE(review): some transformers releases spell this key
            # "torch_dtype" -- confirm against the installed version.
            "dtype": torch.float16,
        },
        tokenizer_kwargs={"padding_side": "left"},
    )


def rerank_cross_encoder(
    reranker: CrossEncoder,
    query: str,
    candidates: list[tuple[int, str]],
    batch_size: int = 32,
):
    """Score every candidate against ``query`` and sort best-first.

    Args:
        reranker: Object whose ``predict(pairs, batch_size=...)`` returns one
            score per ``(query, text)`` pair, where higher means better.
        query: Query text paired with each candidate text.
        candidates: ``[(id, text), ...]`` items to rerank.
        batch_size: Forwarded unchanged to ``reranker.predict``.

    Returns:
        ``[(id, text, score), ...]`` sorted by score in descending order, or
        ``[]`` when ``candidates`` is empty. Scores are returned exactly as
        ``predict`` produced them (no conversion is applied here).
    """
    if not candidates:
        # Skip the model call entirely when there is nothing to score.
        return []

    pairs = [(query, text) for _, text in candidates]
    scores = reranker.predict(pairs, batch_size=batch_size)  # one score per pair

    scored = [
        (cand_id, text, score)
        for (cand_id, text), score in zip(candidates, scores)
    ]
    # sorted() is stable, so equal scores keep their input order.
    return sorted(scored, key=lambda item: item[2], reverse=True)