# rag/rerank.py
"""Cross-encoder reranking helpers for the RAG pipeline."""
import torch
from sentence_transformers import CrossEncoder

# Qwen3-Reranker is published in 0.6B / 4B / 8B sizes; pick a smaller one if VRAM is tight.
# NOTE(review): the upstream Qwen/Qwen3-Reranker-* checkpoints appear to be causal LMs that are
# scored from yes/no token logits under a specific prompt template. Loading them through
# CrossEncoder may attach an untrained classification head. Confirm the scores are
# meaningful, or switch to a sequence-classification conversion of the model.
RERANKER_ID = "Qwen/Qwen3-Reranker-8B"
# RERANKER_ID = "BAAI/bge-reranker-base"  # or -large if you've got VRAM


def get_rerank_model(model_name: str = RERANKER_ID, device: str = "cuda") -> CrossEncoder:
    """Load and return a CrossEncoder reranker.

    Every call loads a fresh model, which is expensive for large checkpoints.
    Load once and reuse the returned object.

    Args:
        model_name: Hugging Face model id or local path. Defaults to RERANKER_ID.
        device: "cuda" (or "cuda:N"), "cpu", or "mps". Defaults to "cuda".

    Returns:
        A CrossEncoder ready for ``predict``.
    """
    if device.startswith("cuda"):
        model_kwargs = {
            # FlashAttention-2 only runs on CUDA GPUs and needs the flash-attn package.
            "attn_implementation": "flash_attention_2",
            "device_map": "auto",
            # NOTE(review): older transformers releases spell this key "torch_dtype"; confirm
            # the installed version honours "dtype".
            "dtype": torch.float16,
        }
    else:
        # Off CUDA, leave the attention implementation to transformers' default and
        # use float32 for the broadest operator support.
        model_kwargs = {"dtype": torch.float32}
    return CrossEncoder(
        model_name,
        device=device,
        model_kwargs=model_kwargs,
        # Left padding keeps the last position of every row a real token.
        # Presumably the Qwen3 scoring head reads that last position; confirm.
        tokenizer_kwargs={"padding_side": "left"},
    )


def rerank_cross_encoder(
    reranker: CrossEncoder,
    query: str,
    candidates: list[tuple[int, str]],
    batch_size: int = 32,
) -> list[tuple[int, str, float]]:
    """Score every candidate against ``query`` and return them best-first.

    Args:
        reranker: A loaded CrossEncoder (see ``get_rerank_model``).
        query: The search query.
        candidates: ``[(id, text), ...]`` pairs to rerank.
        batch_size: Number of (query, text) pairs scored per forward pass.

    Returns:
        ``[(id, text, score), ...]`` sorted by score, descending. Higher means more
        relevant. An empty list is returned when ``candidates`` is empty.
    """
    if not candidates:
        return []
    ids, texts = zip(*candidates)
    pairs = [(query, text) for text in texts]
    # Assumes a single-label reranker, so predict returns one score per pair
    # (np.ndarray of shape [N]).
    scores = reranker.predict(pairs, batch_size=batch_size)
    # float() turns numpy scalars into plain Python floats, so results are JSON-serializable.
    scored = [
        (doc_id, text, float(score))
        for doc_id, text, score in zip(ids, texts, scores)
    ]
    # sorted() is stable even with reverse=True, so equal scores keep their input order.
    return sorted(scored, key=lambda item: item[2], reverse=True)