import gc
import sqlite3

import torch

from rag.ingest import get_embed_model
from rag.mmr import mmr, embed_unit_np  # only needed when use_mmr=True; remove if you skipped MMR
from rag.rerank import get_rerank_model, rerank_cross_encoder


def search_hybrid(
    db: sqlite3.Connection,
    query: str,
    k_vec: int = 40,        # vector (ANN) candidate pool size
    k_bm25: int = 40,       # BM25 candidate pool size
    k_ce: int = 30,         # keep this many after the cross-encoder rerank
    k_final: int = 10,      # return this many
    use_mmr: bool = False,
    mmr_lambda: float = 0.7,
):
    """Hybrid retrieval: ANN + BM25 candidate pools, cross-encoder rerank,
    optional MMR diversification. Returns a list of (id, text, score) tuples."""
    # 1) Embed the query (unit-norm float32) for vector search and, later, MMR.
    emodel = get_embed_model()
    q = emodel.encode(
        [query], normalize_embeddings=True, convert_to_numpy=True
    )[0].astype("float32")

    # Release the embed model before loading the reranker so both never sit in VRAM at once.
    emodel = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    qbytes = q.tobytes()

    # 2) Candidate pools: ANN over sqlite-vec + BM25 over FTS5.
    #    With unit-normalized embeddings, L2 and cosine produce the same ranking.
    #    Older sqlite-vec builds may need "AND k = ?" instead of LIMIT for KNN queries.
    vec_ids = [i for (i, _) in db.execute(
        "SELECT rowid, distance FROM vec WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
        (memoryview(qbytes), k_vec),
    ).fetchall()]
    bm25_ids = [i for (i,) in db.execute(
        "SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?",
        (query, k_bm25),
    ).fetchall()]

    # 3) Merge the pools, vector hits first, de-duplicated.
    seen, merged = set(), []
    for i in vec_ids + bm25_ids:
        if i not in seen:
            merged.append(i)
            seen.add(i)
    if not merged:
        return []

    # 4) Fetch candidate texts for the cross-encoder.
    qmarks = ",".join("?" * len(merged))
    cand = db.execute(
        f"SELECT id, text FROM chunks WHERE id IN ({qmarks})", merged
    ).fetchall()

    # 5) Cross-encoder rerank; returns [(id, text, score)] sorted by score, descending.
    reranker = get_rerank_model()
    ranked = rerank_cross_encoder(reranker, query, cand)

    # Free the reranker too, before the embed model is (possibly) reloaded for MMR.
    reranker = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    ranked = ranked[:k_ce]
    if not use_mmr or len(ranked) <= k_final:
        return ranked[:k_final]

    # 6) MMR diversification over the CE top-k_ce.
    cand_ids = [i for (i, _, _) in ranked]
    cand_text = [t for (_, t, _) in ranked]
    # Reload the embed model; it was released above to keep peak VRAM low.
    emodel = get_embed_model()
    cand_vecs = embed_unit_np(emodel, cand_text)  # [N, D] unit vectors
    sel_ids = set(mmr(q, cand_ids, cand_vecs, k=k_final, lamb=mmr_lambda))

    # Keep the cross-encoder order, filtered down to the MMR picks.
    final = [trip for trip in ranked if trip[0] in sel_ids]
    return final[:k_final]


# Query helper: vector-only (ANN) search over the sqlite-vec table.
def vec_search(db: sqlite3.Connection, qtext: str, k: int = 5):
    model = get_embed_model()
    q = model.encode(
        [qtext], normalize_embeddings=True, convert_to_numpy=True
    )[0].astype("float32")
    # sqlite-vec KNN query; the distance metric is whatever the vec0 table was declared
    # with (L2 by default), which ranks identically to cosine for unit-norm vectors.
    rows = db.execute(
        "SELECT rowid, distance FROM vec WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
        (memoryview(q.tobytes()), k),
    ).fetchall()
    return [
        (rid, db.execute("SELECT text FROM chunks WHERE id=?", (rid,)).fetchone()[0], dist)
        for rid, dist in rows
    ]


# Hybrid + CE rerank query:
# results = search_hybrid(db, "indemnification obligations survive termination",
#                         k_vec=50, k_bm25=50, k_final=8)
# for rid, txt, score in results:
#     print(f"[{rid:04d}] score={score:.3f}\n{txt[:300]}...\n")
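
# ---------------------------------------------------------------------------
# If you haven't written rag/mmr.py yet (see the import note at the top), the
# commented block below is a minimal sketch of the two helpers this file
# assumes: unit-norm embeddings plus greedy Maximal Marginal Relevance over
# cosine similarity. It is illustrative only, not the project's canonical
# implementation — paste it into rag/mmr.py and adapt as needed.
# ---------------------------------------------------------------------------
# def embed_unit_np(model, texts):
#     """Encode texts to unit-normalized float32 vectors, shape [N, D]."""
#     return model.encode(
#         texts, normalize_embeddings=True, convert_to_numpy=True
#     ).astype("float32")
#
#
# def mmr(query_vec, cand_ids, cand_vecs, k=10, lamb=0.7):
#     """Greedy MMR: balance relevance to the query against similarity to the
#     already-selected set. All vectors are unit-norm, so dot product = cosine."""
#     rel = cand_vecs @ query_vec
#     selected, remaining = [], list(range(len(cand_ids)))
#     while remaining and len(selected) < k:
#         if not selected:
#             best = max(remaining, key=lambda i: rel[i])
#         else:
#             chosen = cand_vecs[selected]
#             best = max(
#                 remaining,
#                 key=lambda i: lamb * rel[i]
#                 - (1.0 - lamb) * float((chosen @ cand_vecs[i]).max()),
#             )
#         selected.append(best)
#         remaining.remove(best)
#     return [cand_ids[i] for i in selected]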
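
# ---------------------------------------------------------------------------
# Minimal smoke test — a sketch, not part of the pipeline. It assumes the index
# was built by rag.ingest into "rag.db" using the sqlite-vec Python bindings and
# that the `vec`, `fts`, and `chunks` tables match the queries above; adjust the
# path and connection setup to however rag.ingest actually creates the database.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import sqlite_vec  # pip install sqlite-vec

    db = sqlite3.connect("rag.db")
    db.enable_load_extension(True)
    sqlite_vec.load(db)  # register the vec0 virtual-table module
    db.enable_load_extension(False)

    results = search_hybrid(
        db, "indemnification obligations survive termination",
        k_vec=50, k_bm25=50, k_final=8,
    )
    for rid, txt, score in results:
        print(f"[{rid:04d}] score={score:.3f}\n{txt[:300]}...\n")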