author    | polwex <polwex@sortug.com> | 2025-09-24 23:38:36 +0700
committer | polwex <polwex@sortug.com> | 2025-09-24 23:38:36 +0700
commit    | 734b89570040e97f0c7743c4c0bc28e30a3cd4ee (patch)
tree      | 7142d9f37908138c38d0ade066e960c3a1c69f5d /rag
parent    | 57aaafdb137fe49930711f6ed5ccc83b3a119cd2 (diff)
Diffstat (limited to 'rag')

-rw-r--r-- | rag/constants.py |   5
-rw-r--r-- | rag/db.py        |  75
-rw-r--r-- | rag/ingest.py    |  17
-rw-r--r-- | rag/main.py      | 106
-rw-r--r-- | rag/mmr.py       |  17
-rw-r--r-- | rag/rerank.py    |  36
-rw-r--r-- | rag/search.py    | 205
-rw-r--r-- | rag/test.py      |  39
-rw-r--r-- | rag/utils.py     |  15

9 files changed, 377 insertions, 138 deletions
diff --git a/rag/constants.py b/rag/constants.py
new file mode 100644
index 0000000..327e8e5
--- /dev/null
+++ b/rag/constants.py
@@ -0,0 +1,5 @@
+EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-8B"
+RERANKER_ID = "Qwen/Qwen3-Reranker-8B"  # or -large if you’ve got VRAM
+# RERANKER_ID = "BAAI/bge-reranker-base"  # or -large if you’ve got VRAM
+MAX_TOKENS = 600
+BATCH = 16
diff --git a/rag/db.py b/rag/db.py
--- a/rag/db.py
+++ b/rag/db.py
@@ -1,6 +1,7 @@
 import sqlite3
 from numpy import ndarray
 from sqlite_vec import serialize_float32
+import sqlite_vec
 
 def get_db():
     db = sqlite3.connect("./rag.db")
@@ -13,43 +14,85 @@ def get_db():
     db.enable_load_extension(False)
     return db
 
-def init_schema(db: sqlite3.Connection, DIM: int):
-    db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec USING vec0(embedding float[{DIM}])")
-    db.execute('''
-    CREATE TABLE IF NOT EXISTS chunks (
+def init_schema(db: sqlite3.Connection, col: str, DIM: int, model_id: str, tok_id: str, normalize: bool, preproc_hash: str):
+    print("initing schema", col)
+    db.execute("""
+    CREATE TABLE IF NOT EXISTS collections(
+      name TEXT PRIMARY KEY,
+      model TEXT,
+      tokenizer TEXT,
+      dim INTEGER,
+      normalize INTEGER,
+      preproc_hash TEXT,
+      created_at INTEGER DEFAULT (unixepoch())
+    )""")
+    db.execute("BEGIN")
+    db.execute("INSERT INTO collections(name, model, tokenizer, dim, normalize, preproc_hash) VALUES(?, ?, ?, ?, ?, ?)",
+               (col, model_id, tok_id, DIM, int(normalize), preproc_hash))
+
+    db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec_{col} USING vec0(embedding float[{DIM}])")
+    db.execute(f'''
+    CREATE TABLE IF NOT EXISTS chunks_{col} (
       id INTEGER PRIMARY KEY,
       text TEXT
     )''')
-    db.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts USING fts5(text)")
+    db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS fts_{col} USING fts5(text)")
     db.commit()
 
-def store_chunks(db: sqlite3.Connection, chunks: list[str], V_np: ndarray):
+
+def check_db(db: sqlite3.Connection, coll: str):
+    row = db.execute("SELECT dim, model, normalize FROM collections WHERE name=?", (coll,)).fetchone()
+    return row
+    # assert row and row[0] == DIM and row[1] == EMBED_MODEL_ID and row[2] == 1  # if you normalize
+
+def check_db2(db: sqlite3.Connection, coll: str):
+    row = db.execute(
+        "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
+        (f"vec_{coll}",)
+    ).fetchone()
+    return bool(row)
+
+def store_chunks(db: sqlite3.Connection, col: str, chunks: list[str], V_np: ndarray):
     assert len(chunks) == len(V_np)
     db.execute("BEGIN")
-    db.executemany('''
-    INSERT INTO chunks(id, text) VALUES (?, ?)
+    db.executemany(f'''
+    INSERT INTO chunks_{col}(id, text) VALUES (?, ?)
     ''', list(enumerate(chunks, start=1)))
     db.executemany(
-        "INSERT INTO vec(rowid, embedding) VALUES (?, ?)",
+        f"INSERT INTO vec_{col}(rowid, embedding) VALUES (?, ?)",
        [(i+1, memoryview(V_np[i].tobytes())) for i in range(len(chunks))]
     )
-    db.executemany("INSERT INTO fts(rowid, text) VALUES (?, ?)", list(enumerate(chunks, start=1)))
+    db.executemany(f"INSERT INTO fts_{col}(rowid, text) VALUES (?, ?)", list(enumerate(chunks, start=1)))
     db.commit()
 
-def vec_topk(db, q_vec_f32, k=10):
+def vec_topk(db, col: str, q_vec_f32, k=10):
+    # rows = db.execute(
+    #     "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?",
+    #     (memoryview(q.tobytes()), k)
+    # ).fetchall()
     rows = db.execute(
-        "SELECT rowid, distance FROM vec WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
+        f"SELECT rowid, distance FROM vec_{col} WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
        (serialize_float32(q_vec_f32), k)
     ).fetchall()
     return rows  # [(rowid, distance)]
 
-def bm25_topk(db, qtext, k=10):
+def bm25_topk(db: sqlite3.Connection, col: str, qtext, k=10):
+    safe_q = f'"{qtext}"'
     return [rid for (rid,) in db.execute(
-        "SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?", (qtext, k)
+        f"SELECT rowid FROM fts_{col} WHERE fts_{col} MATCH ? LIMIT ?", (safe_q, k)
     ).fetchall()]
 
-def wipe_db():
+def fetch_chunk(db: sqlite3.Connection, col: str, id: int):
+    return db.execute(f"SELECT text FROM chunks_{col} WHERE id=?", (id,)).fetchone()[0]
+
+def fetch_chunks_in_range(db: sqlite3.Connection, col: str, ids: list[int]):
+    # one placeholder per id: a literal "IN (?)" cannot bind a Python list
+    qmarks = ",".join("?" * len(ids))
+    return db.execute(f"SELECT id, text FROM chunks_{col} WHERE id IN ({qmarks})", ids).fetchall()
+
+
+def wipe_db(col: str):
     db = sqlite3.connect("./rag.db")
-    db.executescript("DROP TABLE IF EXISTS chunks; DROP TABLE IF EXISTS fts; DROP TABLE IF EXISTS vec;")
+    db.executescript(f"DROP TABLE IF EXISTS chunks_{col}; DROP TABLE IF EXISTS fts_{col}; DROP TABLE IF EXISTS vec_{col};")
     db.close()
diff --git a/rag/ingest.py b/rag/ingest.py
index d17690a..5def23d 100644
--- a/rag/ingest.py
+++ b/rag/ingest.py
@@ -7,16 +7,13 @@ from docling.document_converter import DocumentConverter
 from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
 from rag.db import get_db, init_schema, store_chunks
 from sentence_transformers import SentenceTransformer
+from rag.constants import MAX_TOKENS, BATCH, EMBED_MODEL_ID
 
-EMBED_MODEL_ID = "Qwen/Qwen3-Embedding-8B"
-RERANKER_MODEL_ID = "Qwen/Qwen3-Reranker-8B"
-MAX_TOKENS = 600
-BATCH = 16
 
 def get_embed_model():
     return SentenceTransformer(
-        "Qwen/Qwen3-Embedding-8B",
+        EMBED_MODEL_ID,
         model_kwargs={
             # "trust_remote_code":True,
             "attn_implementation":"flash_attention_2",
@@ -48,14 +45,12 @@ def embed_many(model: SentenceTransformer, texts: list[str]):
     return V_np.astype("float32")
 
-def start_ingest(db: sqlite3.Connection, path: Path):
-    model = get_embed_model()
+def start_ingest(db: sqlite3.Connection, model: SentenceTransformer | None, collection: str, path: Path):
+    if model is None:
+        model = get_embed_model()
     chunks = parse_and_chunk(path, model)
     V_np = embed_many(model, chunks)
-    DIM = V_np.shape[1]
-    db = get_db()
-    init_schema(db, DIM)
-    store_chunks(db, chunks, V_np)
+    store_chunks(db, collection, chunks, V_np)
     # TODO some try catch?
     return True
 # float32/fp16 on CPU? ensure float32 for DB:
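
One detail worth a sanity check: `store_chunks` packs embeddings with `memoryview(V_np[i].tobytes())` while `vec_topk` queries with `serialize_float32`. Assuming native little-endian float32 on both sides (the common case, and what sqlite-vec expects), the two paths should produce identical blobs; a minimal check:

```python
import numpy as np
from sqlite_vec import serialize_float32

v = np.random.rand(8).astype("float32")
# struct.pack-based helper vs. raw numpy buffer: same packed float32 bytes
assert serialize_float32(v.tolist()) == v.tobytes()
```
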
diff --git a/rag/main.py b/rag/main.py
index 221adfa..bbd549a 100644
--- a/rag/main.py
+++ b/rag/main.py
@@ -1,10 +1,74 @@
 import sys
 import argparse
 from pathlib import Path
-from rag.ingest import start_ingest
+from rag.constants import EMBED_MODEL_ID
+from rag.ingest import get_embed_model, start_ingest
 from rag.search import search_hybrid, vec_search
-from rag.db import get_db
-# your Docling chunker imports…
+from rag.db import get_db, check_db, check_db2, init_schema
+
+
+def valid_collection(col: str) -> bool:
+    # TODO must have less than 9 characters and be ascii, no spaces
+    return True
+
+
+def cmd_ingest(args):
+    path = Path(args.file)
+    if not valid_collection(args.collection):
+        print(f"Collection name invalid: {args.collection}", file=sys.stderr)
+        sys.exit(1)
+
+    if not path.exists():
+        print(f"File not found: {path}", file=sys.stderr)
+        sys.exit(1)
+
+    db = get_db()
+    if not check_db2(db, args.collection):
+        model = get_embed_model()
+        dim = model.get_sentence_embedding_dimension()
+        if dim is None:
+            sys.exit(1)
+
+        # TODO Try catch in here, tell the user if it crashes
+        init_schema(db, args.collection, dim, EMBED_MODEL_ID, EMBED_MODEL_ID, True, 'idk')
+        stats = start_ingest(db, model, args.collection, path)
+    else:
+        stats = start_ingest(db, None, args.collection, path)
+    print(f"Ingested file={args.file} :: {stats}")
+
+def cmd_query(args):
+    if not valid_collection(args.collection):
+        print(f"Collection name invalid: {args.collection}", file=sys.stderr)
+        sys.exit(1)
+    db = get_db()
+    if not check_db2(db, args.collection):
+        print(f"Collection name not in DB, what are you searching: {args.collection}", file=sys.stderr)
+        sys.exit(1)
+
+    if args.simple:
+        # the new vec_search takes the embed model explicitly, so load it here
+        results = vec_search(db, get_embed_model(), args.collection, args.query, k=args.k_final)
+    else:
+        results = search_hybrid(db,
+                                args.collection,
+                                args.query,
+                                k_vec=args.k_vec,
+                                k_bm25=args.k_bm25,
+                                k_ce=args.k_ce,
+                                k_final=args.k_final,
+                                use_mmr=args.mmr,
+                                mmr_lambda=args.mmr_lambda,
+                                )
+
+    for rid, txt, score in results:
+        print(f"[{rid:05d}] score={score:.3f}\n{txt[:400]}...\n")
+
+    db.close()
+
 
 def main():
     ap = argparse.ArgumentParser(prog="rag")
@@ -13,10 +77,12 @@ def main():
     # ingest
     ap_ing = sub.add_parser("ingest", help="Parse, chunk, embed, and index a file into a collection")
     ap_ing.add_argument("--file", required=True, help="Path to PDF/TXT to ingest")
+    ap_ing.add_argument("--collection", required=True, help="Collection name (e.g. wm_qwen3)")
     ap_ing.set_defaults(func=cmd_ingest)
 
     # query
     ap_q = sub.add_parser("query", help="Query a collection")
+    ap_q.add_argument("--collection", required=True, help="Collection name to search")
     ap_q.add_argument("--query", required=True, help="User query text")
     ap_q.add_argument("--simple", action="store_true", help="Vector-only search (skip reranker)")
     ap_q.add_argument("--mmr", action="store_true", help="Apply MMR after CE")
@@ -35,37 +101,3 @@ if __name__ == "__main__":
     main()
-
-
-def cmd_ingest(args):
-    path = Path(args.file)
-
-    if not path.exists():
-        print(f"File not found: {path}", file=sys.stderr)
-        sys.exit(1)
-
-    db = get_db()
-    stats = start_ingest(db, path)
-    print(f"Ingested file={args.file} :: {stats}")
-
-def cmd_query(args):
-    db = get_db()
-    if args.simple:
-        results = vec_search(db, args.query, k=args.k_final)
-    else:
-        results = search_hybrid(
-            db, args.query,
-            k_vec=args.k_vec,
-            k_bm25=args.k_bm25,
-            k_ce=args.k_ce,
-            k_final=args.k_final,
-            use_mmr=args.mmr,
-            mmr_lambda=args.mmr_lambda,
-        )
-
-    for rid, txt, score in results:
-        print(f"[{rid:05d}] score={score:.3f}\n{txt[:400]}...\n")
-
-    db.close()
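
With `--collection` wired into both subcommands, an end-to-end smoke test might look like the sketch below. The module invocation path and the `k_*` flag registrations sit outside this hunk, so treat the exact spellings as assumptions:

```python
# Drive the argparse CLI programmatically; equivalent to
#   rag ingest --file book.pdf --collection wm
#   rag query  --collection wm --query "..." --simple
import sys
from rag.main import main

sys.argv = ["rag", "ingest", "--file", "book.pdf", "--collection", "wm"]
main()

sys.argv = ["rag", "query", "--collection", "wm", "--simple",
            "--query", "How was Shuihu zhuan received in early modern Japan?"]
main()
```
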
wm_qwen3)") ap_ing.set_defaults(func=cmd_ingest) # query ap_q = sub.add_parser("query", help="Query a collection") + ap_q.add_argument("--collection", required=True, help="Collection name to search") ap_q.add_argument("--query", required=True, help="User query text") ap_q.add_argument("--simple", action="store_true", help="Vector-only search (skip reranker)") ap_q.add_argument("--mmr", action="store_true", help="Apply MMR after CE") @@ -35,37 +101,3 @@ if __name__ == "__main__": main() - - -def cmd_ingest(args): - path = Path(args.file) - - if not path.exists(): - print(f"File not found: {path}", file=sys.stderr) - sys.exit(1) - - - db = get_db() - stats = start_ingest(db,path) - print(f"Ingested file={args.file} :: {stats}") - -def cmd_query(args): - - db = get_db() - if args.simple: - results = vec_search(db, args.query, k=args.k_final) - else: - results = search_hybrid( - db, args.query, - k_vec=args.k_vec, - k_bm25=args.k_bm25, - k_ce=args.k_ce, - k_final=args.k_final, - use_mmr=args.mmr, - mmr_lambda=args.mmr_lambda, - ) - - for rid, txt, score in results: - print(f"[{rid:05d}] score={score:.3f}\n{txt[:400]}...\n") - - db.close() @@ -1,3 +1,4 @@ +from rag.constants import BATCH import numpy as np def cosine(a, b): return float(np.dot(a, b)) @@ -21,6 +22,20 @@ def mmr(query_vec, cand_ids, cand_vecs, k=8, lamb=0.7): return selected def embed_unit_np(st_model, texts: list[str]) -> np.ndarray: - V = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True, batch_size=32) + V = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True, batch_size=BATCH) V = V.astype("float32", copy=False) return V + +def mmr2(qvec: np.ndarray, ids, vecs: np.ndarray, k=8, lamb=0.7): + sel, idxs = [], [] + rest = list(range(len(ids))) + best0 = max(rest, key=lambda i: float(qvec @ vecs[i])) + sel.append(ids[best0]); idxs.append(best0); rest.remove(best0) + while rest and len(sel) < k: + def score(i): + rel = float(qvec @ vecs[i]) + red = max(float(vecs[i] @ vecs[j]) for j in idxs) + return lamb*rel - (1-lamb)*red + nxt = max(rest, key=score) + sel.append(ids[nxt]); idxs.append(nxt); rest.remove(nxt) + return sel diff --git a/rag/rerank.py b/rag/rerank.py index 6ae8938..8a2870d 100644 --- a/rag/rerank.py +++ b/rag/rerank.py @@ -1,23 +1,25 @@ # rag/rerank.py import torch +from transformers import AutoTokenizer, AutoModelForSequenceClassification from sentence_transformers import CrossEncoder +from rag.constants import BATCH, RERANKER_ID -RERANKER_ID = "Qwen/Qwen3-Reranker-8B" # or -large if you’ve got VRAM -# RERANKER_ID = "BAAI/bge-reranker-base" # or -large if you’ve got VRAM # device: "cuda" | "cpu" | "mps" def get_rerank_model(): + id = "BAAI/bge-reranker-base" # or -large if you’ve got VRAM return CrossEncoder( - RERANKER_ID, + id, device="cuda", + max_length=384, model_kwargs={ - "attn_implementation":"flash_attention_2", + # "attn_implementation":"flash_attention_2", "device_map":"auto", "dtype":torch.float16 }, tokenizer_kwargs={"padding_side": "left"} ) -def rerank_cross_encoder(reranker: CrossEncoder, query: str, candidates: list[tuple[int, str]], batch_size: int = 32): +def rerank_cross_encoder(reranker: CrossEncoder, query: str, candidates: list[tuple[int, str]], batch_size: int = BATCH): """ candidates: [(id, text), ...] 
diff --git a/rag/rerank.py b/rag/rerank.py
index 6ae8938..8a2870d 100644
--- a/rag/rerank.py
+++ b/rag/rerank.py
@@ -1,23 +1,25 @@
 # rag/rerank.py
 import torch
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
 from sentence_transformers import CrossEncoder
+from rag.constants import BATCH, RERANKER_ID
 
-RERANKER_ID = "Qwen/Qwen3-Reranker-8B"  # or -large if you’ve got VRAM
-# RERANKER_ID = "BAAI/bge-reranker-base"  # or -large if you’ve got VRAM
 
 # device: "cuda" | "cpu" | "mps"
 def get_rerank_model():
+    id = "BAAI/bge-reranker-base"  # or -large if you’ve got VRAM
     return CrossEncoder(
-        RERANKER_ID,
+        id,
         device="cuda",
+        max_length=384,
         model_kwargs={
-            "attn_implementation":"flash_attention_2",
+            # "attn_implementation":"flash_attention_2",
             "device_map":"auto",
             "dtype":torch.float16
         },
         tokenizer_kwargs={"padding_side": "left"}
     )
 
-def rerank_cross_encoder(reranker: CrossEncoder, query: str, candidates: list[tuple[int, str]], batch_size: int = 32):
+def rerank_cross_encoder(reranker: CrossEncoder, query: str, candidates: list[tuple[int, str]], batch_size: int = BATCH):
     """
     candidates: [(id, text), ...]
     returns: [(id, text, score)] sorted desc by score
@@ -25,7 +27,29 @@ def rerank_cross_encoder(reranker: CrossEncoder, query: str, candidates: list[tu
     if not candidates:
         return []
     ids, texts = zip(*candidates)
-    pairs = [(query, t) for t in texts]
+    # pairs = [(query, t) for t in texts]
+    pairs = [(query, t[:1000]) for t in texts]
     scores = reranker.predict(pairs, batch_size=batch_size)  # np.ndarray [N], higher=better
     ranked = sorted(zip(ids, texts, scores), key=lambda x: x[2], reverse=True)
     return ranked
+
+
+
+
+# tok = AutoTokenizer.from_pretrained(RERANKER_ID, use_fast=True, model_max_length=384)
+# ce = AutoModelForSequenceClassification.from_pretrained(
+#     RERANKER_ID, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2",
+#     device_map="auto"  # or load_in_8bit=True
+# )
+
+# def ce_scores(query, texts, batch_size=16, max_length=384):
+#     scores = []
+#     for i in range(0, len(texts), batch_size):
+#         batch = texts[i:i+batch_size]
+#         enc = tok([ (query, t[:1000]) for t in batch ],
+#                   padding=True, truncation=True, max_length=max_length,
+#                   return_tensors="pt").to(ce.device)
+#         with torch.inference_mode():
+#             logits = ce(**enc).logits.squeeze(-1)  # [B]
+#         scores.extend(logits.float().cpu().tolist())
+#     return scores
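
For reference, a minimal call into the reranker as now configured (this downloads `BAAI/bge-reranker-base` on first use and requires the CUDA device hard-coded in `get_rerank_model`):

```python
from rag.rerank import get_rerank_model, rerank_cross_encoder

reranker = get_rerank_model()
candidates = [
    (1, "Indemnification obligations survive termination of this agreement."),
    (2, "Either party may terminate for convenience with 30 days notice."),
]
ranked = rerank_cross_encoder(reranker, "does indemnity survive termination?", candidates)
for rid, text, score in ranked:  # highest cross-encoder score first
    print(rid, float(score))
```
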
LIMIT ?", (qtext, 40)) +# return [(rid, db.execute(f"SELECT text FROM chunks_{col} WHERE id=?", (rid,)).fetchone()[0], dist) for rid, dist in rows] + +def _dist_to_sim(dist: float) -> float: + # L2 on unit vectors ↔ cosine: ||a-b||^2 = 2 - 2 cos => cos = 1 - dist/2 + return max(0.0, 1.0 - dist / 2.0) + +def vec_search(db, model: SentenceTransformer, col: str, qtext: str, k: int = 10, min_sim: float = 0.25, + max_per_doc: int | None = None, use_mmr: bool = False, mmr_lambda: float = 0.7): + q = model.encode([qtext], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32") # <- flatten + rows = vec_topk(db, col, q, k * 4) # overfetch a bit, filter below + + # fetch texts + compute cosine sim + hits = [] + for rid, dist in rows: + txt = fetch_chunk(db, col, rid) + sim = _dist_to_sim(dist) + if sim >= min_sim: + hits.append((rid, txt, sim)) + + # anti-spam (cap near-duplicates from same doc region if you add metadata later) + if max_per_doc: + capped, seen = [], {} + for rid, txt, sim in hits: + dockey = col # or derive from a future chunk_meta table + cnt = seen.get(dockey, 0) + if cnt < max_per_doc: + capped.append((rid, txt, sim)) + seen[dockey] = cnt + 1 + hits = capped + + # optional light MMR on the filtered set (diversify) + if use_mmr and len(hits) > k: + from rag.mmr import embed_unit_np, mmr2 + ids = [h[0] for h in hits] + texts = [h[1] for h in hits] + qvec = q + cvecs = embed_unit_np(model, texts) # [N,D] unit + keep = set(mmr2(qvec, ids, cvecs, k=k, lamb=mmr_lambda)) + hits = [h for h in hits if h[0] in keep] + + # final crop + hits.sort(key=lambda x: x[2], reverse=True) + return hits[:k] + def search_hybrid( - db: sqlite3.Connection, query: str, + db: sqlite3.Connection, col: str, query: str, k_vec: int = 40, k_bm25: int = 40, k_ce: int = 30, # rerank this many k_final: int = 10, # return this many @@ -19,74 +75,57 @@ def search_hybrid( emodel = get_embed_model() # 1) embed query (unit, float32) for vec search + MMR print("loading") - q = emodel.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32") + query_embeddings = emodel.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32") emodel = None with torch.no_grad(): torch.cuda.empty_cache() gc.collect() print("memory should be free by now!!") - qbytes = q.tobytes() - - # 2) ANN (cosine) + BM25 pools - vec_ids = [i for (i, _) in db.execute( - "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?", - (memoryview(qbytes), k_vec) - ).fetchall()] - bm25_ids = [i for (i,) in db.execute( - "SELECT rowid FROM fts WHERE fts MATCH ? 
LIMIT ?", - (query, k_bm25) - ).fetchall()] - - # 3) merge (vector-first) - seen, merged = set(), [] - for i in vec_ids + bm25_ids: + # qbytes = q.tobytes() + + + # 2) ANN + BM25 + print("phase 2", col, query) + vhits = vec_topk(db, col, query_embeddings, k_vec) # [(id, dist)] + vh_ids = [i for (i, _) in vhits] + bm_ids = bm25_topk(db, col, query, k_bm25) + # + # 3) merge ids [vector-first] + merged, seen = [], set() + for i in vh_ids + bm_ids: if i not in seen: merged.append(i); seen.add(i) if not merged: return [] - # 4) fetch texts for CE + # 4) fetch texts qmarks = ",".join("?"*len(merged)) - cand = db.execute(f"SELECT id, text FROM chunks WHERE id IN ({qmarks})", merged).fetchall() + cand = db.execute(f"SELECT id, text FROM chunks_{col} WHERE id IN ({qmarks})", merged).fetchall() + ids, texts = zip(*cand) + + # 5) rerank + print("loading reranking model") reranker = get_rerank_model() - # 5) cross-encoder rerank (returns [(id,text,score)] desc) - ranked = rerank_cross_encoder(reranker, query, cand) - reranker = None - print("freeing again!!") + scores = reranker.predict([(query, t[:1000]) for t in texts], batch_size=BATCH) + reranker =None with torch.no_grad(): torch.cuda.empty_cache() gc.collect() - # - ranked = ranked[:min(k_ce, len(ranked))] - + print("memory should be free by now!!") + print("unloading reranking model") + ranked = sorted(zip(ids, texts, scores), key=lambda x: x[2], reverse=True) if not use_mmr or len(ranked) <= k_final: - return ranked[:k_final] + return ranked[:min(k_ce, k_final)] - # 6) MMR diversity on CE top-k_ce - cand_ids = [i for (i,_,_) in ranked] - cand_text = [t for (_,t,_) in ranked] - emodel = get_embed_model() - # god this is annoying I should stop being poor - cand_vecs = embed_unit_np(emodel, cand_text) # [N,D], unit vectors - sel_ids = set(mmr(q, cand_ids, cand_vecs, k=k_final, lamb=mmr_lambda)) - final = [trip for trip in ranked if trip[0] in sel_ids] # keep CE order, filter by MMR picks - return final[:k_final] - - -# ) Query helper (cosine distance; operator may be <#> in sqlite-vec) -def vec_search(db: sqlite3.Connection, qtext, k=5): - model = get_embed_model() - q = model.encode([qtext], normalize_embeddings=True, convert_to_numpy=True).astype("float32") - # Cosine distance operator in sqlite-vec is `<#>`; if your build differs, check docs: <-> L2, <=> dot, <#> cosine - rows = db.execute( - "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?", - (memoryview(q.tobytes()), k) - ).fetchall() - -# db.execute("SELECT rowid FROM fts WHERE fts MATCH ? 
LIMIT ?", (qtext, 40)) - return [(rid, db.execute("SELECT text FROM chunks WHERE id=?", (rid,)).fetchone()[0], dist) for rid, dist in rows] + # 6) MMR + ce_ids = [i for (i,_,_) in ranked] + ce_texts = [t for (_,t,_) in ranked] + st_model = get_embed_model() + ce_vecs = st_model.encode(ce_texts, normalize_embeddings=True, convert_to_numpy=True).astype("float32") + keep = set(mmr2(query_embeddings, ce_ids, ce_vecs, k=k_final, lamb=mmr_lambda)) + return [r for r in ranked if r[0] in keep][:k_final] # # Hybrid + CE rerank query: # results = search_hybrid("indemnification obligations survive termination", model, k_vec=50, k_bm25=50, k_final=8) @@ -99,3 +138,67 @@ def vec_search(db: sqlite3.Connection, qtext, k=5): # # # +def search_hybrid_with_mmr(db, col, query, k_vec=50, k_bm25=50, k_ce=30, k_final=10, lamb=0.7): + ranked = search_hybrid(db, col, query, k_vec, k_bm25, k_ce, k_ce) + if not ranked: return [] + ids = [i for (i,_,_) in ranked] + texts = [t for (_,t,_) in ranked] + st_model = get_embed_model() + qvec = st_model.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32") + cvecs = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True).astype("float32") + keep = set(mmr2(qvec, ids, cvecs, k=k_final, lamb=lamb)) + return [r for r in ranked if r[0] in keep][:k_final] + + + +# clean vec expansion says gpt5 +# # +# def expand(q, aliases=()): +# qs = [q, *aliases] +# # embed each, take max similarity per chunk at scoring time + + +# def dist_to_cos(d): return max(0.0, 1.0 - d/2.0) # L2 on unit vecs + +# def vec_topk(db, table, q_vec_f32, k): +# from sqlite_vec import serialize_float32 +# return db.execute( +# f"SELECT rowid, distance FROM {table} WHERE embedding MATCH ? ORDER BY distance LIMIT ?", +# (serialize_float32(q_vec_f32), k) +# ).fetchall() + +# def vec_search(db, st_model, col, qtext, k=12, k_raw=None, min_sim=0.30, use_mmr=True, mmr_lambda=0.7): +# if k_raw is None: k_raw = k*4 +# q = st_model.encode([qtext], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32") +# rows = vec_topk(db, f"vec_{col}", q, k_raw) + +# hits = [] +# for rid, dist in rows: +# cos = dist_to_cos(dist) +# if cos < min_sim: continue +# txt = db.execute(f"SELECT text FROM chunks_{col} WHERE id=?", (rid,)).fetchone()[0] +# hits.append((rid, txt, cos)) + +# hits.sort(key=lambda x: x[2], reverse=True) +# if not use_mmr or len(hits) <= k: +# return hits[:k] + +# # MMR on the (already filtered) pool +# ids = [h[0] for h in hits] +# texts = [h[1] for h in hits] +# cvecs = st_model.encode(texts, normalize_embeddings=True, convert_to_numpy=True, batch_size=32).astype("float32") +# # simple MMR +# import numpy as np +# def cosine(a,b): return float(a@b) +# sel, sel_idx = [], [] +# rem = list(range(len(ids))) +# best0 = max(rem, key=lambda i: cosine(q, cvecs[i])); sel.append(ids[best0]); sel_idx.append(best0); rem.remove(best0) +# while rem and len(sel)<k: +# def score(i): +# rel = cosine(q, cvecs[i]) +# red = max(cosine(cvecs[i], cvecs[j]) for j in sel_idx) +# return mmr_lambda*rel - (1.0 - mmr_lambda)*red +# nxt = max(rem, key=score); sel.append(ids[nxt]); sel_idx.append(nxt); rem.remove(nxt) +# keep = set(sel) +# return [h for h in hits if h[0] in keep][:k] + diff --git a/rag/test.py b/rag/test.py index b7a6d8e..864a3ba 100644 --- a/rag/test.py +++ b/rag/test.py @@ -47,29 +47,36 @@ def t(): # db = sqlite3.connect("./rag.db") queries = [ -"How was Shuihu zhuan received in early modern Japan?", -"Edo-period readers’ image of Song Jiang / Liangshan 
outlaws", -"Channels of transmission for Chinese vernacular fiction into Japan (kanbun kundoku, digests, translations)", -"Role of woodblock prints/illustrations in mediating Chinese fiction", -"Key Japanese scholars, writers, or publishers who popularized Chinese fiction", -"Kyokutei Bakin’s engagement with Chinese vernacular narrative", -"Santō Kyōden, gesaku, and Chinese models", -"Kanzen chōaku (encourage good, punish evil) and Water Margin in Japan", -"Moral ambivalence of outlaw heroes as discussed in the text", -"Censorship or moral debates around reading Chinese fiction", -"Translation strategies from vernacular Chinese to Japanese (furigana, kundoku, glossing)", -"Paratexts: prefaces, commentaries, reader guidance apparatus", -"Bibliographic details: editions, reprints, circulation networks", -"How does this book challenge older narratives about Sino-Japanese literary influence?", -"Methodology: sources, archives, limitations mentioned by the author", +# "How was Shuihu zhuan received in early modern Japan?", +# "Edo-period readers’ image of Song Jiang / Liangshan outlaws", +# "Channels of transmission for Chinese vernacular fiction into Japan (kanbun kundoku, digests, translations)", +# "Role of woodblock prints/illustrations in mediating Chinese fiction", +# "Key Japanese scholars, writers, or publishers who popularized Chinese fiction", +# "Kyokutei Bakin’s engagement with Chinese vernacular narrative", +# "Santō Kyōden, gesaku, and Chinese models", +# "Kanzen chōaku (encourage good, punish evil) and Water Margin in Japan", +# "Moral ambivalence of outlaw heroes as discussed in the text", +# "Censorship or moral debates around reading Chinese fiction", +# "Translation strategies from vernacular Chinese to Japanese (furigana, kundoku, glossing)", +# "Paratexts: prefaces, commentaries, reader guidance apparatus", +# "Bibliographic details: editions, reprints, circulation networks", +# "How does this book challenge older narratives about Sino-Japanese literary influence?", +# "Methodology: sources, archives, limitations mentioned by the author", +"sex" ] def t2(): db = get_db() # Hybrid + CE rerank query: + model = rag.ingest.get_embed_model() for query in queries: print("query", query) - results = rag.search.search_hybrid(db, query, k_vec=50, k_bm25=50, k_final=8) + print("-----------\n\n") + # results = rag.search.search_hybrid(db, "muh", query, k_vec=50, k_bm25=50, k_final=8) + # for rid, txt, score in results: + # sim = score + # print(f"[{rid:04d}] ce_score={sim:.3f}\n{txt[:300]}...\n") + results = rag.search.vec_search(db, model, "muh", query, k=50, min_sim=0.25, max_per_doc=5, use_mmr=False, mmr_lambda=0.7) for rid, txt, score in results: sim = score print(f"[{rid:04d}] ce_score={sim:.3f}\n{txt[:300]}...\n") diff --git a/rag/utils.py b/rag/utils.py new file mode 100644 index 0000000..10519a4 --- /dev/null +++ b/rag/utils.py @@ -0,0 +1,15 @@ +import re + +FTS_META_CHARS = r'''["'*()^+-]''' # include ? if you see issues + +def sanitize_query(q: str, *, allow_ops: bool = False) -> str: + q = q.strip() + if not q: + return q + if allow_ops: + # escape stray double quotes inside, then wrap + q = q.replace('"', '""') + return f'"{q}"' + # literal search: quote and escape special chars + q = re.sub(r'"', '""', q) + return f'"{q}"' |