summaryrefslogtreecommitdiff
path: root/rag/mmr.py
diff options
context:
space:
mode:
authorpolwex <polwex@sortug.com>2025-09-23 03:50:53 +0700
committerpolwex <polwex@sortug.com>2025-09-23 03:50:53 +0700
commit57aaafdb137fe49930711f6ed5ccc83b3a119cd2 (patch)
tree1a7556927bed94377630d33dd29c3bf07d159619 /rag/mmr.py
init
Diffstat (limited to 'rag/mmr.py')
-rw-r--r--rag/mmr.py26
1 files changed, 26 insertions, 0 deletions
diff --git a/rag/mmr.py b/rag/mmr.py
new file mode 100644
index 0000000..5e47c4f
--- /dev/null
+++ b/rag/mmr.py
@@ -0,0 +1,26 @@
+import numpy as np
+
+def cosine(a, b): return float(np.dot(a, b))
+
+def mmr(query_vec, cand_ids, cand_vecs, k=8, lamb=0.7):
+ """cand_ids: [int], cand_vecs: np.ndarray float32 [N,D] (unit vectors) aligned with cand_ids"""
+ selected, selected_idx = [], []
+ remaining = list(range(len(cand_ids)))
+
+ # seed with the most relevant
+ best0 = max(remaining, key=lambda i: cosine(query_vec, cand_vecs[i]))
+ selected.append(cand_ids[best0]); selected_idx.append(best0); remaining.remove(best0)
+
+ while remaining and len(selected) < k:
+ def mmr_score(i):
+ rel = cosine(query_vec, cand_vecs[i])
+ red = max(cosine(cand_vecs[i], cand_vecs[j]) for j in selected_idx) if selected_idx else 0.0
+ return lamb * rel - (1.0 - lamb) * red
+ nxt = max(remaining, key=mmr_score)
+ selected.append(cand_ids[nxt]); selected_idx.append(nxt); remaining.remove(nxt)
+ return selected
+
+def embed_unit_np(st_model, texts: list[str]) -> np.ndarray:
+ V = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True, batch_size=32)
+ V = V.astype("float32", copy=False)
+ return V