summary refs log tree commit diff
path: root/rag/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/main.py')
-rw-r--r-- rag/main.py 106
1 files changed, 69 insertions, 37 deletions
diff --git a/rag/main.py b/rag/main.py
index 221adfa..bbd549a 100644
--- a/rag/main.py
+++ b/rag/main.py
@@ -1,10 +1,74 @@
import sys
import argparse
from pathlib import Path
-from rag.ingest import start_ingest
+from rag.constants import EMBED_MODEL_ID
+from rag.ingest import get_embed_model, start_ingest
from rag.search import search_hybrid, vec_search
-from rag.db import get_db
-# your Docling chunker imports…
+from rag.db import get_db, check_db, check_db2, init_schema
+
+
+def valid_collection(col: str) -> bool:
+    """Return True if *col* is an acceptable collection name.
+
+    A name is accepted when it is non-empty, shorter than 9 characters,
+    made only of ASCII characters, and contains no whitespace.
+
+    Args:
+        col: Collection name as given on the command line.
+
+    Returns:
+        True when the name satisfies every rule above, False otherwise.
+    """
+    # Length rule first: cheap, and it also rejects the empty string.
+    if not 0 < len(col) < 9:
+        return False
+    # ord() < 128 is the ASCII test; isspace() covers spaces, tabs, newlines.
+    return all(ord(ch) < 128 and not ch.isspace() for ch in col)
+
+
+def cmd_ingest(args):
+    """Handle the ``ingest`` sub-command.
+
+    Validates ``args.collection`` and ``args.file``, initialises the
+    collection schema when ``check_db2`` reports the collection is missing,
+    runs ``start_ingest`` and prints the stats it returns.
+
+    Args:
+        args: Parsed argparse namespace; reads ``file`` and ``collection``.
+
+    Raises:
+        SystemExit: With status 1 (and a message on stderr) when the
+            collection name is invalid, the file does not exist, the model
+            reports no embedding dimension, or schema initialisation fails.
+    """
+    path = Path(args.file)
+    if not valid_collection(args.collection):
+        print(f"Collection name invalid: {args.collection}", file=sys.stderr)
+        sys.exit(1)
+
+    if not path.exists():
+        print(f"File not found: {path}", file=sys.stderr)
+        sys.exit(1)
+
+    db = get_db()
+    if not check_db2(db, args.collection):
+        # First ingest into this collection: the schema needs the embedding
+        # dimension, so the model is loaded up front and reused for ingest.
+        model = get_embed_model()
+        dim = model.get_sentence_embedding_dimension()
+        if dim is None:
+            print("Embedding model did not report an embedding dimension", file=sys.stderr)
+            sys.exit(1)
+
+        # CLI boundary: report any setup failure to the user instead of
+        # surfacing a raw traceback.
+        try:
+            init_schema(db, args.collection, dim, EMBED_MODEL_ID, EMBED_MODEL_ID, True, 'idk')
+        except Exception as exc:
+            print(f"Failed to initialise collection {args.collection}: {exc}", file=sys.stderr)
+            sys.exit(1)
+        stats = start_ingest(db, model, args.collection, path)
+    else:
+        # Collection already present: no model is passed (None) here.
+        stats = start_ingest(db, None, args.collection, path)
+    # NOTE(review): db is not closed here, unlike cmd_query — confirm whether
+    # start_ingest owns the handle before adding db.close().
+    print(f"Ingested file={args.file} :: {stats}")
+
+def cmd_query(args):
+    """Handle the ``query`` sub-command.
+
+    Validates ``args.collection``, checks that it is present in the DB, runs
+    either ``vec_search`` (``--simple``) or ``search_hybrid``, and prints
+    each ``(rid, txt, score)`` result with the text cut to 400 characters.
+
+    Args:
+        args: Parsed argparse namespace; reads ``collection``, ``query``,
+            ``simple``, ``k_vec``, ``k_bm25``, ``k_ce``, ``k_final``,
+            ``mmr`` and ``mmr_lambda``.
+
+    Raises:
+        SystemExit: With status 1 when the collection name is invalid or
+            the collection is not in the DB.
+    """
+    if not valid_collection(args.collection):
+        print(f"Collection name invalid: {args.collection}", file=sys.stderr)
+        sys.exit(1)
+
+    db = get_db()
+    # try/finally so the handle is released on every path, including the
+    # sys.exit below and any exception raised by the search functions.
+    try:
+        if not check_db2(db, args.collection):
+            print(f"Collection name not in DB, what are you searching: {args.collection}", file=sys.stderr)
+            sys.exit(1)
+
+        if args.simple:
+            results = vec_search(db, args.collection, args.query, k=args.k_final)
+        else:
+            results = search_hybrid(
+                db,
+                args.collection,
+                args.query,
+                k_vec=args.k_vec,
+                k_bm25=args.k_bm25,
+                k_ce=args.k_ce,
+                k_final=args.k_final,
+                use_mmr=args.mmr,
+                mmr_lambda=args.mmr_lambda,
+            )
+
+        for rid, txt, score in results:
+            print(f"[{rid:05d}] score={score:.3f}\n{txt[:400]}...\n")
+    finally:
+        db.close()
+
+
+
+
def main():
ap = argparse.ArgumentParser(prog="rag")
@@ -13,10 +77,12 @@ def main():
# ingest
ap_ing = sub.add_parser("ingest", help="Parse, chunk, embed, and index a file into a collection")
ap_ing.add_argument("--file", required=True, help="Path to PDF/TXT to ingest")
+ ap_ing.add_argument("--collection", required=True, help="Collection name (e.g. wm_qwen3)")
ap_ing.set_defaults(func=cmd_ingest)
# query
ap_q = sub.add_parser("query", help="Query a collection")
+ ap_q.add_argument("--collection", required=True, help="Collection name to search")
ap_q.add_argument("--query", required=True, help="User query text")
ap_q.add_argument("--simple", action="store_true", help="Vector-only search (skip reranker)")
ap_q.add_argument("--mmr", action="store_true", help="Apply MMR after CE")
@@ -35,37 +101,3 @@ if __name__ == "__main__":
main()
-
-
-def cmd_ingest(args):
- path = Path(args.file)
-
- if not path.exists():
- print(f"File not found: {path}", file=sys.stderr)
- sys.exit(1)
-
-
- db = get_db()
- stats = start_ingest(db,path)
- print(f"Ingested file={args.file} :: {stats}")
-
-def cmd_query(args):
-
- db = get_db()
- if args.simple:
- results = vec_search(db, args.query, k=args.k_final)
- else:
- results = search_hybrid(
- db, args.query,
- k_vec=args.k_vec,
- k_bm25=args.k_bm25,
- k_ce=args.k_ce,
- k_final=args.k_final,
- use_mmr=args.mmr,
- mmr_lambda=args.mmr_lambda,
- )
-
- for rid, txt, score in results:
- print(f"[{rid:05d}] score={score:.3f}\n{txt[:400]}...\n")
-
- db.close()