Diffstat (limited to 'rag')
-rw-r--r--  rag/constants.py    5
-rw-r--r--  rag/db.py          75
-rw-r--r--  rag/ingest.py      17
-rw-r--r--  rag/main.py       106
-rw-r--r--  rag/mmr.py         17
-rw-r--r--  rag/rerank.py      36
-rw-r--r--  rag/search.py     205
-rw-r--r--  rag/test.py        39
-rw-r--r--  rag/utils.py       15
9 files changed, 377 insertions(+), 138 deletions(-)
diff --git a/rag/constants.py b/rag/constants.py
new file mode 100644
index 0000000..327e8e5
--- /dev/null
+++ b/rag/constants.py
@@ -0,0 +1,5 @@
+EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-8B"
+RERANKER_ID = "Qwen/Qwen3-Reranker-8B" # smaller 0.6B/4B variants exist if 8B doesn't fit in VRAM
+# RERANKER_ID = "BAAI/bge-reranker-base" # or -large if you’ve got VRAM
+MAX_TOKENS = 600
+BATCH = 16
diff --git a/rag/db.py b/rag/db.py
index c015c96..94ff3fb 100644
--- a/rag/db.py
+++ b/rag/db.py
@@ -1,6 +1,7 @@
import sqlite3
from numpy import ndarray
from sqlite_vec import serialize_float32
+import sqlite_vec
def get_db():
db = sqlite3.connect("./rag.db")
@@ -13,43 +14,85 @@ def get_db():
db.enable_load_extension(False)
return db
-def init_schema(db: sqlite3.Connection, DIM:int):
- db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec USING vec0(embedding float[{DIM}])")
- db.execute('''
- CREATE TABLE IF NOT EXISTS chunks (
+def init_schema(db: sqlite3.Connection, col: str, DIM: int, model_id: str, tok_id: str, normalize: bool, preproc_hash: str):
+    print("initializing schema", col)
+ db.execute("""
+ CREATE TABLE IF NOT EXISTS collections(
+ name TEXT PRIMARY KEY,
+ model TEXT,
+ tokenizer TEXT,
+ dim INTEGER,
+ normalize INTEGER,
+ preproc_hash TEXT,
+ created_at INTEGER DEFAULT (unixepoch())
+ )""")
+ db.execute("BEGIN")
+ db.execute("INSERT INTO collections(name, model, tokenizer, dim, normalize, preproc_hash) VALUES(?, ?, ?, ?, ?, ?)",
+ (col, model_id, tok_id, DIM, int(normalize), preproc_hash))
+
+ db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec_{col} USING vec0(embedding float[{DIM}])")
+ db.execute(f'''
+ CREATE TABLE IF NOT EXISTS chunks_{col} (
id INTEGER PRIMARY KEY,
text TEXT
)''')
- db.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts USING fts5(text)")
+ db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS fts_{col} USING fts5(text)")
db.commit()
-def store_chunks(db: sqlite3.Connection, chunks: list[str], V_np:ndarray):
+
+def check_db(db: sqlite3.Connection, coll: str):
+    # returns (dim, model, normalize) for a registered collection, or None
+    row = db.execute("SELECT dim, model, normalize FROM collections WHERE name=?", (coll,)).fetchone()
+    # callers may assert the row matches the active model, e.g. row == (DIM, EMBED_MODEL_ID, 1)
+    return row
+
+def check_db2(db: sqlite3.Connection, coll: str):
+    # True if the collection's vec table already exists in this DB file
+    row = db.execute(
+        "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
+        (f"vec_{coll}",)
+    ).fetchone()
+    return bool(row)
+def store_chunks(db: sqlite3.Connection, col: str, chunks: list[str], V_np:ndarray):
assert len(chunks) == len(V_np)
db.execute("BEGIN")
- db.executemany('''
- INSERT INTO chunks(id, text) VALUES (?, ?)
+ db.executemany(f'''
+ INSERT INTO chunks_{col}(id, text) VALUES (?, ?)
''', list(enumerate(chunks, start=1)))
db.executemany(
- "INSERT INTO vec(rowid, embedding) VALUES (?, ?)",
+ f"INSERT INTO vec_{col}(rowid, embedding) VALUES (?, ?)",
[(i+1, memoryview(V_np[i].tobytes())) for i in range(len(chunks))]
)
- db.executemany("INSERT INTO fts(rowid, text) VALUES (?, ?)", list(enumerate(chunks, start=1)))
+ db.executemany(f"INSERT INTO fts_{col}(rowid, text) VALUES (?, ?)", list(enumerate(chunks, start=1)))
db.commit()
-def vec_topk(db, q_vec_f32, k=10):
+def vec_topk(db,col: str, q_vec_f32, k=10):
+ # rows = db.execute(
+ # "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?",
+ # (memoryview(q.tobytes()), k)
+ # ).fetchall()
rows = db.execute(
- "SELECT rowid, distance FROM vec WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
+ f"SELECT rowid, distance FROM vec_{col} WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
(serialize_float32(q_vec_f32), k)
).fetchall()
return rows # [(rowid, distance)]
-def bm25_topk(db, qtext, k=10):
+def bm25_topk(db: sqlite3.Connection, col: str, qtext, k=10):
+    # quote-wrap and escape embedded quotes so user text is matched as a literal
+    # phrase rather than parsed as FTS5 syntax (cf. rag.utils.sanitize_query)
+    safe_q = '"' + qtext.replace('"', '""') + '"'
return [rid for (rid,) in db.execute(
- "SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?", (qtext, k)
+ f"SELECT rowid FROM fts_{col} WHERE fts_{col} MATCH ? LIMIT ?", (safe_q, k)
).fetchall()]
- def wipe_db():
+def fetch_chunk(db: sqlite3.Connection, col: str, id: int):
+    return db.execute(f"SELECT text FROM chunks_{col} WHERE id=?", (id,)).fetchone()[0]
+
+def fetch_chunks_in_range(db: sqlite3.Connection, col: str, ids: list[int]):
+    # "IN (?)" binds one value only; expand to one placeholder per id
+    qmarks = ",".join("?" * len(ids))
+    return db.execute(f"SELECT id, text FROM chunks_{col} WHERE id IN ({qmarks})", ids).fetchall()
+
+
+def wipe_db(col: str):
db = sqlite3.connect("./rag.db")
- db.executescript("DROP TABLE IF EXISTS chunks; DROP TABLE IF EXISTS fts; DROP TABLE IF EXISTS vec;")
+ db.executescript(f"DROP TABLE IF EXISTS chunks_{col}; DROP TABLE IF EXISTS fts_{col}; DROP TABLE IF EXISTS vec_{col};")
db.close()
diff --git a/rag/ingest.py b/rag/ingest.py
index d17690a..5def23d 100644
--- a/rag/ingest.py
+++ b/rag/ingest.py
@@ -7,16 +7,13 @@ from docling.document_converter import DocumentConverter
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
from rag.db import get_db, init_schema, store_chunks
from sentence_transformers import SentenceTransformer
+from rag.constants import MAX_TOKENS, BATCH, EMBED_MODEL_ID
-EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-8B"
-RERANKER_MODEL_ID = "Qwen/Qwen3-Reranker-8B"
-MAX_TOKENS = 600
-BATCH = 16
def get_embed_model():
return SentenceTransformer(
- "Qwen/Qwen3-Embedding-8B",
+ EMBED_MODEL_ID,
model_kwargs={
# "trust_remote_code":True,
"attn_implementation":"flash_attention_2",
@@ -48,14 +45,12 @@ def embed_many(model: SentenceTransformer, texts: list[str]):
return V_np.astype("float32")
-def start_ingest(db: sqlite3.Connection, path: Path):
- model = get_embed_model()
+def start_ingest(db: sqlite3.Connection, model: SentenceTransformer | None, collection: str, path: Path):
+ if model is None:
+ model = get_embed_model()
chunks = parse_and_chunk(path, model)
V_np = embed_many(model, chunks)
- DIM = V_np.shape[1]
- db = get_db()
- init_schema(db, DIM)
- store_chunks(db, chunks, V_np)
+ store_chunks(db, collection, chunks, V_np)
# TODO some try catch?
return True
# float32/fp16 on CPU? ensure float32 for DB:
diff --git a/rag/main.py b/rag/main.py
index 221adfa..bbd549a 100644
--- a/rag/main.py
+++ b/rag/main.py
@@ -1,10 +1,74 @@
import sys
import argparse
from pathlib import Path
-from rag.ingest import start_ingest
+from rag.constants import EMBED_MODEL_ID
+from rag.ingest import get_embed_model, start_ingest
from rag.search import search_hybrid, vec_search
-from rag.db import get_db
-# your Docling chunker imports…
+from rag.db import get_db, check_db, check_db2, init_schema
+
+
+def valid_collection(col: str) -> bool:
+    # names are interpolated into table names (vec_{col} etc.), so keep them strict:
+    # ascii letters/digits/underscore, no spaces, at most 8 characters
+    return 0 < len(col) <= 8 and col.isascii() and all(c.isalnum() or c == "_" for c in col)
+
+
+def cmd_ingest(args):
+ path = Path(args.file)
+ if not valid_collection(args.collection):
+ print(f"Collection name invalid: {args.collection}", file=sys.stderr)
+ sys.exit(1)
+
+ if not path.exists():
+ print(f"File not found: {path}", file=sys.stderr)
+ sys.exit(1)
+
+ db = get_db()
+ if not check_db2(db, args.collection):
+ model = get_embed_model()
+ dim = model.get_sentence_embedding_dimension()
+        if dim is None:
+            print(f"Could not determine embedding dimension for {EMBED_MODEL_ID}", file=sys.stderr)
+            sys.exit(1)
+
+ # TODO Try catch in here, tell the user if it crashes
+ init_schema(db, args.collection, dim, EMBED_MODEL_ID, EMBED_MODEL_ID, True, 'idk')
+ stats = start_ingest(db, model, args.collection, path)
+ else:
+ stats = start_ingest(db, None, args.collection, path)
+ print(f"Ingested file={args.file} :: {stats}")
+
+def cmd_query(args):
+ if not valid_collection(args.collection):
+ print(f"Collection name invalid: {args.collection}", file=sys.stderr)
+ sys.exit(1)
+ db = get_db()
+ if not check_db2(db, args.collection):
+ print(f"Collection name not in DB, what are you searching: {args.collection}", file=sys.stderr)
+ sys.exit(1)
+
+    if args.simple:
+        model = get_embed_model()  # vec_search needs the embedder passed in
+        results = vec_search(db, model, args.collection, args.query, k=args.k_final)
+ else:
+ results = search_hybrid(db,
+ args.collection,
+ args.query,
+ k_vec=args.k_vec,
+ k_bm25=args.k_bm25,
+ k_ce=args.k_ce,
+ k_final=args.k_final,
+ use_mmr=args.mmr,
+ mmr_lambda=args.mmr_lambda,
+ )
+
+ for rid, txt, score in results:
+ print(f"[{rid:05d}] score={score:.3f}\n{txt[:400]}...\n")
+
+ db.close()
+
+
def main():
ap = argparse.ArgumentParser(prog="rag")
@@ -13,10 +77,12 @@ def main():
# ingest
ap_ing = sub.add_parser("ingest", help="Parse, chunk, embed, and index a file into a collection")
ap_ing.add_argument("--file", required=True, help="Path to PDF/TXT to ingest")
+ ap_ing.add_argument("--collection", required=True, help="Collection name (e.g. wm_qwen3)")
ap_ing.set_defaults(func=cmd_ingest)
# query
ap_q = sub.add_parser("query", help="Query a collection")
+ ap_q.add_argument("--collection", required=True, help="Collection name to search")
ap_q.add_argument("--query", required=True, help="User query text")
ap_q.add_argument("--simple", action="store_true", help="Vector-only search (skip reranker)")
ap_q.add_argument("--mmr", action="store_true", help="Apply MMR after CE")
@@ -35,37 +101,3 @@ if __name__ == "__main__":
main()
-
-
-def cmd_ingest(args):
- path = Path(args.file)
-
- if not path.exists():
- print(f"File not found: {path}", file=sys.stderr)
- sys.exit(1)
-
-
- db = get_db()
- stats = start_ingest(db,path)
- print(f"Ingested file={args.file} :: {stats}")
-
-def cmd_query(args):
-
- db = get_db()
- if args.simple:
- results = vec_search(db, args.query, k=args.k_final)
- else:
- results = search_hybrid(
- db, args.query,
- k_vec=args.k_vec,
- k_bm25=args.k_bm25,
- k_ce=args.k_ce,
- k_final=args.k_final,
- use_mmr=args.mmr,
- mmr_lambda=args.mmr_lambda,
- )
-
- for rid, txt, score in results:
- print(f"[{rid:05d}] score={score:.3f}\n{txt[:400]}...\n")
-
- db.close()
diff --git a/rag/mmr.py b/rag/mmr.py
index 5e47c4f..b52751f 100644
--- a/rag/mmr.py
+++ b/rag/mmr.py
@@ -1,3 +1,4 @@
+from rag.constants import BATCH
import numpy as np
def cosine(a, b): return float(np.dot(a, b))
@@ -21,6 +22,20 @@ def mmr(query_vec, cand_ids, cand_vecs, k=8, lamb=0.7):
return selected
def embed_unit_np(st_model, texts: list[str]) -> np.ndarray:
- V = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True, batch_size=32)
+ V = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True, batch_size=BATCH)
V = V.astype("float32", copy=False)
return V
+
+def mmr2(qvec: np.ndarray, ids, vecs: np.ndarray, k=8, lamb=0.7):
+    """Greedy maximal marginal relevance: balance relevance to qvec against
+    redundancy with already-picked items (all vectors assumed unit-norm)."""
+    if len(ids) == 0:
+        return []
+    sel, idxs = [], []
+    rest = list(range(len(ids)))
+    best0 = max(rest, key=lambda i: float(qvec @ vecs[i]))
+    sel.append(ids[best0]); idxs.append(best0); rest.remove(best0)
+    while rest and len(sel) < k:
+        def score(i):
+            rel = float(qvec @ vecs[i])                        # similarity to the query
+            red = max(float(vecs[i] @ vecs[j]) for j in idxs)  # overlap with picks so far
+            return lamb*rel - (1-lamb)*red
+        nxt = max(rest, key=score)
+        sel.append(ids[nxt]); idxs.append(nxt); rest.remove(nxt)
+    return sel
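+
+# minimal usage sketch (names assumed; vectors must be unit-norm, e.g. from embed_unit_np):
+#   qvec  = embed_unit_np(st_model, [query])[0]
+#   cvecs = embed_unit_np(st_model, texts)         # [N, D]
+#   keep  = mmr2(qvec, ids, cvecs, k=8, lamb=0.7)  # diversified ids, relevance-first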
diff --git a/rag/rerank.py b/rag/rerank.py
index 6ae8938..8a2870d 100644
--- a/rag/rerank.py
+++ b/rag/rerank.py
@@ -1,23 +1,25 @@
# rag/rerank.py
import torch
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sentence_transformers import CrossEncoder
+from rag.constants import BATCH, RERANKER_ID
-RERANKER_ID = "Qwen/Qwen3-Reranker-8B" # or -large if you’ve got VRAM
-# RERANKER_ID = "BAAI/bge-reranker-base" # or -large if you’ve got VRAM
# device: "cuda" | "cpu" | "mps"
def get_rerank_model():
+ id = "BAAI/bge-reranker-base" # or -large if you’ve got VRAM
return CrossEncoder(
- RERANKER_ID,
+        model_id,
device="cuda",
+ max_length=384,
model_kwargs={
- "attn_implementation":"flash_attention_2",
+ # "attn_implementation":"flash_attention_2",
"device_map":"auto",
"dtype":torch.float16
},
tokenizer_kwargs={"padding_side": "left"}
)
-def rerank_cross_encoder(reranker: CrossEncoder, query: str, candidates: list[tuple[int, str]], batch_size: int = 32):
+def rerank_cross_encoder(reranker: CrossEncoder, query: str, candidates: list[tuple[int, str]], batch_size: int = BATCH):
"""
candidates: [(id, text), ...]
returns: [(id, text, score)] sorted desc by score
@@ -25,7 +27,29 @@ def rerank_cross_encoder(reranker: CrossEncoder, query: str, candidates: list[tu
if not candidates:
return []
ids, texts = zip(*candidates)
- pairs = [(query, t) for t in texts]
+    # char-level cap before tokenization keeps reranker latency bounded on very long chunks
+    pairs = [(query, t[:1000]) for t in texts]
scores = reranker.predict(pairs, batch_size=batch_size) # np.ndarray [N], higher=better
ranked = sorted(zip(ids, texts, scores), key=lambda x: x[2], reverse=True)
return ranked
+
+
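+# alternative sketch, kept commented for reference: score (query, text) pairs with raw
+# transformers instead of sentence-transformers' CrossEncoder (same truncation idea)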
+# tok = AutoTokenizer.from_pretrained(RERANKER_ID, use_fast=True, model_max_length=384)
+# ce = AutoModelForSequenceClassification.from_pretrained(
+# RERANKER_ID, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2",
+# device_map="auto" # or load_in_8bit=True
+# )
+
+# def ce_scores(query, texts, batch_size=16, max_length=384):
+# scores = []
+# for i in range(0, len(texts), batch_size):
+# batch = texts[i:i+batch_size]
+# enc = tok([ (query, t[:1000]) for t in batch ],
+# padding=True, truncation=True, max_length=max_length,
+# return_tensors="pt").to(ce.device)
+# with torch.inference_mode():
+# logits = ce(**enc).logits.squeeze(-1) # [B]
+# scores.extend(logits.float().cpu().tolist())
+# return scores
diff --git a/rag/search.py b/rag/search.py
index 55b8ffd..291c1b4 100644
--- a/rag/search.py
+++ b/rag/search.py
@@ -5,12 +5,68 @@ import gc
from typing import List, Tuple
from sentence_transformers import CrossEncoder, SentenceTransformer
+from rag.constants import BATCH
from rag.ingest import get_embed_model
from rag.rerank import get_rerank_model, rerank_cross_encoder
-from rag.mmr import mmr, embed_unit_np # if you added MMR; else remove
+from rag.mmr import mmr2, embed_unit_np
+from rag.db import vec_topk, bm25_topk, fetch_chunk
+
+
+def _dist_to_sim(dist: float) -> float:
+ # L2 on unit vectors ↔ cosine: ||a-b||^2 = 2 - 2 cos => cos = 1 - dist/2
+ return max(0.0, 1.0 - dist / 2.0)
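+# e.g. dist 0.0 -> sim 1.0 (identical), dist 1.0 -> sim 0.5, dist 2.0 -> sim 0.0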
+
+def vec_search(db, model: SentenceTransformer, col: str, qtext: str, k: int = 10, min_sim: float = 0.25,
+ max_per_doc: int | None = None, use_mmr: bool = False, mmr_lambda: float = 0.7):
+ q = model.encode([qtext], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32") # <- flatten
+ rows = vec_topk(db, col, q, k * 4) # overfetch a bit, filter below
+
+ # fetch texts + compute cosine sim
+ hits = []
+ for rid, dist in rows:
+ txt = fetch_chunk(db, col, rid)
+ sim = _dist_to_sim(dist)
+ if sim >= min_sim:
+ hits.append((rid, txt, sim))
+
+ # anti-spam (cap near-duplicates from same doc region if you add metadata later)
+ if max_per_doc:
+ capped, seen = [], {}
+ for rid, txt, sim in hits:
+ dockey = col # or derive from a future chunk_meta table
+ cnt = seen.get(dockey, 0)
+ if cnt < max_per_doc:
+ capped.append((rid, txt, sim))
+ seen[dockey] = cnt + 1
+ hits = capped
+
+ # optional light MMR on the filtered set (diversify)
+ if use_mmr and len(hits) > k:
+ ids = [h[0] for h in hits]
+ texts = [h[1] for h in hits]
+ qvec = q
+ cvecs = embed_unit_np(model, texts) # [N,D] unit
+ keep = set(mmr2(qvec, ids, cvecs, k=k, lamb=mmr_lambda))
+ hits = [h for h in hits if h[0] in keep]
+
+ # final crop
+ hits.sort(key=lambda x: x[2], reverse=True)
+ return hits[:k]
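+
+# usage sketch ("wm_qwen3" is the example collection name from the CLI help):
+#   model = get_embed_model()
+#   hits = vec_search(db, model, "wm_qwen3", "survival of indemnification", k=8, use_mmr=True)
+#   # -> [(rowid, chunk_text, cosine_sim), ...] best-first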
+
def search_hybrid(
- db: sqlite3.Connection, query: str,
+ db: sqlite3.Connection, col: str, query: str,
k_vec: int = 40, k_bm25: int = 40,
k_ce: int = 30, # rerank this many
k_final: int = 10, # return this many
@@ -19,74 +75,57 @@ def search_hybrid(
emodel = get_embed_model()
# 1) embed query (unit, float32) for vec search + MMR
print("loading")
- q = emodel.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32")
+ query_embeddings = emodel.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32")
emodel = None
with torch.no_grad():
torch.cuda.empty_cache()
gc.collect()
print("memory should be free by now!!")
- qbytes = q.tobytes()
-
- # 2) ANN (cosine) + BM25 pools
- vec_ids = [i for (i, _) in db.execute(
- "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?",
- (memoryview(qbytes), k_vec)
- ).fetchall()]
- bm25_ids = [i for (i,) in db.execute(
- "SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?",
- (query, k_bm25)
- ).fetchall()]
-
- # 3) merge (vector-first)
- seen, merged = set(), []
- for i in vec_ids + bm25_ids:
+
+ # 2) ANN + BM25
+ print("phase 2", col, query)
+ vhits = vec_topk(db, col, query_embeddings, k_vec) # [(id, dist)]
+ vh_ids = [i for (i, _) in vhits]
+ bm_ids = bm25_topk(db, col, query, k_bm25)
+    # 3) merge ids (vector-first)
+ merged, seen = [], set()
+ for i in vh_ids + bm_ids:
if i not in seen:
merged.append(i); seen.add(i)
if not merged:
return []
- # 4) fetch texts for CE
+ # 4) fetch texts
qmarks = ",".join("?"*len(merged))
- cand = db.execute(f"SELECT id, text FROM chunks WHERE id IN ({qmarks})", merged).fetchall()
+ cand = db.execute(f"SELECT id, text FROM chunks_{col} WHERE id IN ({qmarks})", merged).fetchall()
+ ids, texts = zip(*cand)
+
+ # 5) rerank
+ print("loading reranking model")
reranker = get_rerank_model()
- # 5) cross-encoder rerank (returns [(id,text,score)] desc)
- ranked = rerank_cross_encoder(reranker, query, cand)
- reranker = None
- print("freeing again!!")
+    ranked = rerank_cross_encoder(reranker, query, list(zip(ids, texts)), batch_size=BATCH)
+    reranker = None
with torch.no_grad():
torch.cuda.empty_cache()
gc.collect()
- #
- ranked = ranked[:min(k_ce, len(ranked))]
-
+ print("memory should be free by now!!")
+ print("unloading reranking model")
+ ranked = sorted(zip(ids, texts, scores), key=lambda x: x[2], reverse=True)
if not use_mmr or len(ranked) <= k_final:
- return ranked[:k_final]
+ return ranked[:min(k_ce, k_final)]
- # 6) MMR diversity on CE top-k_ce
- cand_ids = [i for (i,_,_) in ranked]
- cand_text = [t for (_,t,_) in ranked]
- emodel = get_embed_model()
- # god this is annoying I should stop being poor
- cand_vecs = embed_unit_np(emodel, cand_text) # [N,D], unit vectors
- sel_ids = set(mmr(q, cand_ids, cand_vecs, k=k_final, lamb=mmr_lambda))
- final = [trip for trip in ranked if trip[0] in sel_ids] # keep CE order, filter by MMR picks
- return final[:k_final]
-
-
-# ) Query helper (cosine distance; operator may be <#> in sqlite-vec)
-def vec_search(db: sqlite3.Connection, qtext, k=5):
- model = get_embed_model()
- q = model.encode([qtext], normalize_embeddings=True, convert_to_numpy=True).astype("float32")
- # Cosine distance operator in sqlite-vec is `<#>`; if your build differs, check docs: <-> L2, <=> dot, <#> cosine
- rows = db.execute(
- "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?",
- (memoryview(q.tobytes()), k)
- ).fetchall()
-
-# db.execute("SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?", (qtext, 40))
- return [(rid, db.execute("SELECT text FROM chunks WHERE id=?", (rid,)).fetchone()[0], dist) for rid, dist in rows]
+ # 6) MMR
+ ce_ids = [i for (i,_,_) in ranked]
+ ce_texts = [t for (_,t,_) in ranked]
+ st_model = get_embed_model()
+    ce_vecs = embed_unit_np(st_model, ce_texts)  # unit float32 vectors
+ keep = set(mmr2(query_embeddings, ce_ids, ce_vecs, k=k_final, lamb=mmr_lambda))
+ return [r for r in ranked if r[0] in keep][:k_final]
# # Hybrid + CE rerank query:
# results = search_hybrid(db, col, "indemnification obligations survive termination", k_vec=50, k_bm25=50, k_final=8)
@@ -99,3 +138,67 @@ def vec_search(db: sqlite3.Connection, qtext, k=5):
#
#
#
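+# convenience wrapper: let search_hybrid rerank a k_ce-sized pool (its own MMR off),
+# then diversify that pool down to k_final with mmr2 over re-embedded candidates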
+def search_hybrid_with_mmr(db, col, query, k_vec=50, k_bm25=50, k_ce=30, k_final=10, lamb=0.7):
+ ranked = search_hybrid(db, col, query, k_vec, k_bm25, k_ce, k_ce)
+ if not ranked: return []
+ ids = [i for (i,_,_) in ranked]
+ texts = [t for (_,t,_) in ranked]
+ st_model = get_embed_model()
+ qvec = st_model.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32")
+ cvecs = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True).astype("float32")
+ keep = set(mmr2(qvec, ids, cvecs, k=k_final, lamb=lamb))
+ return [r for r in ranked if r[0] in keep][:k_final]
+
+
+# query-expansion idea (kept from review notes):
+# def expand(q, aliases=()):
+#     qs = [q, *aliases]
+#     # embed each variant; take max similarity per chunk at scoring time
+
diff --git a/rag/test.py b/rag/test.py
index b7a6d8e..864a3ba 100644
--- a/rag/test.py
+++ b/rag/test.py
@@ -47,29 +47,36 @@ def t():
# db = sqlite3.connect("./rag.db")
queries = [
-"How was Shuihu zhuan received in early modern Japan?",
-"Edo-period readers’ image of Song Jiang / Liangshan outlaws",
-"Channels of transmission for Chinese vernacular fiction into Japan (kanbun kundoku, digests, translations)",
-"Role of woodblock prints/illustrations in mediating Chinese fiction",
-"Key Japanese scholars, writers, or publishers who popularized Chinese fiction",
-"Kyokutei Bakin’s engagement with Chinese vernacular narrative",
-"Santō Kyōden, gesaku, and Chinese models",
-"Kanzen chōaku (encourage good, punish evil) and Water Margin in Japan",
-"Moral ambivalence of outlaw heroes as discussed in the text",
-"Censorship or moral debates around reading Chinese fiction",
-"Translation strategies from vernacular Chinese to Japanese (furigana, kundoku, glossing)",
-"Paratexts: prefaces, commentaries, reader guidance apparatus",
-"Bibliographic details: editions, reprints, circulation networks",
-"How does this book challenge older narratives about Sino-Japanese literary influence?",
-"Methodology: sources, archives, limitations mentioned by the author",
+# "How was Shuihu zhuan received in early modern Japan?",
+# "Edo-period readers’ image of Song Jiang / Liangshan outlaws",
+# "Channels of transmission for Chinese vernacular fiction into Japan (kanbun kundoku, digests, translations)",
+# "Role of woodblock prints/illustrations in mediating Chinese fiction",
+# "Key Japanese scholars, writers, or publishers who popularized Chinese fiction",
+# "Kyokutei Bakin’s engagement with Chinese vernacular narrative",
+# "Santō Kyōden, gesaku, and Chinese models",
+# "Kanzen chōaku (encourage good, punish evil) and Water Margin in Japan",
+# "Moral ambivalence of outlaw heroes as discussed in the text",
+# "Censorship or moral debates around reading Chinese fiction",
+# "Translation strategies from vernacular Chinese to Japanese (furigana, kundoku, glossing)",
+# "Paratexts: prefaces, commentaries, reader guidance apparatus",
+# "Bibliographic details: editions, reprints, circulation networks",
+# "How does this book challenge older narratives about Sino-Japanese literary influence?",
+# "Methodology: sources, archives, limitations mentioned by the author",
+"sex"
]
def t2():
db = get_db()
# Hybrid + CE rerank query:
+ model = rag.ingest.get_embed_model()
for query in queries:
print("query", query)
- results = rag.search.search_hybrid(db, query, k_vec=50, k_bm25=50, k_final=8)
+ print("-----------\n\n")
+ # results = rag.search.search_hybrid(db, "muh", query, k_vec=50, k_bm25=50, k_final=8)
+ # for rid, txt, score in results:
+ # sim = score
+ # print(f"[{rid:04d}] ce_score={sim:.3f}\n{txt[:300]}...\n")
+ results = rag.search.vec_search(db, model, "muh", query, k=50, min_sim=0.25, max_per_doc=5, use_mmr=False, mmr_lambda=0.7)
for rid, txt, score in results:
sim = score
print(f"[{rid:04d}] ce_score={sim:.3f}\n{txt[:300]}...\n")
diff --git a/rag/utils.py b/rag/utils.py
new file mode 100644
index 0000000..10519a4
--- /dev/null
+++ b/rag/utils.py
@@ -0,0 +1,15 @@
+FTS_META_CHARS = r'''["'*()^+-]''' # FTS5 syntax characters; the quoting below neutralizes them
+
+def sanitize_query(q: str, *, allow_ops: bool = False) -> str:
+    q = q.strip()
+    if not q:
+        return q
+    if allow_ops:
+        # caller opts into raw FTS5 syntax (AND/OR/NEAR, prefix*): pass through untouched
+        return q
+    # literal search: escape embedded double quotes, then wrap as a single phrase
+    q = q.replace('"', '""')
+    return f'"{q}"'