diff options
author | polwex <polwex@sortug.com> | 2025-09-23 03:50:53 +0700 |
---|---|---|
committer | polwex <polwex@sortug.com> | 2025-09-23 03:50:53 +0700 |
commit | 57aaafdb137fe49930711f6ed5ccc83b3a119cd2 (patch) | |
tree | 1a7556927bed94377630d33dd29c3bf07d159619 /rag/mmr.py |
init
Diffstat (limited to 'rag/mmr.py')
-rw-r--r-- | rag/mmr.py | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/rag/mmr.py b/rag/mmr.py new file mode 100644 index 0000000..5e47c4f --- /dev/null +++ b/rag/mmr.py @@ -0,0 +1,26 @@ +import numpy as np + +def cosine(a, b): return float(np.dot(a, b)) + +def mmr(query_vec, cand_ids, cand_vecs, k=8, lamb=0.7): + """cand_ids: [int], cand_vecs: np.ndarray float32 [N,D] (unit vectors) aligned with cand_ids""" + selected, selected_idx = [], [] + remaining = list(range(len(cand_ids))) + + # seed with the most relevant + best0 = max(remaining, key=lambda i: cosine(query_vec, cand_vecs[i])) + selected.append(cand_ids[best0]); selected_idx.append(best0); remaining.remove(best0) + + while remaining and len(selected) < k: + def mmr_score(i): + rel = cosine(query_vec, cand_vecs[i]) + red = max(cosine(cand_vecs[i], cand_vecs[j]) for j in selected_idx) if selected_idx else 0.0 + return lamb * rel - (1.0 - lamb) * red + nxt = max(remaining, key=mmr_score) + selected.append(cand_ids[nxt]); selected_idx.append(nxt); remaining.remove(nxt) + return selected + +def embed_unit_np(st_model, texts: list[str]) -> np.ndarray: + V = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True, batch_size=32) + V = V.astype("float32", copy=False) + return V |