1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
|
# from rag.rerank import search_hybrid
import sqlite3
import torch
from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
from pathlib import Path
from docling.document_converter import DocumentConverter
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
from rag.db import get_db, init_schema, store_chunks
from sentence_transformers import SentenceTransformer
from rag.constants import MAX_TOKENS, BATCH, EMBED_MODEL_ID
def get_embed_model():
    """Load the SentenceTransformer embedding model named by ``EMBED_MODEL_ID``.

    On a CUDA machine that also has the optional ``flash_attn`` package
    installed, the model is loaded in fp16 with FlashAttention-2. Otherwise
    FlashAttention-2 is not requested, because it requires CUDA and flash-attn
    and loading would fail without them. In that case transformers picks its
    default attention implementation. On machines without CUDA the weights are
    loaded in fp32, since fp16 kernels are slow or unavailable on many CPUs.

    The tokenizer is configured with left padding.

    Returns:
        A ready-to-use ``SentenceTransformer`` instance. A new model is
        loaded on every call; callers that ingest repeatedly should reuse it.
    """
    import importlib.util

    has_cuda = torch.cuda.is_available()
    model_kwargs = {
        "device_map": "auto",
        "dtype": torch.float16 if has_cuda else torch.float32,
    }
    # FlashAttention-2 only runs on CUDA and needs the flash_attn package.
    if has_cuda and importlib.util.find_spec("flash_attn") is not None:
        model_kwargs["attn_implementation"] = "flash_attention_2"
    # If EMBED_MODEL_ID ships custom modeling code, add
    # "trust_remote_code": True to model_kwargs (currently disabled).
    return SentenceTransformer(
        EMBED_MODEL_ID,
        model_kwargs=model_kwargs,
        tokenizer_kwargs={"padding_side": "left"},
    )
def parse_and_chunk(source: Path, model: SentenceTransformer) -> list[str]:
    """Convert ``source`` with docling and split it into contextualized text chunks.

    Chunk sizes are measured with ``model``'s own tokenizer, capped at
    ``MAX_TOKENS``. Adjacent peer chunks are merged (``merge_peers=True``).
    Each chunk is rendered with ``chunker.contextualize``, which presumably
    prepends surrounding context such as headings (see the docling docs).
    Chunks whose rendered text is empty or whitespace-only are dropped.

    Args:
        source: path of the document to convert.
        model: embedding model whose tokenizer drives the chunk budget.

    Returns:
        The non-blank contextualized chunk texts, in document order.
    """
    hf_tokenizer = HuggingFaceTokenizer(tokenizer=model.tokenizer, max_tokens=MAX_TOKENS)
    result = DocumentConverter().convert(source)
    chunker = HybridChunker(tokenizer=hf_tokenizer, merge_peers=True)
    rendered = (chunker.contextualize(piece) for piece in chunker.chunk(result.document))
    return [text for text in rendered if text.strip()]
def embed_many(model: SentenceTransformer, texts: list[str]):
    """Embed ``texts`` with ``model`` and return a float32 matrix, one row per text.

    Embeddings are L2-normalized (``normalize_embeddings=True``), so a dot
    product between rows equals their cosine similarity. The result is cast
    to float32 regardless of the dtype the model computes in (e.g. fp16). This
    is presumably the dtype ``store_chunks`` expects; confirm in ``rag.db``.

    An empty ``texts`` list returns an empty ``(0, dim)`` float32 array
    without calling ``encode``. With nothing to infer the width from,
    ``encode([])`` may return a 1-D ``(0,)`` array, which breaks callers that
    expect one row per text.

    Args:
        model: a loaded SentenceTransformer.
        texts: strings to embed; encoded in batches of ``BATCH``.

    Returns:
        A numpy float32 array of shape ``(len(texts), dim)``.
    """
    if not texts:
        import numpy as np

        # get_sentence_embedding_dimension() may return None for some models.
        dim = model.get_sentence_embedding_dimension() or 0
        return np.empty((0, dim), dtype=np.float32)
    V_np = model.encode(
        texts,
        batch_size=BATCH,
        normalize_embeddings=True,
        convert_to_numpy=True,
        show_progress_bar=True,
    )
    # copy=False skips a redundant copy when encode already returned float32.
    return V_np.astype("float32", copy=False)
def start_ingest(db: sqlite3.Connection, model: SentenceTransformer | None, collection: str, path: Path):
    """Parse, chunk, embed and store a single document.

    Args:
        db: open SQLite connection, passed through to ``store_chunks``.
        model: embedding model to reuse. If ``None``, one is loaded via
            ``get_embed_model()`` for this call only.
        collection: collection identifier, passed through to ``store_chunks``.
        path: document to ingest (converted by ``parse_and_chunk``).

    Returns:
        Always ``True``. Any failure propagates as an exception.
    """
    embedder = get_embed_model() if model is None else model
    # One model provides both the chunking tokenizer and the embeddings, so
    # chunk sizes are measured in the embedder's own tokens.
    pieces = parse_and_chunk(path, embedder)
    store_chunks(db, collection, pieces, embed_many(embedder, pieces))
    # TODO: decide on error handling (e.g. logging / rollback) around ingest.
    return True
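# NOTE: the model may compute in fp16 or fp32 depending on how it was loaded; embed_many casts embeddings to float32 before they are passed to the DB.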
|