import sqlite3 import numpy as np import torch import gc from typing import List, Tuple from sentence_transformers import CrossEncoder, SentenceTransformer from rag.constants import BATCH from rag.ingest import get_embed_model from rag.rerank import get_rerank_model, rerank_cross_encoder from rag.mmr import mmr, mmr2, embed_unit_np # if you added MMR; else remove from rag.db import vec_topk, bm25_topk, fetch_chunk # ) Query helper (cosine distance; operator may be <#> in sqlite-vec) # def vec_search(db: sqlite3.Connection, col: str, qtext, k=5): # model = get_embed_model() # q = model.encode([qtext], normalize_embeddings=True, convert_to_numpy=True).astype("float32") # # Cosine distance operator in sqlite-vec is `<#>`; if your build differs, check docs: <-> L2, <=> dot, <#> cosine # rows = vec_topk(db, col, q, k) # # db.execute("SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?", (qtext, 40)) # return [(rid, db.execute(f"SELECT text FROM chunks_{col} WHERE id=?", (rid,)).fetchone()[0], dist) for rid, dist in rows] def _dist_to_sim(dist: float) -> float: # L2 on unit vectors ↔ cosine: ||a-b||^2 = 2 - 2 cos => cos = 1 - dist/2 return max(0.0, 1.0 - dist / 2.0) def vec_search(db, model: SentenceTransformer, col: str, qtext: str, k: int = 10, min_sim: float = 0.25, max_per_doc: int | None = None, use_mmr: bool = False, mmr_lambda: float = 0.7): q = model.encode([qtext], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32") # <- flatten rows = vec_topk(db, col, q, k * 4) # overfetch a bit, filter below # fetch texts + compute cosine sim hits = [] for rid, dist in rows: txt = fetch_chunk(db, col, rid) sim = _dist_to_sim(dist) if sim >= min_sim: hits.append((rid, txt, sim)) # anti-spam (cap near-duplicates from same doc region if you add metadata later) if max_per_doc: capped, seen = [], {} for rid, txt, sim in hits: dockey = col # or derive from a future chunk_meta table cnt = seen.get(dockey, 0) if cnt < max_per_doc: capped.append((rid, txt, sim)) seen[dockey] = cnt + 1 hits = capped # optional light MMR on the filtered set (diversify) if use_mmr and len(hits) > k: from rag.mmr import embed_unit_np, mmr2 ids = [h[0] for h in hits] texts = [h[1] for h in hits] qvec = q cvecs = embed_unit_np(model, texts) # [N,D] unit keep = set(mmr2(qvec, ids, cvecs, k=k, lamb=mmr_lambda)) hits = [h for h in hits if h[0] in keep] # final crop hits.sort(key=lambda x: x[2], reverse=True) return hits[:k] def search_hybrid( db: sqlite3.Connection, col: str, query: str, k_vec: int = 40, k_bm25: int = 40, k_ce: int = 30, # rerank this many k_final: int = 10, # return this many use_mmr: bool = False, mmr_lambda: float = 0.7 ): emodel = get_embed_model() # 1) embed query (unit, float32) for vec search + MMR print("loading") query_embeddings = emodel.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32") emodel = None with torch.no_grad(): torch.cuda.empty_cache() gc.collect() print("memory should be free by now!!") # qbytes = q.tobytes() # 2) ANN + BM25 print("phase 2", col, query) vhits = vec_topk(db, col, query_embeddings, k_vec) # [(id, dist)] vh_ids = [i for (i, _) in vhits] bm_ids = bm25_topk(db, col, query, k_bm25) # # 3) merge ids [vector-first] merged, seen = [], set() for i in vh_ids + bm_ids: if i not in seen: merged.append(i); seen.add(i) if not merged: return [] # 4) fetch texts qmarks = ",".join("?"*len(merged)) cand = db.execute(f"SELECT id, text FROM chunks_{col} WHERE id IN ({qmarks})", merged).fetchall() ids, texts = zip(*cand) # 5) rerank print("loading reranking model") reranker = get_rerank_model() scores = reranker.predict([(query, t[:1000]) for t in texts], batch_size=BATCH) reranker =None with torch.no_grad(): torch.cuda.empty_cache() gc.collect() print("memory should be free by now!!") print("unloading reranking model") ranked = sorted(zip(ids, texts, scores), key=lambda x: x[2], reverse=True) if not use_mmr or len(ranked) <= k_final: return ranked[:min(k_ce, k_final)] # 6) MMR ce_ids = [i for (i,_,_) in ranked] ce_texts = [t for (_,t,_) in ranked] st_model = get_embed_model() ce_vecs = st_model.encode(ce_texts, normalize_embeddings=True, convert_to_numpy=True).astype("float32") keep = set(mmr2(query_embeddings, ce_ids, ce_vecs, k=k_final, lamb=mmr_lambda)) return [r for r in ranked if r[0] in keep][:k_final] # # Hybrid + CE rerank query: # results = search_hybrid("indemnification obligations survive termination", model, k_vec=50, k_bm25=50, k_final=8) # for rid, txt, score in results: # print(f"[{rid:04d}] score={score:.3f}\n{txt[:300]}...\n") # # # # # # # def search_hybrid_with_mmr(db, col, query, k_vec=50, k_bm25=50, k_ce=30, k_final=10, lamb=0.7): ranked = search_hybrid(db, col, query, k_vec, k_bm25, k_ce, k_ce) if not ranked: return [] ids = [i for (i,_,_) in ranked] texts = [t for (_,t,_) in ranked] st_model = get_embed_model() qvec = st_model.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32") cvecs = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True).astype("float32") keep = set(mmr2(qvec, ids, cvecs, k=k_final, lamb=lamb)) return [r for r in ranked if r[0] in keep][:k_final] # clean vec expansion says gpt5 # # # def expand(q, aliases=()): # qs = [q, *aliases] # # embed each, take max similarity per chunk at scoring time # def dist_to_cos(d): return max(0.0, 1.0 - d/2.0) # L2 on unit vecs # def vec_topk(db, table, q_vec_f32, k): # from sqlite_vec import serialize_float32 # return db.execute( # f"SELECT rowid, distance FROM {table} WHERE embedding MATCH ? ORDER BY distance LIMIT ?", # (serialize_float32(q_vec_f32), k) # ).fetchall() # def vec_search(db, st_model, col, qtext, k=12, k_raw=None, min_sim=0.30, use_mmr=True, mmr_lambda=0.7): # if k_raw is None: k_raw = k*4 # q = st_model.encode([qtext], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32") # rows = vec_topk(db, f"vec_{col}", q, k_raw) # hits = [] # for rid, dist in rows: # cos = dist_to_cos(dist) # if cos < min_sim: continue # txt = db.execute(f"SELECT text FROM chunks_{col} WHERE id=?", (rid,)).fetchone()[0] # hits.append((rid, txt, cos)) # hits.sort(key=lambda x: x[2], reverse=True) # if not use_mmr or len(hits) <= k: # return hits[:k] # # MMR on the (already filtered) pool # ids = [h[0] for h in hits] # texts = [h[1] for h in hits] # cvecs = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True, batch_size=32).astype("float32") # # simple MMR # import numpy as np # def cosine(a,b): return float(a@b) # sel, sel_idx = [], [] # rem = list(range(len(ids))) # best0 = max(rem, key=lambda i: cosine(q, cvecs[i])); sel.append(ids[best0]); sel_idx.append(best0); rem.remove(best0) # while rem and len(sel)