summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore4
-rw-r--r--NOTES.md22
-rw-r--r--flake.lock48
-rw-r--r--flake.nix96
-rw-r--r--knobs.md118
-rw-r--r--rag/__init__.py0
-rw-r--r--rag/db.py55
-rw-r--r--rag/ingest.py61
-rw-r--r--rag/main.py71
-rw-r--r--rag/mmr.py26
-rw-r--r--rag/nmain.py118
-rw-r--r--rag/rerank.py31
-rw-r--r--rag/search.py101
-rw-r--r--rag/test.py77
-rw-r--r--sentence.py42
-rw-r--r--tf.py65
-rw-r--r--yek.md316
17 files changed, 1251 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..0dad9e9
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+input
+*.db
+*.bkp
+__pycache__
diff --git a/NOTES.md b/NOTES.md
new file mode 100644
index 0000000..85f571b
--- /dev/null
+++ b/NOTES.md
@@ -0,0 +1,22 @@
+We install some stuff on Nix (as much as we can) and the rest with `uv pip`.
+
+That works well. Which means it installs. But the `venv`, which is in `.venv`, also has a python runtime there, and as we're using the nix store python runtime, it won't see the venv packages. So you gotta do `source .venv/bin/activate` which is pretty weird but it does work.
+
+
+21-9-2025- Long convo with ChatGPT about torch and numpy and tensors and ndarrays. idgi but interesting, should revisit.
+Apparently transformers need the with torch.no_grad(): thing all the time but sentence transformers don't. at all
+
+
+
+
+## DIM creating the DB table is retarded damn you chatgpt
+>> he says
+TL;DR
+
+DIM = number of floats per embedding vector (fixed by the model).
+
+sqlite-vec needs it at table creation.
+
+You don’t restart the DB every run — you just make sure new embeddings match the stored DIM.
+
+Best practice: store model+dim in a meta table so you don’t accidentally mix dimensions.
diff --git a/flake.lock b/flake.lock
new file mode 100644
index 0000000..731a240
--- /dev/null
+++ b/flake.lock
@@ -0,0 +1,48 @@
+{
+ "nodes": {
+ "nix-ai-stuff": {
+ "inputs": {
+ "nixpkgs": [
+ "nixpkgs"
+ ]
+ },
+ "locked": {
+ "lastModified": 1756911915,
+ "narHash": "sha256-2b+GPPCM3Av2rZyuqALsOhnN2LTDmg6GmqBGUm8x/ww=",
+ "owner": "BatteredBunny",
+ "repo": "nix-ai-stuff",
+ "rev": "84db92a097d2c87234e096b880e685cd6423eb88",
+ "type": "github"
+ },
+ "original": {
+ "owner": "BatteredBunny",
+ "repo": "nix-ai-stuff",
+ "type": "github"
+ }
+ },
+ "nixpkgs": {
+ "locked": {
+ "lastModified": 1758277210,
+ "narHash": "sha256-iCGWf/LTy+aY0zFu8q12lK8KuZp7yvdhStehhyX1v8w=",
+ "owner": "nixos",
+ "repo": "nixpkgs",
+ "rev": "8eaee110344796db060382e15d3af0a9fc396e0e",
+ "type": "github"
+ },
+ "original": {
+ "owner": "nixos",
+ "ref": "nixos-unstable",
+ "repo": "nixpkgs",
+ "type": "github"
+ }
+ },
+ "root": {
+ "inputs": {
+ "nix-ai-stuff": "nix-ai-stuff",
+ "nixpkgs": "nixpkgs"
+ }
+ }
+ },
+ "root": "root",
+ "version": 7
+}
diff --git a/flake.nix b/flake.nix
new file mode 100644
index 0000000..61b64d3
--- /dev/null
+++ b/flake.nix
@@ -0,0 +1,96 @@
+{
+ description = "Torch cuda flake using nix-community cachix";
+
+ nixConfig = {
+ extra-substituters = [
+ "https://nix-community.cachix.org"
+ "https://nix-ai-stuff.cachix.org"
+ "https://ai.cachix.org"
+ "https://cuda-maintainers.cachix.org"
+ ];
+ extra-trusted-public-keys = [
+ "nix-community.cachix.org-1:mB9FSh9qf2dCimDSUo8Zy7bkq5CX+/rkCWyvRCYg3Fs="
+ "cuda-maintainers.cachix.org-1:0dq3bujKpuEPMCX6U4WylrUDZ9JyUG0VpVZa7CNfq5E="
+ "ai.cachix.org-1:N9dzRK+alWwoKXQlnn0H6aUx0lU/mspIoz8hMvGvbbc="
+ "nix-ai-stuff.cachix.org-1:WlUGeVCs26w9xF0/rjyg32PujDqbVMlSHufpj1fqix8="
+ ];
+ };
+ inputs = {
+ nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable";
+ nix-ai-stuff = {
+ url = "github:BatteredBunny/nix-ai-stuff";
+ inputs.nixpkgs.follows = "nixpkgs";
+ };
+ };
+
+ outputs = {
+ self,
+ nixpkgs,
+ nix-ai-stuff,
+ ...
+ }: let
+ pkgs = import nixpkgs {
+ system = "x86_64-linux";
+ config.allowUnfree = true;
+ config.cudaSupport = true;
+ };
+ in {
+ devShell.x86_64-linux = with pkgs;
+ mkShell rec {
+ venvDir = "./.venv";
+ pythonPkgs = pkgs.python3.withPackages (
+ ps:
+ with ps; [
+ mypy
+ ruff
+ torch
+ accelerate
+ transformers
+ typing-extensions
+ psutil
+ ninja
+ einops
+ packaging
+ sentence-transformers
+ nix-ai-stuff.packages.${pkgs.system}.flash-attn
+ # docling
+ uv
+ ]
+ );
+
+ py = with pkgs.python3Packages; [
+ mypy
+ ruff
+ torch
+ accelerate
+ transformers
+ typing-extensions
+ psutil
+ ninja
+ einops
+ packaging
+ sentence-transformers
+ nix-ai-stuff.packages.${pkgs.system}.flash-attn
+ # docling
+ uv
+ ];
+ buildInputs = [
+ pyright
+ pythonPkgs
+ pkgs.sqlite
+ pkgs.sqlite-vec
+ pkgs.virtualenv
+ pkgs.python3Packages.venvShellHook
+ ];
+
+ # LOLPATH = lib.makeSearchPath "" buildInputs;
+ LOLPATH = lib.makeSearchPath "" py;
+ postVenvCreation = ''
+ unset SOURCE_DATE_EPOCH
+ '';
+ shellHook = ''
+ fish
+ '';
+ };
+ };
+}
diff --git a/knobs.md b/knobs.md
new file mode 100644
index 0000000..0199714
--- /dev/null
+++ b/knobs.md
@@ -0,0 +1,118 @@
+
+Got it. Let’s break down all the knobs in this hybrid + rerank (+ optional MMR) pipeline and what they control. Think of them as dials on your retrieval engine; how you set them shifts the tradeoff between recall, precision, speed, and diversity.
+
+1. k_vec — how many nearest-neighbor (embedding) hits to pull
+
+What it does: number of top candidates retrieved by cosine similarity from the vector index (vec0).
+
+Low values (10–20): very fast, but risky — if the embedding missed some relevant chunks, they’ll never make it into rerank.
+
+High values (50–100+): better recall, since the CE reranker has a wider pool to choose from. Slower, but vector search is cheap.
+
+Typical sweet spot: 40–80. Vector ANN is blazing fast, so it’s safe to overshoot.
+
+2. k_bm25 — how many lexical hits to pull (FTS5)
+
+What it does: number of chunks fetched with BM25 scoring (keyword-based search).
+
+Why needed: embeddings sometimes fuzz things too much; BM25 catches exact matches, rare names, technical jargon.
+
+Low values (10–20): cheap, but may miss keyword-rich relevant hits.
+
+High values (50–100+): good for “needle in haystack” terms, but can pull lots of noise.
+
+Typical sweet spot: 30–60. Balances recall with noise.
+
+3. Merging strategy (vec+bm25)
+
+Current code: concatenates vector hits then BM25 hits, deduplicates, passes to CE.
+
+Effect: vector has slight priority, but BM25 ensures coverage.
+
+Alternative: interleave or weighted merge (future upgrade if you want).
+
+4. k_ce — how many merged candidates to rerank
+
+What it does: size of candidate pool fed into the CrossEncoder.
+
+Why important: CE is expensive — each (query,doc) is a transformer forward pass.
+
+Low (10–20): very fast, but can miss gems that were just outside the cutoff.
+
+High (50–100): CE sees more context, better chance to surface true top chunks, but slower (linear in k_ce).
+
+Ballpark costs:
+
+bge-reranker-base on GPU: ~2ms per pair.
+
+k_ce=30 → ~60ms.
+
+k_ce=100 → ~200ms.
+
+Typical sweet spot: 20–50. Enough diversity without killing latency.
+
+5. k_final — how many chunks you actually keep
+
+What it does: final number of chunks to return for context injection or answer.
+
+Low (3–5): compact context, but maybe too narrow for complex queries.
+
+High (15–20): more coverage, but can bloat your prompt and confuse the LLM.
+
+Typical sweet spot: 8–12. Enough context richness, still fits in a 4k–8k token window easily.
+
+6. use_mmr — toggle for Maximal Marginal Relevance
+
+What it does: apply MMR on the CE top-N (e.g. 30) before picking final K.
+
+Why: rerankers often cluster — you’ll get 5 almost-identical chunks from one section. MMR diversifies.
+
+Cost: you need vectors for those CE top candidates (either re-embed on the fly or store in DB). Cheap compared to CE.
+
+When to turn on: long documents where redundancy is high (e.g., laws, academic papers, transcripts).
+
+When to skip: short docs, or if you want maximum precision and don’t care about duplicates.
+
+7. mmr_lambda — relevance vs. diversity balance
+
+Range: 0 → pure diversity, 1 → pure relevance.
+
+Typical settings:
+
+0.6 → favors relevance but still kicks out duplicates.
+
+0.7–0.8 → more focused, just enough diversity.
+
+0.4–0.5 → exploratory search, less focused but broad coverage.
+
+Use case: If CE is already precise, set 0.7+. If your doc is redundant, drop closer to 0.5.
+
+8. Secondary knobs (not in your code yet but worth considering)
+
+BM25 cutoff / minimum match: require a keyword overlap for lexical candidates.
+
+Chunk length / overlap: directly affects retriever performance. Shorter chunks = finer retrieval, but noisier. Longer = richer context, but less precise.
+
+Normalization choice: your pipeline uses cosine (good default). Alternatives: dot-product (works if embeddings are already normalized).
+
+Practical example
+
+Let’s say you ask: “How did Japanese scholars engage with Shuihu zhuan?”
+
+If k_vec=20, k_bm25=20, k_ce=20: CE only sees 40 candidates, may miss the one chapter that actually describes Bakin’s commentary.
+
+If k_vec=80, k_bm25=50, k_ce=50: CE sees 130 candidates, reranks, and reliably bubbles up the right passage. Latency maybe 150ms, but precision ↑.
+
+If use_mmr=True, mmr_lambda=0.6: instead of 10 chunks all from the same chapter, you get 10 chunks spread across reception, transmission, and commentary — much better for LLM summarization.
+
+👉 So the way to think about it:
+
+k_vec + k_bm25 = recall reservoir (make this generously high).
+
+k_ce = how much of that reservoir the expensive reranker drinks.
+
+k_final = how many glasses of water you hand to the LLM.
+
+use_mmr + mmr_lambda = whether you want those glasses from one pitcher or spread across the table.
+
+Do you want me to also suggest default knob profiles (like “fast mode”, “balanced mode”, “deep recall mode”) so you can flip between them depending on your use-case?
diff --git a/rag/__init__.py b/rag/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/rag/__init__.py
diff --git a/rag/db.py b/rag/db.py
new file mode 100644
index 0000000..c015c96
--- /dev/null
+++ b/rag/db.py
@@ -0,0 +1,55 @@
+import sqlite3
+from numpy import ndarray
+from sqlite_vec import serialize_float32
+
+def get_db():
+ db = sqlite3.connect("./rag.db")
+ db.execute("PRAGMA journal_mode=WAL;")
+ db.execute("PRAGMA synchronous=NORMAL;")
+ db.execute("PRAGMA temp_store=MEMORY;")
+ db.execute("PRAGMA mmap_size=300000000;") # 30GB if your OS allows
+ db.enable_load_extension(True)
+ sqlite_vec.load(db)
+ db.enable_load_extension(False)
+ return db
+
+def init_schema(db: sqlite3.Connection, DIM:int):
+ db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec USING vec0(embedding float[{DIM}])")
+ db.execute('''
+ CREATE TABLE IF NOT EXISTS chunks (
+ id INTEGER PRIMARY KEY,
+ text TEXT
+ )''')
+ db.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts USING fts5(text)")
+ db.commit()
+
+def store_chunks(db: sqlite3.Connection, chunks: list[str], V_np:ndarray):
+ assert len(chunks) == len(V_np)
+ db.execute("BEGIN")
+ db.executemany('''
+ INSERT INTO chunks(id, text) VALUES (?, ?)
+ ''', list(enumerate(chunks, start=1)))
+ db.executemany(
+ "INSERT INTO vec(rowid, embedding) VALUES (?, ?)",
+ [(i+1, memoryview(V_np[i].tobytes())) for i in range(len(chunks))]
+ )
+ db.executemany("INSERT INTO fts(rowid, text) VALUES (?, ?)", list(enumerate(chunks, start=1)))
+ db.commit()
+
+def vec_topk(db, q_vec_f32, k=10):
+ rows = db.execute(
+ "SELECT rowid, distance FROM vec WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
+ (serialize_float32(q_vec_f32), k)
+ ).fetchall()
+ return rows # [(rowid, distance)]
+
+
+def bm25_topk(db, qtext, k=10):
+ return [rid for (rid,) in db.execute(
+ "SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?", (qtext, k)
+ ).fetchall()]
+
+ def wipe_db():
+ db = sqlite3.connect("./rag.db")
+ db.executescript("DROP TABLE IF EXISTS chunks; DROP TABLE IF EXISTS fts; DROP TABLE IF EXISTS vec;")
+ db.close()
diff --git a/rag/ingest.py b/rag/ingest.py
new file mode 100644
index 0000000..d17690a
--- /dev/null
+++ b/rag/ingest.py
@@ -0,0 +1,61 @@
+# from rag.rerank import search_hybrid
+import sqlite3
+import torch
+from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
+from pathlib import Path
+from docling.document_converter import DocumentConverter
+from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
+from rag.db import get_db, init_schema, store_chunks
+from sentence_transformers import SentenceTransformer
+
+EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-8B"
+RERANKER_MODEL_ID = "Qwen/Qwen3-Reranker-8B"
+MAX_TOKENS = 600
+BATCH = 16
+
+
+def get_embed_model():
+ return SentenceTransformer(
+ "Qwen/Qwen3-Embedding-8B",
+ model_kwargs={
+ # "trust_remote_code":True,
+ "attn_implementation":"flash_attention_2",
+ "device_map":"auto",
+ "dtype":torch.float16
+ }, tokenizer_kwargs={"padding_side": "left"}
+ )
+
+
+def parse_and_chunk(source: Path, model: SentenceTransformer)-> list[str]:
+ tokenizer = HuggingFaceTokenizer(tokenizer=model.tokenizer, max_tokens=MAX_TOKENS)
+
+ converter = DocumentConverter()
+ doc = converter.convert(source)
+ chunker = HybridChunker(tokenizer=tokenizer, merge_peers=True)
+ out = []
+ for ch in chunker.chunk(doc.document):
+ txt = chunker.contextualize(ch)
+ if txt.strip():
+ out.append(txt)
+ return out
+
+def embed_many(model: SentenceTransformer, texts: list[str]):
+ V_np = model.encode(texts,
+ batch_size=BATCH,
+ normalize_embeddings=True,
+ convert_to_numpy=True,
+ show_progress_bar=True)
+ return V_np.astype("float32")
+
+
+def start_ingest(db: sqlite3.Connection, path: Path):
+ model = get_embed_model()
+ chunks = parse_and_chunk(path, model)
+ V_np = embed_many(model, chunks)
+ DIM = V_np.shape[1]
+ db = get_db()
+ init_schema(db, DIM)
+ store_chunks(db, chunks, V_np)
+ # TODO some try catch?
+ return True
+# float32/fp16 on CPU? ensure float32 for DB:
diff --git a/rag/main.py b/rag/main.py
new file mode 100644
index 0000000..221adfa
--- /dev/null
+++ b/rag/main.py
@@ -0,0 +1,71 @@
+import sys
+import argparse
+from pathlib import Path
+from rag.ingest import start_ingest
+from rag.search import search_hybrid, vec_search
+from rag.db import get_db
+# your Docling chunker imports…
+
+def main():
+ ap = argparse.ArgumentParser(prog="rag")
+ sub = ap.add_subparsers(dest="cmd", required=True)
+
+ # ingest
+ ap_ing = sub.add_parser("ingest", help="Parse, chunk, embed, and index a file into a collection")
+ ap_ing.add_argument("--file", required=True, help="Path to PDF/TXT to ingest")
+ ap_ing.set_defaults(func=cmd_ingest)
+
+ # query
+ ap_q = sub.add_parser("query", help="Query a collection")
+ ap_q.add_argument("--query", required=True, help="User query text")
+ ap_q.add_argument("--simple", action="store_true", help="Vector-only search (skip reranker)")
+ ap_q.add_argument("--mmr", action="store_true", help="Apply MMR after CE")
+ ap_q.add_argument("--mmr-lambda", type=float, default=0.7)
+ ap_q.add_argument("--k-vec", type=int, default=50)
+ ap_q.add_argument("--k-bm25", type=int, default=50)
+ ap_q.add_argument("--k-ce", type=int, default=30)
+ ap_q.add_argument("--k-final", type=int, default=10)
+ ap_q.set_defaults(func=cmd_query)
+
+ args = ap.parse_args()
+ args.func(args)
+
+
+if __name__ == "__main__":
+ main()
+
+
+
+
+def cmd_ingest(args):
+ path = Path(args.file)
+
+ if not path.exists():
+ print(f"File not found: {path}", file=sys.stderr)
+ sys.exit(1)
+
+
+ db = get_db()
+ stats = start_ingest(db,path)
+ print(f"Ingested file={args.file} :: {stats}")
+
+def cmd_query(args):
+
+ db = get_db()
+ if args.simple:
+ results = vec_search(db, args.query, k=args.k_final)
+ else:
+ results = search_hybrid(
+ db, args.query,
+ k_vec=args.k_vec,
+ k_bm25=args.k_bm25,
+ k_ce=args.k_ce,
+ k_final=args.k_final,
+ use_mmr=args.mmr,
+ mmr_lambda=args.mmr_lambda,
+ )
+
+ for rid, txt, score in results:
+ print(f"[{rid:05d}] score={score:.3f}\n{txt[:400]}...\n")
+
+ db.close()
diff --git a/rag/mmr.py b/rag/mmr.py
new file mode 100644
index 0000000..5e47c4f
--- /dev/null
+++ b/rag/mmr.py
@@ -0,0 +1,26 @@
+import numpy as np
+
+def cosine(a, b): return float(np.dot(a, b))
+
+def mmr(query_vec, cand_ids, cand_vecs, k=8, lamb=0.7):
+ """cand_ids: [int], cand_vecs: np.ndarray float32 [N,D] (unit vectors) aligned with cand_ids"""
+ selected, selected_idx = [], []
+ remaining = list(range(len(cand_ids)))
+
+ # seed with the most relevant
+ best0 = max(remaining, key=lambda i: cosine(query_vec, cand_vecs[i]))
+ selected.append(cand_ids[best0]); selected_idx.append(best0); remaining.remove(best0)
+
+ while remaining and len(selected) < k:
+ def mmr_score(i):
+ rel = cosine(query_vec, cand_vecs[i])
+ red = max(cosine(cand_vecs[i], cand_vecs[j]) for j in selected_idx) if selected_idx else 0.0
+ return lamb * rel - (1.0 - lamb) * red
+ nxt = max(remaining, key=mmr_score)
+ selected.append(cand_ids[nxt]); selected_idx.append(nxt); remaining.remove(nxt)
+ return selected
+
+def embed_unit_np(st_model, texts: list[str]) -> np.ndarray:
+ V = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True, batch_size=32)
+ V = V.astype("float32", copy=False)
+ return V
diff --git a/rag/nmain.py b/rag/nmain.py
new file mode 100644
index 0000000..aa67dea
--- /dev/null
+++ b/rag/nmain.py
@@ -0,0 +1,118 @@
+# rag/main.py
+import argparse, sqlite3, sys
+from pathlib import Path
+
+from sentence_transformers import SentenceTransformer
+from rag.ingest import start_ingest
+from rag.search import search_hybrid, search as vec_search
+
+DB_PATH = "./rag.db"
+EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-8B"
+
+def open_db():
+ db = sqlite3.connect(DB_PATH)
+ # speed-ish pragmas
+ db.execute("PRAGMA journal_mode=WAL;")
+ db.execute("PRAGMA synchronous=NORMAL;")
+ return db
+
+def load_st_model():
+ # ST handles batching + GPU internally
+ return SentenceTransformer(
+ EMBED_MODEL_ID,
+ model_kwargs={
+ "attn_implementation": "flash_attention_2",
+ "device_map": "auto",
+ "torch_dtype": "float16",
+ },
+ tokenizer_kwargs={"padding_side": "left"},
+ )
+
+def ensure_collection_exists(db, collection: str):
+ row = db.execute(
+ "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
+ (f"vec_{collection}",)
+ ).fetchone()
+ return bool(row)
+
+def cmd_ingest(args):
+ db = open_db()
+ st_model = load_st_model()
+ path = Path(args.file)
+
+ if not path.exists():
+ print(f"File not found: {path}", file=sys.stderr)
+ sys.exit(1)
+
+ if args.rebuild and ensure_collection_exists(db, args.collection):
+ db.executescript(f"""
+ DROP TABLE IF EXISTS chunks_{args.collection};
+ DROP TABLE IF EXISTS fts_{args.collection};
+ DROP TABLE IF EXISTS vec_{args.collection};
+ """)
+
+ stats = start_ingest(
+ db, st_model,
+ path=path,
+ collection=args.collection,
+ )
+ print(f"Ingested collection={args.collection} :: {stats}")
+ db.close()
+
+def cmd_query(args):
+ db = open_db()
+ st_model = load_st_model()
+
+ coll_ok = ensure_collection_exists(db, args.collection)
+ if not coll_ok:
+ print(f"Collection '{args.collection}' not found. Ingest first.", file=sys.stderr)
+ sys.exit(2)
+
+ if args.simple:
+ results = vec_search(db, st_model, args.query, collection=args.collection, k=args.k_final)
+ else:
+ results = search_hybrid(
+ db, st_model, args.query,
+ collection=args.collection,
+ k_vec=args.k_vec,
+ k_bm25=args.k_bm25,
+ k_ce=args.k_ce,
+ k_final=args.k_final,
+ use_mmr=args.mmr,
+ mmr_lambda=args.mmr_lambda,
+ )
+
+ for rid, txt, score in results:
+ print(f"[{rid:05d}] score={score:.3f}\n{txt[:400]}...\n")
+
+ db.close()
+
+def main():
+ ap = argparse.ArgumentParser(prog="rag")
+ sub = ap.add_subparsers(dest="cmd", required=True)
+
+ # ingest
+ ap_ing = sub.add_parser("ingest", help="Parse, chunk, embed, and index a file into a collection")
+ ap_ing.add_argument("--file", required=True, help="Path to PDF/TXT to ingest")
+ ap_ing.add_argument("--collection", required=True, help="Collection name (e.g. wm_qwen3)")
+ ap_ing.add_argument("--rebuild", action="store_true", help="Drop and recreate collection tables")
+ ap_ing.set_defaults(func=cmd_ingest)
+
+ # query
+ ap_q = sub.add_parser("query", help="Query a collection")
+ ap_q.add_argument("--collection", required=True, help="Collection name to search")
+ ap_q.add_argument("--query", required=True, help="User query text")
+ ap_q.add_argument("--simple", action="store_true", help="Vector-only search (skip reranker)")
+ ap_q.add_argument("--mmr", action="store_true", help="Apply MMR after CE")
+ ap_q.add_argument("--mmr-lambda", type=float, default=0.7)
+ ap_q.add_argument("--k-vec", type=int, default=50)
+ ap_q.add_argument("--k-bm25", type=int, default=50)
+ ap_q.add_argument("--k-ce", type=int, default=30)
+ ap_q.add_argument("--k-final", type=int, default=10)
+ ap_q.set_defaults(func=cmd_query)
+
+ args = ap.parse_args()
+ args.func(args)
+
+if __name__ == "__main__":
+ main()
diff --git a/rag/rerank.py b/rag/rerank.py
new file mode 100644
index 0000000..6ae8938
--- /dev/null
+++ b/rag/rerank.py
@@ -0,0 +1,31 @@
+# rag/rerank.py
+import torch
+from sentence_transformers import CrossEncoder
+
+RERANKER_ID = "Qwen/Qwen3-Reranker-8B" # or -large if you’ve got VRAM
+# RERANKER_ID = "BAAI/bge-reranker-base" # or -large if you’ve got VRAM
+# device: "cuda" | "cpu" | "mps"
+def get_rerank_model():
+ return CrossEncoder(
+ RERANKER_ID,
+ device="cuda",
+ model_kwargs={
+ "attn_implementation":"flash_attention_2",
+ "device_map":"auto",
+ "dtype":torch.float16
+ },
+ tokenizer_kwargs={"padding_side": "left"}
+ )
+
+def rerank_cross_encoder(reranker: CrossEncoder, query: str, candidates: list[tuple[int, str]], batch_size: int = 32):
+ """
+ candidates: [(id, text), ...]
+ returns: [(id, text, score)] sorted desc by score
+ """
+ if not candidates:
+ return []
+ ids, texts = zip(*candidates)
+ pairs = [(query, t) for t in texts]
+ scores = reranker.predict(pairs, batch_size=batch_size) # np.ndarray [N], higher=better
+ ranked = sorted(zip(ids, texts, scores), key=lambda x: x[2], reverse=True)
+ return ranked
diff --git a/rag/search.py b/rag/search.py
new file mode 100644
index 0000000..55b8ffd
--- /dev/null
+++ b/rag/search.py
@@ -0,0 +1,101 @@
+import sqlite3
+import numpy as np
+import torch
+import gc
+from typing import List, Tuple
+
+from sentence_transformers import CrossEncoder, SentenceTransformer
+from rag.ingest import get_embed_model
+from rag.rerank import get_rerank_model, rerank_cross_encoder
+from rag.mmr import mmr, embed_unit_np # if you added MMR; else remove
+
+def search_hybrid(
+ db: sqlite3.Connection, query: str,
+ k_vec: int = 40, k_bm25: int = 40,
+ k_ce: int = 30, # rerank this many
+ k_final: int = 10, # return this many
+ use_mmr: bool = False, mmr_lambda: float = 0.7
+):
+ emodel = get_embed_model()
+ # 1) embed query (unit, float32) for vec search + MMR
+ print("loading")
+ q = emodel.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32")
+ emodel = None
+ with torch.no_grad():
+ torch.cuda.empty_cache()
+ gc.collect()
+ print("memory should be free by now!!")
+ qbytes = q.tobytes()
+
+ # 2) ANN (cosine) + BM25 pools
+ vec_ids = [i for (i, _) in db.execute(
+ "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?",
+ (memoryview(qbytes), k_vec)
+ ).fetchall()]
+ bm25_ids = [i for (i,) in db.execute(
+ "SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?",
+ (query, k_bm25)
+ ).fetchall()]
+
+ # 3) merge (vector-first)
+ seen, merged = set(), []
+ for i in vec_ids + bm25_ids:
+ if i not in seen:
+ merged.append(i); seen.add(i)
+ if not merged:
+ return []
+
+ # 4) fetch texts for CE
+ qmarks = ",".join("?"*len(merged))
+ cand = db.execute(f"SELECT id, text FROM chunks WHERE id IN ({qmarks})", merged).fetchall()
+
+ reranker = get_rerank_model()
+ # 5) cross-encoder rerank (returns [(id,text,score)] desc)
+ ranked = rerank_cross_encoder(reranker, query, cand)
+ reranker = None
+ print("freeing again!!")
+ with torch.no_grad():
+ torch.cuda.empty_cache()
+ gc.collect()
+ #
+ ranked = ranked[:min(k_ce, len(ranked))]
+
+ if not use_mmr or len(ranked) <= k_final:
+ return ranked[:k_final]
+
+ # 6) MMR diversity on CE top-k_ce
+ cand_ids = [i for (i,_,_) in ranked]
+ cand_text = [t for (_,t,_) in ranked]
+ emodel = get_embed_model()
+ # god this is annoying I should stop being poor
+ cand_vecs = embed_unit_np(emodel, cand_text) # [N,D], unit vectors
+ sel_ids = set(mmr(q, cand_ids, cand_vecs, k=k_final, lamb=mmr_lambda))
+ final = [trip for trip in ranked if trip[0] in sel_ids] # keep CE order, filter by MMR picks
+ return final[:k_final]
+
+
+# ) Query helper (cosine distance; operator may be <#> in sqlite-vec)
+def vec_search(db: sqlite3.Connection, qtext, k=5):
+ model = get_embed_model()
+ q = model.encode([qtext], normalize_embeddings=True, convert_to_numpy=True).astype("float32")
+ # Cosine distance operator in sqlite-vec is `<#>`; if your build differs, check docs: <-> L2, <=> dot, <#> cosine
+ rows = db.execute(
+ "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?",
+ (memoryview(q.tobytes()), k)
+ ).fetchall()
+
+# db.execute("SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?", (qtext, 40))
+ return [(rid, db.execute("SELECT text FROM chunks WHERE id=?", (rid,)).fetchone()[0], dist) for rid, dist in rows]
+
+
+# # Hybrid + CE rerank query:
+# results = search_hybrid("indemnification obligations survive termination", model, k_vec=50, k_bm25=50, k_final=8)
+# for rid, txt, score in results:
+# print(f"[{rid:04d}] score={score:.3f}\n{txt[:300]}...\n")
+#
+#
+#
+#
+#
+#
+#
diff --git a/rag/test.py b/rag/test.py
new file mode 100644
index 0000000..b7a6d8e
--- /dev/null
+++ b/rag/test.py
@@ -0,0 +1,77 @@
+from pathlib import Path
+from docling.document_converter import DocumentConverter
+from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
+from transformers import AutoModel
+from sentence_transformers import SentenceTransformer
+from rag.db import get_db
+from rag.rerank import get_rerank_model
+import rag.ingest
+import rag.search
+import torch
+
+# converter = DocumentConverter()
+# chunker = HybridChunker()
+# file = Path("yek.md")
+# doc = converter.convert(file).document
+# chunk_iter = chunker.chunk(doc)
+# for chunk in chunk_iter:
+# print(chunk)
+# txt = chunker.contextualize(chunk)
+# print(txt)
+
+
+
+def t():
+ batch: list[str] = ["This son of a bitch has gone too far", "Fuck me baby please", "I'm hungry now", "Charlie Kirk is dead"]
+ model = rag.ingest.get_embed_model()
+ v = model.encode("pepeee", normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True)
+ print("v")
+ print(type(v))
+ print(v)
+ print(v.dtype)
+ print(v.device)
+
+# V = torch.cat([v], dim=0)
+# print("V")
+# print(type(V))
+# print(V)
+# print(V.dtype)
+# print(V.device)
+# print("V_np")
+# V_idk = V.cpu().float()
+
+# when they were pytorch tensors
+# V = embed_many(chunks) # float32/fp16 on CPU? ensure float32 for DB:
+# V_np = V.float().cpu().numpy().astype("float32")
+# DIM = V_np.shape[1]
+# db = sqlite3.connect("./rag.db")
+
+queries = [
+"How was Shuihu zhuan received in early modern Japan?",
+"Edo-period readers’ image of Song Jiang / Liangshan outlaws",
+"Channels of transmission for Chinese vernacular fiction into Japan (kanbun kundoku, digests, translations)",
+"Role of woodblock prints/illustrations in mediating Chinese fiction",
+"Key Japanese scholars, writers, or publishers who popularized Chinese fiction",
+"Kyokutei Bakin’s engagement with Chinese vernacular narrative",
+"Santō Kyōden, gesaku, and Chinese models",
+"Kanzen chōaku (encourage good, punish evil) and Water Margin in Japan",
+"Moral ambivalence of outlaw heroes as discussed in the text",
+"Censorship or moral debates around reading Chinese fiction",
+"Translation strategies from vernacular Chinese to Japanese (furigana, kundoku, glossing)",
+"Paratexts: prefaces, commentaries, reader guidance apparatus",
+"Bibliographic details: editions, reprints, circulation networks",
+"How does this book challenge older narratives about Sino-Japanese literary influence?",
+"Methodology: sources, archives, limitations mentioned by the author",
+
+]
+def t2():
+ db = get_db()
+ # Hybrid + CE rerank query:
+ for query in queries:
+ print("query", query)
+ results = rag.search.search_hybrid(db, query, k_vec=50, k_bm25=50, k_final=8)
+ for rid, txt, score in results:
+ sim = score
+ print(f"[{rid:04d}] ce_score={sim:.3f}\n{txt[:300]}...\n")
+
+t2()
diff --git a/sentence.py b/sentence.py
new file mode 100644
index 0000000..f29927a
--- /dev/null
+++ b/sentence.py
@@ -0,0 +1,42 @@
+# Requires transformers>=4.51.0
+# Requires sentence-transformers>=2.7.0
+import torch
+from sentence_transformers import SentenceTransformer
+
+# Load the model
+# model = SentenceTransformer("Qwen/Qwen3-Embedding-8B")
+
+# We recommend enabling flash_attention_2 for better acceleration and memory saving,
+# together with setting `padding_side` to "left":
+model = SentenceTransformer(
+ "Qwen/Qwen3-Embedding-8B",
+ model_kwargs={"attn_implementation": "flash_attention_2", "device_map": "auto", "dtype": torch.float16},
+ tokenizer_kwargs={"padding_side": "left"},
+)
+
+
+# Flash Attention 2 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in Qwen3Model is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
+
+# The queries and documents to embed
+queries = [
+ "What is the capital of China?",
+ "Explain gravity",
+]
+documents = [
+ "The capital of China is Beijing.",
+ "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
+]
+
+# with torch.autocast(device_type='torch_device'):
+with torch.no_grad():
+# Encode the queries and documents. Note that queries benefit from using a prompt
+# Here we use the prompt called "query" stored under `model.prompts`, but you can
+# also pass your own prompt via the `prompt` argument
+ query_embeddings = model.encode(queries, prompt_name="query")
+ document_embeddings = model.encode(documents)
+
+# Compute the (cosine) similarity between the query and document embeddings
+similarity = model.similarity(query_embeddings, document_embeddings)
+print(similarity)
+# tensor([[0.7493, 0.0751],
+# [0.0880, 0.6318]])
diff --git a/tf.py b/tf.py
new file mode 100644
index 0000000..9ffd868
--- /dev/null
+++ b/tf.py
@@ -0,0 +1,65 @@
+
+import torch
+import torch.nn.functional as F
+
+from torch import Tensor
+from transformers import AutoModel, AutoTokenizer
+
+
+def last_token_pool(last_hidden_states: Tensor,
+ attention_mask: Tensor) -> Tensor:
+ left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
+ if left_padding:
+ return last_hidden_states[:, -1]
+ else:
+ sequence_lengths = attention_mask.sum(dim=1) - 1
+ batch_size = last_hidden_states.shape[0]
+ return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
+
+
+def get_detailed_instruct(task_description: str, query: str) -> str:
+ return f'Instruct: {task_description}\nQuery:{query}'
+
+# Each query must come with a one-sentence instruction that describes the task
+task = 'Given a web search query, retrieve relevant passages that answer the query'
+
+queries = [
+ get_detailed_instruct(task, 'What is the capital of China?'),
+ get_detailed_instruct(task, 'Explain gravity')
+]
+# No need to add instruction for retrieval documents
+documents = [
+ "The capital of China is Beijing.",
+ "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."
+]
+input_texts = queries + documents
+
+
+# We recommend enabling flash_attention_2 for better acceleration and memory saving,
+# together with setting `padding_side` to "left":
+model = AutoModel.from_pretrained(
+ "Qwen/Qwen3-Embedding-8B", attn_implementation="flash_attention_2", device_map="auto", torch_dtype=torch.float16
+)
+tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-Embedding-8B', padding_side="left")
+# The queries and documents to embed
+max_length = 8192
+
+# Tokenize the input texts
+batch_dict = tokenizer(
+ input_texts,
+ padding=True,
+ truncation=True,
+ max_length=max_length,
+ return_tensors="pt",
+)
+batch_dict.to(model.device)
+with torch.no_grad():
+ outputs = model(**batch_dict)
+ embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
+
+ # normalize embeddings
+ embeddings = F.normalize(embeddings, p=2, dim=1)
+ scores = (embeddings[:2] @ embeddings[2:].T)
+
+print(scores.tolist())
+# [[0.7645568251609802, 0.14142508804798126], [0.13549736142158508, 0.5999549627304077]]
diff --git a/yek.md b/yek.md
new file mode 100644
index 0000000..d71c21d
--- /dev/null
+++ b/yek.md
@@ -0,0 +1,316 @@
+>>>> __init__.py
+
+>>>> main.py
+# from rag.rerank import search_hybrid
+import torch
+import sqlite3
+import sqlite_vec
+from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
+from pathlib import Path
+from docling.document_converter import DocumentConverter
+from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
+
+from sentence_transformers import SentenceTransformer
+
+EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-8B"
+RERANKER_MODEL_ID = "Qwen/Qwen3-Reranker-8B"
+MAX_TOKENS = 600
+BATCH = 16
+
+model = SentenceTransformer(
+ "Qwen/Qwen3-Embedding-8B",
+ model_kwargs={
+ # "trust_remote_code":True,
+ "attn_implementation":"flash_attention_2",
+ "device_map":"auto",
+ "dtype":torch.float16
+ }, tokenizer_kwargs={"padding_side": "left"}
+)
+tokenizer = HuggingFaceTokenizer(tokenizer=model.tokenizer, max_tokens=MAX_TOKENS)
+
+def parse_and_chunk()-> list[str]:
+ source = Path("./japan-water-margin.pdf")
+
+ converter = DocumentConverter()
+ doc = converter.convert(source)
+ chunker = HybridChunker(tokenizer=tokenizer, merge_peers=True)
+ out = []
+ for ch in chunker.chunk(doc.document):
+ txt = chunker.contextualize(ch)
+ if txt.strip():
+ out.append(txt)
+ return out
+
+def embed_many(texts: list[str]):
+ V_np = model.encode(texts,
+ batch_size=BATCH,
+ normalize_embeddings=True,
+ convert_to_numpy=True,
+ show_progress_bar=True)
+ return V_np.astype("float32")
+
+
+
+chunks =parse_and_chunk()
+V_np = embed_many(chunks) # float32/fp16 on CPU? ensure float32 for DB:
+DIM = V_np.shape[1]
+db = sqlite3.connect("./rag.db")
+
+
+
+def init_db(DIM: str):
+ db.enable_load_extension(True)
+ sqlite_vec.load(db)
+ db.enable_load_extension(False)
+ db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec USING vec0(embedding float[{DIM}])")
+ db.execute('''
+ CREATE TABLE IF NOT EXISTS chunks (
+ id INTEGER PRIMARY KEY,
+ text TEXT
+ )''')
+ db.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts USING fts5(text)")
+
+def store_chunks(chunks: list[str], V_np):
+ assert len(chunks) == len(V_np)
+ db.execute("BEGIN")
+ db.executemany('''
+ INSERT INTO chunks(id, text) VALUES (?, ?)
+ ''', list(enumerate(chunks, start=1)))
+ db.executemany(
+ "INSERT INTO vec(rowid, embedding) VALUES (?, ?)",
+ [(i+1, memoryview(V_np[i].tobytes())) for i in range(len(chunks))]
+ )
+ db.executemany("INSERT INTO fts(rowid, text) VALUES (?, ?)", list(enumerate(chunks, start=1)))
+ db.commit()
+
+# 5) Query helper (cosine distance; operator may be <#> in sqlite-vec)
+def search(qtext, k=5):
+ q = model.encode([qtext], normalize=True, convert_to_numpy=True).astype("float32")
+ # Cosine distance operator in sqlite-vec is `<#>`; if your build differs, check docs: <-> L2, <=> dot, <#> cosine
+ rows = db.execute(
+ "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?",
+ (memoryview(q.tobytes()), k)
+ ).fetchall()
+
+# db.execute("SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?", (qtext, 40))
+ return [(rid, db.execute("SELECT text FROM chunks WHERE id=?", (rid,)).fetchone()[0], dist) for rid, dist in rows]
+
+
+
+init_db(DIM)
+store_chunks(chunks, V_np)
+
+
+# # Hybrid + CE rerank query:
+# results = search_hybrid("indemnification obligations survive termination", model, k_vec=50, k_bm25=50, k_final=8)
+# for rid, txt, score in results:
+# print(f"[{rid:04d}] score={score:.3f}\n{txt[:300]}...\n")
+
+>>>> mmr.py
+import numpy as np
+
+def cosine(a, b): return float(np.dot(a, b))
+
+def mmr(query_vec, cand_ids, cand_vecs, k=8, lamb=0.7):
+ """cand_ids: [int], cand_vecs: np.ndarray float32 [N,D] (unit vectors) aligned with cand_ids"""
+ selected, selected_idx = [], []
+ remaining = list(range(len(cand_ids)))
+
+ # seed with the most relevant
+ best0 = max(remaining, key=lambda i: cosine(query_vec, cand_vecs[i]))
+ selected.append(cand_ids[best0]); selected_idx.append(best0); remaining.remove(best0)
+
+ while remaining and len(selected) < k:
+ def mmr_score(i):
+ rel = cosine(query_vec, cand_vecs[i])
+ red = max(cosine(cand_vecs[i], cand_vecs[j]) for j in selected_idx) if selected_idx else 0.0
+ return lamb * rel - (1.0 - lamb) * red
+ nxt = max(remaining, key=mmr_score)
+ selected.append(cand_ids[nxt]); selected_idx.append(nxt); remaining.remove(nxt)
+ return selected
+
+>>>> rag.db
+
+>>>> rerank.py
+import torch
+import sqlite3
+import sqlite_vec
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+# -----------------------------
+# RERANKER (cross-encoder)
+# -----------------------------
+RERANKER_ID = "BAAI/bge-reranker-base" # small & solid
+# RERANKER_ID = "Qwen/Qwen3-Reranker-8B" # small & solid
+_rr_tok = AutoTokenizer.from_pretrained(RERANKER_ID)
+_rr_mod = AutoModelForSequenceClassification.from_pretrained(
+ RERANKER_ID, device_map="auto", dtype=torch.bfloat16
+)
+
+def rerank_cross_encoder(query: str, candidates: list[tuple[int, str]], batch_size: int = 32) -> list[tuple[int, str, float]]:
+ """
+ candidates: [(id, text), ...]
+ returns: [(id, text, score)] sorted desc by score
+ """
+ if not candidates:
+ return []
+ ids, texts = zip(*candidates)
+ scores = []
+ for i in range(0, len(texts), batch_size):
+ batch = texts[i:i+batch_size]
+ enc = _rr_tok(list(zip([query]*len(batch), batch)),
+ padding=True, truncation=True, max_length=512,
+ return_tensors="pt").to(_rr_mod.device)
+ with torch.no_grad():
+ logits = _rr_mod(**enc).logits.squeeze(-1) # [B]
+ scores.extend(logits.float().cpu().tolist())
+ ranked = sorted(zip(ids, texts, scores), key=lambda x: x[2], reverse=True)
+ return ranked
+
+# -----------------------------
+# DB: vec + chunks + FTS5
+# -----------------------------
+db = sqlite3.connect("./rag.db")
+
+def init_db(dim: int):
+ db.enable_load_extension(True)
+ sqlite_vec.load(db) # loads vec0
+ db.enable_load_extension(False)
+
+ # Vector table (cosine expects float32 blob)
+ db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec USING vec0(embedding float[{dim}])")
+
+ # Raw chunks
+ db.execute("""
+ CREATE TABLE IF NOT EXISTS chunks(
+ id INTEGER PRIMARY KEY,
+ text TEXT
+ )
+ """)
+
+ # FTS5 over chunks.text (contentless mode keeps it simple)
+ db.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts USING fts5(text)")
+ db.execute("CREATE INDEX IF NOT EXISTS idx_chunks_id ON chunks(id)")
+
+def store_chunks(chunks: list[str], V_np):
+ # chunks: list[str], V_np: np.ndarray float32 [N, D]
+ assert len(chunks) == len(V_np), "mismatch chunks vs vectors"
+
+ db.execute("BEGIN")
+ db.executemany("INSERT INTO chunks(id, text) VALUES (?,?)",
+ list(enumerate(chunks, start=1)))
+ db.executemany("INSERT INTO vec(rowid, embedding) VALUES (?,?)",
+ [(i+1, memoryview(V_np[i].tobytes())) for i in range(len(chunks))])
+ # FTS content
+ db.executemany("INSERT INTO fts(rowid, text) VALUES (?,?)",
+ list(enumerate(chunks, start=1)))
+ db.commit()
+
+# -----------------------------
+# SEARCH: hybrid → rerank
+# -----------------------------
+def _vec_topk(qvec_f32: bytes, k: int):
+ # cosine distance: smaller is better
+ rows = db.execute(
+ "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?",
+ (memoryview(qvec_f32), k)
+ ).fetchall()
+ return [(rid, float(dist)) for (rid, dist) in rows]
+
+def _bm25_topk(query: str, k: int):
+ # note: FTS5 MATCH syntax, tokenizer defaults ok for English/legal
+ rows = db.execute(
+ "SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?",
+ (query, k)
+ ).fetchall()
+ return [rid for (rid,) in rows]
+
+def _fetch_text(ids: list[int]) -> list[tuple[int, str]]:
+ if not ids:
+ return []
+ qmarks = ",".join("?"*len(ids))
+ rows = db.execute(f"SELECT id, text FROM chunks WHERE id IN ({qmarks})", ids).fetchall()
+ by_id = {i:t for (i,t) in rows}
+ # keep input order
+ return [(i, by_id[i]) for i in ids if i in by_id]
+
+def search_hybrid(query: str, embed_model, k_vec: int = 40, k_bm25: int = 40, k_final: int = 10):
+ # 1) embed query (unit-norm) with same model you used for chunks
+ with torch.no_grad():
+ q = embed_model.encode([query], normalize=True).to("cpu").to(torch.float32).numpy()[0]
+ qbytes = q.tobytes()
+
+ # 2) candidate pools
+ vec_hits = _vec_topk(qbytes, k_vec) # [(id, dist)]
+ vec_ids = [i for (i, _) in vec_hits]
+ bm25_ids = _bm25_topk(query, k_bm25)
+
+ # 3) merge (preserve vector order first)
+ seen = set()
+ merged_ids = []
+ for i in vec_ids + bm25_ids:
+ if i not in seen:
+ merged_ids.append(i); seen.add(i)
+
+ # 4) fetch texts, then rerank with cross-encoder
+ cand = _fetch_text(merged_ids)
+ ranked = rerank_cross_encoder(query, cand)
+
+ # 5) return top k_final: [(id, text, ce_score)]
+ return ranked[:k_final]
+
+>>>> test.py
+from pathlib import Path
+from docling.document_converter import DocumentConverter
+from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
+from transformers import AutoModel
+from sentence_transformers import SentenceTransformer
+import torch
+
+# converter = DocumentConverter()
+# chunker = HybridChunker()
+# file = Path("yek.md")
+# doc = converter.convert(file).document
+# chunk_iter = chunker.chunk(doc)
+# for chunk in chunk_iter:
+# print(chunk)
+# txt = chunker.contextualize(chunk)
+# print(txt)
+
+
+
+model = SentenceTransformer(
+ "Qwen/Qwen3-Embedding-8B",
+ model_kwargs={
+ # "trust_remote_code":True,
+ "attn_implementation":"flash_attention_2",
+ "device_map":"auto",
+ "dtype":torch.float16
+ }, tokenizer_kwargs={"padding_side": "left"}
+)
+
+batch: list[str] = ["This son of a bitch has gone too far", "Fuck me baby please", "I'm hungry now", "Charlie Kirk is dead"]
+
+v = model.encode("pepeee", normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True)
+print("v")
+print(type(v))
+print(v)
+print(v.dtype)
+print(v.device)
+
+# V = torch.cat([v], dim=0)
+# print("V")
+# print(type(V))
+# print(V)
+# print(V.dtype)
+# print(V.device)
+# print("V_np")
+# V_idk = V.cpu().float()
+
+# when they were pytorch tensors
+# V = embed_many(chunks) # float32/fp16 on CPU? ensure float32 for DB:
+# V_np = V.float().cpu().numpy().astype("float32")
+# DIM = V_np.shape[1]
+# db = sqlite3.connect("./rag.db")
+
+