diff options
Diffstat (limited to 'rag/main.py')
-rw-r--r-- | rag/main.py | 71 |
1 files changed, 71 insertions, 0 deletions
"""Command-line interface for the ``rag`` package.

Two subcommands are provided:

* ``ingest`` -- pass a file to ``rag.ingest.start_ingest``.
* ``query``  -- run a vector-only or hybrid search and print the hits.
"""

import sys
import argparse
from pathlib import Path
from rag.ingest import start_ingest
from rag.search import search_hybrid, vec_search
from rag.db import get_db
# your Docling chunker imports…


def cmd_ingest(args):
    """Handle the ``ingest`` subcommand.

    Args:
        args: Parsed CLI namespace; only ``args.file`` is read.

    Exits with status 1 (after printing to stderr) when the path does not
    exist. Otherwise calls ``start_ingest(db, path)`` and prints whatever
    it returns.
    """
    path = Path(args.file)

    if not path.exists():
        print(f"File not found: {path}", file=sys.stderr)
        sys.exit(1)

    db = get_db()
    try:
        stats = start_ingest(db, path)
    finally:
        # Release the handle even if ingestion fails, mirroring cmd_query.
        # NOTE(review): assumes start_ingest does not close `db` itself -- confirm.
        db.close()
    print(f"Ingested file={args.file} :: {stats}")


def cmd_query(args):
    """Handle the ``query`` subcommand.

    With ``--simple`` the query goes to ``vec_search`` using ``k_final`` as
    ``k``; otherwise it goes to ``search_hybrid`` with every ``k_*`` and MMR
    option forwarded. Each result is unpacked as ``(rid, txt, score)`` and
    printed with the text cut to its first 400 characters.

    Args:
        args: Parsed CLI namespace produced by the ``query`` sub-parser.
    """
    db = get_db()
    try:
        if args.simple:
            results = vec_search(db, args.query, k=args.k_final)
        else:
            results = search_hybrid(
                db, args.query,
                k_vec=args.k_vec,
                k_bm25=args.k_bm25,
                k_ce=args.k_ce,
                k_final=args.k_final,
                use_mmr=args.mmr,
                mmr_lambda=args.mmr_lambda,
            )

        for rid, txt, score in results:
            print(f"[{rid:05d}] score={score:.3f}\n{txt[:400]}...\n")
    finally:
        # Close the handle on every exit path, not only after a clean run.
        db.close()


def _build_parser():
    """Build the ``rag`` argument parser with its ``ingest``/``query`` subcommands."""
    ap = argparse.ArgumentParser(prog="rag")
    sub = ap.add_subparsers(dest="cmd", required=True)

    # ingest
    ap_ing = sub.add_parser("ingest", help="Parse, chunk, embed, and index a file into a collection")
    ap_ing.add_argument("--file", required=True, help="Path to PDF/TXT to ingest")
    ap_ing.set_defaults(func=cmd_ingest)

    # query
    ap_q = sub.add_parser("query", help="Query a collection")
    ap_q.add_argument("--query", required=True, help="User query text")
    ap_q.add_argument("--simple", action="store_true", help="Vector-only search (skip reranker)")
    ap_q.add_argument("--mmr", action="store_true", help="Apply MMR after CE")
    ap_q.add_argument("--mmr-lambda", type=float, default=0.7)
    ap_q.add_argument("--k-vec", type=int, default=50)
    ap_q.add_argument("--k-bm25", type=int, default=50)
    ap_q.add_argument("--k-ce", type=int, default=30)
    ap_q.add_argument("--k-final", type=int, default=10)
    ap_q.set_defaults(func=cmd_query)

    return ap


def main(argv=None):
    """Parse command-line arguments and dispatch to the chosen subcommand.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            ``argparse`` falls back to ``sys.argv[1:]``.
    """
    args = _build_parser().parse_args(argv)
    args.func(args)


# The guard must come after every command handler is defined: running the
# file as a script executes main() at this point, and main() looks up
# cmd_ingest / cmd_query as module globals.
if __name__ == "__main__":
    main()