summaryrefslogtreecommitdiff
path: root/rag/search.py
diff options
context:
space:
mode:
authorpolwex <polwex@sortug.com>2025-09-23 03:50:53 +0700
committerpolwex <polwex@sortug.com>2025-09-23 03:50:53 +0700
commit57aaafdb137fe49930711f6ed5ccc83b3a119cd2 (patch)
tree1a7556927bed94377630d33dd29c3bf07d159619 /rag/search.py
init
Diffstat (limited to 'rag/search.py')
-rw-r--r--rag/search.py101
1 files changed, 101 insertions, 0 deletions
diff --git a/rag/search.py b/rag/search.py
new file mode 100644
index 0000000..55b8ffd
--- /dev/null
+++ b/rag/search.py
@@ -0,0 +1,101 @@
+import sqlite3
+import numpy as np
+import torch
+import gc
+from typing import List, Tuple
+
+from sentence_transformers import CrossEncoder, SentenceTransformer
+from rag.ingest import get_embed_model
+from rag.rerank import get_rerank_model, rerank_cross_encoder
+from rag.mmr import mmr, embed_unit_np # if you added MMR; else remove
+
def _release_gpu():
    """Best-effort release of cached GPU memory between model stages.

    `empty_cache()` only needs a CUDA runtime, not `no_grad()` (the
    original wrapped it in `torch.no_grad()`, which has no effect on the
    allocator); guard it so CPU-only machines don't touch the CUDA API.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()


def search_hybrid(
    db: sqlite3.Connection, query: str,
    k_vec: int = 40, k_bm25: int = 40,
    k_ce: int = 30,    # rerank this many
    k_final: int = 10, # return this many
    use_mmr: bool = False, mmr_lambda: float = 0.7
):
    """Hybrid retrieval: vector ANN + BM25 pools, cross-encoder rerank, optional MMR.

    Args:
        db: SQLite connection exposing `vec` (sqlite-vec), `fts` (FTS5)
            and `chunks(id, text)` tables.
        query: free-text query string.
        k_vec: size of the vector (cosine ANN) candidate pool.
        k_bm25: size of the BM25 candidate pool.
        k_ce: number of merged candidates kept after cross-encoder rerank.
        k_final: number of results returned.
        use_mmr: if True, diversify the CE top-k_ce with MMR before truncating.
        mmr_lambda: MMR relevance/diversity trade-off in [0, 1].

    Returns:
        List of (chunk_id, text, ce_score) tuples, best first; [] if no hits.
    """
    emodel = get_embed_model()
    # 1) embed query (unit-norm float32) for vector search and, later, MMR
    q = emodel.encode(
        [query], normalize_embeddings=True, convert_to_numpy=True
    )[0].astype("float32")
    emodel = None
    _release_gpu()  # free VRAM before loading the reranker
    qbytes = q.tobytes()

    # 2) candidate pools: ANN (cosine) + BM25
    vec_ids = [rid for (rid, _dist) in db.execute(
        "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?",
        (memoryview(qbytes), k_vec),
    ).fetchall()]
    try:
        bm25_ids = [rid for (rid,) in db.execute(
            "SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?",
            (query, k_bm25),
        ).fetchall()]
    except sqlite3.OperationalError:
        # FTS5 parses MATCH input as query syntax ('"', '-', ':' etc.);
        # a raw user string can raise. Fall back to the vector pool alone
        # rather than crashing the whole search.
        bm25_ids = []

    # 3) merge, vector hits first, preserving order and dropping duplicates
    seen, merged = set(), []
    for rid in vec_ids + bm25_ids:
        if rid not in seen:
            merged.append(rid)
            seen.add(rid)
    if not merged:
        return []

    # 4) fetch candidate texts for the cross-encoder
    qmarks = ",".join("?" * len(merged))
    cand = db.execute(
        f"SELECT id, text FROM chunks WHERE id IN ({qmarks})", merged
    ).fetchall()

    # 5) cross-encoder rerank -> [(id, text, score)] sorted descending
    reranker = get_rerank_model()
    ranked = rerank_cross_encoder(reranker, query, cand)
    reranker = None
    _release_gpu()
    ranked = ranked[:k_ce]  # slicing clamps, no min() needed

    if not use_mmr or len(ranked) <= k_final:
        return ranked[:k_final]

    # 6) MMR diversity over the CE top-k_ce
    cand_ids = [cid for (cid, _t, _s) in ranked]
    cand_texts = [t for (_cid, t, _s) in ranked]
    emodel = get_embed_model()  # re-load: candidate embeddings needed for MMR
    cand_vecs = embed_unit_np(emodel, cand_texts)  # [N, D] unit vectors
    picked = set(mmr(q, cand_ids, cand_vecs, k=k_final, lamb=mmr_lambda))
    # keep CE ordering, filter to the MMR-selected ids
    return [trip for trip in ranked if trip[0] in picked][:k_final]
+
+
+# Query helper (cosine distance; operator may be <#> in sqlite-vec)
def vec_search(db: sqlite3.Connection, qtext: str, k: int = 5):
    """Pure vector (cosine) search over the `vec` table.

    Cosine distance operator in sqlite-vec is `<#>`; if your build
    differs, check the docs: `<->` L2, `<=>` dot, `<#>` cosine.

    Args:
        db: SQLite connection exposing `vec` and `chunks(id, text)`.
        qtext: query string to embed.
        k: number of nearest neighbours to return.

    Returns:
        List of (rowid, text, distance) tuples, nearest first.
    """
    model = get_embed_model()
    # index [0]: encode() returns a [1, D] batch, sqlite-vec wants one vector
    q = model.encode(
        [qtext], normalize_embeddings=True, convert_to_numpy=True
    )[0].astype("float32")
    rows = db.execute(
        "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?",
        (memoryview(q.tobytes()), k),
    ).fetchall()
    if not rows:
        return []
    # Fetch all hit texts in one query instead of one round-trip per hit (N+1).
    qmarks = ",".join("?" * len(rows))
    texts = dict(db.execute(
        f"SELECT id, text FROM chunks WHERE id IN ({qmarks})",
        [rid for rid, _ in rows],
    ).fetchall())
    return [(rid, texts[rid], dist) for rid, dist in rows]
+
+
+# # Hybrid + CE rerank query:
+# results = search_hybrid(db, "indemnification obligations survive termination", k_vec=50, k_bm25=50, k_final=8)
+# for rid, txt, score in results:
+# print(f"[{rid:04d}] score={score:.3f}\n{txt[:300]}...\n")
+#
+#
+#
+#
+#
+#
+#