Diffstat (limited to 'rag')
-rw-r--r--   rag/__init__.py |   0
-rw-r--r--   rag/db.py       |  55
-rw-r--r--   rag/ingest.py   |  61
-rw-r--r--   rag/main.py     |  71
-rw-r--r--   rag/mmr.py      |  26
-rw-r--r--   rag/nmain.py    | 118
-rw-r--r--   rag/rerank.py   |  31
-rw-r--r--   rag/search.py   | 101
-rw-r--r--   rag/test.py     |  77
9 files changed, 540 insertions(+), 0 deletions(-)
diff --git a/rag/__init__.py b/rag/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/rag/__init__.py
diff --git a/rag/db.py b/rag/db.py
new file mode 100644
index 0000000..c015c96
--- /dev/null
+++ b/rag/db.py
@@ -0,0 +1,55 @@
+import sqlite3
+import sqlite_vec
+from numpy import ndarray
+
+def get_db():
+    db = sqlite3.connect("./rag.db")
+    db.execute("PRAGMA journal_mode=WAL;")
+    db.execute("PRAGMA synchronous=NORMAL;")
+    db.execute("PRAGMA temp_store=MEMORY;")
+    db.execute("PRAGMA mmap_size=300000000;")  # ~300 MB, if your OS allows
+    db.enable_load_extension(True)
+    sqlite_vec.load(db)
+    db.enable_load_extension(False)
+    return db
+
+def init_schema(db: sqlite3.Connection, DIM: int):
+    db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec USING vec0(embedding float[{DIM}])")
+    db.execute('''
+        CREATE TABLE IF NOT EXISTS chunks (
+            id INTEGER PRIMARY KEY,
+            text TEXT
+        )''')
+    db.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts USING fts5(text)")
+    db.commit()
+
+def store_chunks(db: sqlite3.Connection, chunks: list[str], V_np: ndarray):
+    assert len(chunks) == len(V_np)
+    db.execute("BEGIN")
+    db.executemany('''
+        INSERT INTO chunks(id, text) VALUES (?, ?)
+    ''', list(enumerate(chunks, start=1)))
+    db.executemany(
+        "INSERT INTO vec(rowid, embedding) VALUES (?, ?)",
+        [(i + 1, memoryview(V_np[i].tobytes())) for i in range(len(chunks))]
+    )
+    db.executemany("INSERT INTO fts(rowid, text) VALUES (?, ?)", list(enumerate(chunks, start=1)))
+    db.commit()
+
+def vec_topk(db, q_vec_f32, k=10):
+    rows = db.execute(
+        "SELECT rowid, distance FROM vec WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
+        (sqlite_vec.serialize_float32(q_vec_f32), k)
+    ).fetchall()
+    return rows  # [(rowid, distance)]
+
+
+def bm25_topk(db, qtext, k=10):
+    return [rid for (rid,) in db.execute(
+        "SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?", (qtext, k)
+    ).fetchall()]
+
+def wipe_db():
+    db = sqlite3.connect("./rag.db")
+    db.executescript("DROP TABLE IF EXISTS chunks; DROP TABLE IF EXISTS fts; DROP TABLE IF EXISTS vec;")
+    db.close()
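A quick way to sanity-check that the three tables stay aligned by rowid: a minimal smoke-test sketch, not part of the commit, using toy 4-dim vectors (assumes sqlite-vec is installed).

# toy check of rag/db.py (hypothetical script, 4-dim vectors)
import numpy as np
from rag.db import get_db, init_schema, store_chunks, vec_topk, bm25_topk, wipe_db

wipe_db()                            # start from a clean rag.db
db = get_db()
init_schema(db, DIM=4)
chunks = ["alpha beta gamma", "delta epsilon zeta"]
V = np.eye(2, 4, dtype=np.float32)   # two orthogonal unit vectors
store_chunks(db, chunks, V)
print(vec_topk(db, V[0], k=1))       # -> [(1, 0.0)]: chunk 1 is its own nearest neighbor
print(bm25_topk(db, "delta", k=1))   # -> [2]: FTS5 keyword match
db.close()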
LIMIT ?", (qtext, k) + ).fetchall()] + + def wipe_db(): + db = sqlite3.connect("./rag.db") + db.executescript("DROP TABLE IF EXISTS chunks; DROP TABLE IF EXISTS fts; DROP TABLE IF EXISTS vec;") + db.close() diff --git a/rag/ingest.py b/rag/ingest.py new file mode 100644 index 0000000..d17690a --- /dev/null +++ b/rag/ingest.py @@ -0,0 +1,61 @@ +# from rag.rerank import search_hybrid +import sqlite3 +import torch +from docling_core.transforms.chunker.hybrid_chunker import HybridChunker +from pathlib import Path +from docling.document_converter import DocumentConverter +from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer +from rag.db import get_db, init_schema, store_chunks +from sentence_transformers import SentenceTransformer + +EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-8B" +RERANKER_MODEL_ID = "Qwen/Qwen3-Reranker-8B" +MAX_TOKENS = 600 +BATCH = 16 + + +def get_embed_model(): + return SentenceTransformer( + "Qwen/Qwen3-Embedding-8B", + model_kwargs={ + # "trust_remote_code":True, + "attn_implementation":"flash_attention_2", + "device_map":"auto", + "dtype":torch.float16 + }, tokenizer_kwargs={"padding_side": "left"} + ) + + +def parse_and_chunk(source: Path, model: SentenceTransformer)-> list[str]: + tokenizer = HuggingFaceTokenizer(tokenizer=model.tokenizer, max_tokens=MAX_TOKENS) + + converter = DocumentConverter() + doc = converter.convert(source) + chunker = HybridChunker(tokenizer=tokenizer, merge_peers=True) + out = [] + for ch in chunker.chunk(doc.document): + txt = chunker.contextualize(ch) + if txt.strip(): + out.append(txt) + return out + +def embed_many(model: SentenceTransformer, texts: list[str]): + V_np = model.encode(texts, + batch_size=BATCH, + normalize_embeddings=True, + convert_to_numpy=True, + show_progress_bar=True) + return V_np.astype("float32") + + +def start_ingest(db: sqlite3.Connection, path: Path): + model = get_embed_model() + chunks = parse_and_chunk(path, model) + V_np = embed_many(model, chunks) + DIM = V_np.shape[1] + db = get_db() + init_schema(db, DIM) + store_chunks(db, chunks, V_np) + # TODO some try catch? + return True +# float32/fp16 on CPU? 
diff --git a/rag/main.py b/rag/main.py
new file mode 100644
index 0000000..221adfa
--- /dev/null
+++ b/rag/main.py
@@ -0,0 +1,71 @@
+import sys
+import argparse
+from pathlib import Path
+from rag.ingest import start_ingest
+from rag.search import search_hybrid, vec_search
+from rag.db import get_db
+# your Docling chunker imports…
+
+def main():
+    ap = argparse.ArgumentParser(prog="rag")
+    sub = ap.add_subparsers(dest="cmd", required=True)
+
+    # ingest
+    ap_ing = sub.add_parser("ingest", help="Parse, chunk, embed, and index a file into a collection")
+    ap_ing.add_argument("--file", required=True, help="Path to PDF/TXT to ingest")
+    ap_ing.set_defaults(func=cmd_ingest)
+
+    # query
+    ap_q = sub.add_parser("query", help="Query a collection")
+    ap_q.add_argument("--query", required=True, help="User query text")
+    ap_q.add_argument("--simple", action="store_true", help="Vector-only search (skip reranker)")
+    ap_q.add_argument("--mmr", action="store_true", help="Apply MMR after CE")
+    ap_q.add_argument("--mmr-lambda", type=float, default=0.7)
+    ap_q.add_argument("--k-vec", type=int, default=50)
+    ap_q.add_argument("--k-bm25", type=int, default=50)
+    ap_q.add_argument("--k-ce", type=int, default=30)
+    ap_q.add_argument("--k-final", type=int, default=10)
+    ap_q.set_defaults(func=cmd_query)
+
+    args = ap.parse_args()
+    args.func(args)
+
+
+
+
+def cmd_ingest(args):
+    path = Path(args.file)
+
+    if not path.exists():
+        print(f"File not found: {path}", file=sys.stderr)
+        sys.exit(1)
+
+
+    db = get_db()
+    stats = start_ingest(db, path)
+    print(f"Ingested file={args.file} :: {stats}")
+
+def cmd_query(args):
+
+    db = get_db()
+    if args.simple:
+        results = vec_search(db, args.query, k=args.k_final)
+    else:
+        results = search_hybrid(
+            db, args.query,
+            k_vec=args.k_vec,
+            k_bm25=args.k_bm25,
+            k_ce=args.k_ce,
+            k_final=args.k_final,
+            use_mmr=args.mmr,
+            mmr_lambda=args.mmr_lambda,
+        )
+
+    for rid, txt, score in results:
+        print(f"[{rid:05d}] score={score:.3f}\n{txt[:400]}...\n")
+
+    db.close()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/rag/mmr.py b/rag/mmr.py
new file mode 100644
index 0000000..5e47c4f
--- /dev/null
+++ b/rag/mmr.py
@@ -0,0 +1,26 @@
+import numpy as np
+
+def cosine(a, b): return float(np.dot(a, b))
+
+def mmr(query_vec, cand_ids, cand_vecs, k=8, lamb=0.7):
+    """cand_ids: [int], cand_vecs: np.ndarray float32 [N,D] (unit vectors) aligned with cand_ids"""
+    selected, selected_idx = [], []
+    remaining = list(range(len(cand_ids)))
+
+    # seed with the most relevant
+    best0 = max(remaining, key=lambda i: cosine(query_vec, cand_vecs[i]))
+    selected.append(cand_ids[best0]); selected_idx.append(best0); remaining.remove(best0)
+
+    while remaining and len(selected) < k:
+        def mmr_score(i):
+            rel = cosine(query_vec, cand_vecs[i])
+            red = max(cosine(cand_vecs[i], cand_vecs[j]) for j in selected_idx) if selected_idx else 0.0
+            return lamb * rel - (1.0 - lamb) * red
+        nxt = max(remaining, key=mmr_score)
+        selected.append(cand_ids[nxt]); selected_idx.append(nxt); remaining.remove(nxt)
+    return selected
+
+def embed_unit_np(st_model, texts: list[str]) -> np.ndarray:
+    V = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True, batch_size=32)
+    V = V.astype("float32", copy=False)
+    return V
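The lambda trade-off in mmr() is easy to see on toy vectors. A sketch (ids and vectors are made up): with pure relevance an exact duplicate outranks a diverse hit; with a lower lambda the redundancy penalty drops it.

import numpy as np
from rag.mmr import mmr

q = np.array([1.0, 0.0], dtype=np.float32)
V = np.array([[1.0, 0.0],           # id 101: exact match for the query
              [1.0, 0.0],           # id 102: duplicate of 101
              [0.7071, 0.7071]],    # id 103: less relevant but diverse
             dtype=np.float32)
print(mmr(q, [101, 102, 103], V, k=2, lamb=1.0))  # -> [101, 102]: relevance only keeps the duplicate
print(mmr(q, [101, 102, 103], V, k=2, lamb=0.4))  # -> [101, 103]: redundancy penalty drops it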
diff --git a/rag/nmain.py b/rag/nmain.py
new file mode 100644
index 0000000..aa67dea
--- /dev/null
+++ b/rag/nmain.py
@@ -0,0 +1,118 @@
+# rag/main.py
+import argparse, sqlite3, sys
+from pathlib import Path
+
+from sentence_transformers import SentenceTransformer
+from rag.ingest import start_ingest
+from rag.search import search_hybrid, vec_search
+
+DB_PATH = "./rag.db"
+EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-8B"
+
+def open_db():
+    db = sqlite3.connect(DB_PATH)
+    # speed-ish pragmas
+    db.execute("PRAGMA journal_mode=WAL;")
+    db.execute("PRAGMA synchronous=NORMAL;")
+    return db
+
+def load_st_model():
+    # ST handles batching + GPU internally
+    return SentenceTransformer(
+        EMBED_MODEL_ID,
+        model_kwargs={
+            "attn_implementation": "flash_attention_2",
+            "device_map": "auto",
+            "torch_dtype": "float16",
+        },
+        tokenizer_kwargs={"padding_side": "left"},
+    )
+
+def ensure_collection_exists(db, collection: str):
+    row = db.execute(
+        "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
+        (f"vec_{collection}",)
+    ).fetchone()
+    return bool(row)
+
+def cmd_ingest(args):
+    db = open_db()
+    st_model = load_st_model()
+    path = Path(args.file)
+
+    if not path.exists():
+        print(f"File not found: {path}", file=sys.stderr)
+        sys.exit(1)
+
+    if args.rebuild and ensure_collection_exists(db, args.collection):
+        db.executescript(f"""
+            DROP TABLE IF EXISTS chunks_{args.collection};
+            DROP TABLE IF EXISTS fts_{args.collection};
+            DROP TABLE IF EXISTS vec_{args.collection};
+        """)
+
+    stats = start_ingest(
+        db, st_model,
+        path=path,
+        collection=args.collection,
+    )
+    print(f"Ingested collection={args.collection} :: {stats}")
+    db.close()
+
+def cmd_query(args):
+    db = open_db()
+    st_model = load_st_model()
+
+    coll_ok = ensure_collection_exists(db, args.collection)
+    if not coll_ok:
+        print(f"Collection '{args.collection}' not found. Ingest first.", file=sys.stderr)
+        sys.exit(2)
+
+    if args.simple:
+        results = vec_search(db, st_model, args.query, collection=args.collection, k=args.k_final)
+    else:
+        results = search_hybrid(
+            db, st_model, args.query,
+            collection=args.collection,
+            k_vec=args.k_vec,
+            k_bm25=args.k_bm25,
+            k_ce=args.k_ce,
+            k_final=args.k_final,
+            use_mmr=args.mmr,
+            mmr_lambda=args.mmr_lambda,
+        )
+
+    for rid, txt, score in results:
+        print(f"[{rid:05d}] score={score:.3f}\n{txt[:400]}...\n")
+
+    db.close()
+
+def main():
+    ap = argparse.ArgumentParser(prog="rag")
+    sub = ap.add_subparsers(dest="cmd", required=True)
+
+    # ingest
+    ap_ing = sub.add_parser("ingest", help="Parse, chunk, embed, and index a file into a collection")
+    ap_ing.add_argument("--file", required=True, help="Path to PDF/TXT to ingest")
+    ap_ing.add_argument("--collection", required=True, help="Collection name (e.g. wm_qwen3)")
+    ap_ing.add_argument("--rebuild", action="store_true", help="Drop and recreate collection tables")
+    ap_ing.set_defaults(func=cmd_ingest)
+
+    # query
+    ap_q = sub.add_parser("query", help="Query a collection")
+    ap_q.add_argument("--collection", required=True, help="Collection name to search")
+    ap_q.add_argument("--query", required=True, help="User query text")
+    ap_q.add_argument("--simple", action="store_true", help="Vector-only search (skip reranker)")
+    ap_q.add_argument("--mmr", action="store_true", help="Apply MMR after CE")
+    ap_q.add_argument("--mmr-lambda", type=float, default=0.7)
+    ap_q.add_argument("--k-vec", type=int, default=50)
+    ap_q.add_argument("--k-bm25", type=int, default=50)
+    ap_q.add_argument("--k-ce", type=int, default=30)
+    ap_q.add_argument("--k-final", type=int, default=10)
+    ap_q.set_defaults(func=cmd_query)
+
+    args = ap.parse_args()
+    args.func(args)
+
+if __name__ == "__main__":
+    main()
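nmain.py drops chunks_<name> / fts_<name> / vec_<name> tables and probes for vec_<name>, but nothing in this commit creates them yet; neither start_ingest nor the search functions accept a collection argument so far. A sketch of the per-collection schema that naming convention implies, an assumption rather than committed code:

def init_collection_schema(db, collection: str, dim: int):
    # mirrors rag.db.init_schema, with the table-name suffix nmain.py expects
    db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec_{collection} USING vec0(embedding float[{dim}])")
    db.execute(f"CREATE TABLE IF NOT EXISTS chunks_{collection} (id INTEGER PRIMARY KEY, text TEXT)")
    db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS fts_{collection} USING fts5(text)")
    db.commit()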
wm_qwen3)") + ap_ing.add_argument("--rebuild", action="store_true", help="Drop and recreate collection tables") + ap_ing.set_defaults(func=cmd_ingest) + + # query + ap_q = sub.add_parser("query", help="Query a collection") + ap_q.add_argument("--collection", required=True, help="Collection name to search") + ap_q.add_argument("--query", required=True, help="User query text") + ap_q.add_argument("--simple", action="store_true", help="Vector-only search (skip reranker)") + ap_q.add_argument("--mmr", action="store_true", help="Apply MMR after CE") + ap_q.add_argument("--mmr-lambda", type=float, default=0.7) + ap_q.add_argument("--k-vec", type=int, default=50) + ap_q.add_argument("--k-bm25", type=int, default=50) + ap_q.add_argument("--k-ce", type=int, default=30) + ap_q.add_argument("--k-final", type=int, default=10) + ap_q.set_defaults(func=cmd_query) + + args = ap.parse_args() + args.func(args) + +if __name__ == "__main__": + main() diff --git a/rag/rerank.py b/rag/rerank.py new file mode 100644 index 0000000..6ae8938 --- /dev/null +++ b/rag/rerank.py @@ -0,0 +1,31 @@ +# rag/rerank.py +import torch +from sentence_transformers import CrossEncoder + +RERANKER_ID = "Qwen/Qwen3-Reranker-8B" # or -large if you’ve got VRAM +# RERANKER_ID = "BAAI/bge-reranker-base" # or -large if you’ve got VRAM +# device: "cuda" | "cpu" | "mps" +def get_rerank_model(): + return CrossEncoder( + RERANKER_ID, + device="cuda", + model_kwargs={ + "attn_implementation":"flash_attention_2", + "device_map":"auto", + "dtype":torch.float16 + }, + tokenizer_kwargs={"padding_side": "left"} + ) + +def rerank_cross_encoder(reranker: CrossEncoder, query: str, candidates: list[tuple[int, str]], batch_size: int = 32): + """ + candidates: [(id, text), ...] + returns: [(id, text, score)] sorted desc by score + """ + if not candidates: + return [] + ids, texts = zip(*candidates) + pairs = [(query, t) for t in texts] + scores = reranker.predict(pairs, batch_size=batch_size) # np.ndarray [N], higher=better + ranked = sorted(zip(ids, texts, scores), key=lambda x: x[2], reverse=True) + return ranked diff --git a/rag/search.py b/rag/search.py new file mode 100644 index 0000000..55b8ffd --- /dev/null +++ b/rag/search.py @@ -0,0 +1,101 @@ +import sqlite3 +import numpy as np +import torch +import gc +from typing import List, Tuple + +from sentence_transformers import CrossEncoder, SentenceTransformer +from rag.ingest import get_embed_model +from rag.rerank import get_rerank_model, rerank_cross_encoder +from rag.mmr import mmr, embed_unit_np # if you added MMR; else remove + +def search_hybrid( + db: sqlite3.Connection, query: str, + k_vec: int = 40, k_bm25: int = 40, + k_ce: int = 30, # rerank this many + k_final: int = 10, # return this many + use_mmr: bool = False, mmr_lambda: float = 0.7 +): + emodel = get_embed_model() + # 1) embed query (unit, float32) for vec search + MMR + print("loading") + q = emodel.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32") + emodel = None + with torch.no_grad(): + torch.cuda.empty_cache() + gc.collect() + print("memory should be free by now!!") + qbytes = q.tobytes() + + # 2) ANN (cosine) + BM25 pools + vec_ids = [i for (i, _) in db.execute( + "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?", + (memoryview(qbytes), k_vec) + ).fetchall()] + bm25_ids = [i for (i,) in db.execute( + "SELECT rowid FROM fts WHERE fts MATCH ? 
LIMIT ?", + (query, k_bm25) + ).fetchall()] + + # 3) merge (vector-first) + seen, merged = set(), [] + for i in vec_ids + bm25_ids: + if i not in seen: + merged.append(i); seen.add(i) + if not merged: + return [] + + # 4) fetch texts for CE + qmarks = ",".join("?"*len(merged)) + cand = db.execute(f"SELECT id, text FROM chunks WHERE id IN ({qmarks})", merged).fetchall() + + reranker = get_rerank_model() + # 5) cross-encoder rerank (returns [(id,text,score)] desc) + ranked = rerank_cross_encoder(reranker, query, cand) + reranker = None + print("freeing again!!") + with torch.no_grad(): + torch.cuda.empty_cache() + gc.collect() + # + ranked = ranked[:min(k_ce, len(ranked))] + + if not use_mmr or len(ranked) <= k_final: + return ranked[:k_final] + + # 6) MMR diversity on CE top-k_ce + cand_ids = [i for (i,_,_) in ranked] + cand_text = [t for (_,t,_) in ranked] + emodel = get_embed_model() + # god this is annoying I should stop being poor + cand_vecs = embed_unit_np(emodel, cand_text) # [N,D], unit vectors + sel_ids = set(mmr(q, cand_ids, cand_vecs, k=k_final, lamb=mmr_lambda)) + final = [trip for trip in ranked if trip[0] in sel_ids] # keep CE order, filter by MMR picks + return final[:k_final] + + +# ) Query helper (cosine distance; operator may be <#> in sqlite-vec) +def vec_search(db: sqlite3.Connection, qtext, k=5): + model = get_embed_model() + q = model.encode([qtext], normalize_embeddings=True, convert_to_numpy=True).astype("float32") + # Cosine distance operator in sqlite-vec is `<#>`; if your build differs, check docs: <-> L2, <=> dot, <#> cosine + rows = db.execute( + "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?", + (memoryview(q.tobytes()), k) + ).fetchall() + +# db.execute("SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?", (qtext, 40)) + return [(rid, db.execute("SELECT text FROM chunks WHERE id=?", (rid,)).fetchone()[0], dist) for rid, dist in rows] + + +# # Hybrid + CE rerank query: +# results = search_hybrid("indemnification obligations survive termination", model, k_vec=50, k_bm25=50, k_final=8) +# for rid, txt, score in results: +# print(f"[{rid:04d}] score={score:.3f}\n{txt[:300]}...\n") +# +# +# +# +# +# +# diff --git a/rag/test.py b/rag/test.py new file mode 100644 index 0000000..b7a6d8e --- /dev/null +++ b/rag/test.py @@ -0,0 +1,77 @@ +from pathlib import Path +from docling.document_converter import DocumentConverter +from docling_core.transforms.chunker.hybrid_chunker import HybridChunker +from transformers import AutoModel +from sentence_transformers import SentenceTransformer +from rag.db import get_db +from rag.rerank import get_rerank_model +import rag.ingest +import rag.search +import torch + +# converter = DocumentConverter() +# chunker = HybridChunker() +# file = Path("yek.md") +# doc = converter.convert(file).document +# chunk_iter = chunker.chunk(doc) +# for chunk in chunk_iter: +# print(chunk) +# txt = chunker.contextualize(chunk) +# print(txt) + + + +def t(): + batch: list[str] = ["This son of a bitch has gone too far", "Fuck me baby please", "I'm hungry now", "Charlie Kirk is dead"] + model = rag.ingest.get_embed_model() + v = model.encode("pepeee", normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True) + print("v") + print(type(v)) + print(v) + print(v.dtype) + print(v.device) + +# V = torch.cat([v], dim=0) +# print("V") +# print(type(V)) +# print(V) +# print(V.dtype) +# print(V.device) +# print("V_np") +# V_idk = V.cpu().float() + +# when they were pytorch tensors +# V = embed_many(chunks) # float32/fp16 on CPU? 
diff --git a/rag/test.py b/rag/test.py
new file mode 100644
index 0000000..b7a6d8e
--- /dev/null
+++ b/rag/test.py
@@ -0,0 +1,77 @@
+from pathlib import Path
+from docling.document_converter import DocumentConverter
+from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
+from transformers import AutoModel
+from sentence_transformers import SentenceTransformer
+from rag.db import get_db
+from rag.rerank import get_rerank_model
+import rag.ingest
+import rag.search
+import torch
+
+# converter = DocumentConverter()
+# chunker = HybridChunker()
+# file = Path("yek.md")
+# doc = converter.convert(file).document
+# chunk_iter = chunker.chunk(doc)
+# for chunk in chunk_iter:
+#     print(chunk)
+#     txt = chunker.contextualize(chunk)
+#     print(txt)
+
+
+
+def t():
+    batch: list[str] = ["This son of a bitch has gone too far", "Fuck me baby please", "I'm hungry now", "Charlie Kirk is dead"]
+    model = rag.ingest.get_embed_model()
+    v = model.encode("pepeee", normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True)
+    print("v")
+    print(type(v))
+    print(v)
+    print(v.dtype)
+    print(v.device)
+
+# V = torch.cat([v], dim=0)
+# print("V")
+# print(type(V))
+# print(V)
+# print(V.dtype)
+# print(V.device)
+# print("V_np")
+# V_idk = V.cpu().float()
+
+# when they were pytorch tensors
+# V = embed_many(chunks)  # float32/fp16 on CPU? ensure float32 for DB:
+# V_np = V.float().cpu().numpy().astype("float32")
+# DIM = V_np.shape[1]
+# db = sqlite3.connect("./rag.db")
+
+queries = [
+    "How was Shuihu zhuan received in early modern Japan?",
+    "Edo-period readers’ image of Song Jiang / Liangshan outlaws",
+    "Channels of transmission for Chinese vernacular fiction into Japan (kanbun kundoku, digests, translations)",
+    "Role of woodblock prints/illustrations in mediating Chinese fiction",
+    "Key Japanese scholars, writers, or publishers who popularized Chinese fiction",
+    "Kyokutei Bakin’s engagement with Chinese vernacular narrative",
+    "Santō Kyōden, gesaku, and Chinese models",
+    "Kanzen chōaku (encourage good, punish evil) and Water Margin in Japan",
+    "Moral ambivalence of outlaw heroes as discussed in the text",
+    "Censorship or moral debates around reading Chinese fiction",
+    "Translation strategies from vernacular Chinese to Japanese (furigana, kundoku, glossing)",
+    "Paratexts: prefaces, commentaries, reader guidance apparatus",
+    "Bibliographic details: editions, reprints, circulation networks",
+    "How does this book challenge older narratives about Sino-Japanese literary influence?",
+    "Methodology: sources, archives, limitations mentioned by the author",
+
+]
+def t2():
+    db = get_db()
+    # Hybrid + CE rerank query:
+    for query in queries:
+        print("query", query)
+        results = rag.search.search_hybrid(db, query, k_vec=50, k_bm25=50, k_final=8)
+        for rid, txt, score in results:
+            sim = score
+            print(f"[{rid:04d}] ce_score={sim:.3f}\n{txt[:300]}...\n")

+t2()
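For completeness, the command-line entry points these files add, assuming module-style invocation (the commit defines no console script; "book.pdf" is a hypothetical input):

# CLI defined in rag/main.py
#   python -m rag.main ingest --file book.pdf
#   python -m rag.main query --query "reception of Shuihu zhuan in Edo Japan"
#   python -m rag.main query --query "..." --simple               # vector-only, skip reranker
#   python -m rag.main query --query "..." --mmr --mmr-lambda 0.6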