# rag/mmr.py
# Provenance: commit 57aaafdb137fe49930711f6ed5ccc83b3a119cd2 ("init"),
# author polwex, Tue, 23 Sep 2025 03:50:53 +0700.
import numpy as np


def cosine(a, b):
    """Return the dot product of *a* and *b* as a Python float.

    This is the cosine similarity only when both inputs are unit-length;
    no normalization is performed here.
    """
    return float(np.dot(a, b))


def mmr(query_vec, cand_ids, cand_vecs, k=8, lamb=0.7):
    """Select up to *k* candidate ids by Maximal Marginal Relevance.

    Each round picks the remaining candidate that maximizes
    ``lamb * sim(query, cand) - (1 - lamb) * max(sim(cand, chosen))``.
    The first pick is simply the candidate most similar to the query.

    Args:
        query_vec: query embedding (unit vector).
        cand_ids: [int], ids aligned with ``cand_vecs``.
        cand_vecs: np.ndarray float32 [N,D] (unit vectors) aligned with
            ``cand_ids``.
        k: maximum number of ids to return.
        lamb: trade-off weight; higher favors relevance, lower favors
            diversity.

    Returns:
        A list of at most ``k`` entries of ``cand_ids`` in selection order.
        Empty when there are no candidates or ``k <= 0``.
    """
    n = len(cand_ids)
    # Guard: max() over an empty pool raises, and k <= 0 must select nothing.
    if n == 0 or k <= 0:
        return []

    # Relevance to the query never changes, so compute it once per candidate.
    relevance = [cosine(query_vec, cand_vecs[i]) for i in range(n)]
    remaining = list(range(n))

    # Seed with the most relevant candidate (first one wins on ties).
    last = max(remaining, key=relevance.__getitem__)
    selected = [cand_ids[last]]
    remaining.remove(last)

    # redundancy[i] tracks the max similarity of candidate i to anything
    # already selected; only the newest pick can raise it each round.
    redundancy = [float("-inf")] * n

    def mmr_score(i):
        return lamb * relevance[i] - (1.0 - lamb) * redundancy[i]

    while remaining and len(selected) < k:
        for i in remaining:
            sim = cosine(cand_vecs[i], cand_vecs[last])
            if sim > redundancy[i]:
                redundancy[i] = sim
        last = max(remaining, key=mmr_score)
        selected.append(cand_ids[last])
        remaining.remove(last)
    return selected


def embed_unit_np(st_model, texts: list[str]) -> np.ndarray:
    """Encode *texts* with *st_model* and return the result as float32.

    Calls ``st_model.encode`` with ``normalize_embeddings=True``,
    ``convert_to_numpy=True`` and ``batch_size=32``, then casts to float32
    without copying when the dtype already matches.

    NOTE(review): presumably ``st_model`` is a sentence-transformers model
    and the result is an [N,D] array of unit vectors -- confirm against the
    caller.
    """
    vectors = st_model.encode(
        texts,
        normalize_embeddings=True,
        convert_to_numpy=True,
        batch_size=32,
    )
    return vectors.astype("float32", copy=False)