summaryrefslogtreecommitdiff
path: root/rag/ingest.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/ingest.py')
-rw-r--r--rag/ingest.py61
1 files changed, 61 insertions, 0 deletions
diff --git a/rag/ingest.py b/rag/ingest.py
new file mode 100644
index 0000000..d17690a
--- /dev/null
+++ b/rag/ingest.py
@@ -0,0 +1,61 @@
+# from rag.rerank import search_hybrid
+import sqlite3
+import torch
+from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
+from pathlib import Path
+from docling.document_converter import DocumentConverter
+from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
+from rag.db import get_db, init_schema, store_chunks
+from sentence_transformers import SentenceTransformer
+
+EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-8B"
+RERANKER_MODEL_ID = "Qwen/Qwen3-Reranker-8B"
+MAX_TOKENS = 600
+BATCH = 16
+
+
+def get_embed_model():
+ return SentenceTransformer(
+ "Qwen/Qwen3-Embedding-8B",
+ model_kwargs={
+ # "trust_remote_code":True,
+ "attn_implementation":"flash_attention_2",
+ "device_map":"auto",
+ "dtype":torch.float16
+ }, tokenizer_kwargs={"padding_side": "left"}
+ )
+
+
+def parse_and_chunk(source: Path, model: SentenceTransformer)-> list[str]:
+ tokenizer = HuggingFaceTokenizer(tokenizer=model.tokenizer, max_tokens=MAX_TOKENS)
+
+ converter = DocumentConverter()
+ doc = converter.convert(source)
+ chunker = HybridChunker(tokenizer=tokenizer, merge_peers=True)
+ out = []
+ for ch in chunker.chunk(doc.document):
+ txt = chunker.contextualize(ch)
+ if txt.strip():
+ out.append(txt)
+ return out
+
+def embed_many(model: SentenceTransformer, texts: list[str]):
+ V_np = model.encode(texts,
+ batch_size=BATCH,
+ normalize_embeddings=True,
+ convert_to_numpy=True,
+ show_progress_bar=True)
+ return V_np.astype("float32")
+
+
+def start_ingest(db: sqlite3.Connection, path: Path):
+ model = get_embed_model()
+ chunks = parse_and_chunk(path, model)
+ V_np = embed_many(model, chunks)
+ DIM = V_np.shape[1]
+ db = get_db()
+ init_schema(db, DIM)
+ store_chunks(db, chunks, V_np)
+ # TODO some try catch?
+ return True
+# float32/fp16 on CPU? ensure float32 for DB: