summary | refs | log | tree | commit | diff
path: root/rag/mmr.py
diff options
context:
space:
mode:
author polwex <polwex@sortug.com> 2025-09-24 23:38:36 +0700
committer polwex <polwex@sortug.com> 2025-09-24 23:38:36 +0700
commit 734b89570040e97f0c7743c4c0bc28e30a3cd4ee (patch)
tree 7142d9f37908138c38d0ade066e960c3a1c69f5d /rag/mmr.py
parent 57aaafdb137fe49930711f6ed5ccc83b3a119cd2 (diff)
Diffstat (limited to 'rag/mmr.py')
-rw-r--r-- rag/mmr.py 17
1 files changed, 16 insertions, 1 deletions
diff --git a/rag/mmr.py b/rag/mmr.py
index 5e47c4f..b52751f 100644
--- a/rag/mmr.py
+++ b/rag/mmr.py
@@ -1,3 +1,4 @@
+from rag.constants import BATCH
import numpy as np
def cosine(a, b):
    """Return the dot product of ``a`` and ``b`` as a built-in ``float``.

    No normalisation is applied here, so the result is a true cosine
    similarity only when both inputs are already unit length.
    """
    dot = np.dot(a, b)
    return float(dot)
@@ -21,6 +22,20 @@ def mmr(query_vec, cand_ids, cand_vecs, k=8, lamb=0.7):
return selected
def embed_unit_np(st_model, texts: list[str]) -> np.ndarray:
- V = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True, batch_size=32)
+ V = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True, batch_size=BATCH)
V = V.astype("float32", copy=False)
return V
+
def mmr2(qvec: np.ndarray, ids, vecs: np.ndarray, k=8, lamb=0.7):
    """Pick up to ``k`` ids by greedy Maximal Marginal Relevance (MMR).

    The first pick is the candidate whose vector has the highest dot product
    with ``qvec``.  Each later pick maximises
    ``lamb * rel - (1 - lamb) * red``, where ``rel`` is the candidate's dot
    product with ``qvec`` and ``red`` is its largest dot product with any
    candidate picked so far.  Ties go to the candidate that appears first.

    NOTE(review): dot products are used as similarities, which matches cosine
    similarity only if ``qvec`` and ``vecs`` are unit length -- confirm with
    the caller; nothing here normalises or checks them.

    Args:
        qvec: Query vector.
        ids: Candidate identifiers; ``ids[i]`` is paired with ``vecs[i]``.
        vecs: Candidate vectors, indexed by candidate position.
        k: Maximum number of ids to return.
        lamb: Weight of relevance against redundancy in the score.

    Returns:
        The selected ids in pick order.  Empty when there are no candidates
        or ``k`` is not positive.
    """
    remaining = list(range(len(ids)))
    # Guard clauses: max() over an empty candidate list raises ValueError,
    # and a non-positive k must not produce the unconditional first pick.
    if not remaining or k <= 0:
        return []

    # Relevance to the query never changes between rounds; compute it once.
    relevance = {i: float(qvec @ vecs[i]) for i in remaining}

    last = max(remaining, key=relevance.__getitem__)
    remaining.remove(last)
    chosen = [ids[last]]

    # Running maximum similarity of each open candidate to anything already
    # chosen.  Folding in only the most recent pick per round yields the same
    # value as rescanning every chosen item, with far fewer dot products.
    redundancy = dict.fromkeys(remaining, float("-inf"))

    def mmr_score(i):
        return lamb * relevance[i] - (1 - lamb) * redundancy[i]

    while remaining and len(chosen) < k:
        for i in remaining:
            redundancy[i] = max(redundancy[i], float(vecs[i] @ vecs[last]))
        last = max(remaining, key=mmr_score)
        remaining.remove(last)
        chosen.append(ids[last])
    return chosen