summaryrefslogtreecommitdiff
path: root/rag/rerank.py
diff options
context:
space:
mode:
authorpolwex <polwex@sortug.com>2025-09-23 03:50:53 +0700
committerpolwex <polwex@sortug.com>2025-09-23 03:50:53 +0700
commit57aaafdb137fe49930711f6ed5ccc83b3a119cd2 (patch)
tree1a7556927bed94377630d33dd29c3bf07d159619 /rag/rerank.py
init
Diffstat (limited to 'rag/rerank.py')
-rw-r--r--rag/rerank.py31
1 files changed, 31 insertions, 0 deletions
diff --git a/rag/rerank.py b/rag/rerank.py
new file mode 100644
index 0000000..6ae8938
--- /dev/null
+++ b/rag/rerank.py
@@ -0,0 +1,31 @@
+# rag/rerank.py
+import torch
+from sentence_transformers import CrossEncoder
+
# Hugging Face model id of the cross-encoder checkpoint loaded by get_rerank_model().
RERANKER_ID = "Qwen/Qwen3-Reranker-8B"
# Lighter alternative (or its -large variant if you've got VRAM):
# RERANKER_ID = "BAAI/bge-reranker-base"
# Device strings: "cuda" | "cpu" | "mps" (get_rerank_model currently hard-codes "cuda").
def get_rerank_model():
    """Build the cross-encoder reranker named by ``RERANKER_ID``.

    The model is requested on CUDA in float16 with the FlashAttention-2
    attention implementation, and the tokenizer is configured to pad on the
    left.

    Returns:
        The constructed ``CrossEncoder`` instance.
    """
    # NOTE(review): ``device="cuda"`` and ``device_map="auto"`` are both set;
    # confirm they do not conflict if the weights end up spread over several
    # devices.
    model_options = {
        "attn_implementation": "flash_attention_2",
        "device_map": "auto",
        "dtype": torch.float16,
    }
    tokenizer_options = {"padding_side": "left"}
    return CrossEncoder(
        RERANKER_ID,
        device="cuda",
        model_kwargs=model_options,
        tokenizer_kwargs=tokenizer_options,
    )
+
def rerank_cross_encoder(
    reranker: CrossEncoder,
    query: str,
    candidates: list[tuple[int, str]],
    batch_size: int = 32,
) -> list[tuple[int, str, float]]:
    """Score every candidate against ``query`` and return them best-first.

    Args:
        reranker: Model whose ``predict(pairs, batch_size=...)`` yields one
            relevance score per ``(query, text)`` pair; higher means better.
        query: Query string paired with each candidate text.
        candidates: ``[(id, text), ...]`` to be scored.
        batch_size: Forwarded unchanged to ``reranker.predict``.

    Returns:
        ``[(id, text, score), ...]`` sorted by descending score. Candidates
        with equal scores keep their input order. Scores are plain Python
        floats. An empty ``candidates`` returns ``[]`` without calling the
        model.
    """
    if not candidates:
        return []

    ids, texts = zip(*candidates)
    pairs = [(query, text) for text in texts]
    scores = reranker.predict(pairs, batch_size=batch_size)

    # Convert NumPy scalars to built-in floats so the result is made of plain
    # Python types (e.g. JSON-serializable) regardless of the model's dtype.
    ranked = [
        (cand_id, text, float(score))
        for cand_id, text, score in zip(ids, texts, scores)
    ]
    # list.sort is stable, so ties preserve the original candidate order.
    ranked.sort(key=lambda item: item[2], reverse=True)
    return ranked