1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
|
import sqlite3
import numpy as np
import torch
import gc
from typing import List, Tuple
from sentence_transformers import CrossEncoder, SentenceTransformer
from rag.constants import BATCH
from rag.ingest import get_embed_model
from rag.rerank import get_rerank_model, rerank_cross_encoder
from rag.mmr import mmr, mmr2, embed_unit_np # if you added MMR; else remove
from rag.db import vec_topk, bm25_topk, fetch_chunk
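# Legacy query helper (superseded by vec_search below; kept for reference). sqlite-vec has no `<#>`-style distance operators: vec0 tables are queried with `MATCH` and ranked by the table's distance metric.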
# def vec_search(db: sqlite3.Connection, col: str, qtext, k=5):
# model = get_embed_model()
# q = model.encode([qtext], normalize_embeddings=True, convert_to_numpy=True).astype("float32")
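# # NOTE: `<->` (L2), `<#>` (negative inner product) and `<=>` (cosine) are pgvector operators, not sqlite-vec; sqlite-vec vec0 uses `embedding MATCH ? ... ORDER BY distance` with the table's configured distance_metric.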
# rows = vec_topk(db, col, q, k)
# # db.execute("SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?", (qtext, 40))
# return [(rid, db.execute(f"SELECT text FROM chunks_{col} WHERE id=?", (rid,)).fetchone()[0], dist) for rid, dist in rows]
def _dist_to_sim(dist: float) -> float:
# L2 on unit vectors ↔ cosine: ||a-b||^2 = 2 - 2 cos => cos = 1 - dist/2
return max(0.0, 1.0 - dist / 2.0)
def vec_search(db, model: SentenceTransformer, col: str, qtext: str, k: int = 10, min_sim: float = 0.25,
               max_per_doc: int | None = None, use_mmr: bool = False, mmr_lambda: float = 0.7,
               overfetch: int = 4):
    """Dense (vector-only) search over collection ``col``.

    Args:
        db: Open sqlite connection holding the vector index and chunk texts.
        model: Embedding model used to encode the query (and the candidates when
            ``use_mmr`` is set).
        col: Collection name passed through to ``vec_topk`` / ``fetch_chunk``.
        qtext: Query text.
        k: Maximum number of hits returned.
        min_sim: Hits whose cosine similarity (via ``_dist_to_sim``) is below
            this are dropped.
        max_per_doc: Optional cap on hits per "document". There is no per-chunk
            document metadata yet, so every hit shares the key ``col`` and this
            currently caps the total number of hits. A falsy value disables it.
        use_mmr: Diversify the filtered pool with MMR when it holds more than
            ``k`` hits.
        mmr_lambda: Relevance/diversity trade-off forwarded to ``mmr2``.
        overfetch: Multiplier on ``k`` for the raw ANN fetch, so enough hits
            survive filtering. Defaults to the previous hard-coded 4.

    Returns:
        List of ``(rid, text, sim)`` tuples sorted by descending similarity,
        at most ``k`` long.
    """
    # Flat float32 query vector (encode() returns a batch of one).
    q = model.encode([qtext], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32")
    rows = vec_topk(db, col, q, k * overfetch)

    # Check the similarity threshold *before* fetching text. Rows below min_sim
    # are discarded anyway, so fetching them cost one DB query each for nothing.
    hits = []
    for rid, dist in rows:
        sim = _dist_to_sim(dist)
        if sim >= min_sim:
            hits.append((rid, fetch_chunk(db, col, rid), sim))

    # Anti-spam cap. Keyed per document once chunk metadata exists; today
    # every hit shares the key ``col``.
    if max_per_doc:
        capped, per_key = [], {}
        for hit in hits:
            dockey = col  # TODO: derive from a future chunk_meta table
            count = per_key.get(dockey, 0)
            if count < max_per_doc:
                capped.append(hit)
                per_key[dockey] = count + 1
        hits = capped

    # Optional light MMR on the filtered pool (diversify).
    if use_mmr and len(hits) > k:
        from rag.mmr import embed_unit_np, mmr2
        ids = [h[0] for h in hits]
        cvecs = embed_unit_np(model, [h[1] for h in hits])  # [N, D] unit vectors
        keep = set(mmr2(q, ids, cvecs, k=k, lamb=mmr_lambda))
        hits = [h for h in hits if h[0] in keep]

    hits.sort(key=lambda h: h[2], reverse=True)
    return hits[:k]
def _free_gpu_memory() -> None:
    """Best-effort release of cached GPU memory after a model reference is dropped."""
    # Collect first so objects kept alive only by reference cycles are destroyed
    # before the CUDA caching allocator is asked to return unused blocks.
    # empty_cache() is a no-op when CUDA was never initialised.
    gc.collect()
    torch.cuda.empty_cache()


def search_hybrid(
    db: sqlite3.Connection, col: str, query: str,
    k_vec: int = 40, k_bm25: int = 40,
    k_ce: int = 30,  # keep this many after cross-encoder reranking (MMR pool size)
    k_final: int = 10,  # return this many
    use_mmr: bool = False, mmr_lambda: float = 0.7
):
    """Hybrid retrieval: ANN + BM25 candidates, cross-encoder rerank, optional MMR.

    Pipeline:
      1. Embed the query once (unit-norm float32).
      2. Take the top ``k_vec`` ANN ids and the top ``k_bm25`` BM25 ids.
      3. Merge them (vector hits first, duplicates removed).
      4. Fetch the chunk texts from ``chunks_{col}``.
      5. Score every candidate with the cross-encoder and keep the top ``k_ce``.
      6. When ``use_mmr`` is set, pick ``k_final`` of those with MMR. Otherwise
         take the first ``k_final``.

    ``col`` is interpolated into a table name (identifiers cannot be bound as
    SQL parameters), so it must come from trusted code, never from user input.

    Returns:
        List of ``(id, text, ce_score)`` tuples in descending cross-encoder
        score, at most ``min(k_ce, k_final)`` long. Returns ``[]`` when no
        candidates are found or none of their ids exist in ``chunks_{col}``.
    """
    import logging
    log = logging.getLogger(__name__)

    # 1) Embed the query (unit, float32). It is reused for vec search and MMR.
    emodel = get_embed_model()
    log.debug("loading")
    query_embeddings = emodel.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32")
    # Drop our reference before loading the reranker. This frees memory only
    # if get_embed_model() does not keep its own reference to the model.
    del emodel
    _free_gpu_memory()
    log.debug("memory should be free by now!!")

    # 2) ANN + BM25 candidate generation.
    log.debug("phase 2 %s %s", col, query)
    vh_ids = [rid for rid, _dist in vec_topk(db, col, query_embeddings, k_vec)]
    bm_ids = bm25_topk(db, col, query, k_bm25)

    # 3) Merge ids, vector hits first. dict preserves insertion order and drops duplicates.
    merged = list(dict.fromkeys([*vh_ids, *bm_ids]))
    if not merged:
        return []

    # 4) Fetch texts. Ids from the indexes may be missing from the chunk table
    # (e.g. a stale index), and zip(*[]) would fail to unpack.
    qmarks = ",".join("?" * len(merged))
    cand = db.execute(f"SELECT id, text FROM chunks_{col} WHERE id IN ({qmarks})", merged).fetchall()
    if not cand:
        return []
    ids, texts = zip(*cand)

    # 5) Cross-encoder rerank. Passages are truncated to 1000 chars to bound cost.
    log.debug("loading reranking model")
    reranker = get_rerank_model()
    scores = reranker.predict([(query, t[:1000]) for t in texts], batch_size=BATCH)
    log.debug("unloading reranking model")
    del reranker
    _free_gpu_memory()
    log.debug("memory should be free by now!!")

    # Only the top k_ce reranked hits move on. This is also the MMR pool.
    ranked = sorted(zip(ids, texts, scores), key=lambda x: x[2], reverse=True)[:k_ce]
    if not use_mmr or len(ranked) <= k_final:
        return ranked[:k_final]

    # 6) MMR over the reranked pool, reusing the query vector from step 1.
    ce_ids = [rid for rid, _t, _s in ranked]
    ce_texts = [t for _rid, t, _s in ranked]
    st_model = get_embed_model()
    ce_vecs = st_model.encode(ce_texts, normalize_embeddings=True, convert_to_numpy=True).astype("float32")
    del st_model
    _free_gpu_memory()
    keep = set(mmr2(query_embeddings, ce_ids, ce_vecs, k=k_final, lamb=mmr_lambda))
    return [r for r in ranked if r[0] in keep][:k_final]
# # Hybrid + CE rerank query:
# results = search_hybrid(db, col, "indemnification obligations survive termination", k_vec=50, k_bm25=50, k_final=8)
# for rid, txt, score in results:
# print(f"[{rid:04d}] score={score:.3f}\n{txt[:300]}...\n")
#
#
#
#
#
#
#
def search_hybrid_with_mmr(db, col, query, k_vec=50, k_bm25=50, k_ce=30, k_final=10, lamb=0.7):
    """Hybrid search, then MMR-diversify the top ``k_ce`` reranked hits down to ``k_final``.

    Args:
        db, col, query, k_vec, k_bm25, k_ce: Forwarded to ``search_hybrid``.
        k_final: Number of hits to return.
        lamb: MMR relevance/diversity trade-off forwarded to ``mmr2``.

    Returns:
        List of ``(id, text, ce_score)`` tuples in descending cross-encoder
        score, at most ``k_final`` long (``[]`` when nothing matches).
    """
    # With k_final == k_ce (and no MMR), search_hybrid returns its top-k_ce
    # reranked hits.
    ranked = search_hybrid(db, col, query, k_vec, k_bm25, k_ce, k_ce)
    if not ranked:
        return []
    # Nothing to diversify: every candidate would survive MMR, so skip loading
    # the embedding model and encoding (same guard search_hybrid uses).
    if len(ranked) <= k_final:
        return ranked

    ids = [rid for rid, _t, _s in ranked]
    texts = [t for _rid, t, _s in ranked]
    st_model = get_embed_model()
    # search_hybrid does not expose its query vector, so the query is re-encoded here.
    qvec = st_model.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32")
    cvecs = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True).astype("float32")
    keep = set(mmr2(qvec, ids, cvecs, k=k_final, lamb=lamb))
    return [r for r in ranked if r[0] in keep][:k_final]
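# Alternative vec-search sketch (suggested by GPT-5) with query-alias expansion and inline MMR; commented out, kept for reference only.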
# #
# def expand(q, aliases=()):
# qs = [q, *aliases]
# # embed each, take max similarity per chunk at scoring time
# def dist_to_cos(d): return max(0.0, 1.0 - d/2.0) # L2 on unit vecs
# def vec_topk(db, table, q_vec_f32, k):
# from sqlite_vec import serialize_float32
# return db.execute(
# f"SELECT rowid, distance FROM {table} WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
# (serialize_float32(q_vec_f32), k)
# ).fetchall()
# def vec_search(db, st_model, col, qtext, k=12, k_raw=None, min_sim=0.30, use_mmr=True, mmr_lambda=0.7):
# if k_raw is None: k_raw = k*4
# q = st_model.encode([qtext], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32")
# rows = vec_topk(db, f"vec_{col}", q, k_raw)
# hits = []
# for rid, dist in rows:
# cos = dist_to_cos(dist)
# if cos < min_sim: continue
# txt = db.execute(f"SELECT text FROM chunks_{col} WHERE id=?", (rid,)).fetchone()[0]
# hits.append((rid, txt, cos))
# hits.sort(key=lambda x: x[2], reverse=True)
# if not use_mmr or len(hits) <= k:
# return hits[:k]
# # MMR on the (already filtered) pool
# ids = [h[0] for h in hits]
# texts = [h[1] for h in hits]
# cvecs = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True, batch_size=32).astype("float32")
# # simple MMR
# import numpy as np
# def cosine(a,b): return float(a@b)
# sel, sel_idx = [], []
# rem = list(range(len(ids)))
# best0 = max(rem, key=lambda i: cosine(q, cvecs[i])); sel.append(ids[best0]); sel_idx.append(best0); rem.remove(best0)
# while rem and len(sel)<k:
# def score(i):
# rel = cosine(q, cvecs[i])
# red = max(cosine(cvecs[i], cvecs[j]) for j in sel_idx)
# return mmr_lambda*rel - (1.0 - mmr_lambda)*red
# nxt = max(rem, key=score); sel.append(ids[nxt]); sel_idx.append(nxt); rem.remove(nxt)
# keep = set(sel)
# return [h for h in hits if h[0] in keep][:k]
|