diff options
Diffstat (limited to 'rag/nmain.py')
-rw-r--r-- | rag/nmain.py | 118 |
1 files changed, 118 insertions, 0 deletions
diff --git a/rag/nmain.py b/rag/nmain.py new file mode 100644 index 0000000..aa67dea --- /dev/null +++ b/rag/nmain.py @@ -0,0 +1,118 @@ +# rag/main.py +import argparse, sqlite3, sys +from pathlib import Path + +from sentence_transformers import SentenceTransformer +from rag.ingest import start_ingest +from rag.search import search_hybrid, search as vec_search + +DB_PATH = "./rag.db" +EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-8B" + +def open_db(): + db = sqlite3.connect(DB_PATH) + # speed-ish pragmas + db.execute("PRAGMA journal_mode=WAL;") + db.execute("PRAGMA synchronous=NORMAL;") + return db + +def load_st_model(): + # ST handles batching + GPU internally + return SentenceTransformer( + EMBED_MODEL_ID, + model_kwargs={ + "attn_implementation": "flash_attention_2", + "device_map": "auto", + "torch_dtype": "float16", + }, + tokenizer_kwargs={"padding_side": "left"}, + ) + +def ensure_collection_exists(db, collection: str): + row = db.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", + (f"vec_{collection}",) + ).fetchone() + return bool(row) + +def cmd_ingest(args): + db = open_db() + st_model = load_st_model() + path = Path(args.file) + + if not path.exists(): + print(f"File not found: {path}", file=sys.stderr) + sys.exit(1) + + if args.rebuild and ensure_collection_exists(db, args.collection): + db.executescript(f""" + DROP TABLE IF EXISTS chunks_{args.collection}; + DROP TABLE IF EXISTS fts_{args.collection}; + DROP TABLE IF EXISTS vec_{args.collection}; + """) + + stats = start_ingest( + db, st_model, + path=path, + collection=args.collection, + ) + print(f"Ingested collection={args.collection} :: {stats}") + db.close() + +def cmd_query(args): + db = open_db() + st_model = load_st_model() + + coll_ok = ensure_collection_exists(db, args.collection) + if not coll_ok: + print(f"Collection '{args.collection}' not found. Ingest first.", file=sys.stderr) + sys.exit(2) + + if args.simple: + results = vec_search(db, st_model, args.query, collection=args.collection, k=args.k_final) + else: + results = search_hybrid( + db, st_model, args.query, + collection=args.collection, + k_vec=args.k_vec, + k_bm25=args.k_bm25, + k_ce=args.k_ce, + k_final=args.k_final, + use_mmr=args.mmr, + mmr_lambda=args.mmr_lambda, + ) + + for rid, txt, score in results: + print(f"[{rid:05d}] score={score:.3f}\n{txt[:400]}...\n") + + db.close() + +def main(): + ap = argparse.ArgumentParser(prog="rag") + sub = ap.add_subparsers(dest="cmd", required=True) + + # ingest + ap_ing = sub.add_parser("ingest", help="Parse, chunk, embed, and index a file into a collection") + ap_ing.add_argument("--file", required=True, help="Path to PDF/TXT to ingest") + ap_ing.add_argument("--collection", required=True, help="Collection name (e.g. wm_qwen3)") + ap_ing.add_argument("--rebuild", action="store_true", help="Drop and recreate collection tables") + ap_ing.set_defaults(func=cmd_ingest) + + # query + ap_q = sub.add_parser("query", help="Query a collection") + ap_q.add_argument("--collection", required=True, help="Collection name to search") + ap_q.add_argument("--query", required=True, help="User query text") + ap_q.add_argument("--simple", action="store_true", help="Vector-only search (skip reranker)") + ap_q.add_argument("--mmr", action="store_true", help="Apply MMR after CE") + ap_q.add_argument("--mmr-lambda", type=float, default=0.7) + ap_q.add_argument("--k-vec", type=int, default=50) + ap_q.add_argument("--k-bm25", type=int, default=50) + ap_q.add_argument("--k-ce", type=int, default=30) + ap_q.add_argument("--k-final", type=int, default=10) + ap_q.set_defaults(func=cmd_query) + + args = ap.parse_args() + args.func(args) + +if __name__ == "__main__": + main() |