embeds/pkg/main.py
2024-11-27 22:29:30 +07:00

79 lines
2.4 KiB
Python

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
import time
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List
import uvicorn
import gc
import os
app = FastAPI()
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
# Load model (do this only once at startup)
print("Loading model...")
model = AutoModel.from_pretrained('nvidia/NV-Embed-v2', trust_remote_code=True, torch_dtype="auto")
model.eval()
# model.half()
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(device.type)
model.to(device)
if device.type == 'cuda':
# gc.collect()
torch.cuda.empty_cache()
torch.cuda.memory.empty_cache()
torch.cuda.max_memory_allocated()
print(f"Model loaded and moved to {device}")
# Define request and response models
class EmbeddingRequest(BaseModel):
texts: List[str]
is_query: bool
class EmbeddingResponse(BaseModel):
embeddings: List[List[float]]
time_taken: float
# Each query needs to be accompanied by an corresponding instruction describing the task.
task_name_to_instruct = {"example": "Given a question, retrieve passages that answer the question",}
query_prefix = "Instruct: " + task_name_to_instruct["example"] + "\nQuery: "
max_length = 32768
@app.post("/embed", response_model=EmbeddingResponse)
async def embed_texts(request: EmbeddingRequest):
start_time = time.time()
try:
with torch.no_grad():
if device.type == 'cuda':
# gc.collect()
torch.cuda.empty_cache()
torch.cuda.memory.empty_cache()
if request.is_query:
texts = [query_prefix + text for text in request.texts]
else:
texts = request.texts
embeddings = model.encode(texts, max_length=max_length)
embeddings = F.normalize(embeddings, p=2, dim=1)
print(f"Embedding size: {embeddings.size()}")
embeddings_list = embeddings.cpu().float().tolist()
time_taken = time.time() - start_time
return EmbeddingResponse(embeddings=embeddings_list, time_taken=time_taken)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)