diff options
author | polwex <polwex@sortug.com> | 2025-09-24 23:38:36 +0700 |
---|---|---|
committer | polwex <polwex@sortug.com> | 2025-09-24 23:38:36 +0700 |
commit | 734b89570040e97f0c7743c4c0bc28e30a3cd4ee (patch) | |
tree | 7142d9f37908138c38d0ade066e960c3a1c69f5d /rag/db.py | |
parent | 57aaafdb137fe49930711f6ed5ccc83b3a119cd2 (diff) |
Diffstat (limited to 'rag/db.py')
-rw-r--r-- | rag/db.py | 75 |
1 files changed, 59 insertions, 16 deletions
```diff
@@ -1,6 +1,7 @@
 import sqlite3
 from numpy import ndarray
 from sqlite_vec import serialize_float32
+import sqlite_vec
 
 def get_db():
     db = sqlite3.connect("./rag.db")
@@ -13,43 +14,85 @@ def get_db():
     db.enable_load_extension(False)
     return db
 
-def init_schema(db: sqlite3.Connection, DIM:int):
-    db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec USING vec0(embedding float[{DIM}])")
-    db.execute('''
-    CREATE TABLE IF NOT EXISTS chunks (
+def init_schema(db: sqlite3.Connection, col: str, DIM:int, model_id: str, tok_id: str, normalize: bool, preproc_hash: str):
+    print("initing schema", col)
+    db.execute("""
+    CREATE TABLE IF NOT EXISTS collections(
+        name TEXT PRIMARY KEY,
+        model TEXT,
+        tokenizer TEXT,
+        dim INTEGER,
+        normalize INTEGER,
+        preproc_hash TEXT,
+        created_at INTEGER DEFAULT (unixepoch())
+    )""")
+    db.execute("BEGIN")
+    db.execute("INSERT INTO collections(name, model, tokenizer, dim, normalize, preproc_hash) VALUES(?, ?, ?, ?, ?, ?)",
+               (col, model_id, tok_id, DIM, int(normalize), preproc_hash))
+
+    db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec_{col} USING vec0(embedding float[{DIM}])")
+    db.execute(f'''
+    CREATE TABLE IF NOT EXISTS chunks_{col} (
         id INTEGER PRIMARY KEY,
         text TEXT
     )''')
-    db.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts USING fts5(text)")
+    db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS fts_{col} USING fts5(text)")
     db.commit()
 
-def store_chunks(db: sqlite3.Connection, chunks: list[str], V_np:ndarray):
+
+def check_db(db: sqlite3.Connection, coll: str):
+    row = db.execute("SELECT dim, model, normalize FROM collections WHERE name=?", (coll,)).fetchone()
+    return row
+    # assert row and row[0] == DIM and row[1] == EMBED_MODEL_ID and row[2] == 1 # if you normalize
+def check_db2(db: sqlite3.Connection, coll: str):
+    row = db.execute(
+        "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
+        (f"vec_{coll}",)
+    ).fetchone()
+    return bool(row)
+def store_chunks(db: sqlite3.Connection, col: str, chunks: list[str], V_np:ndarray):
     assert len(chunks) == len(V_np)
     db.execute("BEGIN")
-    db.executemany('''
-    INSERT INTO chunks(id, text) VALUES (?, ?)
+    db.executemany(f'''
+    INSERT INTO chunks_{col}(id, text) VALUES (?, ?)
     ''', list(enumerate(chunks, start=1)))
     db.executemany(
-        "INSERT INTO vec(rowid, embedding) VALUES (?, ?)",
+        f"INSERT INTO vec_{col}(rowid, embedding) VALUES (?, ?)",
         [(i+1, memoryview(V_np[i].tobytes())) for i in range(len(chunks))]
     )
-    db.executemany("INSERT INTO fts(rowid, text) VALUES (?, ?)", list(enumerate(chunks, start=1)))
+    db.executemany(f"INSERT INTO fts_{col}(rowid, text) VALUES (?, ?)", list(enumerate(chunks, start=1)))
     db.commit()
 
-def vec_topk(db, q_vec_f32, k=10):
+def vec_topk(db,col: str, q_vec_f32, k=10):
+    # rows = db.execute(
+    #     "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?",
+    #     (memoryview(q.tobytes()), k)
+    # ).fetchall()
     rows = db.execute(
-        "SELECT rowid, distance FROM vec WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
+        f"SELECT rowid, distance FROM vec_{col} WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
         (serialize_float32(q_vec_f32), k)
     ).fetchall()
     return rows # [(rowid, distance)]
 
-def bm25_topk(db, qtext, k=10):
+def bm25_topk(db: sqlite3.Connection, col: str, qtext, k=10):
+    safe_q = f'"{qtext}"'
     return [rid for (rid,) in db.execute(
-        "SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?", (qtext, k)
+        f"SELECT rowid FROM fts_{col} WHERE fts_{col} MATCH ? LIMIT ?", (safe_q, k)
     ).fetchall()]
- def wipe_db():
+def fetch_chunk(db: sqlite3.Connection, col: str, id: int):
+
+    return db.execute(f"SELECT text FROM chunks_{col} WHERE id=?", (id,)).fetchone()[0]
+
+def fetch_chunks_in_range(db: sqlite3.Connection, col: str, ids: list[int]):
+    return db.execute(f"SELECT id, text FROM chunks_{col} WHERE id IN (?) ", ids).fetchall()
+
+
+
+
+
+def wipe_db(col: str):
     db = sqlite3.connect("./rag.db")
-    db.executescript("DROP TABLE IF EXISTS chunks; DROP TABLE IF EXISTS fts; DROP TABLE IF EXISTS vec;")
+    db.executescript(f"DROP TABLE IF EXISTS chunks_{col}; DROP TABLE IF EXISTS fts_{col}; DROP TABLE IF EXISTS vec_{col};")
     db.close()
```