Diffstat (limited to 'rag')
-rw-r--r--  rag/__init__.py    0
-rw-r--r--  rag/db.py         55
-rw-r--r--  rag/ingest.py     61
-rw-r--r--  rag/main.py       71
-rw-r--r--  rag/mmr.py        26
-rw-r--r--  rag/nmain.py     118
-rw-r--r--  rag/rerank.py     31
-rw-r--r--  rag/search.py    101
-rw-r--r--  rag/test.py       77
9 files changed, 540 insertions, 0 deletions
diff --git a/rag/__init__.py b/rag/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/rag/__init__.py
diff --git a/rag/db.py b/rag/db.py
new file mode 100644
index 0000000..c015c96
--- /dev/null
+++ b/rag/db.py
@@ -0,0 +1,55 @@
+import sqlite3
+import sqlite_vec
+from numpy import ndarray
+from sqlite_vec import serialize_float32
+
+def get_db():
+    db = sqlite3.connect("./rag.db")
+    db.execute("PRAGMA journal_mode=WAL;")
+    db.execute("PRAGMA synchronous=NORMAL;")
+    db.execute("PRAGMA temp_store=MEMORY;")
+    db.execute("PRAGMA mmap_size=300000000;")  # ~300 MB, if your OS allows
+    db.enable_load_extension(True)
+    sqlite_vec.load(db)
+    db.enable_load_extension(False)
+    return db
+
+def init_schema(db: sqlite3.Connection, DIM: int):
+    db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec USING vec0(embedding float[{DIM}])")
+    db.execute('''
+        CREATE TABLE IF NOT EXISTS chunks (
+            id INTEGER PRIMARY KEY,
+            text TEXT
+        )''')
+    db.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts USING fts5(text)")
+    db.commit()
+
+def store_chunks(db: sqlite3.Connection, chunks: list[str], V_np: ndarray):
+    assert len(chunks) == len(V_np)
+    db.execute("BEGIN")
+    db.executemany('''
+        INSERT INTO chunks(id, text) VALUES (?, ?)
+    ''', list(enumerate(chunks, start=1)))
+    db.executemany(
+        "INSERT INTO vec(rowid, embedding) VALUES (?, ?)",
+        [(i + 1, memoryview(V_np[i].tobytes())) for i in range(len(chunks))]
+    )
+    db.executemany("INSERT INTO fts(rowid, text) VALUES (?, ?)", list(enumerate(chunks, start=1)))
+    db.commit()
+
+def vec_topk(db, q_vec_f32, k=10):
+    rows = db.execute(
+        "SELECT rowid, distance FROM vec WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
+        (serialize_float32(q_vec_f32), k)
+    ).fetchall()
+    return rows  # [(rowid, distance)]
+
+def bm25_topk(db, qtext, k=10):
+    return [rid for (rid,) in db.execute(
+        "SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?", (qtext, k)
+    ).fetchall()]
+
+def wipe_db():
+    db = sqlite3.connect("./rag.db")
+    db.executescript("DROP TABLE IF EXISTS chunks; DROP TABLE IF EXISTS fts; DROP TABLE IF EXISTS vec;")
+    db.close()
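
A minimal smoke test for the helpers above (a sketch: assumes sqlite-vec and numpy are installed; the 8-dim toy vectors are illustrative):

    import numpy as np
    from rag.db import get_db, init_schema, store_chunks, vec_topk, bm25_topk

    chunks = ["alpha document text", "beta document text"]
    V = np.random.rand(2, 8).astype("float32")
    V /= np.linalg.norm(V, axis=1, keepdims=True)   # unit vectors, as ingest produces

    db = get_db()
    init_schema(db, DIM=8)
    store_chunks(db, chunks, V)
    print(vec_topk(db, V[0], k=1))      # -> [(1, <distance>)]
    print(bm25_topk(db, "alpha", k=1))  # -> [1]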
diff --git a/rag/ingest.py b/rag/ingest.py
new file mode 100644
index 0000000..d17690a
--- /dev/null
+++ b/rag/ingest.py
@@ -0,0 +1,61 @@
+import sqlite3
+import torch
+from pathlib import Path
+from docling.document_converter import DocumentConverter
+from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
+from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
+from sentence_transformers import SentenceTransformer
+from rag.db import init_schema, store_chunks
+
+EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-8B"
+RERANKER_MODEL_ID = "Qwen/Qwen3-Reranker-8B"
+MAX_TOKENS = 600
+BATCH = 16
+
+
+def get_embed_model():
+    return SentenceTransformer(
+        EMBED_MODEL_ID,
+        model_kwargs={
+            # "trust_remote_code": True,
+            "attn_implementation": "flash_attention_2",
+            "device_map": "auto",
+            "dtype": torch.float16,
+        },
+        tokenizer_kwargs={"padding_side": "left"},
+    )
+
+
+def parse_and_chunk(source: Path, model: SentenceTransformer) -> list[str]:
+    tokenizer = HuggingFaceTokenizer(tokenizer=model.tokenizer, max_tokens=MAX_TOKENS)
+
+    converter = DocumentConverter()
+    doc = converter.convert(source)
+    chunker = HybridChunker(tokenizer=tokenizer, merge_peers=True)
+    out = []
+    for ch in chunker.chunk(doc.document):
+        txt = chunker.contextualize(ch)
+        if txt.strip():
+            out.append(txt)
+    return out
+
+def embed_many(model: SentenceTransformer, texts: list[str]):
+    V_np = model.encode(texts,
+                        batch_size=BATCH,
+                        normalize_embeddings=True,
+                        convert_to_numpy=True,
+                        show_progress_bar=True)
+    return V_np.astype("float32")  # ensure float32 for the DB
+
+
+def start_ingest(db: sqlite3.Connection, path: Path):
+    model = get_embed_model()
+    chunks = parse_and_chunk(path, model)
+    V_np = embed_many(model, chunks)
+    DIM = V_np.shape[1]
+    init_schema(db, DIM)
+    store_chunks(db, chunks, V_np)
+    # TODO: wrap in try/except and report chunk counts
+    return True
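
A one-shot ingest then looks like this (sketch; the file path is illustrative):

    from pathlib import Path
    from rag.db import get_db
    from rag.ingest import start_ingest

    db = get_db()
    start_ingest(db, Path("docs/sample.pdf"))  # parse -> chunk -> embed -> index
    db.close()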
diff --git a/rag/main.py b/rag/main.py
new file mode 100644
index 0000000..221adfa
--- /dev/null
+++ b/rag/main.py
@@ -0,0 +1,71 @@
+import sys
+import argparse
+from pathlib import Path
+from rag.ingest import start_ingest
+from rag.search import search_hybrid, vec_search
+from rag.db import get_db
+
+
+def cmd_ingest(args):
+    path = Path(args.file)
+
+    if not path.exists():
+        print(f"File not found: {path}", file=sys.stderr)
+        sys.exit(1)
+
+    db = get_db()
+    stats = start_ingest(db, path)
+    print(f"Ingested file={args.file} :: {stats}")
+    db.close()
+
+
+def cmd_query(args):
+    db = get_db()
+    if args.simple:
+        results = vec_search(db, args.query, k=args.k_final)
+    else:
+        results = search_hybrid(
+            db, args.query,
+            k_vec=args.k_vec,
+            k_bm25=args.k_bm25,
+            k_ce=args.k_ce,
+            k_final=args.k_final,
+            use_mmr=args.mmr,
+            mmr_lambda=args.mmr_lambda,
+        )
+
+    for rid, txt, score in results:
+        print(f"[{rid:05d}] score={score:.3f}\n{txt[:400]}...\n")
+
+    db.close()
+
+
+def main():
+    ap = argparse.ArgumentParser(prog="rag")
+    sub = ap.add_subparsers(dest="cmd", required=True)
+
+    # ingest
+    ap_ing = sub.add_parser("ingest", help="Parse, chunk, embed, and index a file")
+    ap_ing.add_argument("--file", required=True, help="Path to PDF/TXT to ingest")
+    ap_ing.set_defaults(func=cmd_ingest)
+
+    # query
+    ap_q = sub.add_parser("query", help="Query the index")
+    ap_q.add_argument("--query", required=True, help="User query text")
+    ap_q.add_argument("--simple", action="store_true", help="Vector-only search (skip reranker)")
+    ap_q.add_argument("--mmr", action="store_true", help="Apply MMR after CE rerank")
+    ap_q.add_argument("--mmr-lambda", type=float, default=0.7)
+    ap_q.add_argument("--k-vec", type=int, default=50)
+    ap_q.add_argument("--k-bm25", type=int, default=50)
+    ap_q.add_argument("--k-ce", type=int, default=30)
+    ap_q.add_argument("--k-final", type=int, default=10)
+    ap_q.set_defaults(func=cmd_query)
+
+    args = ap.parse_args()
+    args.func(args)
+
+
+# Commands are defined above main() so set_defaults can reference them;
+# the entry-point guard belongs at the end of the module.
+if __name__ == "__main__":
+    main()
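
With the subcommands wired up above, invocations look like this (paths and queries are illustrative):

    python -m rag.main ingest --file docs/sample.pdf
    python -m rag.main query --query "indemnification survival" --k-final 5
    python -m rag.main query --query "indemnification survival" --mmr --mmr-lambda 0.6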
diff --git a/rag/mmr.py b/rag/mmr.py
new file mode 100644
index 0000000..5e47c4f
--- /dev/null
+++ b/rag/mmr.py
@@ -0,0 +1,26 @@
+import numpy as np
+
+def cosine(a, b): return float(np.dot(a, b))
+
+def mmr(query_vec, cand_ids, cand_vecs, k=8, lamb=0.7):
+    """cand_ids: [int]; cand_vecs: np.ndarray float32 [N,D] (unit vectors), aligned with cand_ids."""
+    selected, selected_idx = [], []
+    remaining = list(range(len(cand_ids)))
+
+    # seed with the most relevant candidate
+    best0 = max(remaining, key=lambda i: cosine(query_vec, cand_vecs[i]))
+    selected.append(cand_ids[best0]); selected_idx.append(best0); remaining.remove(best0)
+
+    while remaining and len(selected) < k:
+        def mmr_score(i):
+            rel = cosine(query_vec, cand_vecs[i])
+            red = max(cosine(cand_vecs[i], cand_vecs[j]) for j in selected_idx) if selected_idx else 0.0
+            return lamb * rel - (1.0 - lamb) * red
+        nxt = max(remaining, key=mmr_score)
+        selected.append(cand_ids[nxt]); selected_idx.append(nxt); remaining.remove(nxt)
+    return selected
+
+def embed_unit_np(st_model, texts: list[str]) -> np.ndarray:
+    V = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True, batch_size=32)
+    return V.astype("float32", copy=False)
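
A toy check of the tradeoff `lamb` controls (hypothetical 2-D unit vectors: candidates at 20, -60, and 25 degrees from the query, so id 3 is a near-duplicate of id 1 and id 2 is less relevant but diverse):

    import numpy as np
    from rag.mmr import mmr

    ang = np.deg2rad([20.0, -60.0, 25.0])
    vecs = np.stack([np.cos(ang), np.sin(ang)], axis=1).astype("float32")
    q = np.array([1.0, 0.0], dtype="float32")

    print(mmr(q, [1, 2, 3], vecs, k=2, lamb=0.7))  # [1, 3]: relevance dominates
    print(mmr(q, [1, 2, 3], vecs, k=2, lamb=0.5))  # [1, 2]: diversity kicks in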
diff --git a/rag/nmain.py b/rag/nmain.py
new file mode 100644
index 0000000..aa67dea
--- /dev/null
+++ b/rag/nmain.py
@@ -0,0 +1,118 @@
+# rag/nmain.py -- WIP variant of main.py with per-collection tables.
+# NOTE: assumes collection-aware versions of start_ingest / search_hybrid / vec_search,
+# which rag.ingest and rag.search do not implement yet.
+import argparse, sqlite3, sys
+from pathlib import Path
+
+from sentence_transformers import SentenceTransformer
+from rag.ingest import start_ingest
+from rag.search import search_hybrid, vec_search
+
+DB_PATH = "./rag.db"
+EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-8B"
+
+def open_db():
+    db = sqlite3.connect(DB_PATH)
+    # speed-ish pragmas
+    db.execute("PRAGMA journal_mode=WAL;")
+    db.execute("PRAGMA synchronous=NORMAL;")
+    return db
+
+def load_st_model():
+    # SentenceTransformer handles batching + GPU placement internally
+    return SentenceTransformer(
+        EMBED_MODEL_ID,
+        model_kwargs={
+            "attn_implementation": "flash_attention_2",
+            "device_map": "auto",
+            "torch_dtype": "float16",
+        },
+        tokenizer_kwargs={"padding_side": "left"},
+    )
+
+def ensure_collection_exists(db, collection: str):
+    row = db.execute(
+        "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
+        (f"vec_{collection}",)
+    ).fetchone()
+    return bool(row)
+
+def cmd_ingest(args):
+    db = open_db()
+    st_model = load_st_model()
+    path = Path(args.file)
+
+    if not path.exists():
+        print(f"File not found: {path}", file=sys.stderr)
+        sys.exit(1)
+
+    if args.rebuild and ensure_collection_exists(db, args.collection):
+        db.executescript(f"""
+            DROP TABLE IF EXISTS chunks_{args.collection};
+            DROP TABLE IF EXISTS fts_{args.collection};
+            DROP TABLE IF EXISTS vec_{args.collection};
+        """)
+
+    stats = start_ingest(
+        db, st_model,
+        path=path,
+        collection=args.collection,
+    )
+    print(f"Ingested collection={args.collection} :: {stats}")
+    db.close()
+
+def cmd_query(args):
+    db = open_db()
+    st_model = load_st_model()
+
+    if not ensure_collection_exists(db, args.collection):
+        print(f"Collection '{args.collection}' not found. Ingest first.", file=sys.stderr)
+        sys.exit(2)
+
+    if args.simple:
+        results = vec_search(db, st_model, args.query, collection=args.collection, k=args.k_final)
+    else:
+        results = search_hybrid(
+            db, st_model, args.query,
+            collection=args.collection,
+            k_vec=args.k_vec,
+            k_bm25=args.k_bm25,
+            k_ce=args.k_ce,
+            k_final=args.k_final,
+            use_mmr=args.mmr,
+            mmr_lambda=args.mmr_lambda,
+        )
+
+    for rid, txt, score in results:
+        print(f"[{rid:05d}] score={score:.3f}\n{txt[:400]}...\n")
+
+    db.close()
+
+def main():
+    ap = argparse.ArgumentParser(prog="rag")
+    sub = ap.add_subparsers(dest="cmd", required=True)
+
+    # ingest
+    ap_ing = sub.add_parser("ingest", help="Parse, chunk, embed, and index a file into a collection")
+    ap_ing.add_argument("--file", required=True, help="Path to PDF/TXT to ingest")
+    ap_ing.add_argument("--collection", required=True, help="Collection name (e.g. wm_qwen3)")
+    ap_ing.add_argument("--rebuild", action="store_true", help="Drop and recreate collection tables")
+    ap_ing.set_defaults(func=cmd_ingest)
+
+    # query
+    ap_q = sub.add_parser("query", help="Query a collection")
+    ap_q.add_argument("--collection", required=True, help="Collection name to search")
+    ap_q.add_argument("--query", required=True, help="User query text")
+    ap_q.add_argument("--simple", action="store_true", help="Vector-only search (skip reranker)")
+    ap_q.add_argument("--mmr", action="store_true", help="Apply MMR after CE rerank")
+    ap_q.add_argument("--mmr-lambda", type=float, default=0.7)
+    ap_q.add_argument("--k-vec", type=int, default=50)
+    ap_q.add_argument("--k-bm25", type=int, default=50)
+    ap_q.add_argument("--k-ce", type=int, default=30)
+    ap_q.add_argument("--k-final", type=int, default=10)
+    ap_q.set_defaults(func=cmd_query)
+
+    args = ap.parse_args()
+    args.func(args)
+
+if __name__ == "__main__":
+    main()
diff --git a/rag/rerank.py b/rag/rerank.py
new file mode 100644
index 0000000..6ae8938
--- /dev/null
+++ b/rag/rerank.py
@@ -0,0 +1,31 @@
+# rag/rerank.py
+import torch
+from sentence_transformers import CrossEncoder
+
+RERANKER_ID = "Qwen/Qwen3-Reranker-8B"
+# RERANKER_ID = "BAAI/bge-reranker-base"  # or -large, if you've got the VRAM
+# device: "cuda" | "cpu" | "mps"
+def get_rerank_model():
+    return CrossEncoder(
+        RERANKER_ID,
+        device="cuda",
+        model_kwargs={
+            "attn_implementation": "flash_attention_2",
+            "device_map": "auto",
+            "dtype": torch.float16,
+        },
+        tokenizer_kwargs={"padding_side": "left"},
+    )
+
+def rerank_cross_encoder(reranker: CrossEncoder, query: str, candidates: list[tuple[int, str]], batch_size: int = 32):
+    """
+    candidates: [(id, text), ...]
+    returns: [(id, text, score)] sorted desc by score
+    """
+    if not candidates:
+        return []
+    ids, texts = zip(*candidates)
+    pairs = [(query, t) for t in texts]
+    scores = reranker.predict(pairs, batch_size=batch_size)  # np.ndarray [N], higher = better
+    ranked = sorted(zip(ids, texts, scores), key=lambda x: x[2], reverse=True)
+    return ranked
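
Usage sketch (assumes the 8B reranker fits in VRAM; the candidate texts are illustrative):

    from rag.rerank import get_rerank_model, rerank_cross_encoder

    reranker = get_rerank_model()
    cands = [(1, "Indemnification obligations survive termination of this agreement."),
             (2, "The parties agree to binding arbitration in Delaware.")]
    for rid, text, score in rerank_cross_encoder(reranker, "does indemnification survive termination?", cands):
        print(rid, f"{score:.3f}", text)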
diff --git a/rag/search.py b/rag/search.py
new file mode 100644
index 0000000..55b8ffd
--- /dev/null
+++ b/rag/search.py
@@ -0,0 +1,101 @@
+import sqlite3
+import gc
+import torch
+from sqlite_vec import serialize_float32
+
+from rag.ingest import get_embed_model
+from rag.rerank import get_rerank_model, rerank_cross_encoder
+from rag.mmr import mmr, embed_unit_np
+
+def search_hybrid(
+    db: sqlite3.Connection, query: str,
+    k_vec: int = 40, k_bm25: int = 40,
+    k_ce: int = 30,     # rerank this many
+    k_final: int = 10,  # return this many
+    use_mmr: bool = False, mmr_lambda: float = 0.7
+):
+    # 1) embed the query (unit-norm float32) for vector search + MMR
+    emodel = get_embed_model()
+    q = emodel.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32")
+    # free the embedder before loading the reranker (both are 8B models)
+    del emodel
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    # 2) ANN (cosine) + BM25 candidate pools
+    vec_ids = [i for (i, _) in db.execute(
+        "SELECT rowid, distance FROM vec WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
+        (serialize_float32(q), k_vec)
+    ).fetchall()]
+    bm25_ids = [i for (i,) in db.execute(
+        "SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?",
+        (query, k_bm25)
+    ).fetchall()]
+
+    # 3) merge (vector-first, de-duplicated)
+    seen, merged = set(), []
+    for i in vec_ids + bm25_ids:
+        if i not in seen:
+            merged.append(i); seen.add(i)
+    if not merged:
+        return []
+
+    # 4) fetch texts for the cross-encoder
+    qmarks = ",".join("?" * len(merged))
+    cand = db.execute(f"SELECT id, text FROM chunks WHERE id IN ({qmarks})", merged).fetchall()
+
+    # 5) cross-encoder rerank (returns [(id, text, score)] desc)
+    reranker = get_rerank_model()
+    ranked = rerank_cross_encoder(reranker, query, cand)
+    del reranker
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    ranked = ranked[:min(k_ce, len(ranked))]
+
+    if not use_mmr or len(ranked) <= k_final:
+        return ranked[:k_final]
+
+    # 6) MMR diversity on the CE top-k_ce
+    cand_ids = [i for (i, _, _) in ranked]
+    cand_text = [t for (_, t, _) in ranked]
+    # reload the embedder (it was freed above to fit the reranker in VRAM)
+    emodel = get_embed_model()
+    cand_vecs = embed_unit_np(emodel, cand_text)  # [N, D] unit vectors
+    sel_ids = set(mmr(q, cand_ids, cand_vecs, k=k_final, lamb=mmr_lambda))
+    final = [trip for trip in ranked if trip[0] in sel_ids]  # keep CE order, filter by MMR picks
+    return final[:k_final]
+
+
+def vec_search(db: sqlite3.Connection, qtext, k=5):
+    model = get_embed_model()
+    q = model.encode([qtext], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32")
+    # sqlite-vec KNN: MATCH against the embedding column, ordered by distance
+    rows = db.execute(
+        "SELECT rowid, distance FROM vec WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
+        (serialize_float32(q), k)
+    ).fetchall()
+    return [(rid, db.execute("SELECT text FROM chunks WHERE id=?", (rid,)).fetchone()[0], dist)
+            for rid, dist in rows]
+
+
+# Hybrid + CE rerank, e.g.:
+#   results = search_hybrid(db, "indemnification obligations survive termination", k_vec=50, k_bm25=50, k_final=8)
+#   for rid, txt, score in results:
+#       print(f"[{rid:04d}] score={score:.3f}\n{txt[:300]}...\n")
diff --git a/rag/test.py b/rag/test.py
new file mode 100644
index 0000000..b7a6d8e
--- /dev/null
+++ b/rag/test.py
@@ -0,0 +1,77 @@
+from pathlib import Path
+from docling.document_converter import DocumentConverter
+from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
+from rag.db import get_db
+import rag.ingest
+import rag.search
+import torch
+
+# converter = DocumentConverter()
+# chunker = HybridChunker()
+# file = Path("yek.md")
+# doc = converter.convert(file).document
+# chunk_iter = chunker.chunk(doc)
+# for chunk in chunk_iter:
+#     print(chunk)
+#     txt = chunker.contextualize(chunk)
+#     print(txt)
+
+
+
+def t():
+    batch: list[str] = ["a first test sentence", "a second test sentence", "a third test sentence", "a fourth test sentence"]
+    model = rag.ingest.get_embed_model()
+    v = model.encode(batch[0], normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True)
+    print("v")
+    print(type(v))
+    print(v)
+    print(v.dtype)
+    print(v.shape)  # encode(..., convert_to_numpy=True) returns np.ndarray, which has no .device
+
+# V = torch.cat([v], dim=0)
+# print("V")
+# print(type(V))
+# print(V)
+# print(V.dtype)
+# print(V.device)
+# print("V_np")
+# V_idk = V.cpu().float()
+
+# when they were pytorch tensors
+# V = embed_many(chunks) # float32/fp16 on CPU? ensure float32 for DB:
+# V_np = V.float().cpu().numpy().astype("float32")
+# DIM = V_np.shape[1]
+# db = sqlite3.connect("./rag.db")
+
+queries = [
+    "How was Shuihu zhuan received in early modern Japan?",
+    "Edo-period readers’ image of Song Jiang / Liangshan outlaws",
+    "Channels of transmission for Chinese vernacular fiction into Japan (kanbun kundoku, digests, translations)",
+    "Role of woodblock prints/illustrations in mediating Chinese fiction",
+    "Key Japanese scholars, writers, or publishers who popularized Chinese fiction",
+    "Kyokutei Bakin’s engagement with Chinese vernacular narrative",
+    "Santō Kyōden, gesaku, and Chinese models",
+    "Kanzen chōaku (encourage good, punish evil) and Water Margin in Japan",
+    "Moral ambivalence of outlaw heroes as discussed in the text",
+    "Censorship or moral debates around reading Chinese fiction",
+    "Translation strategies from vernacular Chinese to Japanese (furigana, kundoku, glossing)",
+    "Paratexts: prefaces, commentaries, reader guidance apparatus",
+    "Bibliographic details: editions, reprints, circulation networks",
+    "How does this book challenge older narratives about Sino-Japanese literary influence?",
+    "Methodology: sources, archives, limitations mentioned by the author",
+]
+def t2():
+    db = get_db()
+    # hybrid + CE rerank over the canned queries
+    for query in queries:
+        print("query", query)
+        results = rag.search.search_hybrid(db, query, k_vec=50, k_bm25=50, k_final=8)
+        for rid, txt, score in results:
+            print(f"[{rid:04d}] ce_score={score:.3f}\n{txt[:300]}...\n")
+
+
+if __name__ == "__main__":
+    t2()