summaryrefslogtreecommitdiff
path: root/rag/db.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/db.py')
-rw-r--r--rag/db.py75
1 files changed, 59 insertions, 16 deletions
diff --git a/rag/db.py b/rag/db.py
index c015c96..94ff3fb 100644
--- a/rag/db.py
+++ b/rag/db.py
@@ -1,6 +1,7 @@
import sqlite3
from numpy import ndarray
from sqlite_vec import serialize_float32
+import sqlite_vec
def get_db():
db = sqlite3.connect("./rag.db")
@@ -13,43 +14,85 @@ def get_db():
db.enable_load_extension(False)
return db
-def init_schema(db: sqlite3.Connection, DIM:int):
- db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec USING vec0(embedding float[{DIM}])")
- db.execute('''
- CREATE TABLE IF NOT EXISTS chunks (
+def init_schema(db: sqlite3.Connection, col: str, DIM:int, model_id: str, tok_id: str, normalize: bool, preproc_hash: str):
+ print("initing schema", col)
+ db.execute("""
+ CREATE TABLE IF NOT EXISTS collections(
+ name TEXT PRIMARY KEY,
+ model TEXT,
+ tokenizer TEXT,
+ dim INTEGER,
+ normalize INTEGER,
+ preproc_hash TEXT,
+ created_at INTEGER DEFAULT (unixepoch())
+ )""")
+ db.execute("BEGIN")
+ db.execute("INSERT INTO collections(name, model, tokenizer, dim, normalize, preproc_hash) VALUES(?, ?, ?, ?, ?, ?)",
+ (col, model_id, tok_id, DIM, int(normalize), preproc_hash))
+
+ db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec_{col} USING vec0(embedding float[{DIM}])")
+ db.execute(f'''
+ CREATE TABLE IF NOT EXISTS chunks_{col} (
id INTEGER PRIMARY KEY,
text TEXT
)''')
- db.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts USING fts5(text)")
+ db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS fts_{col} USING fts5(text)")
db.commit()
-def store_chunks(db: sqlite3.Connection, chunks: list[str], V_np:ndarray):
+
+def check_db(db: sqlite3.Connection, coll: str):
+ row = db.execute("SELECT dim, model, normalize FROM collections WHERE name=?", (coll,)).fetchone()
+ return row
+ # assert row and row[0] == DIM and row[1] == EMBED_MODEL_ID and row[2] == 1 # if you normalize
+def check_db2(db: sqlite3.Connection, coll: str):
+ row = db.execute(
+ "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
+ (f"vec_{coll}",)
+ ).fetchone()
+ return bool(row)
+def store_chunks(db: sqlite3.Connection, col: str, chunks: list[str], V_np:ndarray):
assert len(chunks) == len(V_np)
db.execute("BEGIN")
- db.executemany('''
- INSERT INTO chunks(id, text) VALUES (?, ?)
+ db.executemany(f'''
+ INSERT INTO chunks_{col}(id, text) VALUES (?, ?)
''', list(enumerate(chunks, start=1)))
db.executemany(
- "INSERT INTO vec(rowid, embedding) VALUES (?, ?)",
+ f"INSERT INTO vec_{col}(rowid, embedding) VALUES (?, ?)",
[(i+1, memoryview(V_np[i].tobytes())) for i in range(len(chunks))]
)
- db.executemany("INSERT INTO fts(rowid, text) VALUES (?, ?)", list(enumerate(chunks, start=1)))
+ db.executemany(f"INSERT INTO fts_{col}(rowid, text) VALUES (?, ?)", list(enumerate(chunks, start=1)))
db.commit()
-def vec_topk(db, q_vec_f32, k=10):
+def vec_topk(db,col: str, q_vec_f32, k=10):
+ # rows = db.execute(
+ # "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?",
+ # (memoryview(q.tobytes()), k)
+ # ).fetchall()
rows = db.execute(
- "SELECT rowid, distance FROM vec WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
+ f"SELECT rowid, distance FROM vec_{col} WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
(serialize_float32(q_vec_f32), k)
).fetchall()
return rows # [(rowid, distance)]
-def bm25_topk(db, qtext, k=10):
+def bm25_topk(db: sqlite3.Connection, col: str, qtext, k=10):
+ safe_q = f'"{qtext}"'
return [rid for (rid,) in db.execute(
- "SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?", (qtext, k)
+ f"SELECT rowid FROM fts_{col} WHERE fts_{col} MATCH ? LIMIT ?", (safe_q, k)
).fetchall()]
- def wipe_db():
+def fetch_chunk(db: sqlite3.Connection, col: str, id: int):
+
+ return db.execute(f"SELECT text FROM chunks_{col} WHERE id=?", (id,)).fetchone()[0]
+
+def fetch_chunks_in_range(db: sqlite3.Connection, col: str, ids: list[int]):
+ return db.execute(f"SELECT id, text FROM chunks_{col} WHERE id IN (?) ", ids).fetchall()
+
+
+
+
+
+def wipe_db(col: str):
db = sqlite3.connect("./rag.db")
- db.executescript("DROP TABLE IF EXISTS chunks; DROP TABLE IF EXISTS fts; DROP TABLE IF EXISTS vec;")
+ db.executescript(f"DROP TABLE IF EXISTS chunks_{col}; DROP TABLE IF EXISTS fts_{col}; DROP TABLE IF EXISTS vec_{col};")
db.close()