import re
import sqlite3

from numpy import ndarray
from sqlite_vec import serialize_float32
import sqlite_vec

# Default on-disk location of the RAG database.
DB_PATH = "./rag.db"

# Collection names are spliced into table names (vec_<col>, chunks_<col>,
# fts_<col>) because SQL placeholders cannot bind identifiers, so they are
# restricted to word characters.
_COLLECTION_NAME_RE = re.compile(r"\w+")


def _check_col(col: str) -> str:
    """Return *col* unchanged, or raise ValueError if it is unsafe as a table-name suffix."""
    if not isinstance(col, str) or not _COLLECTION_NAME_RE.fullmatch(col):
        raise ValueError(f"invalid collection name: {col!r}")
    return col


def get_db(path: str = DB_PATH) -> sqlite3.Connection:
    """Open the RAG database with tuned pragmas and the sqlite-vec extension loaded.

    Args:
        path: Database file to open; defaults to ``./rag.db``.

    Returns:
        An open connection on which ``vec0`` virtual tables can be used.
        The caller is responsible for closing it.
    """
    db = sqlite3.connect(path)
    try:
        db.execute("PRAGMA journal_mode=WAL;")
        db.execute("PRAGMA synchronous=NORMAL;")
        db.execute("PRAGMA temp_store=MEMORY;")
        # 300,000,000 bytes (~300 MB) of memory-mapped I/O, if the OS allows it.
        db.execute("PRAGMA mmap_size=300000000;")
        # Extension loading is enabled only while sqlite-vec is being loaded.
        db.enable_load_extension(True)
        sqlite_vec.load(db)
        db.enable_load_extension(False)
    except Exception:
        # Don't leak the connection if a pragma or the extension load fails.
        db.close()
        raise
    return db


def init_schema(db: sqlite3.Connection, col: str, DIM: int, model_id: str,
                tok_id: str, normalize: bool, preproc_hash: str):
    """Register collection *col* and create its vector, chunk-text and FTS tables.

    Args:
        db: Connection from get_db() (the vec0 module must be loaded).
        col: Collection name; used as the suffix of the vec_/chunks_/fts_ tables.
        DIM: Embedding dimensionality of the vec0 ``float[DIM]`` column.
        model_id, tok_id, normalize, preproc_hash: Metadata stored in the
            ``collections`` table (read back by check_db()).

    Raises:
        ValueError: *col* is not a plain word-character name.
        sqlite3.IntegrityError: *col* is already registered in ``collections``.

    The registration row and the three per-collection tables are written in a
    single transaction that is rolled back if any statement fails.
    """
    _check_col(col)
    dim = int(DIM)  # ensures only an integer is interpolated into the DDL below
    print("initing schema", col)
    db.execute("""
        CREATE TABLE IF NOT EXISTS collections(
            name TEXT PRIMARY KEY,
            model TEXT,
            tokenizer TEXT,
            dim INTEGER,
            normalize INTEGER,
            preproc_hash TEXT,
            created_at INTEGER DEFAULT (unixepoch())
        )""")
    with db:  # commits on success, rolls back on any exception
        db.execute("BEGIN")
        db.execute(
            "INSERT INTO collections(name, model, tokenizer, dim, normalize, preproc_hash) "
            "VALUES(?, ?, ?, ?, ?, ?)",
            (col, model_id, tok_id, dim, int(normalize), preproc_hash),
        )
        db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec_{col} USING vec0(embedding float[{dim}])")
        db.execute(f"""
            CREATE TABLE IF NOT EXISTS chunks_{col} (
                id INTEGER PRIMARY KEY,
                text TEXT
            )""")
        db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS fts_{col} USING fts5(text)")


def check_db(db: sqlite3.Connection, coll: str):
    """Return the ``(dim, model, normalize)`` row recorded for *coll*, or None if unregistered.

    Callers can compare this against their current embedding settings, e.g.
    ``row[0] == DIM`` and ``row[2] == 1`` when vectors are normalized.
    Raises sqlite3.OperationalError if the ``collections`` table does not exist yet.
    """
    return db.execute(
        "SELECT dim, model, normalize FROM collections WHERE name=?", (coll,)
    ).fetchone()


def check_db2(db: sqlite3.Connection, coll: str):
    """Return True if the vector table ``vec_<coll>`` exists in the schema."""
    row = db.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
        (f"vec_{coll}",),
    ).fetchone()
    return bool(row)


def store_chunks(db: sqlite3.Connection, col: str, chunks: list[str], V_np: ndarray):
    """Insert *chunks* and their embeddings into collection *col* in one transaction.

    Chunk i (0-based) is stored with id/rowid ``i + 1`` in chunks_<col>,
    vec_<col> and fts_<col>, so the three tables share ids. Because ids always
    start at 1, a second call on a non-empty collection fails on the duplicate
    ids and the whole batch is rolled back.

    Args:
        chunks: Chunk texts.
        V_np: One embedding row per chunk. Rows are converted to float32,
            the element type of the vec0 ``float[N]`` column, before storing.

    Raises:
        ValueError: ``len(chunks) != len(V_np)``.
    """
    _check_col(col)
    if len(chunks) != len(V_np):
        raise ValueError(
            f"got {len(chunks)} chunks but {len(V_np)} embeddings"
        )
    text_rows = list(enumerate(chunks, start=1))
    vec_rows = [
        (rowid, vec.astype("float32", copy=False).tobytes())
        for rowid, vec in enumerate(V_np, start=1)
    ]
    with db:  # commits on success, rolls back on any exception
        db.execute("BEGIN")
        db.executemany(f"INSERT INTO chunks_{col}(id, text) VALUES (?, ?)", text_rows)
        db.executemany(f"INSERT INTO vec_{col}(rowid, embedding) VALUES (?, ?)", vec_rows)
        db.executemany(f"INSERT INTO fts_{col}(rowid, text) VALUES (?, ?)", text_rows)


def vec_topk(db, col: str, q_vec_f32, k=10):
    """Return the *k* nearest rows of ``vec_<col>`` to *q_vec_f32*.

    Returns:
        A list of ``(rowid, distance)`` tuples ordered by ascending distance.
    """
    _check_col(col)
    return db.execute(
        f"SELECT rowid, distance FROM vec_{col} WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
        (serialize_float32(q_vec_f32), k),
    ).fetchall()


def bm25_topk(db: sqlite3.Connection, col: str, qtext, k=10):
    """Return up to *k* rowids of ``fts_<col>`` matching *qtext* as a phrase, best BM25 first.

    The whole query is wrapped in an FTS5 string (embedded double quotes are
    doubled), so FTS5 operators in *qtext* are treated as literal text.
    """
    _check_col(col)
    safe_q = '"' + qtext.replace('"', '""') + '"'
    # FTS5's hidden ``rank`` column defaults to bm25(); ascending = most relevant first.
    rows = db.execute(
        f"SELECT rowid FROM fts_{col} WHERE fts_{col} MATCH ? ORDER BY rank LIMIT ?",
        (safe_q, k),
    ).fetchall()
    return [rid for (rid,) in rows]


def fetch_chunk(db: sqlite3.Connection, col: str, id: int):
    """Return the text of chunk *id* in collection *col*.

    Raises:
        KeyError: No chunk with that id exists.
    """
    _check_col(col)
    row = db.execute(f"SELECT text FROM chunks_{col} WHERE id=?", (id,)).fetchone()
    if row is None:
        raise KeyError(f"no chunk with id {id} in collection {col!r}")
    return row[0]


def fetch_chunks_in_range(db: sqlite3.Connection, col: str, ids: list[int]):
    """Return ``(id, text)`` rows for every chunk whose id is in *ids*.

    Row order is whatever SQLite returns; missing ids are simply absent.
    NOTE(review): very large *ids* lists can exceed SQLite's bound-parameter
    limit — batch on the caller side if that matters.
    """
    _check_col(col)
    ids = list(ids)
    if not ids:
        return []
    # One placeholder per id; a single "?" cannot bind a whole list.
    placeholders = ",".join("?" * len(ids))
    return db.execute(
        f"SELECT id, text FROM chunks_{col} WHERE id IN ({placeholders})", ids
    ).fetchall()


def wipe_db(col: str, path: str = DB_PATH):
    """Drop all tables of collection *col* and delete its ``collections`` row.

    Uses get_db() rather than a bare connection because dropping the vec0
    virtual table requires the sqlite-vec module to be loaded. Removing the
    registration row lets init_schema() re-create the collection afterwards.
    """
    _check_col(col)
    db = get_db(path)
    try:
        with db:  # commits on success, rolls back on any exception
            db.execute("BEGIN")
            db.execute(f"DROP TABLE IF EXISTS chunks_{col}")
            db.execute(f"DROP TABLE IF EXISTS fts_{col}")
            db.execute(f"DROP TABLE IF EXISTS vec_{col}")
            has_registry = db.execute(
                "SELECT 1 FROM sqlite_master WHERE type='table' AND name='collections'"
            ).fetchone()
            if has_registry:
                db.execute("DELETE FROM collections WHERE name=?", (col,))
    finally:
        db.close()