summaryrefslogtreecommitdiff
path: root/rag/search.py
blob: 55b8ffd60eaaab47558d6f88ea5aee1488077f50 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import gc
import logging
import sqlite3
from typing import List, Tuple

import numpy as np
import torch
from sentence_transformers import CrossEncoder, SentenceTransformer

from rag.ingest import get_embed_model
from rag.mmr import mmr, embed_unit_np  # if you added MMR; else remove
from rag.rerank import get_rerank_model, rerank_cross_encoder

def search_hybrid(
    db: sqlite3.Connection, query: str,
    k_vec: int = 40, k_bm25: int = 40,
    k_ce: int = 30,                 # rerank this many
    k_final: int = 10,              # return this many
    use_mmr: bool = False, mmr_lambda: float = 0.7
):
    emodel = get_embed_model()
    # 1) embed query (unit, float32) for vec search + MMR
    print("loading")
    q = emodel.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0].astype("float32")
    emodel = None
    with torch.no_grad():
        torch.cuda.empty_cache()
    gc.collect()
    print("memory should be free by now!!")
    qbytes = q.tobytes()

    # 2) ANN (cosine) + BM25 pools
    vec_ids = [i for (i, _) in db.execute(
        "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?",
        (memoryview(qbytes), k_vec)
    ).fetchall()]
    bm25_ids = [i for (i,) in db.execute(
        "SELECT rowid FROM fts WHERE fts MATCH ? LIMIT ?",
        (query, k_bm25)
    ).fetchall()]

    # 3) merge (vector-first)
    seen, merged = set(), []
    for i in vec_ids + bm25_ids:
        if i not in seen:
            merged.append(i); seen.add(i)
    if not merged:
        return []

    # 4) fetch texts for CE
    qmarks = ",".join("?"*len(merged))
    cand = db.execute(f"SELECT id, text FROM chunks WHERE id IN ({qmarks})", merged).fetchall()

    reranker = get_rerank_model()
    # 5) cross-encoder rerank (returns [(id,text,score)] desc)
    ranked = rerank_cross_encoder(reranker, query, cand)
    reranker = None
    print("freeing again!!")
    with torch.no_grad():
        torch.cuda.empty_cache()
    gc.collect()
    #
    ranked = ranked[:min(k_ce, len(ranked))]

    if not use_mmr or len(ranked) <= k_final:
        return ranked[:k_final]

    # 6) MMR diversity on CE top-k_ce
    cand_ids  = [i for (i,_,_) in ranked]
    cand_text = [t for (_,t,_) in ranked]
    emodel = get_embed_model()
    # god this is annoying I should stop being poor
    cand_vecs = embed_unit_np(emodel, cand_text)   # [N,D], unit vectors
    sel_ids   = set(mmr(q, cand_ids, cand_vecs, k=k_final, lamb=mmr_lambda))
    final = [trip for trip in ranked if trip[0] in sel_ids]  # keep CE order, filter by MMR picks
    return final[:k_final]


# Query helper: pure vector search returning (id, text, distance) triples.
def vec_search(db: sqlite3.Connection, qtext, k=5):
    """Return the ``k`` chunks nearest to ``qtext`` according to the ``vec`` table.

    Args:
        db: connection exposing the ``vec`` and ``chunks`` tables.
        qtext: free-text query, embedded with the shared embedding model
            (unit-normalised, float32).
        k: number of neighbours to fetch.

    Returns:
        ``(id, text, distance)`` triples in the order the ANN query returned
        them. Ids that have no ``chunks`` row are skipped and logged rather
        than crashing with ``TypeError``.
    """
    model = get_embed_model()
    q = model.encode([qtext], normalize_embeddings=True, convert_to_numpy=True).astype("float32")
    # NOTE(review): `<#>` is pgvector-style syntax (there it is negative inner
    # product; `<=>` is cosine, `<->` L2). sqlite-vec KNN queries are usually
    # written `WHERE embedding MATCH ? AND k = ?` — confirm against your build.
    rows = db.execute(
        "SELECT rowid, distance FROM vec ORDER BY embedding <#> ? LIMIT ?",
        (memoryview(q.tobytes()), k),
    ).fetchall()
    if not rows:
        return []

    # One round-trip for all texts instead of one SELECT per hit.
    ids = [rid for rid, _ in rows]
    placeholders = ",".join("?" * len(ids))
    texts = dict(db.execute(
        f"SELECT id, text FROM chunks WHERE id IN ({placeholders})", ids
    ).fetchall())

    missing = [rid for rid in ids if rid not in texts]
    if missing:
        logging.getLogger(__name__).warning("vec_search: no chunks row for ids %s", missing)
    return [(rid, texts[rid], dist) for rid, dist in rows if rid in texts]


# # Hybrid + CE rerank query (db is an open sqlite3.Connection):
# results = search_hybrid(db, "indemnification obligations survive termination", k_vec=50, k_bm25=50, k_final=8)
# for rid, txt, score in results:
#     print(f"[{rid:04d}] score={score:.3f}\n{txt[:300]}...\n")
#
#
#
#
#
#
#