summaryrefslogtreecommitdiff
path: root/rag/db.py
blob: c015c962abaed9a2b121a9104373b14b98f1425e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import sqlite3

import sqlite_vec
from numpy import ndarray
from sqlite_vec import serialize_float32

def get_db():
    """Open ``./rag.db``, apply the connection pragmas and load sqlite-vec.

    Returns:
        An open ``sqlite3.Connection`` with the sqlite-vec extension loaded.
        The caller is responsible for closing it.

    Raises:
        Any error raised by ``sqlite3`` or ``sqlite_vec.load``. The connection
        is closed before the error propagates, so a failed setup does not
        leak a handle.
    """
    db = sqlite3.connect("./rag.db")
    try:
        db.execute("PRAGMA journal_mode=WAL;")
        db.execute("PRAGMA synchronous=NORMAL;")
        db.execute("PRAGMA temp_store=MEMORY;")
        # mmap_size is expressed in bytes: this requests ~300 MB (not 30 GB).
        db.execute("PRAGMA mmap_size=300000000;")
        # Extension loading stays enabled only while sqlite-vec is being loaded.
        db.enable_load_extension(True)
        sqlite_vec.load(db)
        db.enable_load_extension(False)
    except Exception:
        db.close()
        raise
    return db

def init_schema(db: sqlite3.Connection, DIM: int):
    """Create the ``vec``, ``chunks`` and ``fts`` tables if they are missing.

    Args:
        db: Open connection. Creating ``vec`` uses the ``vec0`` module and
            ``fts`` uses ``fts5``, so both must be available on it.
        DIM: Number of float components per embedding. It is coerced with
            ``int()`` before being placed into the ``CREATE VIRTUAL TABLE``
            statement, because DDL cannot take bound parameters.

    Raises:
        ValueError: If ``DIM`` cannot be converted to an integer or is not
            strictly positive. Nothing is executed in that case.
    """
    # Validate before building SQL: the value is interpolated into the DDL text.
    dim = int(DIM)
    if dim <= 0:
        raise ValueError(f"DIM must be a positive integer, got {DIM!r}")

    db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec USING vec0(embedding float[{dim}])")
    db.execute('''
    CREATE TABLE IF NOT EXISTS chunks (
     id INTEGER PRIMARY KEY,
     text TEXT
    )''')
    db.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts USING fts5(text)")
    db.commit()

def store_chunks(db: sqlite3.Connection, chunks: list[str], V_np: ndarray):
    """Insert chunk texts and their embeddings under matching row ids.

    Row ids are assigned as 1..len(chunks) in input order; the same id is used
    for the ``chunks`` row, the ``vec`` row and the ``fts`` row of a chunk.
    All three inserts run in one transaction that is committed on success and
    rolled back on failure.

    Args:
        db: Open connection holding the ``chunks``, ``vec`` and ``fts`` tables.
        chunks: Chunk texts.
        V_np: One embedding per chunk, in the same order as ``chunks``. Each
            row is converted to float32 before being written as a blob.

    Raises:
        ValueError: If ``chunks`` and ``V_np`` differ in length.
        sqlite3.Error: If an insert fails (for example because the ids already
            exist); the transaction is rolled back before re-raising.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if len(chunks) != len(V_np):
        raise ValueError(
            f"got {len(chunks)} chunks but {len(V_np)} embeddings"
        )

    text_rows = list(enumerate(chunks, start=1))
    # Always write float32 bytes so the blob layout does not depend on the
    # dtype the caller happened to pass in.
    vec_rows = [
        (rowid, row.astype("float32", copy=False).tobytes())
        for rowid, row in enumerate(V_np, start=1)
    ]

    db.execute("BEGIN")
    try:
        db.executemany("INSERT INTO chunks(id, text) VALUES (?, ?)", text_rows)
        db.executemany("INSERT INTO vec(rowid, embedding) VALUES (?, ?)", vec_rows)
        db.executemany("INSERT INTO fts(rowid, text) VALUES (?, ?)", text_rows)
    except Exception:
        # Do not leave the connection stuck inside a half-applied transaction.
        db.rollback()
        raise
    db.commit()

def vec_topk(db, q_vec_f32, k=10):
    """Query ``vec`` for rows matching ``q_vec_f32``, closest first.

    Args:
        db: Connection on which the ``vec`` table can be queried.
        q_vec_f32: Query vector; it is passed through ``serialize_float32``
            before being bound to the ``MATCH`` parameter.
        k: Value bound to ``LIMIT``.

    Returns:
        A list of ``(rowid, distance)`` tuples ordered by ascending distance.
    """
    sql = "SELECT rowid, distance FROM vec WHERE embedding MATCH ? ORDER BY distance LIMIT ?"
    bound = (serialize_float32(q_vec_f32), k)
    cursor = db.execute(sql, bound)
    return cursor.fetchall()  # [(rowid, distance)]


def bm25_topk(db, qtext, k=10):
    """Return the row ids of the best full-text matches for ``qtext``.

    Args:
        db: Connection holding the FTS5 table ``fts``.
        qtext: FTS5 query string, bound to the ``MATCH`` parameter.
        k: Maximum number of row ids to return.

    Returns:
        A list of row ids, best match first.
    """
    # ORDER BY rank makes LIMIT keep the k best-scoring matches; without it
    # the query simply returns the first k matching rows.
    rows = db.execute(
        "SELECT rowid FROM fts WHERE fts MATCH ? ORDER BY rank LIMIT ?", (qtext, k)
    ).fetchall()
    return [rid for (rid,) in rows]


def wipe_db():
    """Drop the ``chunks``, ``fts`` and ``vec`` tables from ``./rag.db``.

    The connection comes from ``get_db()`` so the sqlite-vec extension is
    loaded: dropping the ``vec`` virtual table needs its ``vec0`` module to be
    registered on the connection. The connection is always closed, even when
    the script fails.
    """
    db = get_db()
    try:
        db.executescript("DROP TABLE IF EXISTS chunks; DROP TABLE IF EXISTS fts; DROP TABLE IF EXISTS vec;")
    finally:
        db.close()