author    polwex <polwex@sortug.com>  2025-09-24 23:38:36 +0700
committer polwex <polwex@sortug.com>  2025-09-24 23:38:36 +0700
commit    734b89570040e97f0c7743c4c0bc28e30a3cd4ee (patch)
tree      7142d9f37908138c38d0ade066e960c3a1c69f5d /yek.txt
parent    57aaafdb137fe49930711f6ed5ccc83b3a119cd2 (diff)
Diffstat (limited to 'yek.txt')
-rw-r--r--  yek.txt | 694
1 file changed, 694 insertions(+), 0 deletions(-)
diff --git a/yek.txt b/yek.txt
new file mode 100644
index 0000000..37b02f9
--- /dev/null
+++ b/yek.txt
@@ -0,0 +1,694 @@
+>>>> __init__.py
+
+>>>> constants.py
+EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-8B"
+RERANKER_ID = "Qwen/Qwen3-Reranker-8B"
+# RERANKER_ID = "BAAI/bge-reranker-base" # or -large if you’ve got VRAM
+MAX_TOKENS = 600
+BATCH = 16
+
+>>>> db.py
+import sqlite3
+from numpy import ndarray
+from sqlite_vec import serialize_float32
+import sqlite_vec
+from rag.utils import sanitize_query
+
+def get_db():
+ db = sqlite3.connect("./rag.db")
+ db.execute("PRAGMA journal_mode=WAL;")
+ db.execute("PRAGMA synchronous=NORMAL;")
+ db.execute("PRAGMA temp_store=MEMORY;")
+ db.execute("PRAGMA mmap_size=300000000;") # 30GB if your OS allows
+ db.enable_load_extension(True)
+ sqlite_vec.load(db)
+ db.enable_load_extension(False)
+ return db
+
+def init_schema(db: sqlite3.Connection, col: str, DIM: int, model_id: str, tok_id: str, normalize: bool, preproc_hash: str):
+    print("initializing schema", col)
+ db.execute("""
+ CREATE TABLE IF NOT EXISTS collections(
+ name TEXT PRIMARY KEY,
+ model TEXT,
+ tokenizer TEXT,
+ dim INTEGER,
+ normalize INTEGER,
+ preproc_hash TEXT,
+ created_at INTEGER DEFAULT (unixepoch())
+ )""")
+ db.execute("BEGIN")
+ db.execute("INSERT INTO collections(name, model, tokenizer, dim, normalize, preproc_hash) VALUES(?, ?, ?, ?, ?, ?)",
+ (col, model_id, tok_id, DIM, int(normalize), preproc_hash))
+
+ db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec_{col} USING vec0(embedding float[{DIM}])")
+ db.execute(f'''
+ CREATE TABLE IF NOT EXISTS chunks_{col} (
+ id INTEGER PRIMARY KEY,
+ text TEXT
+ )''')
+ db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS fts_{col} USING fts5(text)")
+ db.commit()
+
+
+def check_db(db: sqlite3.Connection, coll: str):
+    row = db.execute("SELECT dim, model, normalize FROM collections WHERE name=?", (coll,)).fetchone()
+    return row
+    # assert row and row[0] == DIM and row[1] == EMBED_MODEL_ID and row[2] == 1  # if you normalize
+
+def check_db2(db: sqlite3.Connection, coll: str):
+    row = db.execute(
+        "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
+        (f"vec_{coll}",)
+    ).fetchone()
+    return bool(row)
+
+def store_chunks(db: sqlite3.Connection, col: str, chunks: list[str], V_np: ndarray):
+    assert len(chunks) == len(V_np)
+    # continue ids after any existing rows so a second ingest doesn't collide
+    start = db.execute(f"SELECT COALESCE(MAX(id), 0) FROM chunks_{col}").fetchone()[0] + 1
+    db.execute("BEGIN")
+    db.executemany(f'''
+        INSERT INTO chunks_{col}(id, text) VALUES (?, ?)
+    ''', list(enumerate(chunks, start=start)))
+    db.executemany(
+        f"INSERT INTO vec_{col}(rowid, embedding) VALUES (?, ?)",
+        [(start + i, memoryview(V_np[i].tobytes())) for i in range(len(chunks))]
+    )
+    db.executemany(f"INSERT INTO fts_{col}(rowid, text) VALUES (?, ?)", list(enumerate(chunks, start=start)))
+    db.commit()
+
+def vec_topk(db: sqlite3.Connection, col: str, q_vec_f32, k=10):
+    rows = db.execute(
+        f"SELECT rowid, distance FROM vec_{col} WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
+        (serialize_float32(q_vec_f32), k)
+    ).fetchall()
+    return rows  # [(rowid, distance)]
+
+
+def bm25_topk(db: sqlite3.Connection, col: str, qtext, k=10):
+    # escape/quote the query so raw user text can't break FTS5 syntax
+    safe_q = sanitize_query(qtext)
+    return [rid for (rid,) in db.execute(
+        f"SELECT rowid FROM fts_{col} WHERE fts_{col} MATCH ? LIMIT ?", (safe_q, k)
+    ).fetchall()]
+
+def wipe_db(col: str):
+ db = sqlite3.connect("./rag.db")
+ db.executescript(f"DROP TABLE IF EXISTS chunks_{col}; DROP TABLE IF EXISTS fts_{col}; DROP TABLE IF EXISTS vec_{col};")
+ db.close()
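+
+# Minimal smoke test of the helpers above (the scratch collection "demo" and
+# the 4-dim toy vectors are assumptions; real collections use the embedder's dim):
+if __name__ == "__main__":
+    import numpy as np
+    demo = get_db()
+    if not check_db2(demo, "demo"):
+        init_schema(demo, "demo", 4, "toy-model", "toy-tokenizer", True, "none")
+    vecs = np.eye(4, dtype="float32")  # four orthogonal unit vectors
+    store_chunks(demo, "demo", ["alpha", "beta", "gamma", "delta"], vecs)
+    print(vec_topk(demo, "demo", vecs[0], k=2))  # nearest neighbours of "alpha"
+    print(bm25_topk(demo, "demo", "beta", k=2))  # keyword hit for "beta"
+    demo.close()
+    wipe_db("demo")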
+
+>>>> ingest.py
+# from rag.rerank import search_hybrid
+import sqlite3
+import torch
+from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
+from pathlib import Path
+from docling.document_converter import DocumentConverter
+from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
+from rag.db import get_db, init_schema, store_chunks
+from sentence_transformers import SentenceTransformer
+from rag.constants import MAX_TOKENS, BATCH, EMBED_MODEL_ID
+
+
+
+def get_embed_model():
+ return SentenceTransformer(
+ EMBED_MODEL_ID,
+        model_kwargs={
+            # "trust_remote_code": True,
+            "attn_implementation": "flash_attention_2",
+            "device_map": "auto",
+            "torch_dtype": torch.float16,
+        },
+        tokenizer_kwargs={"padding_side": "left"}
+ )
+
+
+def parse_and_chunk(source: Path, model: SentenceTransformer)-> list[str]:
+ tokenizer = HuggingFaceTokenizer(tokenizer=model.tokenizer, max_tokens=MAX_TOKENS)
+
+ converter = DocumentConverter()
+ doc = converter.convert(source)
+ chunker = HybridChunker(tokenizer=tokenizer, merge_peers=True)
+ out = []
+ for ch in chunker.chunk(doc.document):
+ txt = chunker.contextualize(ch)
+ if txt.strip():
+ out.append(txt)
+ return out
+
+def embed_many(model: SentenceTransformer, texts: list[str]):
+ V_np = model.encode(texts,
+ batch_size=BATCH,
+ normalize_embeddings=True,
+ convert_to_numpy=True,
+ show_progress_bar=True)
+ return V_np.astype("float32")
+
+
+def start_ingest(db: sqlite3.Connection, model: SentenceTransformer | None, collection: str, path: Path):
+ if model is None:
+ model = get_embed_model()
+ chunks = parse_and_chunk(path, model)
+ V_np = embed_many(model, chunks)
+ store_chunks(db, collection, chunks, V_np)
+ # TODO some try catch?
+ return True
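+
+# Minimal end-to-end sketch of the functions above ("notes.pdf" and the
+# pre-initialized collection "demo" are assumptions):
+if __name__ == "__main__":
+    demo_db = get_db()
+    # passing model=None makes start_ingest load the embedding model itself
+    start_ingest(demo_db, None, "demo", Path("notes.pdf"))
+    demo_db.close()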
+
+>>>> main.py
+import sys
+import argparse
+from pathlib import Path
+from rag.constants import EMBED_MODEL_ID
+from rag.ingest import get_embed_model, start_ingest
+from rag.search import search_hybrid, vec_search
+from rag.db import get_db, check_db, check_db2, init_schema
+
+
+def valid_collection(col: str) -> bool:
+    # collection names get interpolated into table names, so keep them strict:
+    # an ASCII identifier (no spaces), fewer than 9 characters
+    return col.isascii() and col.isidentifier() and len(col) < 9
+
+
+def cmd_ingest(args):
+ path = Path(args.file)
+ if not valid_collection(args.collection):
+ print(f"Collection name invalid: {args.collection}", file=sys.stderr)
+ sys.exit(1)
+
+    if not path.exists():
+        print(f"File not found: {path}", file=sys.stderr)
+        sys.exit(1)
+
+ db = get_db()
+ if not check_db2(db, args.collection):
+ model = get_embed_model()
+ dim = model.get_sentence_embedding_dimension()
+ if dim is None:
+ sys.exit(1)
+
+ # TODO Try catch in here, tell the user if it crashes
+ init_schema(db, args.collection, dim, EMBED_MODEL_ID, EMBED_MODEL_ID, True, 'idk')
+ stats = start_ingest(db, model, args.collection, path)
+ else:
+ stats = start_ingest(db, None, args.collection, path)
+ print(f"Ingested file={args.file} :: {stats}")
+
+def cmd_query(args):
+ if not valid_collection(args.collection):
+ print(f"Collection name invalid: {args.collection}", file=sys.stderr)
+ sys.exit(1)
+ db = get_db()
+ if not check_db2(db, args.collection):
+ print(f"Collection name not in DB, what are you searching: {args.collection}", file=sys.stderr)
+ sys.exit(1)
+
+ if args.simple:
+ results = vec_search(db, args.collection, args.query, k=args.k_final)
+ else:
+ results = search_hybrid(db,
+ args.collection,
+ args.query,
+ k_vec=args.k_vec,
+ k_bm25=args.k_bm25,
+ k_ce=args.k_ce,
+ k_final=args.k_final,
+ use_mmr=args.mmr,
+ mmr_lambda=args.mmr_lambda,
+ )
+
+ for rid, txt, score in results:
+ print(f"[{rid:05d}] score={score:.3f}\n{txt[:400]}...\n")
+
+ db.close()
+
+
+def main():
+ ap = argparse.ArgumentParser(prog="rag")
+ sub = ap.add_subparsers(dest="cmd", required=True)
+
+ # ingest
+ ap_ing = sub.add_parser("ingest", help="Parse, chunk, embed, and index a file into a collection")
+ ap_ing.add_argument("--file", required=True, help="Path to PDF/TXT to ingest")
+ ap_ing.add_argument("--collection", required=True, help="Collection name (e.g. wm_qwen3)")
+ ap_ing.set_defaults(func=cmd_ingest)
+
+ # query
+ ap_q = sub.add_parser("query", help="Query a collection")
+ ap_q.add_argument("--collection", required=True, help="Collection name to search")
+ ap_q.add_argument("--query", required=True, help="User query text")
+ ap_q.add_argument("--simple", action="store_true", help="Vector-only search (skip reranker)")
+ ap_q.add_argument("--mmr", action="store_true", help="Apply MMR after CE")
+ ap_q.add_argument("--mmr-lambda", type=float, default=0.7)
+ ap_q.add_argument("--k-vec", type=int, default=50)
+ ap_q.add_argument("--k-bm25", type=int, default=50)
+ ap_q.add_argument("--k-ce", type=int, default=30)
+ ap_q.add_argument("--k-final", type=int, default=10)
+ ap_q.set_defaults(func=cmd_query)
+
+ args = ap.parse_args()
+ args.func(args)
+
+
+if __name__ == "__main__":
+ main()
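+
+# Example invocations (assuming the package is importable as `rag` and run
+# from the repo root):
+#
+#   python -m rag.main ingest --file paper.pdf --collection wm_qwen3
+#   python -m rag.main query --collection wm_qwen3 --query "reception of Shuihu zhuan in Edo Japan"
+#   python -m rag.main query --collection wm_qwen3 --query "outlaw heroes" --simple --k-final 5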
+
+
+
+>>>> mmr.py
+from rag.constants import BATCH
+import numpy as np
+
+def cosine(a, b):
+    # plain dot product; equals cosine similarity because vectors are unit-normalized
+    return float(np.dot(a, b))
+
+def mmr(query_vec, cand_ids, cand_vecs, k=8, lamb=0.7):
+ """cand_ids: [int], cand_vecs: np.ndarray float32 [N,D] (unit vectors) aligned with cand_ids"""
+ selected, selected_idx = [], []
+ remaining = list(range(len(cand_ids)))
+
+ # seed with the most relevant
+ best0 = max(remaining, key=lambda i: cosine(query_vec, cand_vecs[i]))
+ selected.append(cand_ids[best0]); selected_idx.append(best0); remaining.remove(best0)
+
+ while remaining and len(selected) < k:
+ def mmr_score(i):
+ rel = cosine(query_vec, cand_vecs[i])
+ red = max(cosine(cand_vecs[i], cand_vecs[j]) for j in selected_idx) if selected_idx else 0.0
+ return lamb * rel - (1.0 - lamb) * red
+ nxt = max(remaining, key=mmr_score)
+ selected.append(cand_ids[nxt]); selected_idx.append(nxt); remaining.remove(nxt)
+ return selected
+
+def embed_unit_np(st_model, texts: list[str]) -> np.ndarray:
+ V = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True, batch_size=BATCH)
+ V = V.astype("float32", copy=False)
+ return V
+
+def mmr2(qvec: np.ndarray, ids, vecs: np.ndarray, k=8, lamb=0.7):
+    """Leaner twin of mmr() above; this is the variant search.py imports."""
+ sel, idxs = [], []
+ rest = list(range(len(ids)))
+ best0 = max(rest, key=lambda i: float(qvec @ vecs[i]))
+ sel.append(ids[best0]); idxs.append(best0); rest.remove(best0)
+ while rest and len(sel) < k:
+ def score(i):
+ rel = float(qvec @ vecs[i])
+ red = max(float(vecs[i] @ vecs[j]) for j in idxs)
+ return lamb*rel - (1-lamb)*red
+ nxt = max(rest, key=score)
+ sel.append(ids[nxt]); idxs.append(nxt); rest.remove(nxt)
+ return sel
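+
+# Toy check of the relevance/diversity trade-off (2-D unit vectors, assumed
+# values; lamb=1.0 is pure relevance, lower lamb penalizes redundancy):
+if __name__ == "__main__":
+    q = np.array([1.0, 0.0], dtype="float32")
+    cands = np.array([[1.0, 0.0], [0.8, 0.6], [0.6, 0.8]], dtype="float32")
+    print(mmr2(q, ["a", "b", "c"], cands, k=2, lamb=1.0))  # ['a', 'b']: most relevant pair
+    print(mmr2(q, ["a", "b", "c"], cands, k=2, lamb=0.2))  # ['a', 'c']: diversity wins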
+
+>>>> nmain.py
+# rag/nmain.py (standalone variant of main.py, kept for reference)
+import argparse, sqlite3, sys
+from pathlib import Path
+
+from sentence_transformers import SentenceTransformer
+from rag.ingest import start_ingest
+from rag.search import search_hybrid, vec_search
+
+DB_PATH = "./rag.db"
+EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-8B"
+
+def open_db():
+ db = sqlite3.connect(DB_PATH)
+ # speed-ish pragmas
+ db.execute("PRAGMA journal_mode=WAL;")
+ db.execute("PRAGMA synchronous=NORMAL;")
+ return db
+
+def load_st_model():
+ # ST handles batching + GPU internally
+ return SentenceTransformer(
+ EMBED_MODEL_ID,
+ model_kwargs={
+ "attn_implementation": "flash_attention_2",
+ "device_map": "auto",
+ "torch_dtype": "float16",
+ },
+ tokenizer_kwargs={"padding_side": "left"},
+ )
+
+def ensure_collection_exists(db, collection: str):
+ row = db.execute(
+ "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
+ (f"vec_{collection}",)
+ ).fetchone()
+ return bool(row)
+
+def cmd_ingest(args):
+ db = open_db()
+ st_model = load_st_model()
+ path = Path(args.file)
+
+ if not path.exists():
+ print(f"File not found: {path}", file=sys.stderr)
+ sys.exit(1)
+
+ if args.rebuild and ensure_collection_exists(db, args.collection):
+ db.executescript(f"""
+ DROP TABLE IF EXISTS chunks_{args.collection};
+ DROP TABLE IF EXISTS fts_{args.collection};
+ DROP TABLE IF EXISTS vec_{args.collection};
+ """)
+
+ stats = start_ingest(
+ db, st_model,
+ path=path,
+ collection=args.collection,
+ )
+ print(f"Ingested collection={args.collection} :: {stats}")
+ db.close()
+
+def cmd_query(args):
+    db = open_db()
+    # no eager model load here; vec_search/search_hybrid load their own models
+
+ coll_ok = ensure_collection_exists(db, args.collection)
+ if not coll_ok:
+ print(f"Collection '{args.collection}' not found. Ingest first.", file=sys.stderr)
+ sys.exit(2)
+
+    if args.simple:
+        results = vec_search(db, args.collection, args.query, k=args.k_final)
+    else:
+        results = search_hybrid(
+            db, args.collection, args.query,
+            k_vec=args.k_vec,
+            k_bm25=args.k_bm25,
+            k_ce=args.k_ce,
+            k_final=args.k_final,
+            use_mmr=args.mmr,
+            mmr_lambda=args.mmr_lambda,
+        )
+
+ for rid, txt, score in results:
+ print(f"[{rid:05d}] score={score:.3f}\n{txt[:400]}...\n")
+
+ db.close()
+
+def main():
+ ap = argparse.ArgumentParser(prog="rag")
+ sub = ap.add_subparsers(dest="cmd", required=True)
+
+ # ingest
+ ap_ing = sub.add_parser("ingest", help="Parse, chunk, embed, and index a file into a collection")
+ ap_ing.add_argument("--file", required=True, help="Path to PDF/TXT to ingest")
+ ap_ing.add_argument("--collection", required=True, help="Collection name (e.g. wm_qwen3)")
+ ap_ing.add_argument("--rebuild", action="store_true", help="Drop and recreate collection tables")
+ ap_ing.set_defaults(func=cmd_ingest)
+
+ # query
+ ap_q = sub.add_parser("query", help="Query a collection")
+ ap_q.add_argument("--collection", required=True, help="Collection name to search")
+ ap_q.add_argument("--query", required=True, help="User query text")
+ ap_q.add_argument("--simple", action="store_true", help="Vector-only search (skip reranker)")
+ ap_q.add_argument("--mmr", action="store_true", help="Apply MMR after CE")
+ ap_q.add_argument("--mmr-lambda", type=float, default=0.7)
+ ap_q.add_argument("--k-vec", type=int, default=50)
+ ap_q.add_argument("--k-bm25", type=int, default=50)
+ ap_q.add_argument("--k-ce", type=int, default=30)
+ ap_q.add_argument("--k-final", type=int, default=10)
+ ap_q.set_defaults(func=cmd_query)
+
+ args = ap.parse_args()
+ args.func(args)
+
+if __name__ == "__main__":
+ main()
+
+>>>> rerank.py
+# rag/rerank.py
+import torch
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+from sentence_transformers import CrossEncoder
+from rag.constants import BATCH, RERANKER_ID
+
+# device: "cuda" | "cpu" | "mps"
+def get_rerank_model():
+    model_id = "BAAI/bge-reranker-base"  # hardcoded (overrides RERANKER_ID); -large if you’ve got VRAM
+    return CrossEncoder(
+        model_id,
+        device="cuda",
+        max_length=384,
+        model_kwargs={
+            # "attn_implementation": "flash_attention_2",
+            "device_map": "auto",
+            "torch_dtype": torch.float16,
+        },
+        tokenizer_kwargs={"padding_side": "left"}
+    )
+
+def rerank_cross_encoder(reranker: CrossEncoder, query: str, candidates: list[tuple[int, str]], batch_size: int = BATCH):
+ """
+ candidates: [(id, text), ...]
+ returns: [(id, text, score)] sorted desc by score
+ """
+ if not candidates:
+ return []
+ ids, texts = zip(*candidates)
+ # pairs = [(query, t) for t in texts]
+ pairs = [(query, t[:1000]) for t in texts]
+ scores = reranker.predict(pairs, batch_size=batch_size) # np.ndarray [N], higher=better
+ ranked = sorted(zip(ids, texts, scores), key=lambda x: x[2], reverse=True)
+ return ranked
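+
+# Hedged usage sketch (the candidate ids/texts are made up; assumes a CUDA
+# device, since get_rerank_model pins device="cuda"):
+if __name__ == "__main__":
+    ce = get_rerank_model()
+    cands = [(1, "The indemnification clause survives termination."),
+             (2, "Payment is due within thirty days.")]
+    for rid, txt, score in rerank_cross_encoder(ce, "Does indemnification survive termination?", cands):
+        print(f"[{rid}] {score:.3f} {txt}")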
+
+# tok = AutoTokenizer.from_pretrained(RERANKER_ID, use_fast=True, model_max_length=384)
+# ce = AutoModelForSequenceClassification.from_pretrained(
+# RERANKER_ID, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2",
+# device_map="auto" # or load_in_8bit=True
+# )
+
+# def ce_scores(query, texts, batch_size=16, max_length=384):
+# scores = []
+# for i in range(0, len(texts), batch_size):
+# batch = texts[i:i+batch_size]
+# enc = tok([ (query, t[:1000]) for t in batch ],
+# padding=True, truncation=True, max_length=max_length,
+# return_tensors="pt").to(ce.device)
+# with torch.inference_mode():
+# logits = ce(**enc).logits.squeeze(-1) # [B]
+# scores.extend(logits.float().cpu().tolist())
+# return scores
+
+>>>> search.py
+import sqlite3
+import gc
+import torch
+
+from rag.constants import BATCH
+from rag.ingest import get_embed_model
+from rag.rerank import get_rerank_model
+from rag.db import vec_topk, bm25_topk
+from rag.mmr import mmr2
+
+
+# Query helper: pure vector search over the vec0 index (distance: lower = closer)
+def vec_search(db: sqlite3.Connection, col: str, qtext, k=5):
+    model = get_embed_model()
+    # encode returns shape [1, D]; take row 0 so sqlite-vec gets a 1-D float32 vector
+    q = model.encode([qtext], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32")
+    rows = vec_topk(db, col, q, k)
+    return [(rid, db.execute(f"SELECT text FROM chunks_{col} WHERE id=?", (rid,)).fetchone()[0], dist) for rid, dist in rows]
+
+
+
+
+def search_hybrid(
+ db: sqlite3.Connection, col: str, query: str,
+ k_vec: int = 40, k_bm25: int = 40,
+ k_ce: int = 30, # rerank this many
+ k_final: int = 10, # return this many
+ use_mmr: bool = False, mmr_lambda: float = 0.7
+):
+    # 1) embed query (unit, float32) for vec search + MMR
+    emodel = get_embed_model()
+    query_embeddings = emodel.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32")
+    # release the embedder before the cross-encoder loads: drop the ref, gc, then clear the CUDA cache
+    del emodel
+    gc.collect()
+    torch.cuda.empty_cache()
+
+ # 2) ANN + BM25
+ print("phase 2", col, query)
+ vhits = vec_topk(db, col, query_embeddings, k_vec) # [(id, dist)]
+ vh_ids = [i for (i, _) in vhits]
+ bm_ids = bm25_topk(db, col, query, k_bm25)
+ #
+ # 3) merge ids [vector-first]
+ merged, seen = [], set()
+ for i in vh_ids + bm_ids:
+ if i not in seen:
+ merged.append(i); seen.add(i)
+    if not merged:
+        return []
+    # cap the cross-encoder workload at k_ce candidates (vector-first order)
+    merged = merged[:k_ce]
+
+ # 4) fetch texts
+ qmarks = ",".join("?"*len(merged))
+ cand = db.execute(f"SELECT id, text FROM chunks_{col} WHERE id IN ({qmarks})", merged).fetchall()
+ ids, texts = zip(*cand)
+
+
+    # 5) rerank with the cross-encoder
+    reranker = get_rerank_model()
+    scores = reranker.predict([(query, t[:1000]) for t in texts], batch_size=BATCH)
+    # release the cross-encoder before any MMR re-embedding below
+    del reranker
+    gc.collect()
+    torch.cuda.empty_cache()
+    ranked = sorted(zip(ids, texts, scores), key=lambda x: x[2], reverse=True)
+    if not use_mmr or len(ranked) <= k_final:
+        return ranked[:k_final]
+
+ # 6) MMR
+
+ ce_ids = [i for (i,_,_) in ranked]
+ ce_texts = [t for (_,t,_) in ranked]
+ st_model = get_embed_model()
+ ce_vecs = st_model.encode(ce_texts, normalize_embeddings=True, convert_to_numpy=True).astype("float32")
+ keep = set(mmr2(query_embeddings, ce_ids, ce_vecs, k=k_final, lamb=mmr_lambda))
+ return [r for r in ranked if r[0] in keep][:k_final]
+
+# # Hybrid + CE rerank query:
+# results = search_hybrid("indemnification obligations survive termination", model, k_vec=50, k_bm25=50, k_final=8)
+# for rid, txt, score in results:
+# print(f"[{rid:04d}] score={score:.3f}\n{txt[:300]}...\n")
+
+def search_hybrid_with_mmr(db, col, query, k_vec=50, k_bm25=50, k_ce=30, k_final=10, lamb=0.7):
+ ranked = search_hybrid(db, col, query, k_vec, k_bm25, k_ce, k_ce)
+ if not ranked: return []
+ ids = [i for (i,_,_) in ranked]
+ texts = [t for (_,t,_) in ranked]
+ st_model = get_embed_model()
+ qvec = st_model.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32")
+ cvecs = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True).astype("float32")
+ keep = set(mmr2(qvec, ids, cvecs, k=k_final, lamb=lamb))
+ return [r for r in ranked if r[0] in keep][:k_final]
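+
+# Hedged usage example ("wm_qwen3" is an assumed collection name; use any
+# collection you have ingested):
+if __name__ == "__main__":
+    from rag.db import get_db
+    _db = get_db()
+    for rid, txt, score in search_hybrid_with_mmr(_db, "wm_qwen3", "indemnification obligations survive termination", k_final=5):
+        print(f"[{rid:05d}] score={score:.3f}\n{txt[:200]}...\n")
+    _db.close()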
+
+>>>> test.py
+from pathlib import Path
+from docling.document_converter import DocumentConverter
+from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
+from transformers import AutoModel
+from sentence_transformers import SentenceTransformer
+from rag.db import get_db
+from rag.rerank import get_rerank_model
+import rag.ingest
+import rag.search
+import torch
+
+# converter = DocumentConverter()
+# chunker = HybridChunker()
+# file = Path("yek.md")
+# doc = converter.convert(file).document
+# chunk_iter = chunker.chunk(doc)
+# for chunk in chunk_iter:
+# print(chunk)
+# txt = chunker.contextualize(chunk)
+# print(txt)
+
+
+
+def t():
+ batch: list[str] = ["This son of a bitch has gone too far", "Fuck me baby please", "I'm hungry now", "Charlie Kirk is dead"]
+ model = rag.ingest.get_embed_model()
+ v = model.encode("pepeee", normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True)
+ print("v")
+ print(type(v))
+ print(v)
+ print(v.dtype)
+ print(v.device)
+
+# V = torch.cat([v], dim=0)
+# print("V")
+# print(type(V))
+# print(V)
+# print(V.dtype)
+# print(V.device)
+# print("V_np")
+# V_idk = V.cpu().float()
+
+# when they were pytorch tensors
+# V = embed_many(chunks) # float32/fp16 on CPU? ensure float32 for DB:
+# V_np = V.float().cpu().numpy().astype("float32")
+# DIM = V_np.shape[1]
+# db = sqlite3.connect("./rag.db")
+
+queries = [
+    "How was Shuihu zhuan received in early modern Japan?",
+    "Edo-period readers’ image of Song Jiang / Liangshan outlaws",
+    "Channels of transmission for Chinese vernacular fiction into Japan (kanbun kundoku, digests, translations)",
+    "Role of woodblock prints/illustrations in mediating Chinese fiction",
+    "Key Japanese scholars, writers, or publishers who popularized Chinese fiction",
+    "Kyokutei Bakin’s engagement with Chinese vernacular narrative",
+    "Santō Kyōden, gesaku, and Chinese models",
+    "Kanzen chōaku (encourage good, punish evil) and Water Margin in Japan",
+    "Moral ambivalence of outlaw heroes as discussed in the text",
+    "Censorship or moral debates around reading Chinese fiction",
+    "Translation strategies from vernacular Chinese to Japanese (furigana, kundoku, glossing)",
+    "Paratexts: prefaces, commentaries, reader guidance apparatus",
+    "Bibliographic details: editions, reprints, circulation networks",
+    "How does this book challenge older narratives about Sino-Japanese literary influence?",
+    "Methodology: sources, archives, limitations mentioned by the author",
+]
+def t2():
+ db = get_db()
+ # Hybrid + CE rerank query:
+ for query in queries:
+ print("query", query)
+ print("-----------\n\n")
+ # results = rag.search.search_hybrid(db, "muh", query, k_vec=50, k_bm25=50, k_final=8)
+ # for rid, txt, score in results:
+ # sim = score
+ # print(f"[{rid:04d}] ce_score={sim:.3f}\n{txt[:300]}...\n")
+        results = rag.search.vec_search(db, "muh", query, k=8)
+ for rid, txt, score in results:
+ sim = score
+ print(f"[{rid:04d}] ce_score={sim:.3f}\n{txt[:300]}...\n")
+
+if __name__ == "__main__":
+    t2()
+
+>>>> utils.py
+import re
+
+FTS_META_CHARS = r'''["'*()^+-]''' # include ? if you see issues
+
+def sanitize_query(q: str, *, allow_ops: bool = False) -> str:
+    q = q.strip()
+    if not q:
+        return q
+    if allow_ops:
+        # keep FTS5 operators (AND/OR/NOT/NEAR/*) usable; only escape stray double quotes
+        return q.replace('"', '""')
+    # literal search: escape embedded quotes, then wrap the whole query as one phrase
+    q = q.replace('"', '""')
+    return f'"{q}"'
+
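+# Quick behavior check (sample strings are assumptions; FTS5 escaping doubles
+# any embedded quotes):
+if __name__ == "__main__":
+    print(sanitize_query('hello world'))                   # "hello world"  (one phrase)
+    print(sanitize_query('say "hi"'))                      # "say ""hi"""   (quotes escaped)
+    print(sanitize_query('cat AND dog', allow_ops=True))   # cat AND dog    (operators kept)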