| author    | polwex <polwex@sortug.com>                       | 2025-09-23 03:50:53 +0700 |
|-----------|--------------------------------------------------|---------------------------|
| committer | polwex <polwex@sortug.com>                       | 2025-09-23 03:50:53 +0700 |
| commit    | 57aaafdb137fe49930711f6ed5ccc83b3a119cd2 (patch) |                           |
| tree      | 1a7556927bed94377630d33dd29c3bf07d159619 /yek.md |                           |
init
Diffstat (limited to 'yek.md')

| -rw-r--r-- | yek.md | 316 |

1 file changed, 316 insertions, 0 deletions
@@ -0,0 +1,316 @@
>>>> __init__.py

>>>> main.py
# from rag.rerank import search_hybrid
import torch
import sqlite3
import sqlite_vec
from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
from pathlib import Path
from docling.document_converter import DocumentConverter
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer

from sentence_transformers import SentenceTransformer

EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-8B"
RERANKER_MODEL_ID = "Qwen/Qwen3-Reranker-8B"
MAX_TOKENS = 600
BATCH = 16

model = SentenceTransformer(
    EMBED_MODEL_ID,
    model_kwargs={
        # "trust_remote_code": True,
        "attn_implementation": "flash_attention_2",
        "device_map": "auto",
        "dtype": torch.float16,
    },
    tokenizer_kwargs={"padding_side": "left"},
)
tokenizer = HuggingFaceTokenizer(tokenizer=model.tokenizer, max_tokens=MAX_TOKENS)

def parse_and_chunk() -> list[str]:
    source = Path("./japan-water-margin.pdf")

    converter = DocumentConverter()
    doc = converter.convert(source)
    chunker = HybridChunker(tokenizer=tokenizer, merge_peers=True)
    out = []
    for ch in chunker.chunk(doc.document):
        txt = chunker.contextualize(ch)
        if txt.strip():
            out.append(txt)
    return out

def embed_many(texts: list[str]):
    V_np = model.encode(texts,
                        batch_size=BATCH,
                        normalize_embeddings=True,
                        convert_to_numpy=True,
                        show_progress_bar=True)
    return V_np.astype("float32")

chunks = parse_and_chunk()
V_np = embed_many(chunks)  # may come back fp16; cast to float32 for the DB blob
DIM = V_np.shape[1]
db = sqlite3.connect("./rag.db")

def init_db(dim: int):
    db.enable_load_extension(True)
    sqlite_vec.load(db)
    db.enable_load_extension(False)
    db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec USING vec0(embedding float[{dim}])")
    db.execute('''
        CREATE TABLE IF NOT EXISTS chunks (
            id INTEGER PRIMARY KEY,
            text TEXT
        )''')
    db.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts USING fts5(text)")

def store_chunks(chunks: list[str], V_np):
    assert len(chunks) == len(V_np)
    db.execute("BEGIN")
    db.executemany('''
        INSERT INTO chunks(id, text) VALUES (?, ?)
    ''', list(enumerate(chunks, start=1)))
    db.executemany(
        "INSERT INTO vec(rowid, embedding) VALUES (?, ?)",
        [(i + 1, memoryview(V_np[i].tobytes())) for i in range(len(chunks))]
    )
    db.executemany("INSERT INTO fts(rowid, text) VALUES (?, ?)", list(enumerate(chunks, start=1)))
    db.commit()

# 5) Query helper. sqlite-vec does KNN through MATCH on the vec0 table, not
# pgvector-style operators; with unit-normalized embeddings, L2 distance ranks
# the same as cosine.
def search(qtext, k=5):
    q = model.encode([qtext], normalize_embeddings=True, convert_to_numpy=True).astype("float32")
    rows = db.execute(
        "SELECT rowid, distance FROM vec WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
        (memoryview(q.tobytes()), k)
    ).fetchall()
    # db.execute("SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?", (qtext, 40))
    return [(rid, db.execute("SELECT text FROM chunks WHERE id=?", (rid,)).fetchone()[0], dist)
            for rid, dist in rows]

init_db(DIM)
store_chunks(chunks, V_np)

# # Hybrid + CE rerank query:
# results = search_hybrid("indemnification obligations survive termination", model, k_vec=50, k_bm25=50, k_final=8)
# for rid, txt, score in results:
#     print(f"[{rid:04d}] score={score:.3f}\n{txt[:300]}...\n")
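# Quick sanity check of the vector path (a sketch: the query string is invented
# and this assumes the ingest above completed without errors).
for rid, txt, dist in search("reception of the Water Margin in Japan", k=3):
    print(f"[{rid:04d}] dist={dist:.4f}\n{txt[:200]}\n")

# The keyword path can be probed the same way; FTS5 sorts by BM25 via ORDER BY rank:
# db.execute("SELECT rowid FROM fts WHERE fts MATCH ? ORDER BY rank LIMIT ?", ("water margin", 5)).fetchall()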
LIMIT ?", (qtext, 40)) + return [(rid, db.execute("SELECT text FROM chunks WHERE id=?", (rid,)).fetchone()[0], dist) for rid, dist in rows] + + + +init_db(DIM) +store_chunks(chunks, V_np) + + +# # Hybrid + CE rerank query: +# results = search_hybrid("indemnification obligations survive termination", model, k_vec=50, k_bm25=50, k_final=8) +# for rid, txt, score in results: +# print(f"[{rid:04d}] score={score:.3f}\n{txt[:300]}...\n") + +>>>> mmr.py +import numpy as np + +def cosine(a, b): return float(np.dot(a, b)) + +def mmr(query_vec, cand_ids, cand_vecs, k=8, lamb=0.7): + """cand_ids: [int], cand_vecs: np.ndarray float32 [N,D] (unit vectors) aligned with cand_ids""" + selected, selected_idx = [], [] + remaining = list(range(len(cand_ids))) + + # seed with the most relevant + best0 = max(remaining, key=lambda i: cosine(query_vec, cand_vecs[i])) + selected.append(cand_ids[best0]); selected_idx.append(best0); remaining.remove(best0) + + while remaining and len(selected) < k: + def mmr_score(i): + rel = cosine(query_vec, cand_vecs[i]) + red = max(cosine(cand_vecs[i], cand_vecs[j]) for j in selected_idx) if selected_idx else 0.0 + return lamb * rel - (1.0 - lamb) * red + nxt = max(remaining, key=mmr_score) + selected.append(cand_ids[nxt]); selected_idx.append(nxt); remaining.remove(nxt) + return selected + +>>>> rag.db + +>>>> rerank.py +import torch +import sqlite3 +import sqlite_vec +from transformers import AutoTokenizer, AutoModelForSequenceClassification + +# ----------------------------- +# RERANKER (cross-encoder) +# ----------------------------- +RERANKER_ID = "BAAI/bge-reranker-base" # small & solid +# RERANKER_ID = "Qwen/Qwen3-Reranker-8B" # small & solid +_rr_tok = AutoTokenizer.from_pretrained(RERANKER_ID) +_rr_mod = AutoModelForSequenceClassification.from_pretrained( + RERANKER_ID, device_map="auto", dtype=torch.bfloat16 +) + +def rerank_cross_encoder(query: str, candidates: list[tuple[int, str]], batch_size: int = 32) -> list[tuple[int, str, float]]: + """ + candidates: [(id, text), ...] 
>>>> rag.db

>>>> rerank.py
import torch
import sqlite3
import sqlite_vec
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# -----------------------------
# RERANKER (cross-encoder)
# -----------------------------
RERANKER_ID = "BAAI/bge-reranker-base"  # small & solid
# RERANKER_ID = "Qwen/Qwen3-Reranker-8B"  # bigger, stronger, slower
_rr_tok = AutoTokenizer.from_pretrained(RERANKER_ID)
_rr_mod = AutoModelForSequenceClassification.from_pretrained(
    RERANKER_ID, device_map="auto", dtype=torch.bfloat16
)

def rerank_cross_encoder(query: str, candidates: list[tuple[int, str]], batch_size: int = 32) -> list[tuple[int, str, float]]:
    """
    candidates: [(id, text), ...]
    returns: [(id, text, score)] sorted desc by score
    """
    if not candidates:
        return []
    ids, texts = zip(*candidates)
    scores = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        enc = _rr_tok(list(zip([query] * len(batch), batch)),
                      padding=True, truncation=True, max_length=512,
                      return_tensors="pt").to(_rr_mod.device)
        with torch.no_grad():
            logits = _rr_mod(**enc).logits.squeeze(-1)  # [B]
        scores.extend(logits.float().cpu().tolist())
    ranked = sorted(zip(ids, texts, scores), key=lambda x: x[2], reverse=True)
    return ranked

# -----------------------------
# DB: vec + chunks + FTS5
# -----------------------------
db = sqlite3.connect("./rag.db")

def init_db(dim: int):
    db.enable_load_extension(True)
    sqlite_vec.load(db)  # loads vec0
    db.enable_load_extension(False)

    # Vector table (vec0 expects a float32 blob per row)
    db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec USING vec0(embedding float[{dim}])")

    # Raw chunks; id INTEGER PRIMARY KEY is the rowid, so no extra index is needed
    db.execute("""
        CREATE TABLE IF NOT EXISTS chunks(
            id INTEGER PRIMARY KEY,
            text TEXT
        )
    """)

    # FTS5 over chunks.text (contentless mode keeps it simple)
    db.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts USING fts5(text)")

def store_chunks(chunks: list[str], V_np):
    # chunks: list[str], V_np: np.ndarray float32 [N, D]
    assert len(chunks) == len(V_np), "mismatch chunks vs vectors"

    db.execute("BEGIN")
    db.executemany("INSERT INTO chunks(id, text) VALUES (?,?)",
                   list(enumerate(chunks, start=1)))
    db.executemany("INSERT INTO vec(rowid, embedding) VALUES (?,?)",
                   [(i + 1, memoryview(V_np[i].tobytes())) for i in range(len(chunks))])
    # FTS content
    db.executemany("INSERT INTO fts(rowid, text) VALUES (?,?)",
                   list(enumerate(chunks, start=1)))
    db.commit()

# -----------------------------
# SEARCH: hybrid → rerank
# -----------------------------
def _vec_topk(qvec_f32: bytes, k: int):
    # KNN via sqlite-vec MATCH; smaller distance is better
    rows = db.execute(
        "SELECT rowid, distance FROM vec WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
        (memoryview(qvec_f32), k)
    ).fetchall()
    return [(rid, float(dist)) for (rid, dist) in rows]

def _bm25_topk(query: str, k: int):
    # FTS5 MATCH syntax; ORDER BY rank sorts by BM25 relevance.
    # Default tokenizer is fine for English/legal text.
    rows = db.execute(
        "SELECT rowid FROM fts WHERE fts MATCH ? ORDER BY rank LIMIT ?",
        (query, k)
    ).fetchall()
    return [rid for (rid,) in rows]

def _fetch_text(ids: list[int]) -> list[tuple[int, str]]:
    if not ids:
        return []
    qmarks = ",".join("?" * len(ids))
    rows = db.execute(f"SELECT id, text FROM chunks WHERE id IN ({qmarks})", ids).fetchall()
    by_id = {i: t for (i, t) in rows}
    # keep input order
    return [(i, by_id[i]) for i in ids if i in by_id]

def search_hybrid(query: str, embed_model, k_vec: int = 40, k_bm25: int = 40, k_final: int = 10):
    # 1) embed query (unit-norm) with the same model used for the chunks
    q = embed_model.encode([query], normalize_embeddings=True,
                           convert_to_numpy=True).astype("float32")[0]
    qbytes = q.tobytes()

    # 2) candidate pools
    vec_hits = _vec_topk(qbytes, k_vec)  # [(id, dist)]
    vec_ids = [i for (i, _) in vec_hits]
    bm25_ids = _bm25_topk(query, k_bm25)

    # 3) merge (preserve vector order first)
    seen = set()
    merged_ids = []
    for i in vec_ids + bm25_ids:
        if i not in seen:
            merged_ids.append(i); seen.add(i)

    # 4) fetch texts, then rerank with the cross-encoder
    cand = _fetch_text(merged_ids)
    ranked = rerank_cross_encoder(query, cand)

    # 5) return top k_final: [(id, text, ce_score)]
    return ranked[:k_final]
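# An alternative to the order-preserving merge in step 3 above: reciprocal rank
# fusion. A sketch of the standard recipe, not called by search_hybrid; k=60 is
# the conventional smoothing constant.
def rrf_merge(vec_ids: list[int], bm25_ids: list[int], k: int = 60) -> list[int]:
    scores: dict[int, float] = {}
    for ranking in (vec_ids, bm25_ids):
        for rank, rid in enumerate(ranking):
            # each list contributes 1/(k + rank + 1) for the ids it ranks
            scores[rid] = scores.get(rid, 0.0) + 1.0 / (k + rank + 1)
    # highest fused score first
    return sorted(scores, key=scores.get, reverse=True)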
LIMIT ?", + (query, k) + ).fetchall() + return [rid for (rid,) in rows] + +def _fetch_text(ids: list[int]) -> list[tuple[int, str]]: + if not ids: + return [] + qmarks = ",".join("?"*len(ids)) + rows = db.execute(f"SELECT id, text FROM chunks WHERE id IN ({qmarks})", ids).fetchall() + by_id = {i:t for (i,t) in rows} + # keep input order + return [(i, by_id[i]) for i in ids if i in by_id] + +def search_hybrid(query: str, embed_model, k_vec: int = 40, k_bm25: int = 40, k_final: int = 10): + # 1) embed query (unit-norm) with same model you used for chunks + with torch.no_grad(): + q = embed_model.encode([query], normalize=True).to("cpu").to(torch.float32).numpy()[0] + qbytes = q.tobytes() + + # 2) candidate pools + vec_hits = _vec_topk(qbytes, k_vec) # [(id, dist)] + vec_ids = [i for (i, _) in vec_hits] + bm25_ids = _bm25_topk(query, k_bm25) + + # 3) merge (preserve vector order first) + seen = set() + merged_ids = [] + for i in vec_ids + bm25_ids: + if i not in seen: + merged_ids.append(i); seen.add(i) + + # 4) fetch texts, then rerank with cross-encoder + cand = _fetch_text(merged_ids) + ranked = rerank_cross_encoder(query, cand) + + # 5) return top k_final: [(id, text, ce_score)] + return ranked[:k_final] + +>>>> test.py +from pathlib import Path +from docling.document_converter import DocumentConverter +from docling_core.transforms.chunker.hybrid_chunker import HybridChunker +from transformers import AutoModel +from sentence_transformers import SentenceTransformer +import torch + +# converter = DocumentConverter() +# chunker = HybridChunker() +# file = Path("yek.md") +# doc = converter.convert(file).document +# chunk_iter = chunker.chunk(doc) +# for chunk in chunk_iter: +# print(chunk) +# txt = chunker.contextualize(chunk) +# print(txt) + + + +model = SentenceTransformer( + "Qwen/Qwen3-Embedding-8B", + model_kwargs={ + # "trust_remote_code":True, + "attn_implementation":"flash_attention_2", + "device_map":"auto", + "dtype":torch.float16 + }, tokenizer_kwargs={"padding_side": "left"} +) + +batch: list[str] = ["This son of a bitch has gone too far", "Fuck me baby please", "I'm hungry now", "Charlie Kirk is dead"] + +v = model.encode("pepeee", normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True) +print("v") +print(type(v)) +print(v) +print(v.dtype) +print(v.device) + +# V = torch.cat([v], dim=0) +# print("V") +# print(type(V)) +# print(V) +# print(V.dtype) +# print(V.device) +# print("V_np") +# V_idk = V.cpu().float() + +# when they were pytorch tensors +# V = embed_many(chunks) # float32/fp16 on CPU? ensure float32 for DB: +# V_np = V.float().cpu().numpy().astype("float32") +# DIM = V_np.shape[1] +# db = sqlite3.connect("./rag.db") + + |