diff options
Diffstat (limited to 'rag/main.py')
-rw-r--r-- | rag/main.py | 106 |
1 files changed, 69 insertions, 37 deletions
diff --git a/rag/main.py b/rag/main.py
index 221adfa..bbd549a 100644
--- a/rag/main.py
+++ b/rag/main.py
@@ -1,10 +1,74 @@
 import sys
 import argparse
 from pathlib import Path
-from rag.ingest import start_ingest
+from rag.constants import EMBED_MODEL_ID
+from rag.ingest import get_embed_model, start_ingest
 from rag.search import search_hybrid, vec_search
-from rag.db import get_db
-# your Docling chunker imports…
+from rag.db import get_db, check_db, check_db2, init_schema
+
+
+def valid_collection(col: str) -> bool:
+    """Return True when col is 1-8 ASCII characters and contains no whitespace."""
+    return 0 < len(col) < 9 and col.isascii() and not any(ch.isspace() for ch in col)
+
+
+def cmd_ingest(args):
+    """Ingest args.file into args.collection, creating its schema on first use."""
+    path = Path(args.file)
+    if not valid_collection(args.collection):
+        print(f"Collection name invalid: {args.collection}", file=sys.stderr)
+        sys.exit(1)
+
+    if not path.exists():
+        print(f"File not found: {path}", file=sys.stderr)
+        sys.exit(1)
+
+    db = get_db()
+    model = None  # a collection that already exists is ingested with model=None
+    if not check_db2(db, args.collection):
+        model = get_embed_model()
+        dim = model.get_sentence_embedding_dimension()
+        if dim is None:
+            print("Embedding model did not report an embedding dimension", file=sys.stderr)
+            sys.exit(1)
+
+        # TODO Try catch in here, tell the user if it crashes
+        init_schema(db, args.collection, dim, EMBED_MODEL_ID, EMBED_MODEL_ID, True, 'idk')
+    stats = start_ingest(db, model, args.collection, path)
+    print(f"Ingested file={args.file} :: {stats}")
+
+
+def cmd_query(args):
+    """Search args.collection for args.query and print every (id, text, score) hit."""
+    if not valid_collection(args.collection):
+        print(f"Collection name invalid: {args.collection}", file=sys.stderr)
+        sys.exit(1)
+    db = get_db()
+    # try/finally: sys.exit() and search errors below must still close the DB handle.
+    try:
+        if not check_db2(db, args.collection):
+            print(f"Collection name not in DB, what are you searching: {args.collection}", file=sys.stderr)
+            sys.exit(1)
+
+        if args.simple:
+            results = vec_search(db, args.collection, args.query, k=args.k_final)
+        else:
+            results = search_hybrid(
+                db, args.collection, args.query,
+                k_vec=args.k_vec,
+                k_bm25=args.k_bm25,
+                k_ce=args.k_ce,
+                k_final=args.k_final,
+                use_mmr=args.mmr,
+                mmr_lambda=args.mmr_lambda,
+            )
+
+        for rid, txt, score in results:
+            print(f"[{rid:05d}] score={score:.3f}\n{txt[:400]}...\n")
+    finally:
+        db.close()
+
+
 
 def main():
     ap = argparse.ArgumentParser(prog="rag")
@@ -13,10 +77,12 @@ def main():
     # ingest
     ap_ing = sub.add_parser("ingest", help="Parse, chunk, embed, and index a file into a collection")
     ap_ing.add_argument("--file", required=True, help="Path to PDF/TXT to ingest")
+    ap_ing.add_argument("--collection", required=True, help="Collection name (e.g. wm_qwen3)")
     ap_ing.set_defaults(func=cmd_ingest)
 
     # query
     ap_q = sub.add_parser("query", help="Query a collection")
+    ap_q.add_argument("--collection", required=True, help="Collection name to search")
     ap_q.add_argument("--query", required=True, help="User query text")
     ap_q.add_argument("--simple", action="store_true", help="Vector-only search (skip reranker)")
     ap_q.add_argument("--mmr", action="store_true", help="Apply MMR after CE")
@@ -35,37 +101,3 @@ if __name__ == "__main__":
     main()
 
 
-
-
-def cmd_ingest(args):
-    path = Path(args.file)
-
-    if not path.exists():
-        print(f"File not found: {path}", file=sys.stderr)
-        sys.exit(1)
-
-
-    db = get_db()
-    stats = start_ingest(db,path)
-    print(f"Ingested file={args.file} :: {stats}")
-
-def cmd_query(args):
-
-    db = get_db()
-    if args.simple:
-        results = vec_search(db, args.query, k=args.k_final)
-    else:
-        results = search_hybrid(
-            db, args.query,
-            k_vec=args.k_vec,
-            k_bm25=args.k_bm25,
-            k_ce=args.k_ce,
-            k_final=args.k_final,
-            use_mmr=args.mmr,
-            mmr_lambda=args.mmr_lambda,
-        )
-
-    for rid, txt, score in results:
-        print(f"[{rid:05d}] score={score:.3f}\n{txt[:400]}...\n")
-
-    db.close()