# Source: rag/db.py (blob 94ff3fb8ac3f5d60bd124158ff3ce6bd18fc4fc4)
import sqlite3
from numpy import ndarray
from sqlite_vec import serialize_float32
import sqlite_vec

def get_db():
    """Open ``./rag.db`` and return a connection with sqlite-vec loaded.

    The connection is tuned with a handful of PRAGMAs and then has the
    sqlite-vec extension loaded so ``vec0`` virtual tables are usable.

    Returns:
        An open ``sqlite3.Connection``; the caller is responsible for
        closing it.

    Raises:
        Any error raised by ``sqlite3`` or ``sqlite_vec.load``. The
        connection is closed before the error propagates.
    """
    db = sqlite3.connect("./rag.db")
    try:
        db.execute("PRAGMA journal_mode=WAL;")
        db.execute("PRAGMA synchronous=NORMAL;")
        db.execute("PRAGMA temp_store=MEMORY;")
        # mmap_size is expressed in bytes: this requests ~300 MB.
        db.execute("PRAGMA mmap_size=300000000;")
        db.enable_load_extension(True)
        try:
            sqlite_vec.load(db)
        finally:
            # Never leave extension loading enabled, even if load() failed.
            db.enable_load_extension(False)
    except Exception:
        # Do not leak the connection when setup fails half-way.
        db.close()
        raise
    return db

def init_schema(db: sqlite3.Connection, col: str, DIM:int, model_id: str, tok_id: str, normalize: bool, preproc_hash: str):
    """Register collection *col* and create its per-collection tables.

    Ensures the shared ``collections`` table exists, then, inside a single
    transaction, inserts the collection's metadata row and creates
    ``vec_{col}`` (vec0), ``chunks_{col}`` and ``fts_{col}`` (fts5).

    Args:
        db: Open connection (``vec0`` requires sqlite-vec to be loaded on it).
        col: Collection name; it is interpolated into table names.
        DIM: Embedding dimension used for the ``vec0`` column.
        model_id: Stored in ``collections.model``.
        tok_id: Stored in ``collections.tokenizer``.
        normalize: Stored as 0/1 in ``collections.normalize``.
        preproc_hash: Stored in ``collections.preproc_hash``.

    Raises:
        sqlite3.Error: If any statement fails (for example the collection
            name is already registered). The transaction is rolled back
            first, so no partial schema or metadata row is left behind and
            the connection remains usable.
    """
    print("initing schema", col)
    db.execute("""
    CREATE TABLE IF NOT EXISTS collections(
     name TEXT PRIMARY KEY,
     model TEXT,
     tokenizer TEXT,
     dim INTEGER,
     normalize INTEGER,
     preproc_hash TEXT,
     created_at INTEGER DEFAULT (unixepoch())
  )""")
    db.execute("BEGIN")
    try:
        db.execute("INSERT INTO collections(name, model, tokenizer, dim, normalize, preproc_hash) VALUES(?, ?, ?, ?, ?, ?)",
                   (col, model_id, tok_id, DIM, int(normalize), preproc_hash))

        db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec_{col} USING vec0(embedding float[{DIM}])")
        db.execute(f'''
    CREATE TABLE IF NOT EXISTS chunks_{col} (
     id INTEGER PRIMARY KEY,
     text TEXT
    )''')
        db.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS fts_{col} USING fts5(text)")
        db.commit()
    except Exception:
        # Without this the explicit BEGIN would stay open and every later
        # BEGIN on this connection would fail.
        db.rollback()
        raise

  
def check_db(db: sqlite3.Connection,   coll: str):
    """Look up the stored settings of collection *coll*.

    Returns the ``(dim, model, normalize)`` row from ``collections``, or
    ``None`` when no collection with that name is registered. Callers can
    compare the row against the dimension / model / normalization they
    expect.
    """
    query = "SELECT dim, model, normalize FROM collections WHERE name=?"
    cursor = db.execute(query, (coll,))
    return cursor.fetchone()
def check_db2(db: sqlite3.Connection, coll: str):
    """Return True when a ``vec_{coll}`` table is listed in ``sqlite_master``."""
    table_name = f"vec_{coll}"
    cursor = db.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
        (table_name,),
    )
    return cursor.fetchone() is not None
def store_chunks(db: sqlite3.Connection, col: str, chunks: list[str], V_np:ndarray):
    """Insert chunk texts and their embeddings into collection *col*.

    Row ``i`` (1-based) of ``chunks_{col}``, ``vec_{col}`` and ``fts_{col}``
    all describe ``chunks[i - 1]`` / ``V_np[i - 1]``. All three inserts run
    in one transaction.

    Args:
        db: Open connection holding the collection's tables.
        col: Collection name; it is interpolated into table names.
        chunks: Chunk texts.
        V_np: One embedding per chunk. Each row is written as float32
            bytes, matching the ``float[...]`` column of the vec0 table.

    Raises:
        ValueError: If ``chunks`` and ``V_np`` differ in length.
        sqlite3.Error: If an insert fails (ids always start at 1, so
            storing into an already populated collection collides). The
            transaction is rolled back before the error propagates.
    """
    if len(chunks) != len(V_np):
        raise ValueError(
            f"got {len(chunks)} chunks but {len(V_np)} embeddings"
        )

    # Built once and reused for both the chunks and the FTS table.
    text_rows = list(enumerate(chunks, start=1))
    vec_rows = [
        # copy=False: no extra copy when the data is already float32.
        (rowid, vec.astype("float32", copy=False).tobytes())
        for rowid, vec in enumerate(V_np, start=1)
    ]

    db.execute("BEGIN")
    try:
        db.executemany(f'''
    INSERT INTO chunks_{col}(id, text) VALUES (?, ?)
  ''', text_rows)
        db.executemany(
            f"INSERT INTO vec_{col}(rowid, embedding) VALUES (?, ?)",
            vec_rows,
        )
        db.executemany(f"INSERT INTO fts_{col}(rowid, text) VALUES (?, ?)", text_rows)
        db.commit()
    except Exception:
        # Leave neither partial rows nor a dangling open transaction.
        db.rollback()
        raise

def vec_topk(db,col: str,  q_vec_f32, k=10):
    """Return up to *k* ``(rowid, distance)`` rows of ``vec_{col}``.

    The query vector is serialized with ``serialize_float32`` and matched
    against the ``embedding`` column; rows come back ordered by ascending
    ``distance``.
    """
    sql = f"SELECT rowid, distance FROM vec_{col} WHERE embedding MATCH ? ORDER BY distance LIMIT ?"
    params = (serialize_float32(q_vec_f32), k)
    cursor = db.execute(sql, params)
    return cursor.fetchall()


def bm25_topk(db: sqlite3.Connection, col: str, qtext, k=10):
    """Return the rowids of the *k* best full-text matches for *qtext*.

    *qtext* is searched as a single FTS5 phrase in ``fts_{col}``. Results
    are ordered by the FTS5 ``rank`` column (BM25 by default), best match
    first.

    Args:
        db: Open connection holding ``fts_{col}``.
        col: Collection name; it is interpolated into the table name.
        qtext: Query text. Embedded double quotes are escaped, so arbitrary
            user text cannot break the FTS5 query syntax.
        k: Maximum number of rowids to return.

    Returns:
        List of rowids, most relevant first.
    """
    # FTS5 string literal: wrap in double quotes, double any embedded quote.
    phrase = '"' + str(qtext).replace('"', '""') + '"'
    rows = db.execute(
        # ORDER BY rank is what makes this a *top*-k; without it the first
        # k matches come back in scan order.
        f"SELECT rowid FROM fts_{col} WHERE fts_{col} MATCH ? ORDER BY rank LIMIT ?",
        (phrase, k),
    ).fetchall()
    return [rid for (rid,) in rows]

def fetch_chunk(db: sqlite3.Connection, col: str, id: int):
    """Return the text of the chunk with the given *id* in ``chunks_{col}``.

    Raises ``TypeError`` when no row has that id, because the missing row
    is indexed directly.
    """
    cursor = db.execute(f"SELECT text FROM chunks_{col} WHERE id=?", (id,))
    row = cursor.fetchone()
    return row[0]
  
def fetch_chunks_in_range(db: sqlite3.Connection, col: str, ids: list[int]):
    """Return ``(id, text)`` rows of ``chunks_{col}`` whose id is in *ids*.

    Args:
        db: Open connection holding ``chunks_{col}``.
        col: Collection name; it is interpolated into the table name.
        ids: Chunk ids to fetch. Ids with no matching row are skipped.

    Returns:
        List of ``(id, text)`` tuples; empty when *ids* is empty.
    """
    ids = list(ids)
    if not ids:
        # "IN ()" is not worth a round trip; nothing can match.
        return []
    # One "?" per id: a single placeholder can only bind a single value.
    placeholders = ",".join("?" * len(ids))
    return db.execute(
        f"SELECT id, text FROM chunks_{col} WHERE id IN ({placeholders})", ids
    ).fetchall()
  


  

def wipe_db(col: str):
    """Remove everything stored for collection *col* from ``./rag.db``.

    Drops ``chunks_{col}``, ``fts_{col}`` and ``vec_{col}`` and deletes the
    collection's row from ``collections`` (when that table exists), all in
    one transaction, so the collection can be initialised again afterwards.

    Args:
        col: Collection name; it is interpolated into the table names.
    """
    # get_db() loads sqlite-vec; SQLite needs the vec0 module available to
    # drop a vec0 virtual table, so a bare sqlite3.connect() is not enough.
    db = get_db()
    try:
        db.execute("BEGIN")
        for prefix in ("chunks", "fts", "vec"):
            db.execute(f"DROP TABLE IF EXISTS {prefix}_{col}")
        has_collections = db.execute(
            "SELECT 1 FROM sqlite_master WHERE type='table' AND name='collections'"
        ).fetchone()
        if has_collections:
            # Leaving the metadata row behind would make a later
            # INSERT for the same collection name fail.
            db.execute("DELETE FROM collections WHERE name=?", (col,))
        db.commit()
    finally:
        # Closing with an uncommitted transaction discards it, so a failure
        # above leaves the database unchanged.
        db.close()