author    polwex <polwex@sortug.com>  2025-09-24 23:38:36 +0700
committer polwex <polwex@sortug.com>  2025-09-24 23:38:36 +0700
commit    734b89570040e97f0c7743c4c0bc28e30a3cd4ee (patch)
tree      7142d9f37908138c38d0ade066e960c3a1c69f5d /rag/rerank.py
parent    57aaafdb137fe49930711f6ed5ccc83b3a119cd2 (diff)
Diffstat (limited to 'rag/rerank.py')
 -rw-r--r--  rag/rerank.py | 36
 1 file changed, 30 insertions(+), 6 deletions(-)
diff --git a/rag/rerank.py b/rag/rerank.py
index 6ae8938..8a2870d 100644
--- a/rag/rerank.py
+++ b/rag/rerank.py
@@ -1,23 +1,25 @@
 # rag/rerank.py
 import torch
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
 from sentence_transformers import CrossEncoder
+from rag.constants import BATCH, RERANKER_ID
-RERANKER_ID = "Qwen/Qwen3-Reranker-8B" # or -large if you’ve got VRAM
-# RERANKER_ID = "BAAI/bge-reranker-base" # or -large if you’ve got VRAM
 
 # device: "cuda" | "cpu" | "mps"
 def get_rerank_model():
+    id = "BAAI/bge-reranker-base" # or -large if you’ve got VRAM
     return CrossEncoder(
-        RERANKER_ID,
+        id,
         device="cuda",
+        max_length=384,
         model_kwargs={
-            "attn_implementation":"flash_attention_2",
+            # "attn_implementation":"flash_attention_2",
            "device_map":"auto",
            "dtype":torch.float16
         },
         tokenizer_kwargs={"padding_side": "left"}
     )
 
-def rerank_cross_encoder(reranker: CrossEncoder, query: str, candidates: list[tuple[int, str]], batch_size: int = 32):
+def rerank_cross_encoder(reranker: CrossEncoder, query: str, candidates: list[tuple[int, str]], batch_size: int = BATCH):
     """
     candidates: [(id, text), ...]
     returns: [(id, text, score)] sorted desc by score
@@ -25,7 +27,29 @@ def rerank_cross_encoder(reranker: CrossEncoder, query: str, candidates: list[tu
     if not candidates:
         return []
     ids, texts = zip(*candidates)
-    pairs = [(query, t) for t in texts]
+    # pairs = [(query, t) for t in texts]
+    pairs = [(query, t[:1000]) for t in texts]
     scores = reranker.predict(pairs, batch_size=batch_size)  # np.ndarray [N], higher=better
     ranked = sorted(zip(ids, texts, scores), key=lambda x: x[2], reverse=True)
     return ranked
+
+
+
+
+# tok = AutoTokenizer.from_pretrained(RERANKER_ID, use_fast=True, model_max_length=384)
+# ce = AutoModelForSequenceClassification.from_pretrained(
+#     RERANKER_ID, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2",
+#     device_map="auto"  # or load_in_8bit=True
+# )
+
+# def ce_scores(query, texts, batch_size=16, max_length=384):
+#     scores = []
+#     for i in range(0, len(texts), batch_size):
+#         batch = texts[i:i+batch_size]
+#         enc = tok([(query, t[:1000]) for t in batch],
+#                   padding=True, truncation=True, max_length=max_length,
+#                   return_tensors="pt").to(ce.device)
+#         with torch.inference_mode():
+#             logits = ce(**enc).logits.squeeze(-1)  # [B]
+#         scores.extend(logits.float().cpu().tolist())
+#     return scores
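
The rag.constants module introduced by the new import is not part of this diff. A minimal sketch of what it presumably exports, inferred only from the import line and from the constant this commit deletes; both values are assumptions, not the committed ones:

# rag/constants.py — hypothetical reconstruction; this file is not shown in the diff.
BATCH = 32  # assumption: matches the old hard-coded batch_size default of 32
RERANKER_ID = "Qwen/Qwen3-Reranker-8B"  # assumption: matches the constant removed above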
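
For context, a minimal usage sketch of the two helpers as they stand after this commit. The query and candidate rows are invented for illustration:

from rag.rerank import get_rerank_model, rerank_cross_encoder

reranker = get_rerank_model()  # loads BAAI/bge-reranker-base onto CUDA

# candidates arrive from a first-stage retriever as (id, text) pairs
candidates = [
    (101, "The Eiffel Tower is located in Paris, France."),
    (102, "Bananas are an excellent source of potassium."),
    (103, "Paris hosted the 1900 and 1924 Summer Olympics."),
]
ranked = rerank_cross_encoder(reranker, "Where is the Eiffel Tower?", candidates)
for doc_id, text, score in ranked:
    print(f"{score:8.3f}  {doc_id}  {text}")

Note that rerank_cross_encoder now clips each candidate to its first 1000 characters before pairing it with the query, a cheap character-level guard on top of the tokenizer-level max_length=384 truncation.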