summaryrefslogtreecommitdiff
path: root/rag/rerank.py
blob: 8a2870d75ea16e1ed55bb19cf1096851d95fbd04 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# rag/rerank.py
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sentence_transformers import CrossEncoder
from rag.constants import BATCH, RERANKER_ID

# device: "cuda" | "cpu" | "mps"
def get_rerank_model(
    model_id: str = "BAAI/bge-reranker-base",
    device: str = "cuda",
    max_length: int = 384,
) -> CrossEncoder:
    """Load the cross-encoder used to rerank retrieved passages.

    Args:
        model_id: Hugging Face model id. ``BAAI/bge-reranker-large`` is a
            drop-in alternative if you have the VRAM.
        device: Torch device string handed to ``CrossEncoder``
            (e.g. "cuda", "cpu", "mps").
        max_length: Maximum tokenized length of a (query, passage) pair;
            the tokenizer truncates longer pairs.

    Returns:
        A ``CrossEncoder`` ready for ``predict``. Weights are loaded in
        float16 except on CPU, where float32 is used.
    """
    # fp16 is the original (GPU) behaviour; keep CPU in fp32, where half
    # precision is slow or unsupported for some ops.
    dtype = torch.float32 if device == "cpu" else torch.float16
    model_kwargs = {
        # "attn_implementation": "flash_attention_2",
        "dtype": dtype,
    }
    if device.startswith("cuda"):
        # Let accelerate place the weights on the GPU at load time, as before.
        model_kwargs["device_map"] = "auto"

    # Padding is deliberately left on the tokenizer's default (right) side:
    # bge-reranker is an XLM-RoBERTa classifier that scores from the hidden
    # state at position 0 (the <s> token). Left padding would put <pad> at
    # position 0 for every shorter pair in a batch and corrupt its score.
    return CrossEncoder(
        model_id,
        device=device,
        max_length=max_length,
        model_kwargs=model_kwargs,
    )

def rerank_cross_encoder(
    reranker: CrossEncoder,
    query: str,
    candidates: list[tuple[int, str]],
    batch_size: int = BATCH,
    max_chars: int = 1000,
) -> list[tuple[int, str, float]]:
    """Score each candidate passage against ``query`` and sort best-first.

    Args:
        reranker: Cross-encoder whose ``predict`` returns one score per
            (query, passage) pair, higher meaning more relevant.
        query: The search query.
        candidates: ``[(id, text), ...]``; any iterable of pairs is accepted.
        batch_size: Batch size forwarded to ``reranker.predict``.
        max_chars: Each passage is cut to this many characters before
            scoring. The returned text is always the full original passage.

    Returns:
        ``[(id, text, score), ...]`` sorted by score, descending. Scores are
        plain Python floats, so the result is JSON-serialisable. Ties keep
        the input order because the sort is stable. Returns ``[]`` for no
        candidates.
    """
    # Materialise first so generators work and an empty one returns []
    # instead of failing on the ``zip(*...)`` unpacking below.
    candidates = list(candidates)
    if not candidates:
        return []
    ids, texts = zip(*candidates)
    # Cap passage length in characters before tokenization to bound its cost.
    # The tokenizer still truncates each pair to the model's max_length.
    pairs = [(query, text[:max_chars]) for text in texts]
    scores = reranker.predict(pairs, batch_size=batch_size)
    # float() turns numpy float32 scores into Python floats; float64 holds
    # every float32 value exactly, so the ordering is unchanged.
    ranked = [
        (cand_id, text, float(score))
        for cand_id, text, score in zip(ids, texts, scores)
    ]
    ranked.sort(key=lambda item: item[2], reverse=True)
    return ranked




# tok = AutoTokenizer.from_pretrained(RERANKER_ID, use_fast=True, model_max_length=384)
# ce  = AutoModelForSequenceClassification.from_pretrained(
#     RERANKER_ID, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2",
#     device_map="auto"  # or load_in_8bit=True
# )

# def ce_scores(query, texts, batch_size=16, max_length=384):
#     scores = []
#     for i in range(0, len(texts), batch_size):
#         batch = texts[i:i+batch_size]
#         enc = tok([ (query, t[:1000]) for t in batch ],
#                   padding=True, truncation=True, max_length=max_length,
#                   return_tensors="pt").to(ce.device)
#         with torch.inference_mode():
#             logits = ce(**enc).logits.squeeze(-1)  # [B]
#         scores.extend(logits.float().cpu().tolist())
#     return scores