diff options
Diffstat (limited to 'rag/ingest.py')
-rw-r--r-- | rag/ingest.py | 17 |
1 file changed, 6 insertions, 11 deletions
diff --git a/rag/ingest.py b/rag/ingest.py index d17690a..5def23d 100644 --- a/rag/ingest.py +++ b/rag/ingest.py @@ -7,16 +7,13 @@ from docling.document_converter import DocumentConverter from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer from rag.db import get_db, init_schema, store_chunks from sentence_transformers import SentenceTransformer +from rag.constants import MAX_TOKENS, BATCH, EMBED_MODEL_ID -EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-8B" -RERANKER_MODEL_ID = "Qwen/Qwen3-Reranker-8B" -MAX_TOKENS = 600 -BATCH = 16 def get_embed_model(): return SentenceTransformer( - "Qwen/Qwen3-Embedding-8B", + EMBED_MODEL_ID, model_kwargs={ # "trust_remote_code":True, "attn_implementation":"flash_attention_2", @@ -48,14 +45,12 @@ def embed_many(model: SentenceTransformer, texts: list[str]): return V_np.astype("float32") -def start_ingest(db: sqlite3.Connection, path: Path): - model = get_embed_model() +def start_ingest(db: sqlite3.Connection, model: SentenceTransformer | None, collection: str, path: Path): + if model is None: + model = get_embed_model() chunks = parse_and_chunk(path, model) V_np = embed_many(model, chunks) - DIM = V_np.shape[1] - db = get_db() - init_schema(db, DIM) - store_chunks(db, chunks, V_np) + store_chunks(db, collection, chunks, V_np) # TODO some try catch? return True # float32/fp16 on CPU? ensure float32 for DB: |