diff options
author | polwex <polwex@sortug.com> | 2025-09-24 23:38:36 +0700 |
---|---|---|
committer | polwex <polwex@sortug.com> | 2025-09-24 23:38:36 +0700 |
commit | 734b89570040e97f0c7743c4c0bc28e30a3cd4ee (patch) | |
tree | 7142d9f37908138c38d0ade066e960c3a1c69f5d /rag/mmr.py | |
parent | 57aaafdb137fe49930711f6ed5ccc83b3a119cd2 (diff) |
Diffstat (limited to 'rag/mmr.py')
-rw-r--r-- | rag/mmr.py | 17 |
1 files changed, 16 insertions, 1 deletions
@@ -1,3 +1,4 @@ +from rag.constants import BATCH import numpy as np def cosine(a, b): return float(np.dot(a, b)) @@ -21,6 +22,20 @@ def mmr(query_vec, cand_ids, cand_vecs, k=8, lamb=0.7): return selected def embed_unit_np(st_model, texts: list[str]) -> np.ndarray: - V = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True, batch_size=32) + V = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True, batch_size=BATCH) V = V.astype("float32", copy=False) return V + +def mmr2(qvec: np.ndarray, ids, vecs: np.ndarray, k=8, lamb=0.7): + sel, idxs = [], [] + rest = list(range(len(ids))) + best0 = max(rest, key=lambda i: float(qvec @ vecs[i])) + sel.append(ids[best0]); idxs.append(best0); rest.remove(best0) + while rest and len(sel) < k: + def score(i): + rel = float(qvec @ vecs[i]) + red = max(float(vecs[i] @ vecs[j]) for j in idxs) + return lamb*rel - (1-lamb)*red + nxt = max(rest, key=score) + sel.append(ids[nxt]); idxs.append(nxt); rest.remove(nxt) + return sel |