diff options
author | polwex <polwex@sortug.com> | 2025-09-23 03:50:53 +0700 |
---|---|---|
committer | polwex <polwex@sortug.com> | 2025-09-23 03:50:53 +0700 |
commit | 57aaafdb137fe49930711f6ed5ccc83b3a119cd2 (patch) | |
tree | 1a7556927bed94377630d33dd29c3bf07d159619 /rag/search.py |
init
Diffstat (limited to 'rag/search.py')
-rw-r--r-- | rag/search.py | 101 |
1 file changed, 101 insertions, 0 deletions
import sqlite3
import numpy as np
import torch
import gc
from typing import List, Tuple

from sentence_transformers import CrossEncoder, SentenceTransformer
from rag.ingest import get_embed_model
from rag.rerank import get_rerank_model, rerank_cross_encoder
from rag.mmr import mmr, embed_unit_np  # if you added MMR; else remove


def _release_gpu_memory() -> None:
    """Best-effort release of cached GPU memory between model swaps.

    Runs ``gc.collect()`` first so unreferenced Python objects (and the
    CUDA tensors they own) become reclaimable, then returns freed blocks
    to the driver via ``empty_cache()``.  The CUDA call is guarded so
    CPU-only torch builds never touch the CUDA API.  (The original code
    wrapped ``empty_cache()`` in ``torch.no_grad()``, which affects only
    autograd recording, not the allocator — removed.)
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def search_hybrid(
    db: sqlite3.Connection, query: str,
    k_vec: int = 40, k_bm25: int = 40,
    k_ce: int = 30,      # rerank this many
    k_final: int = 10,   # return this many
    use_mmr: bool = False, mmr_lambda: float = 0.7
):
    """Hybrid retrieval over a sqlite database.

    Pipeline: (1) embed the query, (2) pull candidate pools from the ANN
    table ``vec`` and the FTS/BM25 table ``fts``, (3) merge vector-first,
    (4) fetch candidate texts from ``chunks``, (5) cross-encoder rerank,
    (6) optionally diversify the top results with MMR.

    Args:
        db: open sqlite3 connection with ``vec``, ``fts`` and ``chunks``
            tables (vector extension loaded).
        query: free-text query.
        k_vec / k_bm25: candidate pool sizes for ANN / BM25.
        k_ce: how many merged candidates survive into the CE rerank cut.
        k_final: how many results to return.
        use_mmr: apply MMR diversification to the CE top-``k_ce``.
        mmr_lambda: relevance/diversity trade-off passed to ``mmr``.

    Returns:
        List of ``(id, text, score)`` triples, best first; ``[]`` when no
        candidate matches either pool.
    """
    # 1) embed query (unit-normalized float32) for vector search + MMR
    emodel = get_embed_model()
    q = emodel.encode(
        [query], normalize_embeddings=True, convert_to_numpy=True
    )[0].astype("float32")
    # NOTE(review): get_embed_model() likely caches the model internally,
    # in which case dropping this local reference frees nothing — verify.
    emodel = None
    _release_gpu_memory()
    qbytes = q.tobytes()

    # 2) candidate pools: ANN (cosine) + BM25
    vec_ids = [rid for (rid, _dist) in db.execute(
        "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?",
        (memoryview(qbytes), k_vec),
    ).fetchall()]
    bm25_ids = [rid for (rid,) in db.execute(
        "SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?",
        (query, k_bm25),
    ).fetchall()]

    # 3) merge vector-first, preserving order and dropping duplicates
    seen: set = set()
    merged: list = []
    for rid in vec_ids + bm25_ids:
        if rid not in seen:
            merged.append(rid)
            seen.add(rid)
    if not merged:
        return []

    # 4) fetch candidate texts for the cross-encoder
    qmarks = ",".join("?" * len(merged))
    cand = db.execute(
        f"SELECT id, text FROM chunks WHERE id IN ({qmarks})", merged
    ).fetchall()

    # 5) cross-encoder rerank -> [(id, text, score)] sorted desc by score
    reranker = get_rerank_model()
    ranked = rerank_cross_encoder(reranker, query, cand)
    reranker = None
    _release_gpu_memory()
    ranked = ranked[:k_ce]  # slicing already clamps to len(ranked)

    if not use_mmr or len(ranked) <= k_final:
        return ranked[:k_final]

    # 6) MMR diversification over the CE top-k_ce
    cand_ids = [rid for (rid, _t, _s) in ranked]
    cand_text = [t for (_rid, t, _s) in ranked]
    emodel = get_embed_model()
    cand_vecs = embed_unit_np(emodel, cand_text)  # [N, D] unit vectors
    sel_ids = set(mmr(q, cand_ids, cand_vecs, k=k_final, lamb=mmr_lambda))
    # keep CE ordering, filter to the MMR picks
    final = [trip for trip in ranked if trip[0] in sel_ids]
    return final[:k_final]


def vec_search(db: sqlite3.Connection, qtext, k=5):
    """Plain ANN search: top-``k`` nearest chunks to ``qtext``.

    Returns a list of ``(rowid, text, distance)`` triples ordered by
    ascending distance.
    """
    model = get_embed_model()
    # encode() returns shape [1, D]; take row 0 so tobytes() serializes
    # exactly one D-float vector
    q = model.encode(
        [qtext], normalize_embeddings=True, convert_to_numpy=True
    )[0].astype("float32")
    # NOTE(review): the `<#>` / `<->` / `<=>` distance-operator syntax is
    # pgvector-style (where `<=>` is cosine and `<#>` is negative inner
    # product); stock sqlite-vec instead uses `embedding MATCH ?` with
    # `k = ?`.  Confirm which dialect this build actually accepts.
    rows = db.execute(
        "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?",
        (memoryview(q.tobytes()), k),
    ).fetchall()
    return [
        (rid,
         db.execute("SELECT text FROM chunks WHERE id=?", (rid,)).fetchone()[0],
         dist)
        for rid, dist in rows
    ]


# # Hybrid + CE rerank query:
# results = search_hybrid("indemnification obligations survive termination", model, k_vec=50, k_bm25=50, k_final=8)
# for rid, txt, score in results:
#     print(f"[{rid:04d}] score={score:.3f}\n{txt[:300]}...\n")