>>>> __init__.py
>>>> constants.py
EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-8B"
RERANKER_ID = "Qwen/Qwen3-Reranker-8B" # or -large if you’ve got VRAM
# RERANKER_ID = "BAAI/bge-reranker-base" # or -large if you’ve got VRAM
MAX_TOKENS = 600
BATCH = 16
>>>> db.py
import sqlite3
from numpy import ndarray
from sqlite_vec import serialize_float32
import sqlite_vec
def get_db():
db = sqlite3.connect("./rag.db")
db.execute("PRAGMA journal_mode=WAL;")
db.execute("PRAGMA synchronous=NORMAL;")
db.execute("PRAGMA temp_store=MEMORY;")
db.execute("PRAGMA mmap_size=300000000;") # 30GB if your OS allows
db.enable_load_extension(True)
sqlite_vec.load(db)
db.enable_load_extension(False)
return db
def init_schema(db: sqlite3.Connection, col: str, dim: int, model_id: str, tok_id: str, normalize: bool, preproc_hash: str):
    print("initializing schema for collection:", col)
    db.execute("""
    CREATE TABLE IF NOT EXISTS collections(
      name TEXT PRIMARY KEY,
      model TEXT,
      tokenizer TEXT,
      dim INTEGER,
      normalize INTEGER,
      preproc_hash TEXT,
      created_at INTEGER DEFAULT (unixepoch())
    )""")
    db.execute("BEGIN")
    # OR REPLACE: wipe_db() drops the per-collection tables but not this row,
    # so a re-ingest must be able to overwrite stale metadata
    db.execute("INSERT OR REPLACE INTO collections(name, model, tokenizer, dim, normalize, preproc_hash) VALUES(?, ?, ?, ?, ?, ?)",
               (col, model_id, tok_id, dim, int(normalize), preproc_hash))
    db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec_{col} USING vec0(embedding float[{dim}])")
    db.execute(f'''
    CREATE TABLE IF NOT EXISTS chunks_{col} (
      id INTEGER PRIMARY KEY,
      text TEXT
    )''')
    db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS fts_{col} USING fts5(text)")
    db.commit()
def check_db(db: sqlite3.Connection, coll: str):
    """Return (dim, model, normalize) for a collection, or None if it was never registered."""
    row = db.execute("SELECT dim, model, normalize FROM collections WHERE name=?", (coll,)).fetchone()
    return row
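# a minimal guard sketch using check_db (assumes the collection was built with normalize=True):
#   row = check_db(db, col)
#   if row is None or row[0] != dim or row[1] != EMBED_MODEL_ID or row[2] != 1:
#       raise RuntimeError(f"collection {col!r} was built with a different model/config")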
def check_db2(db: sqlite3.Connection, coll: str):
    """True if the collection's vec0 table exists, i.e. the schema has been initialized."""
    row = db.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
        (f"vec_{coll}",)
    ).fetchone()
    return bool(row)
def store_chunks(db: sqlite3.Connection, col: str, chunks: list[str], V_np: ndarray):
    """Insert chunks and embeddings; rowids are 1-based and shared across the chunks/vec/fts tables."""
    assert len(chunks) == len(V_np)
    assert V_np.dtype == "float32"  # vec0 column is float[dim]; store raw float32 bytes
    db.execute("BEGIN")
    db.executemany(f'''
    INSERT INTO chunks_{col}(id, text) VALUES (?, ?)
    ''', list(enumerate(chunks, start=1)))
    db.executemany(
        f"INSERT INTO vec_{col}(rowid, embedding) VALUES (?, ?)",
        [(i + 1, memoryview(V_np[i].tobytes())) for i in range(len(chunks))]
    )
    db.executemany(f"INSERT INTO fts_{col}(rowid, text) VALUES (?, ?)", list(enumerate(chunks, start=1)))
    db.commit()
def vec_topk(db, col: str, q_vec_f32, k=10):
    # sqlite-vec KNN: MATCH takes the serialized query vector; vectors are unit-normalized,
    # so the default L2 distance ranks the same as cosine
    rows = db.execute(
        f"SELECT rowid, distance FROM vec_{col} WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
        (serialize_float32(q_vec_f32), k)
    ).fetchall()
    return rows  # [(rowid, distance)]
def bm25_topk(db: sqlite3.Connection, col: str, qtext, k=10):
    # escape embedded double quotes and wrap as a phrase so FTS5 operators in user text can't break MATCH
    safe_q = '"' + qtext.replace('"', '""') + '"'
    return [rid for (rid,) in db.execute(
        f"SELECT rowid FROM fts_{col} WHERE fts_{col} MATCH ? LIMIT ?", (safe_q, k)
    ).fetchall()]
def wipe_db(col: str):
db = sqlite3.connect("./rag.db")
db.executescript(f"DROP TABLE IF EXISTS chunks_{col}; DROP TABLE IF EXISTS fts_{col}; DROP TABLE IF EXISTS vec_{col};")
db.close()
>>>> ingest.py
import sqlite3
import torch
from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
from pathlib import Path
from docling.document_converter import DocumentConverter
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
from rag.db import get_db, init_schema, store_chunks
from sentence_transformers import SentenceTransformer
from rag.constants import MAX_TOKENS, BATCH, EMBED_MODEL_ID
def get_embed_model():
return SentenceTransformer(
EMBED_MODEL_ID,
        model_kwargs={
            # "trust_remote_code": True,
            "attn_implementation": "flash_attention_2",  # requires the flash-attn package; drop if unavailable
            "device_map": "auto",
            "dtype": torch.float16,
        },
        tokenizer_kwargs={"padding_side": "left"},
)
def parse_and_chunk(source: Path, model: SentenceTransformer) -> list[str]:
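    """Convert `source` with Docling and chunk it into pieces of at most MAX_TOKENS tokens; contextualize() prepends heading context."""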
tokenizer = HuggingFaceTokenizer(tokenizer=model.tokenizer, max_tokens=MAX_TOKENS)
converter = DocumentConverter()
doc = converter.convert(source)
chunker = HybridChunker(tokenizer=tokenizer, merge_peers=True)
out = []
for ch in chunker.chunk(doc.document):
txt = chunker.contextualize(ch)
if txt.strip():
out.append(txt)
return out
def embed_many(model: SentenceTransformer, texts: list[str]):
V_np = model.encode(texts,
batch_size=BATCH,
normalize_embeddings=True,
convert_to_numpy=True,
show_progress_bar=True)
return V_np.astype("float32")
def start_ingest(db: sqlite3.Connection, model: SentenceTransformer | None, collection: str, path: Path):
if model is None:
model = get_embed_model()
chunks = parse_and_chunk(path, model)
V_np = embed_many(model, chunks)
store_chunks(db, collection, chunks, V_np)
# TODO some try catch?
return True
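# usage sketch (hypothetical file; the collection's schema must already exist, see main.cmd_ingest):
#   db = get_db()
#   start_ingest(db, None, "wm_qwen3", Path("book.pdf"))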
>>>> main.py
import sys
import argparse
from pathlib import Path
from rag.constants import EMBED_MODEL_ID
from rag.ingest import get_embed_model, start_ingest
from rag.search import search_hybrid, vec_search
from rag.db import get_db, check_db, check_db2, init_schema
def valid_collection(col: str) -> bool:
    # names are interpolated into table names: keep them short ASCII identifiers (no spaces)
    return len(col) < 9 and col.isascii() and col.isidentifier()
def cmd_ingest(args):
path = Path(args.file)
if not valid_collection(args.collection):
print(f"Collection name invalid: {args.collection}", file=sys.stderr)
sys.exit(1)
if not path.exists():
print(f"File not found: {path}", file=sys.stderr)
sys.exit(1)
db = get_db()
if not check_db2(db, args.collection):
model = get_embed_model()
dim = model.get_sentence_embedding_dimension()
        if dim is None:
            print("Could not determine embedding dimension", file=sys.stderr)
            sys.exit(1)
        # TODO try/except here; tell the user if init/ingest crashes
        init_schema(db, args.collection, dim, EMBED_MODEL_ID, EMBED_MODEL_ID, True, "idk")  # tokenizer id == model id; preproc_hash is a placeholder
stats = start_ingest(db, model, args.collection, path)
else:
stats = start_ingest(db, None, args.collection, path)
print(f"Ingested file={args.file} :: {stats}")
def cmd_query(args):
if not valid_collection(args.collection):
print(f"Collection name invalid: {args.collection}", file=sys.stderr)
sys.exit(1)
db = get_db()
if not check_db2(db, args.collection):
print(f"Collection name not in DB, what are you searching: {args.collection}", file=sys.stderr)
sys.exit(1)
if args.simple:
results = vec_search(db, args.collection, args.query, k=args.k_final)
else:
results = search_hybrid(db,
args.collection,
args.query,
k_vec=args.k_vec,
k_bm25=args.k_bm25,
k_ce=args.k_ce,
k_final=args.k_final,
use_mmr=args.mmr,
mmr_lambda=args.mmr_lambda,
)
for rid, txt, score in results:
print(f"[{rid:05d}] score={score:.3f}\n{txt[:400]}...\n")
db.close()
def main():
ap = argparse.ArgumentParser(prog="rag")
sub = ap.add_subparsers(dest="cmd", required=True)
# ingest
ap_ing = sub.add_parser("ingest", help="Parse, chunk, embed, and index a file into a collection")
ap_ing.add_argument("--file", required=True, help="Path to PDF/TXT to ingest")
ap_ing.add_argument("--collection", required=True, help="Collection name (e.g. wm_qwen3)")
ap_ing.set_defaults(func=cmd_ingest)
# query
ap_q = sub.add_parser("query", help="Query a collection")
ap_q.add_argument("--collection", required=True, help="Collection name to search")
ap_q.add_argument("--query", required=True, help="User query text")
ap_q.add_argument("--simple", action="store_true", help="Vector-only search (skip reranker)")
ap_q.add_argument("--mmr", action="store_true", help="Apply MMR after CE")
ap_q.add_argument("--mmr-lambda", type=float, default=0.7)
ap_q.add_argument("--k-vec", type=int, default=50)
ap_q.add_argument("--k-bm25", type=int, default=50)
ap_q.add_argument("--k-ce", type=int, default=30)
ap_q.add_argument("--k-final", type=int, default=10)
ap_q.set_defaults(func=cmd_query)
args = ap.parse_args()
args.func(args)
if __name__ == "__main__":
main()
>>>> mmr.py
from rag.constants import BATCH
import numpy as np
def cosine(a, b):
    # inputs are unit vectors, so the dot product equals cosine similarity
    return float(np.dot(a, b))
def mmr(query_vec, cand_ids, cand_vecs, k=8, lamb=0.7):
"""cand_ids: [int], cand_vecs: np.ndarray float32 [N,D] (unit vectors) aligned with cand_ids"""
selected, selected_idx = [], []
remaining = list(range(len(cand_ids)))
# seed with the most relevant
best0 = max(remaining, key=lambda i: cosine(query_vec, cand_vecs[i]))
selected.append(cand_ids[best0]); selected_idx.append(best0); remaining.remove(best0)
while remaining and len(selected) < k:
def mmr_score(i):
rel = cosine(query_vec, cand_vecs[i])
red = max(cosine(cand_vecs[i], cand_vecs[j]) for j in selected_idx) if selected_idx else 0.0
return lamb * rel - (1.0 - lamb) * red
nxt = max(remaining, key=mmr_score)
selected.append(cand_ids[nxt]); selected_idx.append(nxt); remaining.remove(nxt)
return selected
def embed_unit_np(st_model, texts: list[str]) -> np.ndarray:
V = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True, batch_size=BATCH)
V = V.astype("float32", copy=False)
return V
def mmr2(qvec: np.ndarray, ids, vecs: np.ndarray, k=8, lamb=0.7):
    # compact variant of mmr() above; assumes unit vectors, so `@` is cosine similarity
sel, idxs = [], []
rest = list(range(len(ids)))
best0 = max(rest, key=lambda i: float(qvec @ vecs[i]))
sel.append(ids[best0]); idxs.append(best0); rest.remove(best0)
while rest and len(sel) < k:
def score(i):
rel = float(qvec @ vecs[i])
red = max(float(vecs[i] @ vecs[j]) for j in idxs)
return lamb*rel - (1-lamb)*red
nxt = max(rest, key=score)
sel.append(ids[nxt]); idxs.append(nxt); rest.remove(nxt)
return sel
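if __name__ == "__main__":
    # tiny smoke test on toy 2-D unit vectors (illustrative values only)
    q = np.array([1.0, 0.0], dtype=np.float32)
    vecs = np.array([[1.0, 0.0], [0.999, 0.0447], [0.0, 1.0]], dtype=np.float32)
    # with a diversity-leaning lambda, the near-duplicate of the seed loses to the orthogonal vector
    print(mmr(q, [10, 11, 12], vecs, k=2, lamb=0.4))  # -> [10, 12]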
>>>> nmain.py
# rag/nmain.py (alternate CLI entrypoint; adds --rebuild)
import argparse, sqlite3, sys
from pathlib import Path
from sentence_transformers import SentenceTransformer
from rag.ingest import start_ingest
from rag.search import search_hybrid, vec_search
DB_PATH = "./rag.db"
EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-8B"
def open_db():
db = sqlite3.connect(DB_PATH)
# speed-ish pragmas
db.execute("PRAGMA journal_mode=WAL;")
db.execute("PRAGMA synchronous=NORMAL;")
return db
def load_st_model():
# ST handles batching + GPU internally
return SentenceTransformer(
EMBED_MODEL_ID,
model_kwargs={
"attn_implementation": "flash_attention_2",
"device_map": "auto",
"torch_dtype": "float16",
},
tokenizer_kwargs={"padding_side": "left"},
)
def ensure_collection_exists(db, collection: str):
row = db.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name=?",
(f"vec_{collection}",)
).fetchone()
return bool(row)
def cmd_ingest(args):
db = open_db()
st_model = load_st_model()
path = Path(args.file)
if not path.exists():
print(f"File not found: {path}", file=sys.stderr)
sys.exit(1)
if args.rebuild and ensure_collection_exists(db, args.collection):
db.executescript(f"""
DROP TABLE IF EXISTS chunks_{args.collection};
DROP TABLE IF EXISTS fts_{args.collection};
DROP TABLE IF EXISTS vec_{args.collection};
""")
stats = start_ingest(
db, st_model,
path=path,
collection=args.collection,
)
print(f"Ingested collection={args.collection} :: {stats}")
db.close()
def cmd_query(args):
db = open_db()
    # no model needed here: search.py loads its own embedding/reranking models
coll_ok = ensure_collection_exists(db, args.collection)
if not coll_ok:
print(f"Collection '{args.collection}' not found. Ingest first.", file=sys.stderr)
sys.exit(2)
    if args.simple:
        results = vec_search(db, args.collection, args.query, k=args.k_final)
    else:
        results = search_hybrid(
            db, args.collection, args.query,
            k_vec=args.k_vec,
            k_bm25=args.k_bm25,
            k_ce=args.k_ce,
            k_final=args.k_final,
            use_mmr=args.mmr,
            mmr_lambda=args.mmr_lambda,
        )
for rid, txt, score in results:
print(f"[{rid:05d}] score={score:.3f}\n{txt[:400]}...\n")
db.close()
def main():
ap = argparse.ArgumentParser(prog="rag")
sub = ap.add_subparsers(dest="cmd", required=True)
# ingest
ap_ing = sub.add_parser("ingest", help="Parse, chunk, embed, and index a file into a collection")
ap_ing.add_argument("--file", required=True, help="Path to PDF/TXT to ingest")
ap_ing.add_argument("--collection", required=True, help="Collection name (e.g. wm_qwen3)")
ap_ing.add_argument("--rebuild", action="store_true", help="Drop and recreate collection tables")
ap_ing.set_defaults(func=cmd_ingest)
# query
ap_q = sub.add_parser("query", help="Query a collection")
ap_q.add_argument("--collection", required=True, help="Collection name to search")
ap_q.add_argument("--query", required=True, help="User query text")
ap_q.add_argument("--simple", action="store_true", help="Vector-only search (skip reranker)")
ap_q.add_argument("--mmr", action="store_true", help="Apply MMR after CE")
ap_q.add_argument("--mmr-lambda", type=float, default=0.7)
ap_q.add_argument("--k-vec", type=int, default=50)
ap_q.add_argument("--k-bm25", type=int, default=50)
ap_q.add_argument("--k-ce", type=int, default=30)
ap_q.add_argument("--k-final", type=int, default=10)
ap_q.set_defaults(func=cmd_query)
args = ap.parse_args()
args.func(args)
if __name__ == "__main__":
main()
>>>> rerank.py
# rag/rerank.py
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification  # only used by the commented-out raw-transformers variant below
from sentence_transformers import CrossEncoder
from rag.constants import BATCH, RERANKER_ID
# device: "cuda" | "cpu" | "mps"
def get_rerank_model():
    # hardcoded to bge-reranker for now; overrides RERANKER_ID from constants (use -large if you’ve got VRAM)
    model_id = "BAAI/bge-reranker-base"
    return CrossEncoder(
        model_id,
device="cuda",
max_length=384,
model_kwargs={
# "attn_implementation":"flash_attention_2",
"device_map":"auto",
"dtype":torch.float16
},
tokenizer_kwargs={"padding_side": "left"}
)
def rerank_cross_encoder(reranker: CrossEncoder, query: str, candidates: list[tuple[int, str]], batch_size: int = BATCH):
"""
candidates: [(id, text), ...]
returns: [(id, text, score)] sorted desc by score
"""
if not candidates:
return []
ids, texts = zip(*candidates)
    pairs = [(query, t[:1000]) for t in texts]  # pre-truncate long texts; max_length=384 clips tokens anyway
scores = reranker.predict(pairs, batch_size=batch_size) # np.ndarray [N], higher=better
ranked = sorted(zip(ids, texts, scores), key=lambda x: x[2], reverse=True)
return ranked
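# usage sketch (hypothetical candidates):
#   ce = get_rerank_model()
#   hits = rerank_cross_encoder(ce, "indemnification survives termination",
#                               [(1, "Section 9.2: indemnity obligations survive..."), (2, "Payment is due in 30 days.")])
#   best_id, best_text, best_score = hits[0]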
# tok = AutoTokenizer.from_pretrained(RERANKER_ID, use_fast=True, model_max_length=384)
# ce = AutoModelForSequenceClassification.from_pretrained(
# RERANKER_ID, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2",
# device_map="auto" # or load_in_8bit=True
# )
# def ce_scores(query, texts, batch_size=16, max_length=384):
# scores = []
# for i in range(0, len(texts), batch_size):
# batch = texts[i:i+batch_size]
# enc = tok([ (query, t[:1000]) for t in batch ],
# padding=True, truncation=True, max_length=max_length,
# return_tensors="pt").to(ce.device)
# with torch.inference_mode():
# logits = ce(**enc).logits.squeeze(-1) # [B]
# scores.extend(logits.float().cpu().tolist())
# return scores
>>>> search.py
import sqlite3
import torch
import gc
from rag.constants import BATCH
from rag.ingest import get_embed_model
from rag.rerank import get_rerank_model, rerank_cross_encoder
from rag.mmr import mmr2
from rag.db import vec_topk, bm25_topk
# Vector-only query helper: vectors are unit-normalized, so sqlite-vec's distance ranking matches cosine
def vec_search(db: sqlite3.Connection, col: str, qtext, k=5):
    model = get_embed_model()
    q = model.encode([qtext], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32")  # 1-D query vector
    rows = vec_topk(db, col, q, k)
    return [(rid, db.execute(f"SELECT text FROM chunks_{col} WHERE id=?", (rid,)).fetchone()[0], dist) for rid, dist in rows]
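# usage sketch:
#   from rag.db import get_db
#   rows = vec_search(get_db(), "wm_qwen3", "moral ambivalence of outlaw heroes", k=5)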
def search_hybrid(
db: sqlite3.Connection, col: str, query: str,
k_vec: int = 40, k_bm25: int = 40,
k_ce: int = 30, # rerank this many
k_final: int = 10, # return this many
use_mmr: bool = False, mmr_lambda: float = 0.7
):
    # 1) embed the query (unit-norm float32) for vector search and MMR
    emodel = get_embed_model()
    query_embeddings = emodel.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32")
    # drop the embedder before the reranker loads; both are large
    del emodel
    gc.collect()
    torch.cuda.empty_cache()
    # 2) ANN + BM25
vhits = vec_topk(db, col, query_embeddings, k_vec) # [(id, dist)]
vh_ids = [i for (i, _) in vhits]
bm_ids = bm25_topk(db, col, query, k_bm25)
    # 3) merge ids, vector hits first
merged, seen = [], set()
for i in vh_ids + bm_ids:
if i not in seen:
merged.append(i); seen.add(i)
if not merged:
return []
    # 4) cap the rerank workload at k_ce, then fetch candidate texts
    merged = merged[:k_ce]
    qmarks = ",".join("?" * len(merged))
    cand = db.execute(f"SELECT id, text FROM chunks_{col} WHERE id IN ({qmarks})", merged).fetchall()
ids, texts = zip(*cand)
# 5) rerank
print("loading reranking model")
reranker = get_rerank_model()
scores = reranker.predict([(query, t[:1000]) for t in texts], batch_size=BATCH)
reranker =None
with torch.no_grad():
torch.cuda.empty_cache()
gc.collect()
print("memory should be free by now!!")
print("unloading reranking model")
ranked = sorted(zip(ids, texts, scores), key=lambda x: x[2], reverse=True)
    if not use_mmr or len(ranked) <= k_final:
        return ranked[:k_final]
# 6) MMR
ce_ids = [i for (i,_,_) in ranked]
ce_texts = [t for (_,t,_) in ranked]
st_model = get_embed_model()
ce_vecs = st_model.encode(ce_texts, normalize_embeddings=True, convert_to_numpy=True).astype("float32")
keep = set(mmr2(query_embeddings, ce_ids, ce_vecs, k=k_final, lamb=mmr_lambda))
return [r for r in ranked if r[0] in keep][:k_final]
# # Hybrid + CE rerank query:
# results = search_hybrid(db, "wm_qwen3", "indemnification obligations survive termination", k_vec=50, k_bm25=50, k_final=8)
# for rid, txt, score in results:
# print(f"[{rid:04d}] score={score:.3f}\n{txt[:300]}...\n")
def search_hybrid_with_mmr(db, col, query, k_vec=50, k_bm25=50, k_ce=30, k_final=10, lamb=0.7):
    # alternative wrapper: take the top-k_ce reranked hits, then MMR them down to k_final
ranked = search_hybrid(db, col, query, k_vec, k_bm25, k_ce, k_ce)
if not ranked: return []
ids = [i for (i,_,_) in ranked]
texts = [t for (_,t,_) in ranked]
st_model = get_embed_model()
qvec = st_model.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32")
cvecs = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True).astype("float32")
keep = set(mmr2(qvec, ids, cvecs, k=k_final, lamb=lamb))
return [r for r in ranked if r[0] in keep][:k_final]
>>>> test.py
from pathlib import Path
from docling.document_converter import DocumentConverter
from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
from transformers import AutoModel
from sentence_transformers import SentenceTransformer
from rag.db import get_db
from rag.rerank import get_rerank_model
import rag.ingest
import rag.search
import torch
# converter = DocumentConverter()
# chunker = HybridChunker()
# file = Path("yek.md")
# doc = converter.convert(file).document
# chunk_iter = chunker.chunk(doc)
# for chunk in chunk_iter:
# print(chunk)
# txt = chunker.contextualize(chunk)
# print(txt)
def t():
batch: list[str] = ["This son of a bitch has gone too far", "Fuck me baby please", "I'm hungry now", "Charlie Kirk is dead"]
model = rag.ingest.get_embed_model()
v = model.encode("pepeee", normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True)
print("v")
print(type(v))
print(v)
print(v.dtype)
    print(v.shape)  # v is a numpy array (convert_to_numpy=True), so it has no .device
# V = torch.cat([v], dim=0)
# print("V")
# print(type(V))
# print(V)
# print(V.dtype)
# print(V.device)
# print("V_np")
# V_idk = V.cpu().float()
# when they were pytorch tensors
# V = embed_many(chunks) # float32/fp16 on CPU? ensure float32 for DB:
# V_np = V.float().cpu().numpy().astype("float32")
# DIM = V_np.shape[1]
# db = sqlite3.connect("./rag.db")
queries = [
"How was Shuihu zhuan received in early modern Japan?",
"Edo-period readers’ image of Song Jiang / Liangshan outlaws",
"Channels of transmission for Chinese vernacular fiction into Japan (kanbun kundoku, digests, translations)",
"Role of woodblock prints/illustrations in mediating Chinese fiction",
"Key Japanese scholars, writers, or publishers who popularized Chinese fiction",
"Kyokutei Bakin’s engagement with Chinese vernacular narrative",
"Santō Kyōden, gesaku, and Chinese models",
"Kanzen chōaku (encourage good, punish evil) and Water Margin in Japan",
"Moral ambivalence of outlaw heroes as discussed in the text",
"Censorship or moral debates around reading Chinese fiction",
"Translation strategies from vernacular Chinese to Japanese (furigana, kundoku, glossing)",
"Paratexts: prefaces, commentaries, reader guidance apparatus",
"Bibliographic details: editions, reprints, circulation networks",
"How does this book challenge older narratives about Sino-Japanese literary influence?",
"Methodology: sources, archives, limitations mentioned by the author",
]
def t2():
db = get_db()
# Hybrid + CE rerank query:
for query in queries:
print("query", query)
print("-----------\n\n")
# results = rag.search.search_hybrid(db, "muh", query, k_vec=50, k_bm25=50, k_final=8)
# for rid, txt, score in results:
# sim = score
# print(f"[{rid:04d}] ce_score={sim:.3f}\n{txt[:300]}...\n")
results = rag.search.vec_search(db, "muh", query, k_vec=50, k_bm25=50, k_final=8)
for rid, txt, score in results:
sim = score
print(f"[{rid:04d}] ce_score={sim:.3f}\n{txt[:300]}...\n")
t2()
>>>> utils.py
FTS_META_CHARS = r'''["'*()^+-]'''  # include ? if you see issues

def sanitize_query(q: str, *, allow_ops: bool = False) -> str:
    q = q.strip()
    if not q:
        return q
    if allow_ops:
        # pass FTS5 operators (AND/OR/NEAR/prefix*) through unchanged
        return q
    # literal search: escape embedded double quotes and wrap the query as a single phrase
    q = q.replace('"', '""')
    return f'"{q}"'