79 lines
2.4 KiB
Python
79 lines
2.4 KiB
Python
import torch
|
|
import torch.nn.functional as F
|
|
from transformers import AutoTokenizer, AutoModel
|
|
import time
|
|
from fastapi import FastAPI, HTTPException
|
|
from pydantic import BaseModel
|
|
from typing import List
|
|
import uvicorn
|
|
import gc
|
|
import os
|
|
|
|
# Process-wide GPU settings, written to the environment in one step:
#   CUDA_VISIBLE_DEVICES    - expose only device index 0 to this process
#   PYTORCH_CUDA_ALLOC_CONF - allocator option max_split_size_mb set to 512
os.environ.update(
    {
        'CUDA_VISIBLE_DEVICES': '0',
        'PYTORCH_CUDA_ALLOC_CONF': 'max_split_size_mb:512',
    }
)

# ASGI application object that the route decorators below attach to.
app = FastAPI()
|
|
|
|
# Load the model once at import time so every request reuses the same weights.
print("Loading model...")
model = AutoModel.from_pretrained('nvidia/NV-Embed-v2', trust_remote_code=True, torch_dtype="auto")
model.eval()  # inference only

# Pick the best available backend: CUDA first, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(device.type)
model.to(device)

if device.type == 'cuda':
    # Release unused blocks held by the CUDA caching allocator. A single call
    # is enough: torch.cuda.empty_cache and torch.cuda.memory.empty_cache are
    # the same function. The previous bare torch.cuda.max_memory_allocated()
    # call discarded its result and has been dropped.
    torch.cuda.empty_cache()

print(f"Model loaded and moved to {device}")
|
|
|
|
# Define request and response models
|
|
# Request body accepted by POST /embed.
class EmbeddingRequest(BaseModel):
    # Strings to embed; one embedding is requested per entry.
    texts: List[str]
    # True  -> each text is prefixed with the retrieval instruction (query side).
    # False -> texts are embedded exactly as submitted.
    is_query: bool
|
|
|
|
# Response body returned by POST /embed.
class EmbeddingResponse(BaseModel):
    # One vector of floats per input text.
    embeddings: List[List[float]]
    # Seconds the handler spent on the request, as measured by the handler itself.
    time_taken: float
|
|
|
|
# Each query must be accompanied by a corresponding instruction describing the task.
task_name_to_instruct = {
    "example": "Given a question, retrieve passages that answer the question",
}

# Text placed in front of every query before it is encoded.
query_prefix = "".join(("Instruct: ", task_name_to_instruct["example"], "\nQuery: "))

# Value passed as `max_length` to model.encode.
max_length = 32768
|
|
|
|
@app.post("/embed", response_model=EmbeddingResponse)
async def embed_texts(request: EmbeddingRequest):
    """Embed the submitted texts and return L2-normalized vectors.

    When ``request.is_query`` is true, every text is prefixed with
    ``query_prefix`` before encoding; otherwise the texts are encoded as
    given. An empty ``texts`` list yields an empty ``embeddings`` list
    without calling the model.

    Args:
        request: The texts to embed and the query/passage flag.

    Returns:
        EmbeddingResponse with one embedding per input text and the elapsed
        handler time in seconds.

    Raises:
        HTTPException: Status 500, with the underlying error message as the
            detail, if encoding or post-processing fails.
    """
    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # system clock adjustments.
    start_time = time.perf_counter()

    # Nothing to encode: answer directly instead of handing the model an
    # empty batch.
    if not request.texts:
        return EmbeddingResponse(embeddings=[], time_taken=time.perf_counter() - start_time)

    if request.is_query:
        texts = [query_prefix + text for text in request.texts]
    else:
        texts = request.texts

    try:
        with torch.no_grad():
            if device.type == 'cuda':
                # One call suffices; torch.cuda.memory.empty_cache is the
                # same function as torch.cuda.empty_cache.
                torch.cuda.empty_cache()

            # NOTE(review): this call is synchronous inside an async handler,
            # so it occupies the event loop for the duration of the encode.
            embeddings = model.encode(texts, max_length=max_length)
            embeddings = F.normalize(embeddings, p=2, dim=1)

        print(f"Embedding size: {embeddings.size()}")

        # Move to host memory and widen to float32 before converting to
        # plain Python lists for the JSON response.
        embeddings_list = embeddings.cpu().float().tolist()

        time_taken = time.perf_counter() - start_time
        return EmbeddingResponse(embeddings=embeddings_list, time_taken=time_taken)

    except Exception as e:
        # Surface the failure to the client and keep the original traceback
        # chained for server-side debugging.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces, port 8000.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )
|