From 57aaafdb137fe49930711f6ed5ccc83b3a119cd2 Mon Sep 17 00:00:00 2001 From: polwex Date: Tue, 23 Sep 2025 03:50:53 +0700 Subject: init --- rag/db.py | 55 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 rag/db.py (limited to 'rag/db.py') diff --git a/rag/db.py b/rag/db.py new file mode 100644 index 0000000..c015c96 --- /dev/null +++ b/rag/db.py @@ -0,0 +1,55 @@ +import sqlite3 +from numpy import ndarray +from sqlite_vec import serialize_float32 + +def get_db(): + db = sqlite3.connect("./rag.db") + db.execute("PRAGMA journal_mode=WAL;") + db.execute("PRAGMA synchronous=NORMAL;") + db.execute("PRAGMA temp_store=MEMORY;") + db.execute("PRAGMA mmap_size=300000000;") # 30GB if your OS allows + db.enable_load_extension(True) + sqlite_vec.load(db) + db.enable_load_extension(False) + return db + +def init_schema(db: sqlite3.Connection, DIM:int): + db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec USING vec0(embedding float[{DIM}])") + db.execute(''' + CREATE TABLE IF NOT EXISTS chunks ( + id INTEGER PRIMARY KEY, + text TEXT + )''') + db.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts USING fts5(text)") + db.commit() + +def store_chunks(db: sqlite3.Connection, chunks: list[str], V_np:ndarray): + assert len(chunks) == len(V_np) + db.execute("BEGIN") + db.executemany(''' + INSERT INTO chunks(id, text) VALUES (?, ?) + ''', list(enumerate(chunks, start=1))) + db.executemany( + "INSERT INTO vec(rowid, embedding) VALUES (?, ?)", + [(i+1, memoryview(V_np[i].tobytes())) for i in range(len(chunks))] + ) + db.executemany("INSERT INTO fts(rowid, text) VALUES (?, ?)", list(enumerate(chunks, start=1))) + db.commit() + +def vec_topk(db, q_vec_f32, k=10): + rows = db.execute( + "SELECT rowid, distance FROM vec WHERE embedding MATCH ? ORDER BY distance LIMIT ?", + (serialize_float32(q_vec_f32), k) + ).fetchall() + return rows # [(rowid, distance)] + + +def bm25_topk(db, qtext, k=10): + return [rid for (rid,) in db.execute( + "SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?", (qtext, k) + ).fetchall()] + + def wipe_db(): + db = sqlite3.connect("./rag.db") + db.executescript("DROP TABLE IF EXISTS chunks; DROP TABLE IF EXISTS fts; DROP TABLE IF EXISTS vec;") + db.close() -- cgit v1.2.3