"""SQLite-backed storage for RAG chunks: raw text, vector index and FTS5 index.

Three tables share the same rowid space:

* ``chunks`` -- the chunk text, keyed by ``id``.
* ``vec``    -- a sqlite-vec ``vec0`` virtual table holding float32 embeddings.
* ``fts``    -- an FTS5 virtual table over the chunk text.
"""
import sqlite3

import numpy as np
import sqlite_vec
from numpy import ndarray
from sqlite_vec import serialize_float32


def get_db():
    """Open ``./rag.db``, apply performance pragmas and load sqlite-vec.

    Returns:
        An open ``sqlite3.Connection`` with the sqlite-vec extension loaded.

    Raises:
        Whatever ``sqlite3`` / ``sqlite_vec`` raise during setup; the
        connection is closed before the exception propagates.
    """
    db = sqlite3.connect("./rag.db")
    try:
        db.execute("PRAGMA journal_mode=WAL;")
        db.execute("PRAGMA synchronous=NORMAL;")
        db.execute("PRAGMA temp_store=MEMORY;")
        db.execute("PRAGMA mmap_size=300000000;")  # ~300 MB, if your OS allows
        db.enable_load_extension(True)
        try:
            sqlite_vec.load(db)
        finally:
            # Never leave extension loading enabled, even if load() failed.
            db.enable_load_extension(False)
    except Exception:
        db.close()
        raise
    return db


def init_schema(db: sqlite3.Connection, DIM: int):
    """Create the ``vec``, ``chunks`` and ``fts`` tables if they do not exist.

    Args:
        db: Connection with the sqlite-vec extension loaded.
        DIM: Number of float components per embedding.

    Raises:
        ValueError: If ``DIM`` is not a positive integer.
    """
    # DIM is interpolated into DDL (it cannot be a bound parameter), so force
    # it to a plain positive int before building the statement.
    dim = int(DIM)
    if dim <= 0:
        raise ValueError(f"DIM must be a positive integer, got {DIM!r}")

    db.execute(
        f"CREATE VIRTUAL TABLE IF NOT EXISTS vec USING vec0(embedding float[{dim}])"
    )
    db.execute(
        """
        CREATE TABLE IF NOT EXISTS chunks (
            id INTEGER PRIMARY KEY,
            text TEXT
        )"""
    )
    db.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts USING fts5(text)")
    db.commit()


def store_chunks(db: sqlite3.Connection, chunks: list[str], V_np: ndarray):
    """Insert chunks and their embeddings into all three tables atomically.

    Row ids are assigned as 1..len(chunks) on every call, so this is meant
    for populating empty tables; a second call will collide on ``chunks.id``.

    Args:
        db: Connection with the schema from ``init_schema`` in place.
        chunks: Chunk texts.
        V_np: One embedding per chunk, in the same order as ``chunks``.
            Converted to float32 before being written.

    Raises:
        ValueError: If the number of chunks and embeddings differ, or the
            embeddings are not a 2-D array.
        sqlite3.Error: If any insert fails. The transaction is rolled back
            first, which also discards any work the caller had pending on
            ``db``.
    """
    if len(chunks) != len(V_np):
        raise ValueError(
            f"got {len(chunks)} chunks but {len(V_np)} embeddings"
        )

    # The vec table is declared float[DIM]; make sure the raw bytes really are
    # float32 regardless of the dtype the caller's model produced.
    vectors = np.ascontiguousarray(V_np, dtype=np.float32)
    if chunks and vectors.ndim != 2:
        raise ValueError(
            f"V_np must be 2-D (n_chunks, dim), got shape {vectors.shape}"
        )

    text_rows = list(enumerate(chunks, start=1))
    vec_rows = [
        (rowid, vector.tobytes())
        for rowid, vector in enumerate(vectors, start=1)
    ]

    # Only open a transaction if one is not already active; an unconditional
    # BEGIN fails inside an existing transaction.
    if not db.in_transaction:
        db.execute("BEGIN")
    try:
        db.executemany(
            "INSERT INTO chunks(id, text) VALUES (?, ?)", text_rows
        )
        db.executemany(
            "INSERT INTO vec(rowid, embedding) VALUES (?, ?)", vec_rows
        )
        db.executemany(
            "INSERT INTO fts(rowid, text) VALUES (?, ?)", text_rows
        )
    except Exception:
        # Keep the three tables in sync: all rows land or none do.
        db.rollback()
        raise
    db.commit()


def vec_topk(db, q_vec_f32, k=10):
    """Return the ``k`` nearest stored embeddings to a query vector.

    Args:
        db: Connection with the sqlite-vec extension loaded.
        q_vec_f32: Query embedding, passed through ``serialize_float32``.
        k: Maximum number of rows to return.

    Returns:
        ``[(rowid, distance), ...]`` ordered by ascending distance.
    """
    rows = db.execute(
        "SELECT rowid, distance FROM vec WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
        (serialize_float32(q_vec_f32), k),
    ).fetchall()
    return rows


def bm25_topk(db, qtext, k=10):
    """Return the rowids of the ``k`` best full-text matches for ``qtext``.

    Args:
        db: Connection with the ``fts`` table in place.
        qtext: FTS5 query string.
        k: Maximum number of rowids to return.

    Returns:
        A list of rowids, best match first.
    """
    # Without ORDER BY, LIMIT would keep arbitrary matches. FTS5's `rank`
    # column defaults to bm25(), where a lower value means a better match.
    cursor = db.execute(
        "SELECT rowid FROM fts WHERE fts MATCH ? ORDER BY rank LIMIT ?",
        (qtext, k),
    )
    return [rid for (rid,) in cursor.fetchall()]


def wipe_db():
    """Drop the ``chunks``, ``fts`` and ``vec`` tables from ``./rag.db``."""
    # Go through get_db() so sqlite-vec is loaded: SQLite needs the vec0
    # module registered in order to drop the `vec` virtual table.
    db = get_db()
    try:
        db.executescript(
            "DROP TABLE IF EXISTS chunks; DROP TABLE IF EXISTS fts; DROP TABLE IF EXISTS vec;"
        )
    finally:
        db.close()