| author | polwex <polwex@sortug.com> | 2025-09-24 23:38:36 +0700 |
|---|---|---|
| committer | polwex <polwex@sortug.com> | 2025-09-24 23:38:36 +0700 |
| commit | 734b89570040e97f0c7743c4c0bc28e30a3cd4ee (patch) | |
| tree | 7142d9f37908138c38d0ade066e960c3a1c69f5d /rag/search.py | |
| parent | 57aaafdb137fe49930711f6ed5ccc83b3a119cd2 (diff) | |
Diffstat (limited to 'rag/search.py')
| -rw-r--r-- | rag/search.py | 205 |
|---|---|---|

1 file changed, 154 insertions, 51 deletions
```diff
diff --git a/rag/search.py b/rag/search.py
index 55b8ffd..291c1b4 100644
--- a/rag/search.py
+++ b/rag/search.py
@@ -5,12 +5,68 @@ import gc
 from typing import List, Tuple
 from sentence_transformers import CrossEncoder, SentenceTransformer
+from rag.constants import BATCH
 from rag.ingest import get_embed_model
 from rag.rerank import get_rerank_model, rerank_cross_encoder
-from rag.mmr import mmr, embed_unit_np  # if you added MMR; else remove
+from rag.mmr import mmr, mmr2, embed_unit_np  # if you added MMR; else remove
+from rag.db import vec_topk, bm25_topk, fetch_chunk
+
+
+# ) Query helper (cosine distance; operator may be <#> in sqlite-vec)
+# def vec_search(db: sqlite3.Connection, col: str, qtext, k=5):
+#     model = get_embed_model()
+#     q = model.encode([qtext], normalize_embeddings=True, convert_to_numpy=True).astype("float32")
+#     # Cosine distance operator in sqlite-vec is `<#>`; if your build differs, check docs: <-> L2, <=> dot, <#> cosine
+#     rows = vec_topk(db, col, q, k)
+
+#     # db.execute("SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?", (qtext, 40))
+#     return [(rid, db.execute(f"SELECT text FROM chunks_{col} WHERE id=?", (rid,)).fetchone()[0], dist) for rid, dist in rows]
+
+def _dist_to_sim(dist: float) -> float:
+    # L2 on unit vectors ↔ cosine: ||a-b||^2 = 2 - 2 cos  =>  cos = 1 - dist/2
+    return max(0.0, 1.0 - dist / 2.0)
+
+def vec_search(db, model: SentenceTransformer, col: str, qtext: str, k: int = 10, min_sim: float = 0.25,
+               max_per_doc: int | None = None, use_mmr: bool = False, mmr_lambda: float = 0.7):
+    q = model.encode([qtext], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32")  # <- flatten
+    rows = vec_topk(db, col, q, k * 4)  # overfetch a bit, filter below
+
+    # fetch texts + compute cosine sim
+    hits = []
+    for rid, dist in rows:
+        txt = fetch_chunk(db, col, rid)
+        sim = _dist_to_sim(dist)
+        if sim >= min_sim:
+            hits.append((rid, txt, sim))
+
+    # anti-spam (cap near-duplicates from same doc region if you add metadata later)
+    if max_per_doc:
+        capped, seen = [], {}
+        for rid, txt, sim in hits:
+            dockey = col  # or derive from a future chunk_meta table
+            cnt = seen.get(dockey, 0)
+            if cnt < max_per_doc:
+                capped.append((rid, txt, sim))
+                seen[dockey] = cnt + 1
+        hits = capped
+
+    # optional light MMR on the filtered set (diversify)
+    if use_mmr and len(hits) > k:
+        from rag.mmr import embed_unit_np, mmr2
+        ids = [h[0] for h in hits]
+        texts = [h[1] for h in hits]
+        qvec = q
+        cvecs = embed_unit_np(model, texts)  # [N,D] unit
+        keep = set(mmr2(qvec, ids, cvecs, k=k, lamb=mmr_lambda))
+        hits = [h for h in hits if h[0] in keep]
+
+    # final crop
+    hits.sort(key=lambda x: x[2], reverse=True)
+    return hits[:k]
+
 def search_hybrid(
-    db: sqlite3.Connection, query: str,
+    db: sqlite3.Connection, col: str, query: str,
     k_vec: int = 40, k_bm25: int = 40,
     k_ce: int = 30,      # rerank this many
     k_final: int = 10,   # return this many
@@ -19,74 +75,57 @@ def search_hybrid(
     emodel = get_embed_model()
     # 1) embed query (unit, float32) for vec search + MMR
     print("loading")
-    q = emodel.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32")
+    query_embeddings = emodel.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32")
     emodel = None
     with torch.no_grad():
         torch.cuda.empty_cache()
     gc.collect()
     print("memory should be free by now!!")
-    qbytes = q.tobytes()
-
-    # 2) ANN (cosine) + BM25 pools
-    vec_ids = [i for (i, _) in db.execute(
-        "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?",
-        (memoryview(qbytes), k_vec)
-    ).fetchall()]
-    bm25_ids = [i for (i,) in db.execute(
-        "SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?",
-        (query, k_bm25)
-    ).fetchall()]
-
-    # 3) merge (vector-first)
-    seen, merged = set(), []
-    for i in vec_ids + bm25_ids:
+    # qbytes = q.tobytes()
+
+
+    # 2) ANN + BM25
+    print("phase 2", col, query)
+    vhits = vec_topk(db, col, query_embeddings, k_vec)  # [(id, dist)]
+    vh_ids = [i for (i, _) in vhits]
+    bm_ids = bm25_topk(db, col, query, k_bm25)
+    #
+    # 3) merge ids [vector-first]
+    merged, seen = [], set()
+    for i in vh_ids + bm_ids:
         if i not in seen:
             merged.append(i); seen.add(i)
     if not merged:
         return []

-    # 4) fetch texts for CE
+    # 4) fetch texts
     qmarks = ",".join("?"*len(merged))
-    cand = db.execute(f"SELECT id, text FROM chunks WHERE id IN ({qmarks})", merged).fetchall()
+    cand = db.execute(f"SELECT id, text FROM chunks_{col} WHERE id IN ({qmarks})", merged).fetchall()
+    ids, texts = zip(*cand)
+
+    # 5) rerank
+    print("loading reranking model")
     reranker = get_rerank_model()
-    # 5) cross-encoder rerank (returns [(id,text,score)] desc)
-    ranked = rerank_cross_encoder(reranker, query, cand)
-    reranker = None
-    print("freeing again!!")
+    scores = reranker.predict([(query, t[:1000]) for t in texts], batch_size=BATCH)
+    reranker = None
     with torch.no_grad():
         torch.cuda.empty_cache()
     gc.collect()
-    #
-    ranked = ranked[:min(k_ce, len(ranked))]
-
+    print("memory should be free by now!!")
+    print("unloading reranking model")
+    ranked = sorted(zip(ids, texts, scores), key=lambda x: x[2], reverse=True)
     if not use_mmr or len(ranked) <= k_final:
-        return ranked[:k_final]
+        return ranked[:min(k_ce, k_final)]

-    # 6) MMR diversity on CE top-k_ce
-    cand_ids = [i for (i,_,_) in ranked]
-    cand_text = [t for (_,t,_) in ranked]
-    emodel = get_embed_model()
-    # god this is annoying I should stop being poor
-    cand_vecs = embed_unit_np(emodel, cand_text)  # [N,D], unit vectors
-    sel_ids = set(mmr(q, cand_ids, cand_vecs, k=k_final, lamb=mmr_lambda))
-    final = [trip for trip in ranked if trip[0] in sel_ids]  # keep CE order, filter by MMR picks
-    return final[:k_final]
-
-
-# ) Query helper (cosine distance; operator may be <#> in sqlite-vec)
-def vec_search(db: sqlite3.Connection, qtext, k=5):
-    model = get_embed_model()
-    q = model.encode([qtext], normalize_embeddings=True, convert_to_numpy=True).astype("float32")
-    # Cosine distance operator in sqlite-vec is `<#>`; if your build differs, check docs: <-> L2, <=> dot, <#> cosine
-    rows = db.execute(
-        "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?",
-        (memoryview(q.tobytes()), k)
-    ).fetchall()
-
-#    db.execute("SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?", (qtext, 40))
-    return [(rid, db.execute("SELECT text FROM chunks WHERE id=?", (rid,)).fetchone()[0], dist) for rid, dist in rows]
+    # 6) MMR
+    ce_ids = [i for (i,_,_) in ranked]
+    ce_texts = [t for (_,t,_) in ranked]
+    st_model = get_embed_model()
+    ce_vecs = st_model.encode(ce_texts, normalize_embeddings=True, convert_to_numpy=True).astype("float32")
+    keep = set(mmr2(query_embeddings, ce_ids, ce_vecs, k=k_final, lamb=mmr_lambda))
+    return [r for r in ranked if r[0] in keep][:k_final]

 # # Hybrid + CE rerank query:
 # results = search_hybrid("indemnification obligations survive termination", model, k_vec=50, k_bm25=50, k_final=8)
@@ -99,3 +138,67 @@ def vec_search(db: sqlite3.Connection, qtext, k=5):
 #
 #
 #
+def search_hybrid_with_mmr(db, col, query, k_vec=50, k_bm25=50, k_ce=30, k_final=10, lamb=0.7):
+    ranked = search_hybrid(db, col, query, k_vec, k_bm25, k_ce, k_ce)
+    if not ranked: return []
+    ids = [i for (i,_,_) in ranked]
+    texts = [t for (_,t,_) in ranked]
+    st_model = get_embed_model()
+    qvec = st_model.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32")
+    cvecs = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True).astype("float32")
+    keep = set(mmr2(qvec, ids, cvecs, k=k_final, lamb=lamb))
+    return [r for r in ranked if r[0] in keep][:k_final]
+
+
+
+# clean vec expansion says gpt5
+# #
+# def expand(q, aliases=()):
+#     qs = [q, *aliases]
+#     # embed each, take max similarity per chunk at scoring time
+
+
+# def dist_to_cos(d): return max(0.0, 1.0 - d/2.0)  # L2 on unit vecs
+
+# def vec_topk(db, table, q_vec_f32, k):
+#     from sqlite_vec import serialize_float32
+#     return db.execute(
+#         f"SELECT rowid, distance FROM {table} WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
+#         (serialize_float32(q_vec_f32), k)
+#     ).fetchall()
+
+# def vec_search(db, st_model, col, qtext, k=12, k_raw=None, min_sim=0.30, use_mmr=True, mmr_lambda=0.7):
+#     if k_raw is None: k_raw = k*4
+#     q = st_model.encode([qtext], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32")
+#     rows = vec_topk(db, f"vec_{col}", q, k_raw)
+
+#     hits = []
+#     for rid, dist in rows:
+#         cos = dist_to_cos(dist)
+#         if cos < min_sim: continue
+#         txt = db.execute(f"SELECT text FROM chunks_{col} WHERE id=?", (rid,)).fetchone()[0]
+#         hits.append((rid, txt, cos))
+
+#     hits.sort(key=lambda x: x[2], reverse=True)
+#     if not use_mmr or len(hits) <= k:
+#         return hits[:k]
+
+#     # MMR on the (already filtered) pool
+#     ids = [h[0] for h in hits]
+#     texts = [h[1] for h in hits]
+#     cvecs = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True, batch_size=32).astype("float32")
+#     # simple MMR
+#     import numpy as np
+#     def cosine(a,b): return float(a@b)
+#     sel, sel_idx = [], []
+#     rem = list(range(len(ids)))
+#     best0 = max(rem, key=lambda i: cosine(q, cvecs[i])); sel.append(ids[best0]); sel_idx.append(best0); rem.remove(best0)
+#     while rem and len(sel)<k:
+#         def score(i):
+#             rel = cosine(q, cvecs[i])
+#             red = max(cosine(cvecs[i], cvecs[j]) for j in sel_idx)
+#             return mmr_lambda*rel - (1.0 - mmr_lambda)*red
+#         nxt = max(rem, key=score); sel.append(ids[nxt]); sel_idx.append(nxt); rem.remove(nxt)
+#     keep = set(sel)
+#     return [h for h in hits if h[0] in keep][:k]
```