author | polwex <polwex@sortug.com> | 2025-09-23 03:50:53 +0700
committer | polwex <polwex@sortug.com> | 2025-09-23 03:50:53 +0700
commit | 57aaafdb137fe49930711f6ed5ccc83b3a119cd2 (patch)
tree | 1a7556927bed94377630d33dd29c3bf07d159619
init
-rw-r--r-- | .gitignore | 4
-rw-r--r-- | NOTES.md | 22
-rw-r--r-- | flake.lock | 48
-rw-r--r-- | flake.nix | 96
-rw-r--r-- | knobs.md | 118
-rw-r--r-- | rag/__init__.py | 0
-rw-r--r-- | rag/db.py | 55
-rw-r--r-- | rag/ingest.py | 61
-rw-r--r-- | rag/main.py | 71
-rw-r--r-- | rag/mmr.py | 26
-rw-r--r-- | rag/nmain.py | 118
-rw-r--r-- | rag/rerank.py | 31
-rw-r--r-- | rag/search.py | 101
-rw-r--r-- | rag/test.py | 77
-rw-r--r-- | sentence.py | 42
-rw-r--r-- | tf.py | 65
-rw-r--r-- | yek.md | 316
17 files changed, 1251 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..0dad9e9 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +input +*.db +*.bkp +__pycache__ diff --git a/NOTES.md b/NOTES.md new file mode 100644 index 0000000..85f571b --- /dev/null +++ b/NOTES.md @@ -0,0 +1,22 @@ +We install some stuff on Nix (as much as we can) and the rest with `uv pip`. + +That works well. Which means it installs. But the `venv`, which is in `.venv`, also has a python runtime there, and as we're using the nix store python runtime, it won't see the venv packages. So you gotta do `source .venv/bin/activate` which is pretty weird but it does work. + + +21-9-2025- Long convo with ChatGPT about torch and numpy and tensors and ndarrays. idgi but interesting, should revisit. +Apparently transformers need the with torch.no_grad(): thing all the time but sentence transformers don't. at all + + + + +## DIM creating the DB table is retarded damn you chatgpt +>> he says +TL;DR + +DIM = number of floats per embedding vector (fixed by the model). + +sqlite-vec needs it at table creation. + +You don’t restart the DB every run — you just make sure new embeddings match the stored DIM. + +Best practice: store model+dim in a meta table so you don’t accidentally mix dimensions. diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..731a240 --- /dev/null +++ b/flake.lock @@ -0,0 +1,48 @@ +{ + "nodes": { + "nix-ai-stuff": { + "inputs": { + "nixpkgs": [ + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1756911915, + "narHash": "sha256-2b+GPPCM3Av2rZyuqALsOhnN2LTDmg6GmqBGUm8x/ww=", + "owner": "BatteredBunny", + "repo": "nix-ai-stuff", + "rev": "84db92a097d2c87234e096b880e685cd6423eb88", + "type": "github" + }, + "original": { + "owner": "BatteredBunny", + "repo": "nix-ai-stuff", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1758277210, + "narHash": "sha256-iCGWf/LTy+aY0zFu8q12lK8KuZp7yvdhStehhyX1v8w=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "8eaee110344796db060382e15d3af0a9fc396e0e", + "type": "github" + }, + "original": { + "owner": "nixos", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "nix-ai-stuff": "nix-ai-stuff", + "nixpkgs": "nixpkgs" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..61b64d3 --- /dev/null +++ b/flake.nix @@ -0,0 +1,96 @@ +{ + description = "Torch cuda flake using nix-community cachix"; + + nixConfig = { + extra-substituters = [ + "https://nix-community.cachix.org" + "https://nix-ai-stuff.cachix.org" + "https://ai.cachix.org" + "https://cuda-maintainers.cachix.org" + ]; + extra-trusted-public-keys = [ + "nix-community.cachix.org-1:mB9FSh9qf2dCimDSUo8Zy7bkq5CX+/rkCWyvRCYg3Fs=" + "cuda-maintainers.cachix.org-1:0dq3bujKpuEPMCX6U4WylrUDZ9JyUG0VpVZa7CNfq5E=" + "ai.cachix.org-1:N9dzRK+alWwoKXQlnn0H6aUx0lU/mspIoz8hMvGvbbc=" + "nix-ai-stuff.cachix.org-1:WlUGeVCs26w9xF0/rjyg32PujDqbVMlSHufpj1fqix8=" + ]; + }; + inputs = { + nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable"; + nix-ai-stuff = { + url = "github:BatteredBunny/nix-ai-stuff"; + inputs.nixpkgs.follows = "nixpkgs"; + }; + }; + + outputs = { + self, + nixpkgs, + nix-ai-stuff, + ... 
+ }: let + pkgs = import nixpkgs { + system = "x86_64-linux"; + config.allowUnfree = true; + config.cudaSupport = true; + }; + in { + devShell.x86_64-linux = with pkgs; + mkShell rec { + venvDir = "./.venv"; + pythonPkgs = pkgs.python3.withPackages ( + ps: + with ps; [ + mypy + ruff + torch + accelerate + transformers + typing-extensions + psutil + ninja + einops + packaging + sentence-transformers + nix-ai-stuff.packages.${pkgs.system}.flash-attn + # docling + uv + ] + ); + + py = with pkgs.python3Packages; [ + mypy + ruff + torch + accelerate + transformers + typing-extensions + psutil + ninja + einops + packaging + sentence-transformers + nix-ai-stuff.packages.${pkgs.system}.flash-attn + # docling + uv + ]; + buildInputs = [ + pyright + pythonPkgs + pkgs.sqlite + pkgs.sqlite-vec + pkgs.virtualenv + pkgs.python3Packages.venvShellHook + ]; + + # LOLPATH = lib.makeSearchPath "" buildInputs; + LOLPATH = lib.makeSearchPath "" py; + postVenvCreation = '' + unset SOURCE_DATE_EPOCH + ''; + shellHook = '' + fish + ''; + }; + }; +} diff --git a/knobs.md b/knobs.md new file mode 100644 index 0000000..0199714 --- /dev/null +++ b/knobs.md @@ -0,0 +1,118 @@ + +Got it. Let’s break down all the knobs in this hybrid + rerank (+ optional MMR) pipeline and what they control. Think of them as dials on your retrieval engine; how you set them shifts the tradeoff between recall, precision, speed, and diversity. + +1. k_vec — how many nearest-neighbor (embedding) hits to pull + +What it does: number of top candidates retrieved by cosine similarity from the vector index (vec0). + +Low values (10–20): very fast, but risky — if the embedding missed some relevant chunks, they’ll never make it into rerank. + +High values (50–100+): better recall, since the CE reranker has a wider pool to choose from. Slower, but vector search is cheap. + +Typical sweet spot: 40–80. Vector ANN is blazing fast, so it’s safe to overshoot. + +2. k_bm25 — how many lexical hits to pull (FTS5) + +What it does: number of chunks fetched with BM25 scoring (keyword-based search). + +Why needed: embeddings sometimes fuzz things too much; BM25 catches exact matches, rare names, technical jargon. + +Low values (10–20): cheap, but may miss keyword-rich relevant hits. + +High values (50–100+): good for “needle in haystack” terms, but can pull lots of noise. + +Typical sweet spot: 30–60. Balances recall with noise. + +3. Merging strategy (vec+bm25) + +Current code: concatenates vector hits then BM25 hits, deduplicates, passes to CE. + +Effect: vector has slight priority, but BM25 ensures coverage. + +Alternative: interleave or weighted merge (future upgrade if you want). + +4. k_ce — how many merged candidates to rerank + +What it does: size of candidate pool fed into the CrossEncoder. + +Why important: CE is expensive — each (query,doc) is a transformer forward pass. + +Low (10–20): very fast, but can miss gems that were just outside the cutoff. + +High (50–100): CE sees more context, better chance to surface true top chunks, but slower (linear in k_ce). + +Ballpark costs: + +bge-reranker-base on GPU: ~2ms per pair. + +k_ce=30 → ~60ms. + +k_ce=100 → ~200ms. + +Typical sweet spot: 20–50. Enough diversity without killing latency. + +5. k_final — how many chunks you actually keep + +What it does: final number of chunks to return for context injection or answer. + +Low (3–5): compact context, but maybe too narrow for complex queries. + +High (15–20): more coverage, but can bloat your prompt and confuse the LLM. + +Typical sweet spot: 8–12. 
Enough context richness, still fits in a 4k–8k token window easily. + +6. use_mmr — toggle for Maximal Marginal Relevance + +What it does: apply MMR on the CE top-N (e.g. 30) before picking final K. + +Why: rerankers often cluster — you’ll get 5 almost-identical chunks from one section. MMR diversifies. + +Cost: you need vectors for those CE top candidates (either re-embed on the fly or store in DB). Cheap compared to CE. + +When to turn on: long documents where redundancy is high (e.g., laws, academic papers, transcripts). + +When to skip: short docs, or if you want maximum precision and don’t care about duplicates. + +7. mmr_lambda — relevance vs. diversity balance + +Range: 0 → pure diversity, 1 → pure relevance. + +Typical settings: + +0.6 → favors relevance but still kicks out duplicates. + +0.7–0.8 → more focused, just enough diversity. + +0.4–0.5 → exploratory search, less focused but broad coverage. + +Use case: If CE is already precise, set 0.7+. If your doc is redundant, drop closer to 0.5. + +8. Secondary knobs (not in your code yet but worth considering) + +BM25 cutoff / minimum match: require a keyword overlap for lexical candidates. + +Chunk length / overlap: directly affects retriever performance. Shorter chunks = finer retrieval, but noisier. Longer = richer context, but less precise. + +Normalization choice: your pipeline uses cosine (good default). Alternatives: dot-product (works if embeddings are already normalized). + +Practical example + +Let’s say you ask: “How did Japanese scholars engage with Shuihu zhuan?” + +If k_vec=20, k_bm25=20, k_ce=20: CE only sees 40 candidates, may miss the one chapter that actually describes Bakin’s commentary. + +If k_vec=80, k_bm25=50, k_ce=50: CE sees 130 candidates, reranks, and reliably bubbles up the right passage. Latency maybe 150ms, but precision ↑. + +If use_mmr=True, mmr_lambda=0.6: instead of 10 chunks all from the same chapter, you get 10 chunks spread across reception, transmission, and commentary — much better for LLM summarization. + +👉 So the way to think about it: + +k_vec + k_bm25 = recall reservoir (make this generously high). + +k_ce = how much of that reservoir the expensive reranker drinks. + +k_final = how many glasses of water you hand to the LLM. + +use_mmr + mmr_lambda = whether you want those glasses from one pitcher or spread across the table. + +Do you want me to also suggest default knob profiles (like “fast mode”, “balanced mode”, “deep recall mode”) so you can flip between them depending on your use-case? 
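
As a follow-up to that closing question: a minimal sketch of what such default profiles could look like, wired to the `search_hybrid` signature defined in `rag/search.py` in this commit. The numbers are only the "typical sweet spot" ranges quoted above, not tuned or benchmarked values.

```python
# Illustrative knob presets; the exact numbers are assumptions drawn from the
# ranges discussed above, not measured defaults.
PROFILES = {
    # low latency: small recall reservoir, no MMR
    "fast":     dict(k_vec=20, k_bm25=20, k_ce=15, k_final=5,  use_mmr=False),
    # everyday default: decent recall, moderate cross-encoder load
    "balanced": dict(k_vec=50, k_bm25=40, k_ce=30, k_final=10, use_mmr=False),
    # recall-heavy: big candidate pools, MMR to de-duplicate the CE output
    "deep":     dict(k_vec=80, k_bm25=60, k_ce=50, k_final=12,
                     use_mmr=True, mmr_lambda=0.6),
}

def query_with_profile(db, query: str, profile: str = "balanced"):
    # thin wrapper around the hybrid search in rag/search.py
    from rag.search import search_hybrid
    return search_hybrid(db, query, **PROFILES[profile])
```

Flipping between modes is then just `query_with_profile(db, q, "deep")` when recall matters more than latency.
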
diff --git a/rag/__init__.py b/rag/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/rag/__init__.py diff --git a/rag/db.py b/rag/db.py new file mode 100644 index 0000000..c015c96 --- /dev/null +++ b/rag/db.py @@ -0,0 +1,55 @@ +import sqlite3 +from numpy import ndarray +from sqlite_vec import serialize_float32 + +def get_db(): + db = sqlite3.connect("./rag.db") + db.execute("PRAGMA journal_mode=WAL;") + db.execute("PRAGMA synchronous=NORMAL;") + db.execute("PRAGMA temp_store=MEMORY;") + db.execute("PRAGMA mmap_size=300000000;") # 30GB if your OS allows + db.enable_load_extension(True) + sqlite_vec.load(db) + db.enable_load_extension(False) + return db + +def init_schema(db: sqlite3.Connection, DIM:int): + db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec USING vec0(embedding float[{DIM}])") + db.execute(''' + CREATE TABLE IF NOT EXISTS chunks ( + id INTEGER PRIMARY KEY, + text TEXT + )''') + db.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts USING fts5(text)") + db.commit() + +def store_chunks(db: sqlite3.Connection, chunks: list[str], V_np:ndarray): + assert len(chunks) == len(V_np) + db.execute("BEGIN") + db.executemany(''' + INSERT INTO chunks(id, text) VALUES (?, ?) + ''', list(enumerate(chunks, start=1))) + db.executemany( + "INSERT INTO vec(rowid, embedding) VALUES (?, ?)", + [(i+1, memoryview(V_np[i].tobytes())) for i in range(len(chunks))] + ) + db.executemany("INSERT INTO fts(rowid, text) VALUES (?, ?)", list(enumerate(chunks, start=1))) + db.commit() + +def vec_topk(db, q_vec_f32, k=10): + rows = db.execute( + "SELECT rowid, distance FROM vec WHERE embedding MATCH ? ORDER BY distance LIMIT ?", + (serialize_float32(q_vec_f32), k) + ).fetchall() + return rows # [(rowid, distance)] + + +def bm25_topk(db, qtext, k=10): + return [rid for (rid,) in db.execute( + "SELECT rowid FROM fts WHERE fts MATCH ? 
LIMIT ?", (qtext, k) + ).fetchall()] + + def wipe_db(): + db = sqlite3.connect("./rag.db") + db.executescript("DROP TABLE IF EXISTS chunks; DROP TABLE IF EXISTS fts; DROP TABLE IF EXISTS vec;") + db.close() diff --git a/rag/ingest.py b/rag/ingest.py new file mode 100644 index 0000000..d17690a --- /dev/null +++ b/rag/ingest.py @@ -0,0 +1,61 @@ +# from rag.rerank import search_hybrid +import sqlite3 +import torch +from docling_core.transforms.chunker.hybrid_chunker import HybridChunker +from pathlib import Path +from docling.document_converter import DocumentConverter +from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer +from rag.db import get_db, init_schema, store_chunks +from sentence_transformers import SentenceTransformer + +EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-8B" +RERANKER_MODEL_ID = "Qwen/Qwen3-Reranker-8B" +MAX_TOKENS = 600 +BATCH = 16 + + +def get_embed_model(): + return SentenceTransformer( + "Qwen/Qwen3-Embedding-8B", + model_kwargs={ + # "trust_remote_code":True, + "attn_implementation":"flash_attention_2", + "device_map":"auto", + "dtype":torch.float16 + }, tokenizer_kwargs={"padding_side": "left"} + ) + + +def parse_and_chunk(source: Path, model: SentenceTransformer)-> list[str]: + tokenizer = HuggingFaceTokenizer(tokenizer=model.tokenizer, max_tokens=MAX_TOKENS) + + converter = DocumentConverter() + doc = converter.convert(source) + chunker = HybridChunker(tokenizer=tokenizer, merge_peers=True) + out = [] + for ch in chunker.chunk(doc.document): + txt = chunker.contextualize(ch) + if txt.strip(): + out.append(txt) + return out + +def embed_many(model: SentenceTransformer, texts: list[str]): + V_np = model.encode(texts, + batch_size=BATCH, + normalize_embeddings=True, + convert_to_numpy=True, + show_progress_bar=True) + return V_np.astype("float32") + + +def start_ingest(db: sqlite3.Connection, path: Path): + model = get_embed_model() + chunks = parse_and_chunk(path, model) + V_np = embed_many(model, chunks) + DIM = V_np.shape[1] + db = get_db() + init_schema(db, DIM) + store_chunks(db, chunks, V_np) + # TODO some try catch? + return True +# float32/fp16 on CPU? 
ensure float32 for DB: diff --git a/rag/main.py b/rag/main.py new file mode 100644 index 0000000..221adfa --- /dev/null +++ b/rag/main.py @@ -0,0 +1,71 @@ +import sys +import argparse +from pathlib import Path +from rag.ingest import start_ingest +from rag.search import search_hybrid, vec_search +from rag.db import get_db +# your Docling chunker imports… + +def main(): + ap = argparse.ArgumentParser(prog="rag") + sub = ap.add_subparsers(dest="cmd", required=True) + + # ingest + ap_ing = sub.add_parser("ingest", help="Parse, chunk, embed, and index a file into a collection") + ap_ing.add_argument("--file", required=True, help="Path to PDF/TXT to ingest") + ap_ing.set_defaults(func=cmd_ingest) + + # query + ap_q = sub.add_parser("query", help="Query a collection") + ap_q.add_argument("--query", required=True, help="User query text") + ap_q.add_argument("--simple", action="store_true", help="Vector-only search (skip reranker)") + ap_q.add_argument("--mmr", action="store_true", help="Apply MMR after CE") + ap_q.add_argument("--mmr-lambda", type=float, default=0.7) + ap_q.add_argument("--k-vec", type=int, default=50) + ap_q.add_argument("--k-bm25", type=int, default=50) + ap_q.add_argument("--k-ce", type=int, default=30) + ap_q.add_argument("--k-final", type=int, default=10) + ap_q.set_defaults(func=cmd_query) + + args = ap.parse_args() + args.func(args) + + +if __name__ == "__main__": + main() + + + + +def cmd_ingest(args): + path = Path(args.file) + + if not path.exists(): + print(f"File not found: {path}", file=sys.stderr) + sys.exit(1) + + + db = get_db() + stats = start_ingest(db,path) + print(f"Ingested file={args.file} :: {stats}") + +def cmd_query(args): + + db = get_db() + if args.simple: + results = vec_search(db, args.query, k=args.k_final) + else: + results = search_hybrid( + db, args.query, + k_vec=args.k_vec, + k_bm25=args.k_bm25, + k_ce=args.k_ce, + k_final=args.k_final, + use_mmr=args.mmr, + mmr_lambda=args.mmr_lambda, + ) + + for rid, txt, score in results: + print(f"[{rid:05d}] score={score:.3f}\n{txt[:400]}...\n") + + db.close() diff --git a/rag/mmr.py b/rag/mmr.py new file mode 100644 index 0000000..5e47c4f --- /dev/null +++ b/rag/mmr.py @@ -0,0 +1,26 @@ +import numpy as np + +def cosine(a, b): return float(np.dot(a, b)) + +def mmr(query_vec, cand_ids, cand_vecs, k=8, lamb=0.7): + """cand_ids: [int], cand_vecs: np.ndarray float32 [N,D] (unit vectors) aligned with cand_ids""" + selected, selected_idx = [], [] + remaining = list(range(len(cand_ids))) + + # seed with the most relevant + best0 = max(remaining, key=lambda i: cosine(query_vec, cand_vecs[i])) + selected.append(cand_ids[best0]); selected_idx.append(best0); remaining.remove(best0) + + while remaining and len(selected) < k: + def mmr_score(i): + rel = cosine(query_vec, cand_vecs[i]) + red = max(cosine(cand_vecs[i], cand_vecs[j]) for j in selected_idx) if selected_idx else 0.0 + return lamb * rel - (1.0 - lamb) * red + nxt = max(remaining, key=mmr_score) + selected.append(cand_ids[nxt]); selected_idx.append(nxt); remaining.remove(nxt) + return selected + +def embed_unit_np(st_model, texts: list[str]) -> np.ndarray: + V = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True, batch_size=32) + V = V.astype("float32", copy=False) + return V diff --git a/rag/nmain.py b/rag/nmain.py new file mode 100644 index 0000000..aa67dea --- /dev/null +++ b/rag/nmain.py @@ -0,0 +1,118 @@ +# rag/main.py +import argparse, sqlite3, sys +from pathlib import Path + +from sentence_transformers import SentenceTransformer 
+from rag.ingest import start_ingest +from rag.search import search_hybrid, search as vec_search + +DB_PATH = "./rag.db" +EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-8B" + +def open_db(): + db = sqlite3.connect(DB_PATH) + # speed-ish pragmas + db.execute("PRAGMA journal_mode=WAL;") + db.execute("PRAGMA synchronous=NORMAL;") + return db + +def load_st_model(): + # ST handles batching + GPU internally + return SentenceTransformer( + EMBED_MODEL_ID, + model_kwargs={ + "attn_implementation": "flash_attention_2", + "device_map": "auto", + "torch_dtype": "float16", + }, + tokenizer_kwargs={"padding_side": "left"}, + ) + +def ensure_collection_exists(db, collection: str): + row = db.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", + (f"vec_{collection}",) + ).fetchone() + return bool(row) + +def cmd_ingest(args): + db = open_db() + st_model = load_st_model() + path = Path(args.file) + + if not path.exists(): + print(f"File not found: {path}", file=sys.stderr) + sys.exit(1) + + if args.rebuild and ensure_collection_exists(db, args.collection): + db.executescript(f""" + DROP TABLE IF EXISTS chunks_{args.collection}; + DROP TABLE IF EXISTS fts_{args.collection}; + DROP TABLE IF EXISTS vec_{args.collection}; + """) + + stats = start_ingest( + db, st_model, + path=path, + collection=args.collection, + ) + print(f"Ingested collection={args.collection} :: {stats}") + db.close() + +def cmd_query(args): + db = open_db() + st_model = load_st_model() + + coll_ok = ensure_collection_exists(db, args.collection) + if not coll_ok: + print(f"Collection '{args.collection}' not found. Ingest first.", file=sys.stderr) + sys.exit(2) + + if args.simple: + results = vec_search(db, st_model, args.query, collection=args.collection, k=args.k_final) + else: + results = search_hybrid( + db, st_model, args.query, + collection=args.collection, + k_vec=args.k_vec, + k_bm25=args.k_bm25, + k_ce=args.k_ce, + k_final=args.k_final, + use_mmr=args.mmr, + mmr_lambda=args.mmr_lambda, + ) + + for rid, txt, score in results: + print(f"[{rid:05d}] score={score:.3f}\n{txt[:400]}...\n") + + db.close() + +def main(): + ap = argparse.ArgumentParser(prog="rag") + sub = ap.add_subparsers(dest="cmd", required=True) + + # ingest + ap_ing = sub.add_parser("ingest", help="Parse, chunk, embed, and index a file into a collection") + ap_ing.add_argument("--file", required=True, help="Path to PDF/TXT to ingest") + ap_ing.add_argument("--collection", required=True, help="Collection name (e.g. 
wm_qwen3)") + ap_ing.add_argument("--rebuild", action="store_true", help="Drop and recreate collection tables") + ap_ing.set_defaults(func=cmd_ingest) + + # query + ap_q = sub.add_parser("query", help="Query a collection") + ap_q.add_argument("--collection", required=True, help="Collection name to search") + ap_q.add_argument("--query", required=True, help="User query text") + ap_q.add_argument("--simple", action="store_true", help="Vector-only search (skip reranker)") + ap_q.add_argument("--mmr", action="store_true", help="Apply MMR after CE") + ap_q.add_argument("--mmr-lambda", type=float, default=0.7) + ap_q.add_argument("--k-vec", type=int, default=50) + ap_q.add_argument("--k-bm25", type=int, default=50) + ap_q.add_argument("--k-ce", type=int, default=30) + ap_q.add_argument("--k-final", type=int, default=10) + ap_q.set_defaults(func=cmd_query) + + args = ap.parse_args() + args.func(args) + +if __name__ == "__main__": + main() diff --git a/rag/rerank.py b/rag/rerank.py new file mode 100644 index 0000000..6ae8938 --- /dev/null +++ b/rag/rerank.py @@ -0,0 +1,31 @@ +# rag/rerank.py +import torch +from sentence_transformers import CrossEncoder + +RERANKER_ID = "Qwen/Qwen3-Reranker-8B" # or -large if you’ve got VRAM +# RERANKER_ID = "BAAI/bge-reranker-base" # or -large if you’ve got VRAM +# device: "cuda" | "cpu" | "mps" +def get_rerank_model(): + return CrossEncoder( + RERANKER_ID, + device="cuda", + model_kwargs={ + "attn_implementation":"flash_attention_2", + "device_map":"auto", + "dtype":torch.float16 + }, + tokenizer_kwargs={"padding_side": "left"} + ) + +def rerank_cross_encoder(reranker: CrossEncoder, query: str, candidates: list[tuple[int, str]], batch_size: int = 32): + """ + candidates: [(id, text), ...] + returns: [(id, text, score)] sorted desc by score + """ + if not candidates: + return [] + ids, texts = zip(*candidates) + pairs = [(query, t) for t in texts] + scores = reranker.predict(pairs, batch_size=batch_size) # np.ndarray [N], higher=better + ranked = sorted(zip(ids, texts, scores), key=lambda x: x[2], reverse=True) + return ranked diff --git a/rag/search.py b/rag/search.py new file mode 100644 index 0000000..55b8ffd --- /dev/null +++ b/rag/search.py @@ -0,0 +1,101 @@ +import sqlite3 +import numpy as np +import torch +import gc +from typing import List, Tuple + +from sentence_transformers import CrossEncoder, SentenceTransformer +from rag.ingest import get_embed_model +from rag.rerank import get_rerank_model, rerank_cross_encoder +from rag.mmr import mmr, embed_unit_np # if you added MMR; else remove + +def search_hybrid( + db: sqlite3.Connection, query: str, + k_vec: int = 40, k_bm25: int = 40, + k_ce: int = 30, # rerank this many + k_final: int = 10, # return this many + use_mmr: bool = False, mmr_lambda: float = 0.7 +): + emodel = get_embed_model() + # 1) embed query (unit, float32) for vec search + MMR + print("loading") + q = emodel.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32") + emodel = None + with torch.no_grad(): + torch.cuda.empty_cache() + gc.collect() + print("memory should be free by now!!") + qbytes = q.tobytes() + + # 2) ANN (cosine) + BM25 pools + vec_ids = [i for (i, _) in db.execute( + "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?", + (memoryview(qbytes), k_vec) + ).fetchall()] + bm25_ids = [i for (i,) in db.execute( + "SELECT rowid FROM fts WHERE fts MATCH ? 
LIMIT ?", + (query, k_bm25) + ).fetchall()] + + # 3) merge (vector-first) + seen, merged = set(), [] + for i in vec_ids + bm25_ids: + if i not in seen: + merged.append(i); seen.add(i) + if not merged: + return [] + + # 4) fetch texts for CE + qmarks = ",".join("?"*len(merged)) + cand = db.execute(f"SELECT id, text FROM chunks WHERE id IN ({qmarks})", merged).fetchall() + + reranker = get_rerank_model() + # 5) cross-encoder rerank (returns [(id,text,score)] desc) + ranked = rerank_cross_encoder(reranker, query, cand) + reranker = None + print("freeing again!!") + with torch.no_grad(): + torch.cuda.empty_cache() + gc.collect() + # + ranked = ranked[:min(k_ce, len(ranked))] + + if not use_mmr or len(ranked) <= k_final: + return ranked[:k_final] + + # 6) MMR diversity on CE top-k_ce + cand_ids = [i for (i,_,_) in ranked] + cand_text = [t for (_,t,_) in ranked] + emodel = get_embed_model() + # god this is annoying I should stop being poor + cand_vecs = embed_unit_np(emodel, cand_text) # [N,D], unit vectors + sel_ids = set(mmr(q, cand_ids, cand_vecs, k=k_final, lamb=mmr_lambda)) + final = [trip for trip in ranked if trip[0] in sel_ids] # keep CE order, filter by MMR picks + return final[:k_final] + + +# ) Query helper (cosine distance; operator may be <#> in sqlite-vec) +def vec_search(db: sqlite3.Connection, qtext, k=5): + model = get_embed_model() + q = model.encode([qtext], normalize_embeddings=True, convert_to_numpy=True).astype("float32") + # Cosine distance operator in sqlite-vec is `<#>`; if your build differs, check docs: <-> L2, <=> dot, <#> cosine + rows = db.execute( + "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?", + (memoryview(q.tobytes()), k) + ).fetchall() + +# db.execute("SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?", (qtext, 40)) + return [(rid, db.execute("SELECT text FROM chunks WHERE id=?", (rid,)).fetchone()[0], dist) for rid, dist in rows] + + +# # Hybrid + CE rerank query: +# results = search_hybrid("indemnification obligations survive termination", model, k_vec=50, k_bm25=50, k_final=8) +# for rid, txt, score in results: +# print(f"[{rid:04d}] score={score:.3f}\n{txt[:300]}...\n") +# +# +# +# +# +# +# diff --git a/rag/test.py b/rag/test.py new file mode 100644 index 0000000..b7a6d8e --- /dev/null +++ b/rag/test.py @@ -0,0 +1,77 @@ +from pathlib import Path +from docling.document_converter import DocumentConverter +from docling_core.transforms.chunker.hybrid_chunker import HybridChunker +from transformers import AutoModel +from sentence_transformers import SentenceTransformer +from rag.db import get_db +from rag.rerank import get_rerank_model +import rag.ingest +import rag.search +import torch + +# converter = DocumentConverter() +# chunker = HybridChunker() +# file = Path("yek.md") +# doc = converter.convert(file).document +# chunk_iter = chunker.chunk(doc) +# for chunk in chunk_iter: +# print(chunk) +# txt = chunker.contextualize(chunk) +# print(txt) + + + +def t(): + batch: list[str] = ["This son of a bitch has gone too far", "Fuck me baby please", "I'm hungry now", "Charlie Kirk is dead"] + model = rag.ingest.get_embed_model() + v = model.encode("pepeee", normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True) + print("v") + print(type(v)) + print(v) + print(v.dtype) + print(v.device) + +# V = torch.cat([v], dim=0) +# print("V") +# print(type(V)) +# print(V) +# print(V.dtype) +# print(V.device) +# print("V_np") +# V_idk = V.cpu().float() + +# when they were pytorch tensors +# V = embed_many(chunks) # float32/fp16 on CPU? 
ensure float32 for DB: +# V_np = V.float().cpu().numpy().astype("float32") +# DIM = V_np.shape[1] +# db = sqlite3.connect("./rag.db") + +queries = [ +"How was Shuihu zhuan received in early modern Japan?", +"Edo-period readers’ image of Song Jiang / Liangshan outlaws", +"Channels of transmission for Chinese vernacular fiction into Japan (kanbun kundoku, digests, translations)", +"Role of woodblock prints/illustrations in mediating Chinese fiction", +"Key Japanese scholars, writers, or publishers who popularized Chinese fiction", +"Kyokutei Bakin’s engagement with Chinese vernacular narrative", +"Santō Kyōden, gesaku, and Chinese models", +"Kanzen chōaku (encourage good, punish evil) and Water Margin in Japan", +"Moral ambivalence of outlaw heroes as discussed in the text", +"Censorship or moral debates around reading Chinese fiction", +"Translation strategies from vernacular Chinese to Japanese (furigana, kundoku, glossing)", +"Paratexts: prefaces, commentaries, reader guidance apparatus", +"Bibliographic details: editions, reprints, circulation networks", +"How does this book challenge older narratives about Sino-Japanese literary influence?", +"Methodology: sources, archives, limitations mentioned by the author", + +] +def t2(): + db = get_db() + # Hybrid + CE rerank query: + for query in queries: + print("query", query) + results = rag.search.search_hybrid(db, query, k_vec=50, k_bm25=50, k_final=8) + for rid, txt, score in results: + sim = score + print(f"[{rid:04d}] ce_score={sim:.3f}\n{txt[:300]}...\n") + +t2() diff --git a/sentence.py b/sentence.py new file mode 100644 index 0000000..f29927a --- /dev/null +++ b/sentence.py @@ -0,0 +1,42 @@ +# Requires transformers>=4.51.0 +# Requires sentence-transformers>=2.7.0 +import torch +from sentence_transformers import SentenceTransformer + +# Load the model +# model = SentenceTransformer("Qwen/Qwen3-Embedding-8B") + +# We recommend enabling flash_attention_2 for better acceleration and memory saving, +# together with setting `padding_side` to "left": +model = SentenceTransformer( + "Qwen/Qwen3-Embedding-8B", + model_kwargs={"attn_implementation": "flash_attention_2", "device_map": "auto", "dtype": torch.float16}, + tokenizer_kwargs={"padding_side": "left"}, +) + + +# Flash Attention 2 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in Qwen3Model is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)` + +# The queries and documents to embed +queries = [ + "What is the capital of China?", + "Explain gravity", +] +documents = [ + "The capital of China is Beijing.", + "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.", +] + +# with torch.autocast(device_type='torch_device'): +with torch.no_grad(): +# Encode the queries and documents. 
Note that queries benefit from using a prompt +# Here we use the prompt called "query" stored under `model.prompts`, but you can +# also pass your own prompt via the `prompt` argument + query_embeddings = model.encode(queries, prompt_name="query") + document_embeddings = model.encode(documents) + +# Compute the (cosine) similarity between the query and document embeddings +similarity = model.similarity(query_embeddings, document_embeddings) +print(similarity) +# tensor([[0.7493, 0.0751], +# [0.0880, 0.6318]]) @@ -0,0 +1,65 @@ + +import torch +import torch.nn.functional as F + +from torch import Tensor +from transformers import AutoModel, AutoTokenizer + + +def last_token_pool(last_hidden_states: Tensor, + attention_mask: Tensor) -> Tensor: + left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0]) + if left_padding: + return last_hidden_states[:, -1] + else: + sequence_lengths = attention_mask.sum(dim=1) - 1 + batch_size = last_hidden_states.shape[0] + return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths] + + +def get_detailed_instruct(task_description: str, query: str) -> str: + return f'Instruct: {task_description}\nQuery:{query}' + +# Each query must come with a one-sentence instruction that describes the task +task = 'Given a web search query, retrieve relevant passages that answer the query' + +queries = [ + get_detailed_instruct(task, 'What is the capital of China?'), + get_detailed_instruct(task, 'Explain gravity') +] +# No need to add instruction for retrieval documents +documents = [ + "The capital of China is Beijing.", + "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun." 
+] +input_texts = queries + documents + + +# We recommend enabling flash_attention_2 for better acceleration and memory saving, +# together with setting `padding_side` to "left": +model = AutoModel.from_pretrained( + "Qwen/Qwen3-Embedding-8B", attn_implementation="flash_attention_2", device_map="auto", torch_dtype=torch.float16 +) +tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-Embedding-8B', padding_side="left") +# The queries and documents to embed +max_length = 8192 + +# Tokenize the input texts +batch_dict = tokenizer( + input_texts, + padding=True, + truncation=True, + max_length=max_length, + return_tensors="pt", +) +batch_dict.to(model.device) +with torch.no_grad(): + outputs = model(**batch_dict) + embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask']) + + # normalize embeddings + embeddings = F.normalize(embeddings, p=2, dim=1) + scores = (embeddings[:2] @ embeddings[2:].T) + +print(scores.tolist()) +# [[0.7645568251609802, 0.14142508804798126], [0.13549736142158508, 0.5999549627304077]] @@ -0,0 +1,316 @@ +>>>> __init__.py + +>>>> main.py +# from rag.rerank import search_hybrid +import torch +import sqlite3 +import sqlite_vec +from docling_core.transforms.chunker.hybrid_chunker import HybridChunker +from pathlib import Path +from docling.document_converter import DocumentConverter +from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer + +from sentence_transformers import SentenceTransformer + +EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-8B" +RERANKER_MODEL_ID = "Qwen/Qwen3-Reranker-8B" +MAX_TOKENS = 600 +BATCH = 16 + +model = SentenceTransformer( + "Qwen/Qwen3-Embedding-8B", + model_kwargs={ + # "trust_remote_code":True, + "attn_implementation":"flash_attention_2", + "device_map":"auto", + "dtype":torch.float16 + }, tokenizer_kwargs={"padding_side": "left"} +) +tokenizer = HuggingFaceTokenizer(tokenizer=model.tokenizer, max_tokens=MAX_TOKENS) + +def parse_and_chunk()-> list[str]: + source = Path("./japan-water-margin.pdf") + + converter = DocumentConverter() + doc = converter.convert(source) + chunker = HybridChunker(tokenizer=tokenizer, merge_peers=True) + out = [] + for ch in chunker.chunk(doc.document): + txt = chunker.contextualize(ch) + if txt.strip(): + out.append(txt) + return out + +def embed_many(texts: list[str]): + V_np = model.encode(texts, + batch_size=BATCH, + normalize_embeddings=True, + convert_to_numpy=True, + show_progress_bar=True) + return V_np.astype("float32") + + + +chunks =parse_and_chunk() +V_np = embed_many(chunks) # float32/fp16 on CPU? ensure float32 for DB: +DIM = V_np.shape[1] +db = sqlite3.connect("./rag.db") + + + +def init_db(DIM: str): + db.enable_load_extension(True) + sqlite_vec.load(db) + db.enable_load_extension(False) + db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec USING vec0(embedding float[{DIM}])") + db.execute(''' + CREATE TABLE IF NOT EXISTS chunks ( + id INTEGER PRIMARY KEY, + text TEXT + )''') + db.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts USING fts5(text)") + +def store_chunks(chunks: list[str], V_np): + assert len(chunks) == len(V_np) + db.execute("BEGIN") + db.executemany(''' + INSERT INTO chunks(id, text) VALUES (?, ?) 
+ ''', list(enumerate(chunks, start=1))) + db.executemany( + "INSERT INTO vec(rowid, embedding) VALUES (?, ?)", + [(i+1, memoryview(V_np[i].tobytes())) for i in range(len(chunks))] + ) + db.executemany("INSERT INTO fts(rowid, text) VALUES (?, ?)", list(enumerate(chunks, start=1))) + db.commit() + +# 5) Query helper (cosine distance; operator may be <#> in sqlite-vec) +def search(qtext, k=5): + q = model.encode([qtext], normalize=True, convert_to_numpy=True).astype("float32") + # Cosine distance operator in sqlite-vec is `<#>`; if your build differs, check docs: <-> L2, <=> dot, <#> cosine + rows = db.execute( + "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?", + (memoryview(q.tobytes()), k) + ).fetchall() + +# db.execute("SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?", (qtext, 40)) + return [(rid, db.execute("SELECT text FROM chunks WHERE id=?", (rid,)).fetchone()[0], dist) for rid, dist in rows] + + + +init_db(DIM) +store_chunks(chunks, V_np) + + +# # Hybrid + CE rerank query: +# results = search_hybrid("indemnification obligations survive termination", model, k_vec=50, k_bm25=50, k_final=8) +# for rid, txt, score in results: +# print(f"[{rid:04d}] score={score:.3f}\n{txt[:300]}...\n") + +>>>> mmr.py +import numpy as np + +def cosine(a, b): return float(np.dot(a, b)) + +def mmr(query_vec, cand_ids, cand_vecs, k=8, lamb=0.7): + """cand_ids: [int], cand_vecs: np.ndarray float32 [N,D] (unit vectors) aligned with cand_ids""" + selected, selected_idx = [], [] + remaining = list(range(len(cand_ids))) + + # seed with the most relevant + best0 = max(remaining, key=lambda i: cosine(query_vec, cand_vecs[i])) + selected.append(cand_ids[best0]); selected_idx.append(best0); remaining.remove(best0) + + while remaining and len(selected) < k: + def mmr_score(i): + rel = cosine(query_vec, cand_vecs[i]) + red = max(cosine(cand_vecs[i], cand_vecs[j]) for j in selected_idx) if selected_idx else 0.0 + return lamb * rel - (1.0 - lamb) * red + nxt = max(remaining, key=mmr_score) + selected.append(cand_ids[nxt]); selected_idx.append(nxt); remaining.remove(nxt) + return selected + +>>>> rag.db + +>>>> rerank.py +import torch +import sqlite3 +import sqlite_vec +from transformers import AutoTokenizer, AutoModelForSequenceClassification + +# ----------------------------- +# RERANKER (cross-encoder) +# ----------------------------- +RERANKER_ID = "BAAI/bge-reranker-base" # small & solid +# RERANKER_ID = "Qwen/Qwen3-Reranker-8B" # small & solid +_rr_tok = AutoTokenizer.from_pretrained(RERANKER_ID) +_rr_mod = AutoModelForSequenceClassification.from_pretrained( + RERANKER_ID, device_map="auto", dtype=torch.bfloat16 +) + +def rerank_cross_encoder(query: str, candidates: list[tuple[int, str]], batch_size: int = 32) -> list[tuple[int, str, float]]: + """ + candidates: [(id, text), ...] 
+ returns: [(id, text, score)] sorted desc by score + """ + if not candidates: + return [] + ids, texts = zip(*candidates) + scores = [] + for i in range(0, len(texts), batch_size): + batch = texts[i:i+batch_size] + enc = _rr_tok(list(zip([query]*len(batch), batch)), + padding=True, truncation=True, max_length=512, + return_tensors="pt").to(_rr_mod.device) + with torch.no_grad(): + logits = _rr_mod(**enc).logits.squeeze(-1) # [B] + scores.extend(logits.float().cpu().tolist()) + ranked = sorted(zip(ids, texts, scores), key=lambda x: x[2], reverse=True) + return ranked + +# ----------------------------- +# DB: vec + chunks + FTS5 +# ----------------------------- +db = sqlite3.connect("./rag.db") + +def init_db(dim: int): + db.enable_load_extension(True) + sqlite_vec.load(db) # loads vec0 + db.enable_load_extension(False) + + # Vector table (cosine expects float32 blob) + db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec USING vec0(embedding float[{dim}])") + + # Raw chunks + db.execute(""" + CREATE TABLE IF NOT EXISTS chunks( + id INTEGER PRIMARY KEY, + text TEXT + ) + """) + + # FTS5 over chunks.text (contentless mode keeps it simple) + db.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts USING fts5(text)") + db.execute("CREATE INDEX IF NOT EXISTS idx_chunks_id ON chunks(id)") + +def store_chunks(chunks: list[str], V_np): + # chunks: list[str], V_np: np.ndarray float32 [N, D] + assert len(chunks) == len(V_np), "mismatch chunks vs vectors" + + db.execute("BEGIN") + db.executemany("INSERT INTO chunks(id, text) VALUES (?,?)", + list(enumerate(chunks, start=1))) + db.executemany("INSERT INTO vec(rowid, embedding) VALUES (?,?)", + [(i+1, memoryview(V_np[i].tobytes())) for i in range(len(chunks))]) + # FTS content + db.executemany("INSERT INTO fts(rowid, text) VALUES (?,?)", + list(enumerate(chunks, start=1))) + db.commit() + +# ----------------------------- +# SEARCH: hybrid → rerank +# ----------------------------- +def _vec_topk(qvec_f32: bytes, k: int): + # cosine distance: smaller is better + rows = db.execute( + "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?", + (memoryview(qvec_f32), k) + ).fetchall() + return [(rid, float(dist)) for (rid, dist) in rows] + +def _bm25_topk(query: str, k: int): + # note: FTS5 MATCH syntax, tokenizer defaults ok for English/legal + rows = db.execute( + "SELECT rowid FROM fts WHERE fts MATCH ? 
LIMIT ?", + (query, k) + ).fetchall() + return [rid for (rid,) in rows] + +def _fetch_text(ids: list[int]) -> list[tuple[int, str]]: + if not ids: + return [] + qmarks = ",".join("?"*len(ids)) + rows = db.execute(f"SELECT id, text FROM chunks WHERE id IN ({qmarks})", ids).fetchall() + by_id = {i:t for (i,t) in rows} + # keep input order + return [(i, by_id[i]) for i in ids if i in by_id] + +def search_hybrid(query: str, embed_model, k_vec: int = 40, k_bm25: int = 40, k_final: int = 10): + # 1) embed query (unit-norm) with same model you used for chunks + with torch.no_grad(): + q = embed_model.encode([query], normalize=True).to("cpu").to(torch.float32).numpy()[0] + qbytes = q.tobytes() + + # 2) candidate pools + vec_hits = _vec_topk(qbytes, k_vec) # [(id, dist)] + vec_ids = [i for (i, _) in vec_hits] + bm25_ids = _bm25_topk(query, k_bm25) + + # 3) merge (preserve vector order first) + seen = set() + merged_ids = [] + for i in vec_ids + bm25_ids: + if i not in seen: + merged_ids.append(i); seen.add(i) + + # 4) fetch texts, then rerank with cross-encoder + cand = _fetch_text(merged_ids) + ranked = rerank_cross_encoder(query, cand) + + # 5) return top k_final: [(id, text, ce_score)] + return ranked[:k_final] + +>>>> test.py +from pathlib import Path +from docling.document_converter import DocumentConverter +from docling_core.transforms.chunker.hybrid_chunker import HybridChunker +from transformers import AutoModel +from sentence_transformers import SentenceTransformer +import torch + +# converter = DocumentConverter() +# chunker = HybridChunker() +# file = Path("yek.md") +# doc = converter.convert(file).document +# chunk_iter = chunker.chunk(doc) +# for chunk in chunk_iter: +# print(chunk) +# txt = chunker.contextualize(chunk) +# print(txt) + + + +model = SentenceTransformer( + "Qwen/Qwen3-Embedding-8B", + model_kwargs={ + # "trust_remote_code":True, + "attn_implementation":"flash_attention_2", + "device_map":"auto", + "dtype":torch.float16 + }, tokenizer_kwargs={"padding_side": "left"} +) + +batch: list[str] = ["This son of a bitch has gone too far", "Fuck me baby please", "I'm hungry now", "Charlie Kirk is dead"] + +v = model.encode("pepeee", normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True) +print("v") +print(type(v)) +print(v) +print(v.dtype) +print(v.device) + +# V = torch.cat([v], dim=0) +# print("V") +# print(type(V)) +# print(V) +# print(V.dtype) +# print(V.device) +# print("V_np") +# V_idk = V.cpu().float() + +# when they were pytorch tensors +# V = embed_many(chunks) # float32/fp16 on CPU? ensure float32 for DB: +# V_np = V.float().cpu().numpy().astype("float32") +# DIM = V_np.shape[1] +# db = sqlite3.connect("./rag.db") + + |
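
One loose end from NOTES.md: it recommends storing the model and DIM in a meta table so embeddings from different models never get mixed. A minimal sketch of that idea, assuming a hypothetical `meta` table and helper names that are not part of rag/db.py as committed:

```python
import sqlite3

def init_meta(db: sqlite3.Connection, model_id: str, dim: int) -> None:
    # one-row-per-key table recording which embedding model/dimension built this DB
    db.execute("CREATE TABLE IF NOT EXISTS meta (key TEXT PRIMARY KEY, value TEXT)")
    db.executemany(
        "INSERT OR REPLACE INTO meta(key, value) VALUES (?, ?)",
        [("embed_model", model_id), ("dim", str(dim))],
    )
    db.commit()

def assert_dim(db: sqlite3.Connection, dim: int) -> None:
    # refuse to ingest vectors whose dimension differs from the stored one
    row = db.execute("SELECT value FROM meta WHERE key = 'dim'").fetchone()
    if row is not None and int(row[0]) != dim:
        raise ValueError(f"new embeddings have dim={dim}, but DB was built with dim={row[0]}")
```

`start_ingest` could call `assert_dim(db, V_np.shape[1])` right before `store_chunks` to catch a model/dimension mismatch early.
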