summary | refs | log | tree | commit | diff
path: root/rag/ingest.py
diff options
context:
space:
mode:
author polwex <polwex@sortug.com> 2025-09-24 23:38:36 +0700
committer polwex <polwex@sortug.com> 2025-09-24 23:38:36 +0700
commit 734b89570040e97f0c7743c4c0bc28e30a3cd4ee (patch)
tree 7142d9f37908138c38d0ade066e960c3a1c69f5d /rag/ingest.py
parent 57aaafdb137fe49930711f6ed5ccc83b3a119cd2 (diff)
Diffstat (limited to 'rag/ingest.py')
-rw-r--r-- rag/ingest.py 17
1 files changed, 6 insertions, 11 deletions
diff --git a/rag/ingest.py b/rag/ingest.py
index d17690a..5def23d 100644
--- a/rag/ingest.py
+++ b/rag/ingest.py
@@ -7,16 +7,13 @@ from docling.document_converter import DocumentConverter
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
from rag.db import get_db, init_schema, store_chunks
from sentence_transformers import SentenceTransformer
+from rag.constants import MAX_TOKENS, BATCH, EMBED_MODEL_ID
-EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-8B"
-RERANKER_MODEL_ID = "Qwen/Qwen3-Reranker-8B"
-MAX_TOKENS = 600
-BATCH = 16
def get_embed_model():
return SentenceTransformer(
- "Qwen/Qwen3-Embedding-8B",
+ EMBED_MODEL_ID,
model_kwargs={
# "trust_remote_code":True,
"attn_implementation":"flash_attention_2",
@@ -48,14 +45,12 @@ def embed_many(model: SentenceTransformer, texts: list[str]):
return V_np.astype("float32")
-def start_ingest(db: sqlite3.Connection, path: Path):
- model = get_embed_model()
+def start_ingest(db: sqlite3.Connection, model: SentenceTransformer | None, collection: str, path: Path):
+ if model is None:
+ model = get_embed_model()
chunks = parse_and_chunk(path, model)
V_np = embed_many(model, chunks)
- DIM = V_np.shape[1]
- db = get_db()
- init_schema(db, DIM)
- store_chunks(db, chunks, V_np)
+ store_chunks(db, collection, chunks, V_np)
# TODO some try catch?
return True
# float32/fp16 on CPU? ensure float32 for DB: