summaryrefslogtreecommitdiff
path: root/rag/search.py
diff options
context:
space:
mode:
authorpolwex <polwex@sortug.com>2025-09-24 23:38:36 +0700
committerpolwex <polwex@sortug.com>2025-09-24 23:38:36 +0700
commit734b89570040e97f0c7743c4c0bc28e30a3cd4ee (patch)
tree7142d9f37908138c38d0ade066e960c3a1c69f5d /rag/search.py
parent57aaafdb137fe49930711f6ed5ccc83b3a119cd2 (diff)
Diffstat (limited to 'rag/search.py')
-rw-r--r--rag/search.py205
1 files changed, 154 insertions, 51 deletions
diff --git a/rag/search.py b/rag/search.py
index 55b8ffd..291c1b4 100644
--- a/rag/search.py
+++ b/rag/search.py
@@ -5,12 +5,68 @@ import gc
from typing import List, Tuple
from sentence_transformers import CrossEncoder, SentenceTransformer
+from rag.constants import BATCH
from rag.ingest import get_embed_model
from rag.rerank import get_rerank_model, rerank_cross_encoder
-from rag.mmr import mmr, embed_unit_np # if you added MMR; else remove
+from rag.mmr import mmr, mmr2, embed_unit_np # if you added MMR; else remove
+from rag.db import vec_topk, bm25_topk, fetch_chunk
+
+
+# Query helper (cosine distance; operator may be <#> in sqlite-vec)
+# def vec_search(db: sqlite3.Connection, col: str, qtext, k=5):
+# model = get_embed_model()
+# q = model.encode([qtext], normalize_embeddings=True, convert_to_numpy=True).astype("float32")
+# # Cosine distance operator in sqlite-vec is `<#>`; if your build differs, check docs: <-> L2, <=> dot, <#> cosine
+# rows = vec_topk(db, col, q, k)
+
+# # db.execute("SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?", (qtext, 40))
+# return [(rid, db.execute(f"SELECT text FROM chunks_{col} WHERE id=?", (rid,)).fetchone()[0], dist) for rid, dist in rows]
+
+def _dist_to_sim(dist: float) -> float:
+ # L2 on unit vectors ↔ cosine: ||a-b||^2 = 2 - 2 cos => cos = 1 - dist/2
+ return max(0.0, 1.0 - dist / 2.0)
+
+def vec_search(db, model: SentenceTransformer, col: str, qtext: str, k: int = 10, min_sim: float = 0.25,
+ max_per_doc: int | None = None, use_mmr: bool = False, mmr_lambda: float = 0.7):
+ q = model.encode([qtext], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32") # <- flatten
+ rows = vec_topk(db, col, q, k * 4) # overfetch a bit, filter below
+
+ # fetch texts + compute cosine sim
+ hits = []
+ for rid, dist in rows:
+ txt = fetch_chunk(db, col, rid)
+ sim = _dist_to_sim(dist)
+ if sim >= min_sim:
+ hits.append((rid, txt, sim))
+
+ # anti-spam (cap near-duplicates from same doc region if you add metadata later)
+ if max_per_doc:
+ capped, seen = [], {}
+ for rid, txt, sim in hits:
+ dockey = col # or derive from a future chunk_meta table
+ cnt = seen.get(dockey, 0)
+ if cnt < max_per_doc:
+ capped.append((rid, txt, sim))
+ seen[dockey] = cnt + 1
+ hits = capped
+
+ # optional light MMR on the filtered set (diversify)
+ if use_mmr and len(hits) > k:
+ from rag.mmr import embed_unit_np, mmr2
+ ids = [h[0] for h in hits]
+ texts = [h[1] for h in hits]
+ qvec = q
+ cvecs = embed_unit_np(model, texts) # [N,D] unit
+ keep = set(mmr2(qvec, ids, cvecs, k=k, lamb=mmr_lambda))
+ hits = [h for h in hits if h[0] in keep]
+
+ # final crop
+ hits.sort(key=lambda x: x[2], reverse=True)
+ return hits[:k]
+
def search_hybrid(
- db: sqlite3.Connection, query: str,
+ db: sqlite3.Connection, col: str, query: str,
k_vec: int = 40, k_bm25: int = 40,
k_ce: int = 30, # rerank this many
k_final: int = 10, # return this many
@@ -19,74 +75,57 @@ def search_hybrid(
emodel = get_embed_model()
# 1) embed query (unit, float32) for vec search + MMR
print("loading")
- q = emodel.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32")
+ query_embeddings = emodel.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32")
emodel = None
with torch.no_grad():
torch.cuda.empty_cache()
gc.collect()
print("memory should be free by now!!")
- qbytes = q.tobytes()
-
- # 2) ANN (cosine) + BM25 pools
- vec_ids = [i for (i, _) in db.execute(
- "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?",
- (memoryview(qbytes), k_vec)
- ).fetchall()]
- bm25_ids = [i for (i,) in db.execute(
- "SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?",
- (query, k_bm25)
- ).fetchall()]
-
- # 3) merge (vector-first)
- seen, merged = set(), []
- for i in vec_ids + bm25_ids:
+ # qbytes = q.tobytes()
+
+
+ # 2) ANN + BM25
+ print("phase 2", col, query)
+ vhits = vec_topk(db, col, query_embeddings, k_vec) # [(id, dist)]
+ vh_ids = [i for (i, _) in vhits]
+ bm_ids = bm25_topk(db, col, query, k_bm25)
+ #
+ # 3) merge ids [vector-first]
+ merged, seen = [], set()
+ for i in vh_ids + bm_ids:
if i not in seen:
merged.append(i); seen.add(i)
if not merged:
return []
- # 4) fetch texts for CE
+ # 4) fetch texts
qmarks = ",".join("?"*len(merged))
- cand = db.execute(f"SELECT id, text FROM chunks WHERE id IN ({qmarks})", merged).fetchall()
+ cand = db.execute(f"SELECT id, text FROM chunks_{col} WHERE id IN ({qmarks})", merged).fetchall()
+ ids, texts = zip(*cand)
+
+ # 5) rerank
+ print("loading reranking model")
reranker = get_rerank_model()
- # 5) cross-encoder rerank (returns [(id,text,score)] desc)
- ranked = rerank_cross_encoder(reranker, query, cand)
- reranker = None
- print("freeing again!!")
+ scores = reranker.predict([(query, t[:1000]) for t in texts], batch_size=BATCH)
+ reranker =None
with torch.no_grad():
torch.cuda.empty_cache()
gc.collect()
- #
- ranked = ranked[:min(k_ce, len(ranked))]
-
+ print("memory should be free by now!!")
+ print("unloading reranking model")
+ ranked = sorted(zip(ids, texts, scores), key=lambda x: x[2], reverse=True)
if not use_mmr or len(ranked) <= k_final:
- return ranked[:k_final]
+ return ranked[:min(k_ce, k_final)]
- # 6) MMR diversity on CE top-k_ce
- cand_ids = [i for (i,_,_) in ranked]
- cand_text = [t for (_,t,_) in ranked]
- emodel = get_embed_model()
- # god this is annoying I should stop being poor
- cand_vecs = embed_unit_np(emodel, cand_text) # [N,D], unit vectors
- sel_ids = set(mmr(q, cand_ids, cand_vecs, k=k_final, lamb=mmr_lambda))
- final = [trip for trip in ranked if trip[0] in sel_ids] # keep CE order, filter by MMR picks
- return final[:k_final]
-
-
-# ) Query helper (cosine distance; operator may be <#> in sqlite-vec)
-def vec_search(db: sqlite3.Connection, qtext, k=5):
- model = get_embed_model()
- q = model.encode([qtext], normalize_embeddings=True, convert_to_numpy=True).astype("float32")
- # Cosine distance operator in sqlite-vec is `<#>`; if your build differs, check docs: <-> L2, <=> dot, <#> cosine
- rows = db.execute(
- "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?",
- (memoryview(q.tobytes()), k)
- ).fetchall()
-
-# db.execute("SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?", (qtext, 40))
- return [(rid, db.execute("SELECT text FROM chunks WHERE id=?", (rid,)).fetchone()[0], dist) for rid, dist in rows]
+ # 6) MMR
+ ce_ids = [i for (i,_,_) in ranked]
+ ce_texts = [t for (_,t,_) in ranked]
+ st_model = get_embed_model()
+ ce_vecs = st_model.encode(ce_texts, normalize_embeddings=True, convert_to_numpy=True).astype("float32")
+ keep = set(mmr2(query_embeddings, ce_ids, ce_vecs, k=k_final, lamb=mmr_lambda))
+ return [r for r in ranked if r[0] in keep][:k_final]
# # Hybrid + CE rerank query:
# results = search_hybrid("indemnification obligations survive termination", model, k_vec=50, k_bm25=50, k_final=8)
@@ -99,3 +138,67 @@ def vec_search(db: sqlite3.Connection, qtext, k=5):
#
#
#
+def search_hybrid_with_mmr(db, col, query, k_vec=50, k_bm25=50, k_ce=30, k_final=10, lamb=0.7):
+ ranked = search_hybrid(db, col, query, k_vec, k_bm25, k_ce, k_ce)
+ if not ranked: return []
+ ids = [i for (i,_,_) in ranked]
+ texts = [t for (_,t,_) in ranked]
+ st_model = get_embed_model()
+ qvec = st_model.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32")
+ cvecs = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True).astype("float32")
+ keep = set(mmr2(qvec, ids, cvecs, k=k_final, lamb=lamb))
+ return [r for r in ranked if r[0] in keep][:k_final]
+
+
+
+# Cleaner vec_search expansion suggested by GPT-5; kept commented out for reference
+# #
+# def expand(q, aliases=()):
+# qs = [q, *aliases]
+# # embed each, take max similarity per chunk at scoring time
+
+
+# def dist_to_cos(d): return max(0.0, 1.0 - d/2.0) # L2 on unit vecs
+
+# def vec_topk(db, table, q_vec_f32, k):
+# from sqlite_vec import serialize_float32
+# return db.execute(
+# f"SELECT rowid, distance FROM {table} WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
+# (serialize_float32(q_vec_f32), k)
+# ).fetchall()
+
+# def vec_search(db, st_model, col, qtext, k=12, k_raw=None, min_sim=0.30, use_mmr=True, mmr_lambda=0.7):
+# if k_raw is None: k_raw = k*4
+# q = st_model.encode([qtext], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32")
+# rows = vec_topk(db, f"vec_{col}", q, k_raw)
+
+# hits = []
+# for rid, dist in rows:
+# cos = dist_to_cos(dist)
+# if cos < min_sim: continue
+# txt = db.execute(f"SELECT text FROM chunks_{col} WHERE id=?", (rid,)).fetchone()[0]
+# hits.append((rid, txt, cos))
+
+# hits.sort(key=lambda x: x[2], reverse=True)
+# if not use_mmr or len(hits) <= k:
+# return hits[:k]
+
+# # MMR on the (already filtered) pool
+# ids = [h[0] for h in hits]
+# texts = [h[1] for h in hits]
+# cvecs = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True, batch_size=32).astype("float32")
+# # simple MMR
+# import numpy as np
+# def cosine(a,b): return float(a@b)
+# sel, sel_idx = [], []
+# rem = list(range(len(ids)))
+# best0 = max(rem, key=lambda i: cosine(q, cvecs[i])); sel.append(ids[best0]); sel_idx.append(best0); rem.remove(best0)
+# while rem and len(sel)<k:
+# def score(i):
+# rel = cosine(q, cvecs[i])
+# red = max(cosine(cvecs[i], cvecs[j]) for j in sel_idx)
+# return mmr_lambda*rel - (1.0 - mmr_lambda)*red
+# nxt = max(rem, key=score); sel.append(ids[nxt]); sel_idx.append(nxt); rem.remove(nxt)
+# keep = set(sel)
+# return [h for h in hits if h[0] in keep][:k]
+