summaryrefslogtreecommitdiff
path: root/rag/db.py
diff options
context:
space:
mode:
authorpolwex <polwex@sortug.com>2025-09-23 03:50:53 +0700
committerpolwex <polwex@sortug.com>2025-09-23 03:50:53 +0700
commit57aaafdb137fe49930711f6ed5ccc83b3a119cd2 (patch)
tree1a7556927bed94377630d33dd29c3bf07d159619 /rag/db.py
init
Diffstat (limited to 'rag/db.py')
-rw-r--r--rag/db.py55
1 files changed, 55 insertions, 0 deletions
diff --git a/rag/db.py b/rag/db.py
new file mode 100644
index 0000000..c015c96
--- /dev/null
+++ b/rag/db.py
@@ -0,0 +1,55 @@
+import sqlite3
+from numpy import ndarray
+from sqlite_vec import serialize_float32
+
+def get_db():
+ db = sqlite3.connect("./rag.db")
+ db.execute("PRAGMA journal_mode=WAL;")
+ db.execute("PRAGMA synchronous=NORMAL;")
+ db.execute("PRAGMA temp_store=MEMORY;")
+ db.execute("PRAGMA mmap_size=300000000;") # 30GB if your OS allows
+ db.enable_load_extension(True)
+ sqlite_vec.load(db)
+ db.enable_load_extension(False)
+ return db
+
+def init_schema(db: sqlite3.Connection, DIM:int):
+ db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec USING vec0(embedding float[{DIM}])")
+ db.execute('''
+ CREATE TABLE IF NOT EXISTS chunks (
+ id INTEGER PRIMARY KEY,
+ text TEXT
+ )''')
+ db.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts USING fts5(text)")
+ db.commit()
+
+def store_chunks(db: sqlite3.Connection, chunks: list[str], V_np:ndarray):
+ assert len(chunks) == len(V_np)
+ db.execute("BEGIN")
+ db.executemany('''
+ INSERT INTO chunks(id, text) VALUES (?, ?)
+ ''', list(enumerate(chunks, start=1)))
+ db.executemany(
+ "INSERT INTO vec(rowid, embedding) VALUES (?, ?)",
+ [(i+1, memoryview(V_np[i].tobytes())) for i in range(len(chunks))]
+ )
+ db.executemany("INSERT INTO fts(rowid, text) VALUES (?, ?)", list(enumerate(chunks, start=1)))
+ db.commit()
+
+def vec_topk(db, q_vec_f32, k=10):
+ rows = db.execute(
+ "SELECT rowid, distance FROM vec WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
+ (serialize_float32(q_vec_f32), k)
+ ).fetchall()
+ return rows # [(rowid, distance)]
+
+
+def bm25_topk(db, qtext, k=10):
+ return [rid for (rid,) in db.execute(
+ "SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?", (qtext, k)
+ ).fetchall()]
+
+ def wipe_db():
+ db = sqlite3.connect("./rag.db")
+ db.executescript("DROP TABLE IF EXISTS chunks; DROP TABLE IF EXISTS fts; DROP TABLE IF EXISTS vec;")
+ db.close()