diff options
author | polwex <polwex@sortug.com> | 2025-09-23 03:50:53 +0700 |
---|---|---|
committer | polwex <polwex@sortug.com> | 2025-09-23 03:50:53 +0700 |
commit | 57aaafdb137fe49930711f6ed5ccc83b3a119cd2 (patch) | |
tree | 1a7556927bed94377630d33dd29c3bf07d159619 /rag/ingest.py |
init
Diffstat (limited to 'rag/ingest.py')
-rw-r--r-- | rag/ingest.py | 61 |
1 files changed, 61 insertions, 0 deletions
"""Ingest pipeline: convert a document, chunk it, embed the chunks, store them."""
# from rag.rerank import search_hybrid
import sqlite3
from pathlib import Path

import torch
from docling.document_converter import DocumentConverter
from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
from sentence_transformers import SentenceTransformer

from rag.db import get_db, init_schema, store_chunks

EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-8B"
RERANKER_MODEL_ID = "Qwen/Qwen3-Reranker-8B"
MAX_TOKENS = 600  # per-chunk token budget handed to the chunker's tokenizer
BATCH = 16  # batch size passed to ``model.encode``


def get_embed_model() -> SentenceTransformer:
    """Load the embedding model named by ``EMBED_MODEL_ID``.

    The model is requested with the ``flash_attention_2`` attention
    implementation, ``device_map="auto"`` and float16 weights; the tokenizer
    is configured to pad on the left.

    Returns:
        A ready-to-use ``SentenceTransformer`` instance.
    """
    return SentenceTransformer(
        EMBED_MODEL_ID,
        model_kwargs={
            # "trust_remote_code": True,
            "attn_implementation": "flash_attention_2",
            "device_map": "auto",
            "dtype": torch.float16,
        },
        tokenizer_kwargs={"padding_side": "left"},
    )


def parse_and_chunk(source: Path, model: SentenceTransformer) -> list[str]:
    """Convert ``source`` with Docling and split it into contextualized chunks.

    Args:
        source: Path of the document to convert.
        model: Embedding model whose tokenizer bounds each chunk to
            ``MAX_TOKENS`` tokens.

    Returns:
        The contextualized text of every chunk, in chunker order, skipping
        chunks that are empty or whitespace-only.
    """
    tokenizer = HuggingFaceTokenizer(tokenizer=model.tokenizer, max_tokens=MAX_TOKENS)
    document = DocumentConverter().convert(source).document
    chunker = HybridChunker(tokenizer=tokenizer, merge_peers=True)

    texts = (chunker.contextualize(chunk) for chunk in chunker.chunk(document))
    return [text for text in texts if text.strip()]


def embed_many(model: SentenceTransformer, texts: list[str]):
    """Embed ``texts`` and return the vectors as a float32 NumPy array.

    Args:
        model: Embedding model used to encode the texts.
        texts: Chunk texts to embed.

    Returns:
        The normalized embeddings produced by ``model.encode``, cast to
        float32.
    """
    vectors = model.encode(
        texts,
        batch_size=BATCH,
        normalize_embeddings=True,
        convert_to_numpy=True,
        show_progress_bar=True,
    )
    # The model is loaded with float16 weights; always hand float32 to the DB.
    return vectors.astype("float32")


def start_ingest(db: sqlite3.Connection, path: Path) -> bool:
    """Parse, chunk, embed and store the document at ``path``.

    Args:
        db: Connection to store the chunks in. When ``None``, a connection is
            obtained from ``get_db()`` instead.
        path: Path of the document to ingest.

    Returns:
        ``True`` once the chunks have been stored, ``False`` when the document
        yielded no non-empty chunks and nothing was written.
    """
    model = get_embed_model()
    chunks = parse_and_chunk(path, model)
    if not chunks:
        # Nothing to embed or store, so skip the model call and the DB writes.
        return False

    vectors = embed_many(model, chunks)
    dim = vectors.shape[1]  # embedding width, needed to create the schema

    # Honour the caller's connection; previously it was unconditionally
    # replaced by a fresh get_db() connection and the argument was ignored.
    if db is None:
        db = get_db()
    init_schema(db, dim)
    store_chunks(db, chunks, vectors)
    # TODO some try catch?
    return True