# rag/rerank.py
import torch
from sentence_transformers import CrossEncoder

from rag.constants import BATCH, RERANKER_ID


def get_rerank_model(device: str = "cuda") -> CrossEncoder:
    """Load the cross-encoder reranker. device: "cuda" | "cpu" | "mps"."""
    model_id = "BAAI/bge-reranker-base"  # or -large if you've got VRAM
    return CrossEncoder(
        model_id,
        device=device,
        max_length=384,
        model_kwargs={
            # "attn_implementation": "flash_attention_2",
            # "device_map": "auto",  # can clash with the explicit device= above
            "torch_dtype": torch.float16,
        },
        tokenizer_kwargs={"padding_side": "left"},
    )


def rerank_cross_encoder(
    reranker: CrossEncoder,
    query: str,
    candidates: list[tuple[int, str]],
    batch_size: int = BATCH,
) -> list[tuple[int, str, float]]:
    """
    candidates: [(id, text), ...]
    returns: [(id, text, score)] sorted desc by score
    """
    if not candidates:
        return []
    ids, texts = zip(*candidates)
    # Truncate each passage to 1000 chars: the tokenizer truncates to
    # max_length anyway, this just keeps tokenization cheap on long docs.
    pairs = [(query, t[:1000]) for t in texts]
    scores = reranker.predict(pairs, batch_size=batch_size)  # np.ndarray [N], higher = better
    ranked = sorted(zip(ids, texts, scores), key=lambda x: x[2], reverse=True)
    return ranked


# Alternative: same scoring with raw transformers instead of sentence-transformers
# (kept for reference; uses RERANKER_ID from rag.constants).
#
# from transformers import AutoTokenizer, AutoModelForSequenceClassification
#
# tok = AutoTokenizer.from_pretrained(RERANKER_ID, use_fast=True, model_max_length=384)
# ce = AutoModelForSequenceClassification.from_pretrained(
#     RERANKER_ID, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2",
#     device_map="auto",  # or load_in_8bit=True
# )
#
# def ce_scores(query, texts, batch_size=16, max_length=384):
#     scores = []
#     for i in range(0, len(texts), batch_size):
#         batch = texts[i:i + batch_size]
#         enc = tok([(query, t[:1000]) for t in batch],
#                   padding=True, truncation=True, max_length=max_length,
#                   return_tensors="pt").to(ce.device)
#         with torch.inference_mode():
#             logits = ce(**enc).logits.squeeze(-1)  # [B]
#         scores.extend(logits.float().cpu().tolist())
#     return scores
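
# --- Quick smoke test: a minimal, illustrative sketch of how the two functions
# above fit together. The query and passages are made-up sample data, and it
# assumes a CUDA GPU (the model loads in fp16); pass device="cpu" and a
# full-precision dtype otherwise. First run downloads the model weights. ---
if __name__ == "__main__":
    reranker = get_rerank_model()
    candidates = [
        (0, "Paris is the capital and largest city of France."),
        (1, "The Eiffel Tower is a wrought-iron lattice tower in Paris."),
        (2, "Bananas are an excellent source of potassium."),
    ]
    hits = rerank_cross_encoder(reranker, "Where is the Eiffel Tower located?", candidates)
    for doc_id, text, score in hits:
        print(f"{score:8.3f}  [{doc_id}] {text}")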