Diffstat (limited to 'rag/rerank.py')
-rw-r--r--  rag/rerank.py  36
1 file changed, 30 insertions(+), 6 deletions(-)
diff --git a/rag/rerank.py b/rag/rerank.py
index 6ae8938..8a2870d 100644
--- a/rag/rerank.py
+++ b/rag/rerank.py
@@ -1,23 +1,25 @@
 # rag/rerank.py
 import torch
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
 from sentence_transformers import CrossEncoder
+from rag.constants import BATCH, RERANKER_ID
 
-RERANKER_ID = "Qwen/Qwen3-Reranker-8B" # or -large if you’ve got VRAM
-# RERANKER_ID = "BAAI/bge-reranker-base" # or -large if you’ve got VRAM
 # device: "cuda" | "cpu" | "mps"
 def get_rerank_model():
+    id = "BAAI/bge-reranker-base" # or -large if you’ve got VRAM
     return CrossEncoder(
-        RERANKER_ID,
+        id,
         device="cuda",
+        max_length=384,
         model_kwargs={
-            "attn_implementation":"flash_attention_2",
+            # "attn_implementation":"flash_attention_2",
             "device_map":"auto",
             "dtype":torch.float16
         },
         tokenizer_kwargs={"padding_side": "left"}
     )
 
-def rerank_cross_encoder(reranker: CrossEncoder, query: str, candidates: list[tuple[int, str]], batch_size: int = 32):
+def rerank_cross_encoder(reranker: CrossEncoder, query: str, candidates: list[tuple[int, str]], batch_size: int = BATCH):
     """
     candidates: [(id, text), ...]
     returns: [(id, text, score)] sorted desc by score
@@ -25,7 +27,29 @@ def rerank_cross_encoder(reranker: CrossEncoder, query: str, candidates: list[tu
     if not candidates:
         return []
     ids, texts = zip(*candidates)
-    pairs = [(query, t) for t in texts]
+    # pairs = [(query, t) for t in texts]
+    pairs = [(query, t[:1000]) for t in texts]
     scores = reranker.predict(pairs, batch_size=batch_size) # np.ndarray [N], higher=better
     ranked = sorted(zip(ids, texts, scores), key=lambda x: x[2], reverse=True)
     return ranked
+
+
+
+
+# tok = AutoTokenizer.from_pretrained(RERANKER_ID, use_fast=True, model_max_length=384)
+# ce = AutoModelForSequenceClassification.from_pretrained(
+#     RERANKER_ID, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2",
+#     device_map="auto" # or load_in_8bit=True
+# )
+
+# def ce_scores(query, texts, batch_size=16, max_length=384):
+#     scores = []
+#     for i in range(0, len(texts), batch_size):
+#         batch = texts[i:i+batch_size]
+#         enc = tok([(query, t[:1000]) for t in batch],
+#                   padding=True, truncation=True, max_length=max_length,
+#                   return_tensors="pt").to(ce.device)
+#         with torch.inference_mode():
+#             logits = ce(**enc).logits.squeeze(-1) # [B]
+#         scores.extend(logits.float().cpu().tolist())
+#     return scores
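
A minimal usage sketch of the module after this change (not part of the commit; the query and candidate values are hypothetical, it assumes rag/constants.py exports BATCH and RERANKER_ID as the new import implies, and it needs CUDA since get_rerank_model() hardcodes device="cuda"):

    from rag.rerank import get_rerank_model, rerank_cross_encoder

    reranker = get_rerank_model()  # loads BAAI/bge-reranker-base as a CrossEncoder
    candidates = [
        (1, "Paris is the capital of France."),          # (id, text) pairs from retrieval
        (2, "Bananas are a good source of potassium."),
    ]
    ranked = rerank_cross_encoder(reranker, "What is the capital of France?", candidates)
    for doc_id, text, score in ranked:                   # sorted descending by score
        print(doc_id, round(float(score), 3), text[:60])

Truncating each candidate to 1000 characters and capping max_length at 384 tokens bounds per-pair latency; the commented-out ce_scores block sketches the equivalent raw transformers loop for setups where sentence_transformers' defaults don't fit.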