diff options
author | polwex <polwex@sortug.com> | 2025-09-23 03:50:53 +0700 |
---|---|---|
committer | polwex <polwex@sortug.com> | 2025-09-23 03:50:53 +0700 |
commit | 57aaafdb137fe49930711f6ed5ccc83b3a119cd2 (patch) | |
tree | 1a7556927bed94377630d33dd29c3bf07d159619 /rag/db.py |
init
Diffstat (limited to 'rag/db.py')
-rw-r--r-- | rag/db.py | 55 |
1 files changed, 55 insertions, 0 deletions
diff --git a/rag/db.py b/rag/db.py new file mode 100644 index 0000000..c015c96 --- /dev/null +++ b/rag/db.py @@ -0,0 +1,55 @@ +import sqlite3 +from numpy import ndarray +from sqlite_vec import serialize_float32 + +def get_db(): + db = sqlite3.connect("./rag.db") + db.execute("PRAGMA journal_mode=WAL;") + db.execute("PRAGMA synchronous=NORMAL;") + db.execute("PRAGMA temp_store=MEMORY;") + db.execute("PRAGMA mmap_size=300000000;") # 30GB if your OS allows + db.enable_load_extension(True) + sqlite_vec.load(db) + db.enable_load_extension(False) + return db + +def init_schema(db: sqlite3.Connection, DIM:int): + db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec USING vec0(embedding float[{DIM}])") + db.execute(''' + CREATE TABLE IF NOT EXISTS chunks ( + id INTEGER PRIMARY KEY, + text TEXT + )''') + db.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts USING fts5(text)") + db.commit() + +def store_chunks(db: sqlite3.Connection, chunks: list[str], V_np:ndarray): + assert len(chunks) == len(V_np) + db.execute("BEGIN") + db.executemany(''' + INSERT INTO chunks(id, text) VALUES (?, ?) + ''', list(enumerate(chunks, start=1))) + db.executemany( + "INSERT INTO vec(rowid, embedding) VALUES (?, ?)", + [(i+1, memoryview(V_np[i].tobytes())) for i in range(len(chunks))] + ) + db.executemany("INSERT INTO fts(rowid, text) VALUES (?, ?)", list(enumerate(chunks, start=1))) + db.commit() + +def vec_topk(db, q_vec_f32, k=10): + rows = db.execute( + "SELECT rowid, distance FROM vec WHERE embedding MATCH ? ORDER BY distance LIMIT ?", + (serialize_float32(q_vec_f32), k) + ).fetchall() + return rows # [(rowid, distance)] + + +def bm25_topk(db, qtext, k=10): + return [rid for (rid,) in db.execute( + "SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?", (qtext, k) + ).fetchall()] + + def wipe_db(): + db = sqlite3.connect("./rag.db") + db.executescript("DROP TABLE IF EXISTS chunks; DROP TABLE IF EXISTS fts; DROP TABLE IF EXISTS vec;") + db.close() |