summaryrefslogtreecommitdiff
path: root/yek.md
diff options
context:
space:
mode:
Diffstat (limited to 'yek.md')
-rw-r--r--yek.md316
1 files changed, 316 insertions, 0 deletions
diff --git a/yek.md b/yek.md
new file mode 100644
index 0000000..d71c21d
--- /dev/null
+++ b/yek.md
@@ -0,0 +1,316 @@
+>>>> __init__.py
+
+>>>> main.py
+# from rag.rerank import search_hybrid
+import torch
+import sqlite3
+import sqlite_vec
+from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
+from pathlib import Path
+from docling.document_converter import DocumentConverter
+from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
+
+from sentence_transformers import SentenceTransformer
+
+EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-8B"
+RERANKER_MODEL_ID = "Qwen/Qwen3-Reranker-8B"
+MAX_TOKENS = 600
+BATCH = 16
+
+model = SentenceTransformer(
+ "Qwen/Qwen3-Embedding-8B",
+ model_kwargs={
+ # "trust_remote_code":True,
+ "attn_implementation":"flash_attention_2",
+ "device_map":"auto",
+ "dtype":torch.float16
+ }, tokenizer_kwargs={"padding_side": "left"}
+)
+tokenizer = HuggingFaceTokenizer(tokenizer=model.tokenizer, max_tokens=MAX_TOKENS)
+
+def parse_and_chunk()-> list[str]:
+ source = Path("./japan-water-margin.pdf")
+
+ converter = DocumentConverter()
+ doc = converter.convert(source)
+ chunker = HybridChunker(tokenizer=tokenizer, merge_peers=True)
+ out = []
+ for ch in chunker.chunk(doc.document):
+ txt = chunker.contextualize(ch)
+ if txt.strip():
+ out.append(txt)
+ return out
+
+def embed_many(texts: list[str]):
+ V_np = model.encode(texts,
+ batch_size=BATCH,
+ normalize_embeddings=True,
+ convert_to_numpy=True,
+ show_progress_bar=True)
+ return V_np.astype("float32")
+
+
+
+chunks =parse_and_chunk()
+V_np = embed_many(chunks) # float32/fp16 on CPU? ensure float32 for DB:
+DIM = V_np.shape[1]
+db = sqlite3.connect("./rag.db")
+
+
+
+def init_db(DIM: str):
+ db.enable_load_extension(True)
+ sqlite_vec.load(db)
+ db.enable_load_extension(False)
+ db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec USING vec0(embedding float[{DIM}])")
+ db.execute('''
+ CREATE TABLE IF NOT EXISTS chunks (
+ id INTEGER PRIMARY KEY,
+ text TEXT
+ )''')
+ db.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts USING fts5(text)")
+
+def store_chunks(chunks: list[str], V_np):
+ assert len(chunks) == len(V_np)
+ db.execute("BEGIN")
+ db.executemany('''
+ INSERT INTO chunks(id, text) VALUES (?, ?)
+ ''', list(enumerate(chunks, start=1)))
+ db.executemany(
+ "INSERT INTO vec(rowid, embedding) VALUES (?, ?)",
+ [(i+1, memoryview(V_np[i].tobytes())) for i in range(len(chunks))]
+ )
+ db.executemany("INSERT INTO fts(rowid, text) VALUES (?, ?)", list(enumerate(chunks, start=1)))
+ db.commit()
+
+# 5) Query helper (cosine distance; operator may be <#> in sqlite-vec)
+def search(qtext, k=5):
+ q = model.encode([qtext], normalize=True, convert_to_numpy=True).astype("float32")
+ # Cosine distance operator in sqlite-vec is `<#>`; if your build differs, check docs: <-> L2, <=> dot, <#> cosine
+ rows = db.execute(
+ "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?",
+ (memoryview(q.tobytes()), k)
+ ).fetchall()
+
+# db.execute("SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?", (qtext, 40))
+ return [(rid, db.execute("SELECT text FROM chunks WHERE id=?", (rid,)).fetchone()[0], dist) for rid, dist in rows]
+
+
+
+init_db(DIM)
+store_chunks(chunks, V_np)
+
+
+# # Hybrid + CE rerank query:
+# results = search_hybrid("indemnification obligations survive termination", model, k_vec=50, k_bm25=50, k_final=8)
+# for rid, txt, score in results:
+# print(f"[{rid:04d}] score={score:.3f}\n{txt[:300]}...\n")
+
+>>>> mmr.py
+import numpy as np
+
+def cosine(a, b): return float(np.dot(a, b))
+
+def mmr(query_vec, cand_ids, cand_vecs, k=8, lamb=0.7):
+ """cand_ids: [int], cand_vecs: np.ndarray float32 [N,D] (unit vectors) aligned with cand_ids"""
+ selected, selected_idx = [], []
+ remaining = list(range(len(cand_ids)))
+
+ # seed with the most relevant
+ best0 = max(remaining, key=lambda i: cosine(query_vec, cand_vecs[i]))
+ selected.append(cand_ids[best0]); selected_idx.append(best0); remaining.remove(best0)
+
+ while remaining and len(selected) < k:
+ def mmr_score(i):
+ rel = cosine(query_vec, cand_vecs[i])
+ red = max(cosine(cand_vecs[i], cand_vecs[j]) for j in selected_idx) if selected_idx else 0.0
+ return lamb * rel - (1.0 - lamb) * red
+ nxt = max(remaining, key=mmr_score)
+ selected.append(cand_ids[nxt]); selected_idx.append(nxt); remaining.remove(nxt)
+ return selected
+
+>>>> rag.db
+
+>>>> rerank.py
+import torch
+import sqlite3
+import sqlite_vec
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+# -----------------------------
+# RERANKER (cross-encoder)
+# -----------------------------
+RERANKER_ID = "BAAI/bge-reranker-base" # small & solid
+# RERANKER_ID = "Qwen/Qwen3-Reranker-8B" # small & solid
+_rr_tok = AutoTokenizer.from_pretrained(RERANKER_ID)
+_rr_mod = AutoModelForSequenceClassification.from_pretrained(
+ RERANKER_ID, device_map="auto", dtype=torch.bfloat16
+)
+
+def rerank_cross_encoder(query: str, candidates: list[tuple[int, str]], batch_size: int = 32) -> list[tuple[int, str, float]]:
+ """
+ candidates: [(id, text), ...]
+ returns: [(id, text, score)] sorted desc by score
+ """
+ if not candidates:
+ return []
+ ids, texts = zip(*candidates)
+ scores = []
+ for i in range(0, len(texts), batch_size):
+ batch = texts[i:i+batch_size]
+ enc = _rr_tok(list(zip([query]*len(batch), batch)),
+ padding=True, truncation=True, max_length=512,
+ return_tensors="pt").to(_rr_mod.device)
+ with torch.no_grad():
+ logits = _rr_mod(**enc).logits.squeeze(-1) # [B]
+ scores.extend(logits.float().cpu().tolist())
+ ranked = sorted(zip(ids, texts, scores), key=lambda x: x[2], reverse=True)
+ return ranked
+
+# -----------------------------
+# DB: vec + chunks + FTS5
+# -----------------------------
+db = sqlite3.connect("./rag.db")
+
+def init_db(dim: int):
+ db.enable_load_extension(True)
+ sqlite_vec.load(db) # loads vec0
+ db.enable_load_extension(False)
+
+ # Vector table (cosine expects float32 blob)
+ db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec USING vec0(embedding float[{dim}])")
+
+ # Raw chunks
+ db.execute("""
+ CREATE TABLE IF NOT EXISTS chunks(
+ id INTEGER PRIMARY KEY,
+ text TEXT
+ )
+ """)
+
+ # FTS5 over chunks.text (contentless mode keeps it simple)
+ db.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts USING fts5(text)")
+ db.execute("CREATE INDEX IF NOT EXISTS idx_chunks_id ON chunks(id)")
+
+def store_chunks(chunks: list[str], V_np):
+ # chunks: list[str], V_np: np.ndarray float32 [N, D]
+ assert len(chunks) == len(V_np), "mismatch chunks vs vectors"
+
+ db.execute("BEGIN")
+ db.executemany("INSERT INTO chunks(id, text) VALUES (?,?)",
+ list(enumerate(chunks, start=1)))
+ db.executemany("INSERT INTO vec(rowid, embedding) VALUES (?,?)",
+ [(i+1, memoryview(V_np[i].tobytes())) for i in range(len(chunks))])
+ # FTS content
+ db.executemany("INSERT INTO fts(rowid, text) VALUES (?,?)",
+ list(enumerate(chunks, start=1)))
+ db.commit()
+
+# -----------------------------
+# SEARCH: hybrid → rerank
+# -----------------------------
+def _vec_topk(qvec_f32: bytes, k: int):
+ # cosine distance: smaller is better
+ rows = db.execute(
+ "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?",
+ (memoryview(qvec_f32), k)
+ ).fetchall()
+ return [(rid, float(dist)) for (rid, dist) in rows]
+
+def _bm25_topk(query: str, k: int):
+ # note: FTS5 MATCH syntax, tokenizer defaults ok for English/legal
+ rows = db.execute(
+ "SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?",
+ (query, k)
+ ).fetchall()
+ return [rid for (rid,) in rows]
+
+def _fetch_text(ids: list[int]) -> list[tuple[int, str]]:
+ if not ids:
+ return []
+ qmarks = ",".join("?"*len(ids))
+ rows = db.execute(f"SELECT id, text FROM chunks WHERE id IN ({qmarks})", ids).fetchall()
+ by_id = {i:t for (i,t) in rows}
+ # keep input order
+ return [(i, by_id[i]) for i in ids if i in by_id]
+
+def search_hybrid(query: str, embed_model, k_vec: int = 40, k_bm25: int = 40, k_final: int = 10):
+ # 1) embed query (unit-norm) with same model you used for chunks
+ with torch.no_grad():
+ q = embed_model.encode([query], normalize=True).to("cpu").to(torch.float32).numpy()[0]
+ qbytes = q.tobytes()
+
+ # 2) candidate pools
+ vec_hits = _vec_topk(qbytes, k_vec) # [(id, dist)]
+ vec_ids = [i for (i, _) in vec_hits]
+ bm25_ids = _bm25_topk(query, k_bm25)
+
+ # 3) merge (preserve vector order first)
+ seen = set()
+ merged_ids = []
+ for i in vec_ids + bm25_ids:
+ if i not in seen:
+ merged_ids.append(i); seen.add(i)
+
+ # 4) fetch texts, then rerank with cross-encoder
+ cand = _fetch_text(merged_ids)
+ ranked = rerank_cross_encoder(query, cand)
+
+ # 5) return top k_final: [(id, text, ce_score)]
+ return ranked[:k_final]
+
+>>>> test.py
+from pathlib import Path
+from docling.document_converter import DocumentConverter
+from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
+from transformers import AutoModel
+from sentence_transformers import SentenceTransformer
+import torch
+
+# converter = DocumentConverter()
+# chunker = HybridChunker()
+# file = Path("yek.md")
+# doc = converter.convert(file).document
+# chunk_iter = chunker.chunk(doc)
+# for chunk in chunk_iter:
+# print(chunk)
+# txt = chunker.contextualize(chunk)
+# print(txt)
+
+
+
+model = SentenceTransformer(
+ "Qwen/Qwen3-Embedding-8B",
+ model_kwargs={
+ # "trust_remote_code":True,
+ "attn_implementation":"flash_attention_2",
+ "device_map":"auto",
+ "dtype":torch.float16
+ }, tokenizer_kwargs={"padding_side": "left"}
+)
+
+batch: list[str] = ["This son of a bitch has gone too far", "Fuck me baby please", "I'm hungry now", "Charlie Kirk is dead"]
+
+v = model.encode("pepeee", normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True)
+print("v")
+print(type(v))
+print(v)
+print(v.dtype)
+print(v.device)
+
+# V = torch.cat([v], dim=0)
+# print("V")
+# print(type(V))
+# print(V)
+# print(V.dtype)
+# print(V.device)
+# print("V_np")
+# V_idk = V.cpu().float()
+
+# when they were pytorch tensors
+# V = embed_many(chunks) # float32/fp16 on CPU? ensure float32 for DB:
+# V_np = V.float().cpu().numpy().astype("float32")
+# DIM = V_np.shape[1]
+# db = sqlite3.connect("./rag.db")
+
+