summaryrefslogtreecommitdiff
path: root/rag/ingest.py
blob: d17690a826906f3181d9d6077ac4bf8b6b28998f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# from rag.rerank import search_hybrid
import sqlite3
import torch
from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
from pathlib import Path
from docling.document_converter import DocumentConverter
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
from rag.db import get_db, init_schema, store_chunks
from sentence_transformers import SentenceTransformer

# Hugging Face model id of the embedding model.
EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-8B"
# Hugging Face model id of the reranker. Not referenced by the code in this
# module; presumably consumed by the reranking side (cf. the commented-out
# rag.rerank import) — confirm before removing.
RERANKER_MODEL_ID = "Qwen/Qwen3-Reranker-8B"
# Maximum chunk size, in tokens, handed to docling's HuggingFaceTokenizer.
MAX_TOKENS = 600
# Batch size passed to SentenceTransformer.encode.
BATCH = 16
  

def get_embed_model():
  """Load the sentence-embedding model identified by ``EMBED_MODEL_ID``.

  The underlying transformers model is loaded in float16 with the
  FlashAttention-2 attention implementation, and ``device_map="auto"`` lets
  transformers/accelerate decide device placement. The tokenizer pads on the
  left.

  Returns:
    SentenceTransformer: the loaded embedding model.
  """
  return SentenceTransformer(
      EMBED_MODEL_ID,
      model_kwargs={
          # Requires the flash-attn package and a supported GPU.
          "attn_implementation": "flash_attention_2",
          "device_map": "auto",
          # NOTE(review): "dtype" is the newer transformers spelling; older
          # transformers releases only honour "torch_dtype" — confirm against
          # the pinned version.
          "dtype": torch.float16,
      },
      # Left padding — presumably because Qwen3-Embedding pools the last
      # token, which must be a real token rather than padding. Verify
      # against the model card.
      tokenizer_kwargs={"padding_side": "left"},
  )
  

def parse_and_chunk(source: Path, model: SentenceTransformer) -> list[str]:
  """Convert a document with docling and split it into contextualized chunks.

  Chunk sizes are bounded by MAX_TOKENS, counted with ``model``'s own
  tokenizer, so the chunking matches the tokenizer the embedding model uses.
  Chunks whose contextualized text is empty or whitespace-only are dropped.

  Args:
    source: the document to convert.
    model: embedding model whose tokenizer drives chunk sizing.

  Returns:
    The text produced by ``chunker.contextualize`` for every non-blank chunk,
    in the order the chunker yields them.
  """
  token_counter = HuggingFaceTokenizer(tokenizer=model.tokenizer, max_tokens=MAX_TOKENS)
  conversion = DocumentConverter().convert(source)
  chunker = HybridChunker(tokenizer=token_counter, merge_peers=True)

  contextualized = (chunker.contextualize(piece) for piece in chunker.chunk(conversion.document))
  return [text for text in contextualized if text.strip()]

def embed_many(model: SentenceTransformer, texts: list[str]):
  """Embed ``texts`` in batches of BATCH and return the vectors as float32.

  Embeddings are normalized by ``encode`` (``normalize_embeddings=True``) and
  a progress bar is shown while encoding.

  Args:
    model: the embedding model.
    texts: the strings to embed.

  Returns:
    A NumPy array of the embeddings, cast to float32.
  """
  embeddings = model.encode(
      texts,
      batch_size=BATCH,
      normalize_embeddings=True,
      convert_to_numpy=True,
      show_progress_bar=True,
  )
  # The model may emit float16 (it can be loaded in fp16); always hand float32 onward.
  return embeddings.astype("float32")
  

def start_ingest(db: sqlite3.Connection, path: Path):
  """Parse, chunk, embed and store a single document.

  Args:
    db: connection to write into. If ``None``, a connection is obtained from
      ``get_db()`` instead.
    path: the document to ingest.

  Returns:
    True once the chunks and their embeddings have been passed to
    ``store_chunks``; False if the document produced no non-blank chunks, in
    which case nothing is embedded and the schema is not initialised.
  """
  # Use the caller's connection; previously it was silently replaced by get_db().
  if db is None:
    db = get_db()

  model = get_embed_model()
  chunks = parse_and_chunk(path, model)
  if not chunks:
    # With no vectors there is no embedding dimension to size the schema
    # with (vectors.shape[1] would raise IndexError).
    return False

  vectors = embed_many(model, chunks)
  dim = vectors.shape[1]
  init_schema(db, dim)
  store_chunks(db, chunks, vectors)
  return True
# NOTE: the embedding model is loaded in float16; embed_many casts its output to float32 before anything reaches the DB.