From 734b89570040e97f0c7743c4c0bc28e30a3cd4ee Mon Sep 17 00:00:00 2001
From: polwex
Date: Wed, 24 Sep 2025 23:38:36 +0700
Subject: init

---
 rag/mmr.py | 44 +++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 43 insertions(+), 1 deletion(-)

(limited to 'rag/mmr.py')

diff --git a/rag/mmr.py b/rag/mmr.py
index 5e47c4f..b52751f 100644
--- a/rag/mmr.py
+++ b/rag/mmr.py
@@ -1,3 +1,4 @@
+from rag.constants import BATCH
 import numpy as np
 
 def cosine(a, b): return float(np.dot(a, b))
@@ -21,6 +22,47 @@ def mmr(query_vec, cand_ids, cand_vecs, k=8, lamb=0.7):
     return selected
 
 def embed_unit_np(st_model, texts: list[str]) -> np.ndarray:
-    V = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True, batch_size=32)
+    V = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True, batch_size=BATCH)
     V = V.astype("float32", copy=False)
     return V
+
+def mmr2(qvec: np.ndarray, ids, vecs: np.ndarray, k=8, lamb=0.7):
+    """Select up to ``k`` ids by greedy Maximal Marginal Relevance.
+
+    The first pick is the candidate whose vector has the largest dot
+    product with ``qvec``.  Every later pick maximises
+    ``lamb * rel - (1 - lamb) * red``, where ``rel`` is the candidate's
+    dot product with ``qvec`` and ``red`` is its largest dot product
+    with any vector already picked.  Raw dot products are used, so
+    they are cosine similarities only for unit-length vectors.
+
+    Args:
+        qvec: Query vector.
+        ids: Candidate ids; ``ids[i]`` belongs to ``vecs[i]``.
+        vecs: Candidate vectors, indexed by position.
+        k: Maximum number of ids to return.
+        lamb: Weight on relevance; ``1 - lamb`` weights redundancy.
+
+    Returns:
+        The picked ids in selection order; ``[]`` when ``ids`` is
+        empty or ``k <= 0``.
+    """
+    rest = list(range(len(ids)))
+    if not rest or k <= 0:
+        # max() raises ValueError on an empty pool, and the seed pick
+        # would otherwise be returned even when k == 0.
+        return []
+    # Relevance never changes, so compute it once per candidate.
+    rel = {i: float(qvec @ vecs[i]) for i in rest}
+    nxt = max(rest, key=rel.__getitem__)
+    sel = [ids[nxt]]
+    rest.remove(nxt)
+    # red[i]: largest similarity of candidate i to anything picked.
+    red = {i: float(vecs[i] @ vecs[nxt]) for i in rest}
+    while rest and len(sel) < k:
+        nxt = max(rest, key=lambda i: lamb * rel[i] - (1 - lamb) * red[i])
+        sel.append(ids[nxt])
+        rest.remove(nxt)
+        for i in rest:
+            red[i] = max(red[i], float(vecs[i] @ vecs[nxt]))
+    return sel
-- 
cgit v1.2.3