# rag/main.py
"""Command-line entry point for the RAG pipeline (``rag ingest`` / ``rag query``)."""

import argparse
import sqlite3
import sys
from contextlib import closing
from pathlib import Path

from sentence_transformers import SentenceTransformer

from rag.ingest import start_ingest
from rag.search import search_hybrid, search as vec_search

DB_PATH = "./rag.db"
EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-8B"

# Maximum number of characters of each hit's text printed by ``rag query``.
PREVIEW_CHARS = 400

# Table-name prefixes that together make up one collection.
_COLLECTION_TABLE_PREFIXES = ("chunks", "fts", "vec")


def open_db():
    """Open the SQLite database at ``DB_PATH`` with WAL journaling enabled.

    Returns:
        A new ``sqlite3.Connection``; the caller is responsible for closing it.
    """
    db = sqlite3.connect(DB_PATH)
    # WAL allows readers alongside a writer; synchronous=NORMAL is the usual
    # WAL pairing and avoids a full fsync on every commit.
    db.execute("PRAGMA journal_mode=WAL;")
    db.execute("PRAGMA synchronous=NORMAL;")
    return db


def load_st_model():
    """Load the ``EMBED_MODEL_ID`` SentenceTransformer in float16.

    SentenceTransformer handles batching and device placement internally
    (``device_map="auto"``).
    """
    # NOTE(review): "flash_attention_2" requires the flash-attn package and a
    # compatible GPU; loading fails otherwise. Left padding is presumably
    # required by the model's last-token pooling — confirm against model card.
    return SentenceTransformer(
        EMBED_MODEL_ID,
        model_kwargs={
            "attn_implementation": "flash_attention_2",
            "device_map": "auto",
            "torch_dtype": "float16",
        },
        tokenizer_kwargs={"padding_side": "left"},
    )


def ensure_collection_exists(db, collection: str) -> bool:
    """Return True if the collection's vector table ``vec_<collection>`` exists.

    Despite the name, nothing is created. Only the ``vec_`` table is checked;
    the ``chunks_`` and ``fts_`` tables are not inspected.
    """
    row = db.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
        (f"vec_{collection}",),
    ).fetchone()
    return row is not None


def _quote_ident(name: str) -> str:
    """Return *name* as a double-quoted SQLite identifier with quotes escaped."""
    return '"' + name.replace('"', '""') + '"'


def _drop_collection(db, collection: str) -> None:
    """Drop all tables of *collection*; tables that do not exist are skipped."""
    # NOTE(review): dropping a sqlite-vec virtual table needs the vec extension
    # loaded on this connection, and open_db() does not load it — confirm.
    script = "".join(
        f"DROP TABLE IF EXISTS {_quote_ident(f'{prefix}_{collection}')};\n"
        for prefix in _COLLECTION_TABLE_PREFIXES
    )
    db.executescript(script)


def cmd_ingest(args):
    """Handle ``rag ingest``: optionally drop the collection, then ingest ``--file``.

    Exits with status 1 if ``--file`` is not an existing regular file.
    """
    path = Path(args.file)
    # Validate cheaply before opening the DB or loading the embedding model.
    if not path.is_file():
        print(f"File not found: {path}", file=sys.stderr)
        sys.exit(1)

    with closing(open_db()) as db:
        # Load the model before dropping anything, so a failed model load
        # cannot leave the user with a deleted collection.
        st_model = load_st_model()
        if args.rebuild:
            # Unconditional: IF EXISTS makes missing tables a no-op, and this
            # also clears a partial collection whose vec_ table was never made.
            _drop_collection(db, args.collection)
        stats = start_ingest(
            db,
            st_model,
            path=path,
            collection=args.collection,
        )
        print(f"Ingested collection={args.collection} :: {stats}")


def cmd_query(args):
    """Handle ``rag query``: search a collection and print the ranked hits.

    Exits with status 2 if the collection has not been ingested.
    """
    with closing(open_db()) as db:
        if not ensure_collection_exists(db, args.collection):
            print(f"Collection '{args.collection}' not found. Ingest first.", file=sys.stderr)
            sys.exit(2)

        # Only pay for the model load once there is something to search.
        st_model = load_st_model()
        if args.simple:
            results = vec_search(db, st_model, args.query, collection=args.collection, k=args.k_final)
        else:
            results = search_hybrid(
                db,
                st_model,
                args.query,
                collection=args.collection,
                k_vec=args.k_vec,
                k_bm25=args.k_bm25,
                k_ce=args.k_ce,
                k_final=args.k_final,
                use_mmr=args.mmr,
                mmr_lambda=args.mmr_lambda,
            )

        for rid, txt, score in results:
            # Only mark text as elided when it was actually truncated.
            preview = txt if len(txt) <= PREVIEW_CHARS else txt[:PREVIEW_CHARS] + "..."
            print(f"[{rid:05d}] score={score:.3f}\n{preview}\n")


def _positive_int(value: str) -> int:
    """argparse ``type=``: parse *value* as an int and require it to be >= 1."""
    n = int(value)
    if n < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value!r}")
    return n


def main():
    """Parse command-line arguments and dispatch to the selected subcommand."""
    ap = argparse.ArgumentParser(prog="rag")
    sub = ap.add_subparsers(dest="cmd", required=True)

    # ingest
    ap_ing = sub.add_parser("ingest", help="Parse, chunk, embed, and index a file into a collection")
    ap_ing.add_argument("--file", required=True, help="Path to PDF/TXT to ingest")
    ap_ing.add_argument("--collection", required=True, help="Collection name (e.g. wm_qwen3)")
    ap_ing.add_argument("--rebuild", action="store_true", help="Drop and recreate collection tables")
    ap_ing.set_defaults(func=cmd_ingest)

    # query
    ap_q = sub.add_parser("query", help="Query a collection")
    ap_q.add_argument("--collection", required=True, help="Collection name to search")
    ap_q.add_argument("--query", required=True, help="User query text")
    ap_q.add_argument("--simple", action="store_true", help="Vector-only search (skip reranker)")
    ap_q.add_argument("--mmr", action="store_true", help="Apply MMR after CE")
    ap_q.add_argument("--mmr-lambda", type=float, default=0.7, help="MMR relevance/diversity trade-off")
    ap_q.add_argument("--k-vec", type=_positive_int, default=50, help="Vector-search candidates")
    ap_q.add_argument("--k-bm25", type=_positive_int, default=50, help="BM25 candidates")
    ap_q.add_argument("--k-ce", type=_positive_int, default=30, help="Candidates passed to the cross-encoder")
    ap_q.add_argument("--k-final", type=_positive_int, default=10, help="Number of results to print")
    ap_q.set_defaults(func=cmd_query)

    args = ap.parse_args()
    args.func(args)


if __name__ == "__main__":
    main()