1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
|
import sqlite3
import numpy as np
import torch
import gc
from typing import List, Tuple
from sentence_transformers import CrossEncoder, SentenceTransformer
from rag.ingest import get_embed_model
from rag.rerank import get_rerank_model, rerank_cross_encoder
from rag.mmr import mmr, embed_unit_np # if you added MMR; else remove
def search_hybrid(
    db: sqlite3.Connection, query: str,
    k_vec: int = 40, k_bm25: int = 40,
    k_ce: int = 30,        # rerank this many
    k_final: int = 10,     # return this many
    use_mmr: bool = False, mmr_lambda: float = 0.7,
):
    """Hybrid retrieval with cross-encoder reranking.

    Pipeline: ANN (cosine) pool + BM25 pool -> vector-first merge ->
    cross-encoder rerank -> optional MMR diversification.

    Args:
        db: open SQLite connection with `vec`, `fts` and `chunks` tables.
        query: natural-language query string.
        k_vec / k_bm25: pool sizes for the ANN and BM25 candidate sets.
        k_ce: how many merged candidates survive into the CE-ranked list.
        k_final: maximum number of results returned.
        use_mmr: apply MMR diversity on the CE top-k_ce before truncating.
        mmr_lambda: relevance/diversity trade-off passed through to mmr().

    Returns:
        List of (id, text, score) triples, best first, at most k_final long.
        Empty list when neither pool produced a candidate.
    """
    emodel = get_embed_model()
    # Query embedding must match the index: unit-normalized float32 vector.
    q = emodel.encode([query], normalize_embeddings=True,
                      convert_to_numpy=True)[0].astype("float32")
    # Drop our reference before loading the cross-encoder to lower peak
    # memory. NOTE(review): if get_embed_model() caches the model globally,
    # this frees nothing — confirm against rag.ingest.
    emodel = None
    gc.collect()
    if torch.cuda.is_available():  # guard keeps CPU-only hosts safe
        torch.cuda.empty_cache()

    qbytes = q.tobytes()
    # 1) Candidate pools: ANN over the vec table, BM25 over FTS.
    vec_ids = [i for (i, _) in db.execute(
        "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?",
        (memoryview(qbytes), k_vec),
    ).fetchall()]
    bm25_ids = [i for (i,) in db.execute(
        "SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?",
        (query, k_bm25),
    ).fetchall()]

    # 2) Merge vector-first, preserving order and dropping duplicates.
    seen, merged = set(), []
    for i in vec_ids + bm25_ids:
        if i not in seen:
            seen.add(i)
            merged.append(i)
    if not merged:
        return []

    # 3) Fetch texts for the cross-encoder. IN (...) does not preserve
    # order, but the reranker rescores everything so order is irrelevant.
    qmarks = ",".join("?" * len(merged))
    cand = db.execute(
        f"SELECT id, text FROM chunks WHERE id IN ({qmarks})", merged
    ).fetchall()

    # 4) Cross-encoder rerank -> [(id, text, score)] sorted best-first.
    reranker = get_rerank_model()
    ranked = rerank_cross_encoder(reranker, query, cand)
    reranker = None  # same caveat as emodel above
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    ranked = ranked[:k_ce]  # slicing already clamps to len(ranked)
    if not use_mmr or len(ranked) <= k_final:
        return ranked[:k_final]

    # 5) MMR diversity on the CE top-k_ce; keep CE order among the picks.
    cand_ids = [i for (i, _, _) in ranked]
    cand_text = [t for (_, t, _) in ranked]
    emodel = get_embed_model()  # reload: needed again for candidate vectors
    cand_vecs = embed_unit_np(emodel, cand_text)  # [N, D] unit vectors
    sel_ids = set(mmr(q, cand_ids, cand_vecs, k=k_final, lamb=mmr_lambda))
    return [trip for trip in ranked if trip[0] in sel_ids][:k_final]
# Query helper (cosine distance; operator may be <#> in sqlite-vec)
def vec_search(db: sqlite3.Connection, qtext: str, k: int = 5):
    """Pure ANN search: return up to k (rowid, text, distance) triples.

    Cosine distance operator in sqlite-vec is `<#>`; if your build differs,
    check docs: <-> L2, <=> dot, <#> cosine.
    """
    model = get_embed_model()
    # Index row [0] so we serialize a [D] vector, not a [1, D] matrix —
    # the bytes happen to match either way, but the intent is one vector.
    q = model.encode([qtext], normalize_embeddings=True,
                     convert_to_numpy=True)[0].astype("float32")
    rows = db.execute(
        "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?",
        (memoryview(q.tobytes()), k),
    ).fetchall()
    results = []
    for rid, dist in rows:
        hit = db.execute("SELECT text FROM chunks WHERE id=?", (rid,)).fetchone()
        # Tolerate a vec row whose chunk was deleted: emit None for the text
        # instead of crashing on fetchone()[0] when fetchone() is None.
        results.append((rid, hit[0] if hit else None, dist))
    return results
# # Hybrid + CE rerank query:
# results = search_hybrid(db, "indemnification obligations survive termination", k_vec=50, k_bm25=50, k_final=8)
# for rid, txt, score in results:
#     print(f"[{rid:04d}] score={score:.3f}\n{txt[:300]}...\n")
#
#
#
#
#
#
#
|