>>>> __init__.py

>>>> constants.py
EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-8B"
RERANKER_ID = "Qwen/Qwen3-Reranker-8B"
# RERANKER_ID = "BAAI/bge-reranker-base"  # or -large if you've got VRAM
MAX_TOKENS = 600
BATCH = 16

>>>> db.py
import sqlite3

from numpy import ndarray
import sqlite_vec
from sqlite_vec import serialize_float32


def get_db():
    db = sqlite3.connect("./rag.db")
    db.execute("PRAGMA journal_mode=WAL;")
    db.execute("PRAGMA synchronous=NORMAL;")
    db.execute("PRAGMA temp_store=MEMORY;")
    db.execute("PRAGMA mmap_size=300000000;")  # ~300 MB, if your OS allows
    db.enable_load_extension(True)
    sqlite_vec.load(db)
    db.enable_load_extension(False)
    return db


def init_schema(db: sqlite3.Connection, col: str, dim: int, model_id: str,
                tok_id: str, normalize: bool, preproc_hash: str):
    print("initing schema", col)
    db.execute("""
        CREATE TABLE IF NOT EXISTS collections(
            name TEXT PRIMARY KEY,
            model TEXT,
            tokenizer TEXT,
            dim INTEGER,
            normalize INTEGER,
            preproc_hash TEXT,
            created_at INTEGER DEFAULT (unixepoch())
        )""")
    db.execute("BEGIN")
    db.execute(
        "INSERT INTO collections(name, model, tokenizer, dim, normalize, preproc_hash) VALUES(?, ?, ?, ?, ?, ?)",
        (col, model_id, tok_id, dim, int(normalize), preproc_hash))
    # Three parallel tables per collection, aligned by rowid:
    # vec_<col> (ANN), chunks_<col> (text), fts_<col> (BM25).
    db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec_{col} USING vec0(embedding float[{dim}])")
    db.execute(f"CREATE TABLE IF NOT EXISTS chunks_{col} (id INTEGER PRIMARY KEY, text TEXT)")
    db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS fts_{col} USING fts5(text)")
    db.commit()


def check_db(db: sqlite3.Connection, coll: str):
    row = db.execute("SELECT dim, model, normalize FROM collections WHERE name=?", (coll,)).fetchone()
    return row
    # assert row and row[0] == DIM and row[1] == EMBED_MODEL_ID and row[2] == 1  # if you normalize


def check_db2(db: sqlite3.Connection, coll: str):
    row = db.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
        (f"vec_{coll}",)
    ).fetchone()
    return bool(row)


def store_chunks(db: sqlite3.Connection, col: str, chunks: list[str], V_np: ndarray):
    # NOTE: ids always start at 1, so ingesting a second file into the same
    # collection will collide on the primary key.
    assert len(chunks) == len(V_np)
    db.execute("BEGIN")
    db.executemany(
        f"INSERT INTO chunks_{col}(id, text) VALUES (?, ?)",
        list(enumerate(chunks, start=1)))
    db.executemany(
        f"INSERT INTO vec_{col}(rowid, embedding) VALUES (?, ?)",
        [(i + 1, memoryview(V_np[i].tobytes())) for i in range(len(chunks))])
    db.executemany(
        f"INSERT INTO fts_{col}(rowid, text) VALUES (?, ?)",
        list(enumerate(chunks, start=1)))
    db.commit()


def vec_topk(db: sqlite3.Connection, col: str, q_vec_f32, k=10):
    rows = db.execute(
        f"SELECT rowid, distance FROM vec_{col} WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
        (serialize_float32(q_vec_f32), k)
    ).fetchall()
    return rows  # [(rowid, distance)]


def bm25_topk(db: sqlite3.Connection, col: str, qtext, k=10):
    safe_q = f'"{qtext}"'  # quote as a phrase so FTS5 operators in the query can't break the syntax
    return [rid for (rid,) in db.execute(
        f"SELECT rowid FROM fts_{col} WHERE fts_{col} MATCH ? LIMIT ?",
        (safe_q, k)
    ).fetchall()]


def wipe_db(col: str):
    db = sqlite3.connect("./rag.db")
    db.executescript(
        f"DROP TABLE IF EXISTS chunks_{col}; "
        f"DROP TABLE IF EXISTS fts_{col}; "
        f"DROP TABLE IF EXISTS vec_{col};")
    db.close()
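A quick smoke test of the db.py contract — a sketch only, to be run against a fresh rag.db; the "demo" collection, 8-dim random vectors, and stub metadata are invented for illustration:

import numpy as np
from rag.db import get_db, init_schema, store_chunks, vec_topk, bm25_topk

db = get_db()
init_schema(db, "demo", dim=8, model_id="stub", tok_id="stub", normalize=True, preproc_hash="none")
chunks = ["alpha beta", "gamma delta"]
V = np.random.rand(2, 8).astype("float32")
V /= np.linalg.norm(V, axis=1, keepdims=True)   # unit vectors, as the real pipeline stores
store_chunks(db, "demo", chunks, V)
print(vec_topk(db, "demo", V[0], k=1))     # -> [(1, ~0.0)]: row 1 is nearest to its own vector
print(bm25_topk(db, "demo", "alpha", k=1)) # -> [1]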
LIMIT ?", (safe_q, k) ).fetchall()] def wipe_db(col: str): db = sqlite3.connect("./rag.db") db.executescript(f"DROP TABLE IF EXISTS chunks_{col}; DROP TABLE IF EXISTS fts_{col}; DROP TABLE IF EXISTS vec_{col};") db.close() >>>> ingest.py # from rag.rerank import search_hybrid import sqlite3 import torch from docling_core.transforms.chunker.hybrid_chunker import HybridChunker from pathlib import Path from docling.document_converter import DocumentConverter from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer from rag.db import get_db, init_schema, store_chunks from sentence_transformers import SentenceTransformer from rag.constants import MAX_TOKENS, BATCH, EMBED_MODEL_ID def get_embed_model(): return SentenceTransformer( EMBED_MODEL_ID, model_kwargs={ # "trust_remote_code":True, "attn_implementation":"flash_attention_2", "device_map":"auto", "dtype":torch.float16 }, tokenizer_kwargs={"padding_side": "left"} ) def parse_and_chunk(source: Path, model: SentenceTransformer)-> list[str]: tokenizer = HuggingFaceTokenizer(tokenizer=model.tokenizer, max_tokens=MAX_TOKENS) converter = DocumentConverter() doc = converter.convert(source) chunker = HybridChunker(tokenizer=tokenizer, merge_peers=True) out = [] for ch in chunker.chunk(doc.document): txt = chunker.contextualize(ch) if txt.strip(): out.append(txt) return out def embed_many(model: SentenceTransformer, texts: list[str]): V_np = model.encode(texts, batch_size=BATCH, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True) return V_np.astype("float32") def start_ingest(db: sqlite3.Connection, model: SentenceTransformer | None, collection: str, path: Path): if model is None: model = get_embed_model() chunks = parse_and_chunk(path, model) V_np = embed_many(model, chunks) store_chunks(db, collection, chunks, V_np) # TODO some try catch? return True # float32/fp16 on CPU? 
>>>> main.py
import sys
import argparse
from pathlib import Path

from rag.constants import EMBED_MODEL_ID
from rag.ingest import get_embed_model, start_ingest
from rag.search import search_hybrid, vec_search
from rag.db import get_db, check_db, check_db2, init_schema


def valid_collection(col: str) -> bool:
    # TODO: must be ASCII, fewer than 9 characters, no spaces
    return True


def cmd_ingest(args):
    path = Path(args.file)
    if not valid_collection(args.collection):
        print(f"Collection name invalid: {args.collection}", file=sys.stderr)
        sys.exit(1)
    if not path.exists():
        print(f"File not found: {path}", file=sys.stderr)
        sys.exit(1)
    db = get_db()
    if not check_db2(db, args.collection):
        model = get_embed_model()
        dim = model.get_sentence_embedding_dimension()
        if dim is None:
            sys.exit(1)
        # TODO: try/except here, tell the user if it crashes
        init_schema(db, args.collection, dim, EMBED_MODEL_ID, EMBED_MODEL_ID, True, "idk")
        stats = start_ingest(db, model, args.collection, path)
    else:
        stats = start_ingest(db, None, args.collection, path)
    print(f"Ingested file={args.file} :: {stats}")


def cmd_query(args):
    if not valid_collection(args.collection):
        print(f"Collection name invalid: {args.collection}", file=sys.stderr)
        sys.exit(1)
    db = get_db()
    if not check_db2(db, args.collection):
        print(f"Collection not in DB, what are you searching: {args.collection}", file=sys.stderr)
        sys.exit(1)
    if args.simple:
        results = vec_search(db, args.collection, args.query, k=args.k_final)
    else:
        results = search_hybrid(
            db, args.collection, args.query,
            k_vec=args.k_vec, k_bm25=args.k_bm25,
            k_ce=args.k_ce, k_final=args.k_final,
            use_mmr=args.mmr, mmr_lambda=args.mmr_lambda,
        )
    for rid, txt, score in results:
        print(f"[{rid:05d}] score={score:.3f}\n{txt[:400]}...\n")
    db.close()


def main():
    ap = argparse.ArgumentParser(prog="rag")
    sub = ap.add_subparsers(dest="cmd", required=True)

    # ingest
    ap_ing = sub.add_parser("ingest", help="Parse, chunk, embed, and index a file into a collection")
    ap_ing.add_argument("--file", required=True, help="Path to PDF/TXT to ingest")
    ap_ing.add_argument("--collection", required=True, help="Collection name (e.g. wm_qwen3)")
    ap_ing.set_defaults(func=cmd_ingest)

    # query
    ap_q = sub.add_parser("query", help="Query a collection")
    ap_q.add_argument("--collection", required=True, help="Collection name to search")
    ap_q.add_argument("--query", required=True, help="User query text")
    ap_q.add_argument("--simple", action="store_true", help="Vector-only search (skip reranker)")
    ap_q.add_argument("--mmr", action="store_true", help="Apply MMR after CE")
    ap_q.add_argument("--mmr-lambda", type=float, default=0.7)
    ap_q.add_argument("--k-vec", type=int, default=50)
    ap_q.add_argument("--k-bm25", type=int, default=50)
    ap_q.add_argument("--k-ce", type=int, default=30)
    ap_q.add_argument("--k-final", type=int, default=10)
    ap_q.set_defaults(func=cmd_query)

    args = ap.parse_args()
    args.func(args)


if __name__ == "__main__":
    main()
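A CLI usage sketch (assuming the package is importable as rag; the file and collection names are made up):

python -m rag.main ingest --file ./paper.pdf --collection demo
python -m rag.main query --collection demo --query "How was Shuihu zhuan received in Japan?" --k-final 5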
>>>> mmr.py
import numpy as np

from rag.constants import BATCH


def cosine(a, b):
    return float(np.dot(a, b))


def mmr(query_vec, cand_ids, cand_vecs, k=8, lamb=0.7):
    """cand_ids: [int]; cand_vecs: np.ndarray float32 [N, D] (unit vectors) aligned with cand_ids."""
    selected, selected_idx = [], []
    remaining = list(range(len(cand_ids)))
    # seed with the most relevant candidate
    best0 = max(remaining, key=lambda i: cosine(query_vec, cand_vecs[i]))
    selected.append(cand_ids[best0]); selected_idx.append(best0); remaining.remove(best0)
    while remaining and len(selected) < k:
        def mmr_score(i):
            rel = cosine(query_vec, cand_vecs[i])
            red = max(cosine(cand_vecs[i], cand_vecs[j]) for j in selected_idx) if selected_idx else 0.0
            return lamb * rel - (1.0 - lamb) * red
        nxt = max(remaining, key=mmr_score)
        selected.append(cand_ids[nxt]); selected_idx.append(nxt); remaining.remove(nxt)
    return selected


def embed_unit_np(st_model, texts: list[str]) -> np.ndarray:
    V = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True, batch_size=BATCH)
    return V.astype("float32", copy=False)


def mmr2(qvec: np.ndarray, ids, vecs: np.ndarray, k=8, lamb=0.7):
    # same algorithm as mmr(), written against raw dot products
    sel, idxs = [], []
    rest = list(range(len(ids)))
    best0 = max(rest, key=lambda i: float(qvec @ vecs[i]))
    sel.append(ids[best0]); idxs.append(best0); rest.remove(best0)
    while rest and len(sel) < k:
        def score(i):
            rel = float(qvec @ vecs[i])
            red = max(float(vecs[i] @ vecs[j]) for j in idxs)
            return lamb * rel - (1 - lamb) * red
        nxt = max(rest, key=score)
        sel.append(ids[nxt]); idxs.append(nxt); rest.remove(nxt)
    return sel
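A toy run of mmr2 (values invented for illustration): with a low lambda the redundancy penalty dominates, so the near-duplicate of the best hit is skipped in favor of a diverse one.

import numpy as np
from rag.mmr import mmr2

q = np.array([1.0, 0.0], dtype="float32")
vecs = np.array([[1.0, 0.0],    # id 10: exact match for the query
                 [0.99, 0.14],  # id 11: near-duplicate of 10
                 [0.0, 1.0]],   # id 12: orthogonal, diverse
                dtype="float32")
vecs /= np.linalg.norm(vecs, axis=1, keepdims=True)
print(mmr2(q, [10, 11, 12], vecs, k=2, lamb=0.3))  # -> [10, 12] (diversity wins)
print(mmr2(q, [10, 11, 12], vecs, k=2, lamb=0.9))  # -> [10, 11] (relevance wins)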
>>>> nmain.py
# rag/nmain.py — standalone variant of main.py with a --rebuild flag
import argparse
import sqlite3
import sys
from pathlib import Path

import sqlite_vec
from sentence_transformers import SentenceTransformer

from rag.ingest import start_ingest
from rag.search import search_hybrid, vec_search

DB_PATH = "./rag.db"
EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-8B"


def open_db():
    db = sqlite3.connect(DB_PATH)
    # speed-ish pragmas
    db.execute("PRAGMA journal_mode=WAL;")
    db.execute("PRAGMA synchronous=NORMAL;")
    # the vec_* virtual tables are unusable without the sqlite-vec extension
    db.enable_load_extension(True)
    sqlite_vec.load(db)
    db.enable_load_extension(False)
    return db


def load_st_model():
    # ST handles batching + GPU internally
    return SentenceTransformer(
        EMBED_MODEL_ID,
        model_kwargs={
            "attn_implementation": "flash_attention_2",
            "device_map": "auto",
            "torch_dtype": "float16",
        },
        tokenizer_kwargs={"padding_side": "left"},
    )


def ensure_collection_exists(db, collection: str):
    row = db.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
        (f"vec_{collection}",)
    ).fetchone()
    return bool(row)


def cmd_ingest(args):
    db = open_db()
    st_model = load_st_model()
    path = Path(args.file)
    if not path.exists():
        print(f"File not found: {path}", file=sys.stderr)
        sys.exit(1)
    if args.rebuild and ensure_collection_exists(db, args.collection):
        db.executescript(f"""
            DROP TABLE IF EXISTS chunks_{args.collection};
            DROP TABLE IF EXISTS fts_{args.collection};
            DROP TABLE IF EXISTS vec_{args.collection};
        """)
    # NOTE: assumes the collection schema already exists (see rag.db.init_schema)
    stats = start_ingest(db, st_model, collection=args.collection, path=path)
    print(f"Ingested collection={args.collection} :: {stats}")
    db.close()


def cmd_query(args):
    db = open_db()
    if not ensure_collection_exists(db, args.collection):
        print(f"Collection '{args.collection}' not found. Ingest first.", file=sys.stderr)
        sys.exit(2)
    if args.simple:
        results = vec_search(db, args.collection, args.query, k=args.k_final)
    else:
        results = search_hybrid(
            db, args.collection, args.query,
            k_vec=args.k_vec, k_bm25=args.k_bm25,
            k_ce=args.k_ce, k_final=args.k_final,
            use_mmr=args.mmr, mmr_lambda=args.mmr_lambda,
        )
    for rid, txt, score in results:
        print(f"[{rid:05d}] score={score:.3f}\n{txt[:400]}...\n")
    db.close()


def main():
    ap = argparse.ArgumentParser(prog="rag")
    sub = ap.add_subparsers(dest="cmd", required=True)

    # ingest
    ap_ing = sub.add_parser("ingest", help="Parse, chunk, embed, and index a file into a collection")
    ap_ing.add_argument("--file", required=True, help="Path to PDF/TXT to ingest")
    ap_ing.add_argument("--collection", required=True, help="Collection name (e.g. wm_qwen3)")
    ap_ing.add_argument("--rebuild", action="store_true", help="Drop and recreate collection tables")
    ap_ing.set_defaults(func=cmd_ingest)

    # query
    ap_q = sub.add_parser("query", help="Query a collection")
    ap_q.add_argument("--collection", required=True, help="Collection name to search")
    ap_q.add_argument("--query", required=True, help="User query text")
    ap_q.add_argument("--simple", action="store_true", help="Vector-only search (skip reranker)")
    ap_q.add_argument("--mmr", action="store_true", help="Apply MMR after CE")
    ap_q.add_argument("--mmr-lambda", type=float, default=0.7)
    ap_q.add_argument("--k-vec", type=int, default=50)
    ap_q.add_argument("--k-bm25", type=int, default=50)
    ap_q.add_argument("--k-ce", type=int, default=30)
    ap_q.add_argument("--k-final", type=int, default=10)
    ap_q.set_defaults(func=cmd_query)

    args = ap.parse_args()
    args.func(args)


if __name__ == "__main__":
    main()

>>>> rerank.py
# rag/rerank.py
import torch
from sentence_transformers import CrossEncoder

from rag.constants import BATCH, RERANKER_ID


def get_rerank_model():
    # device: "cuda" | "cpu" | "mps"
    return CrossEncoder(
        "BAAI/bge-reranker-base",  # or -large if you've got VRAM; swap in RERANKER_ID for the Qwen model
        device="cuda",
        max_length=384,
        model_kwargs={
            # "attn_implementation": "flash_attention_2",
            "device_map": "auto",
            "dtype": torch.float16,
        },
        tokenizer_kwargs={"padding_side": "left"},
    )


def rerank_cross_encoder(reranker: CrossEncoder, query: str,
                         candidates: list[tuple[int, str]], batch_size: int = BATCH):
    """
    candidates: [(id, text), ...]
    returns: [(id, text, score)] sorted desc by score
    """
    if not candidates:
        return []
    ids, texts = zip(*candidates)
    # truncate long chunks; the CE only sees max_length tokens anyway
    pairs = [(query, t[:1000]) for t in texts]
    scores = reranker.predict(pairs, batch_size=batch_size)  # np.ndarray [N], higher = better
    return sorted(zip(ids, texts, scores), key=lambda x: x[2], reverse=True)


# Raw-transformers alternative:
# from transformers import AutoTokenizer, AutoModelForSequenceClassification
# tok = AutoTokenizer.from_pretrained(RERANKER_ID, use_fast=True, model_max_length=384)
# ce = AutoModelForSequenceClassification.from_pretrained(
#     RERANKER_ID, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2",
#     device_map="auto",  # or load_in_8bit=True
# )
# def ce_scores(query, texts, batch_size=16, max_length=384):
#     scores = []
#     for i in range(0, len(texts), batch_size):
#         batch = texts[i:i+batch_size]
#         enc = tok([(query, t[:1000]) for t in batch],
#                   padding=True, truncation=True, max_length=max_length,
#                   return_tensors="pt").to(ce.device)
#         with torch.inference_mode():
#             logits = ce(**enc).logits.squeeze(-1)  # [B]
#         scores.extend(logits.float().cpu().tolist())
#     return scores
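Usage sketch for the reranker (the candidate texts are invented; needs a CUDA GPU as configured, and the first call downloads the model):

from rag.rerank import get_rerank_model, rerank_cross_encoder

ce = get_rerank_model()
cands = [(1, "Indemnification obligations survive termination of this Agreement."),
         (2, "The cafeteria is closed on Sundays.")]
for rid, txt, score in rerank_cross_encoder(ce, "does indemnification survive termination?", cands):
    print(rid, f"{float(score):.3f}", txt[:60])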
>>>> search.py
import gc
import sqlite3

import torch

from rag.constants import BATCH
from rag.db import vec_topk, bm25_topk
from rag.ingest import get_embed_model
from rag.mmr import mmr2
from rag.rerank import get_rerank_model


def vec_search(db: sqlite3.Connection, col: str, qtext, k=5):
    model = get_embed_model()
    q = model.encode([qtext], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32")
    # vec0 ranks by its distance metric (L2 by default); on unit vectors the
    # L2 ordering matches the cosine ordering, so this is fine for normalized embeddings
    rows = vec_topk(db, col, q, k)
    return [(rid,
             db.execute(f"SELECT text FROM chunks_{col} WHERE id=?", (rid,)).fetchone()[0],
             dist)
            for rid, dist in rows]


def search_hybrid(
    db: sqlite3.Connection,
    col: str,
    query: str,
    k_vec: int = 40,
    k_bm25: int = 40,
    k_ce: int = 30,     # rerank this many
    k_final: int = 10,  # return this many
    use_mmr: bool = False,
    mmr_lambda: float = 0.7,
):
    # 1) embed the query (unit, float32) for vec search + MMR
    emodel = get_embed_model()
    query_embedding = emodel.encode([query], normalize_embeddings=True,
                                    convert_to_numpy=True)[0].astype("float32")
    # free the embedder before loading the reranker; both are large
    emodel = None
    gc.collect()
    torch.cuda.empty_cache()

    # 2) ANN + BM25
    vhits = vec_topk(db, col, query_embedding, k_vec)  # [(id, dist)]
    vh_ids = [i for (i, _) in vhits]
    bm_ids = bm25_topk(db, col, query, k_bm25)

    # 3) merge ids, vector hits first
    merged, seen = [], set()
    for i in vh_ids + bm_ids:
        if i not in seen:
            merged.append(i); seen.add(i)
    if not merged:
        return []

    # 4) fetch texts
    qmarks = ",".join("?" * len(merged))
    cand = db.execute(f"SELECT id, text FROM chunks_{col} WHERE id IN ({qmarks})", merged).fetchall()
    ids, texts = zip(*cand)

    # 5) cross-encoder rerank
    reranker = get_rerank_model()
    scores = reranker.predict([(query, t[:1000]) for t in texts], batch_size=BATCH)
    reranker = None
    gc.collect()
    torch.cuda.empty_cache()
    ranked = sorted(zip(ids, texts, scores), key=lambda x: x[2], reverse=True)

    if not use_mmr or len(ranked) <= k_final:
        return ranked[:min(k_ce, k_final)]

    # 6) MMR over the reranked pool for diversity
    ce_ids = [i for (i, _, _) in ranked]
    ce_texts = [t for (_, t, _) in ranked]
    st_model = get_embed_model()
    ce_vecs = st_model.encode(ce_texts, normalize_embeddings=True,
                              convert_to_numpy=True).astype("float32")
    keep = set(mmr2(query_embedding, ce_ids, ce_vecs, k=k_final, lamb=mmr_lambda))
    return [r for r in ranked if r[0] in keep][:k_final]


# Hybrid + CE rerank query:
# results = search_hybrid(db, "demo", "indemnification obligations survive termination",
#                         k_vec=50, k_bm25=50, k_final=8)
# for rid, txt, score in results:
#     print(f"[{rid:04d}] score={score:.3f}\n{txt[:300]}...\n")


def search_hybrid_with_mmr(db, col, query, k_vec=50, k_bm25=50, k_ce=30, k_final=10, lamb=0.7):
    ranked = search_hybrid(db, col, query, k_vec, k_bm25, k_ce, k_ce)
    if not ranked:
        return []
    ids = [i for (i, _, _) in ranked]
    texts = [t for (_, t, _) in ranked]
    st_model = get_embed_model()
    qvec = st_model.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32")
    cvecs = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True).astype("float32")
    keep = set(mmr2(qvec, ids, cvecs, k=k_final, lamb=lamb))
    return [r for r in ranked if r[0] in keep][:k_final]
LIMIT ?", (qtext, 40)) return [(rid, db.execute(f"SELECT text FROM chunks_{col} WHERE id=?", (rid,)).fetchone()[0], dist) for rid, dist in rows] def search_hybrid( db: sqlite3.Connection, col: str, query: str, k_vec: int = 40, k_bm25: int = 40, k_ce: int = 30, # rerank this many k_final: int = 10, # return this many use_mmr: bool = False, mmr_lambda: float = 0.7 ): emodel = get_embed_model() # 1) embed query (unit, float32) for vec search + MMR print("loading") query_embeddings = emodel.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32") emodel = None with torch.no_grad(): torch.cuda.empty_cache() gc.collect() print("memory should be free by now!!") # qbytes = q.tobytes() # 2) ANN + BM25 print("phase 2", col, query) vhits = vec_topk(db, col, query_embeddings, k_vec) # [(id, dist)] vh_ids = [i for (i, _) in vhits] bm_ids = bm25_topk(db, col, query, k_bm25) # # 3) merge ids [vector-first] merged, seen = [], set() for i in vh_ids + bm_ids: if i not in seen: merged.append(i); seen.add(i) if not merged: return [] # 4) fetch texts qmarks = ",".join("?"*len(merged)) cand = db.execute(f"SELECT id, text FROM chunks_{col} WHERE id IN ({qmarks})", merged).fetchall() ids, texts = zip(*cand) # 5) rerank print("loading reranking model") reranker = get_rerank_model() scores = reranker.predict([(query, t[:1000]) for t in texts], batch_size=BATCH) reranker =None with torch.no_grad(): torch.cuda.empty_cache() gc.collect() print("memory should be free by now!!") print("unloading reranking model") ranked = sorted(zip(ids, texts, scores), key=lambda x: x[2], reverse=True) if not use_mmr or len(ranked) <= k_final: return ranked[:min(k_ce, k_final)] # 6) MMR ce_ids = [i for (i,_,_) in ranked] ce_texts = [t for (_,t,_) in ranked] st_model = get_embed_model() ce_vecs = st_model.encode(ce_texts, normalize_embeddings=True, convert_to_numpy=True).astype("float32") keep = set(mmr2(query_embeddings, ce_ids, ce_vecs, k=k_final, lamb=mmr_lambda)) return [r for r in ranked if r[0] in keep][:k_final] # # Hybrid + CE rerank query: # results = search_hybrid("indemnification obligations survive termination", model, k_vec=50, k_bm25=50, k_final=8) # for rid, txt, score in results: # print(f"[{rid:04d}] score={score:.3f}\n{txt[:300]}...\n") # # # # # # # def search_hybrid_with_mmr(db, col, query, k_vec=50, k_bm25=50, k_ce=30, k_final=10, lamb=0.7): ranked = search_hybrid(db, col, query, k_vec, k_bm25, k_ce, k_ce) if not ranked: return [] ids = [i for (i,_,_) in ranked] texts = [t for (_,t,_) in ranked] st_model = get_embed_model() qvec = st_model.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32") cvecs = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True).astype("float32") keep = set(mmr2(qvec, ids, cvecs, k=k_final, lamb=lamb)) return [r for r in ranked if r[0] in keep][:k_final] >>>> test.py from pathlib import Path from docling.document_converter import DocumentConverter from docling_core.transforms.chunker.hybrid_chunker import HybridChunker from transformers import AutoModel from sentence_transformers import SentenceTransformer from rag.db import get_db from rag.rerank import get_rerank_model import rag.ingest import rag.search import torch # converter = DocumentConverter() # chunker = HybridChunker() # file = Path("yek.md") # doc = converter.convert(file).document # chunk_iter = chunker.chunk(doc) # for chunk in chunk_iter: # print(chunk) # txt = chunker.contextualize(chunk) # print(txt) def t(): batch: list[str] = ["This son of a 
>>>> utils.py
import re

FTS_META_CHARS = r'''["'*()^+-]'''  # include ? if you see issues


def sanitize_query(q: str, *, allow_ops: bool = False) -> str:
    # NOTE: both branches currently produce the same phrase-quoted output
    q = q.strip()
    if not q:
        return q
    if allow_ops:
        # escape stray double quotes inside, then wrap
        q = q.replace('"', '""')
        return f'"{q}"'
    # literal search: quote the whole query and escape embedded double quotes
    q = re.sub(r'"', '""', q)
    return f'"{q}"'
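Usage sketch: sanitize_query is not wired in yet; bm25_topk's inline f'"{qtext}"' quoting breaks if the query itself contains a double quote, which this helper escapes.

from rag.utils import sanitize_query

print(sanitize_query('indemnification "survives" termination'))
# -> "indemnification ""survives"" termination"  (a single FTS5 phrase query)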