summaryrefslogtreecommitdiff
path: root/cuda/flashattn/sentence.py
blob: f29927aa190ead86fce223dca0520a8434b20ddc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# Requires transformers>=4.51.0
# Requires sentence-transformers>=2.7.0
import torch
from sentence_transformers import SentenceTransformer

# Load the model.
#
# Minimal variant (default attention implementation and dtype), e.g. when the
# flash-attn package is not installed:
# model = SentenceTransformer("Qwen/Qwen3-Embedding-8B")
#
# Recommended: enable flash_attention_2 for better acceleration and memory
# saving, together with setting the tokenizer's `padding_side` to "left".
model = SentenceTransformer(
    "Qwen/Qwen3-Embedding-8B",
    model_kwargs={
        "attn_implementation": "flash_attention_2",
        "device_map": "auto",
        # Flash Attention 2 only supports torch.float16 and torch.bfloat16, so
        # the weights must be loaded in half precision. Use the `torch_dtype`
        # key: transformers releases older than the `dtype` alias (4.56) do not
        # apply a `dtype` kwarg. In that case the model stays in float32 and
        # transformers warns "Flash Attention 2 only supports torch.float16 and
        # torch.bfloat16 dtypes, but the current dtype in Qwen3Model is
        # torch.float32". Newer releases still accept `torch_dtype`, although
        # they may emit a deprecation warning.
        "torch_dtype": torch.float16,
    },
    tokenizer_kwargs={"padding_side": "left"},
)

# The queries and documents to embed.
queries = [
    "What is the capital of China?",
    "Explain gravity",
]
documents = [
    "The capital of China is Beijing.",
    "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
]

with torch.no_grad():
    # Encode the queries and documents. Note that queries benefit from using a
    # prompt: here we use the prompt called "query" stored under
    # `model.prompts`, but you can also pass your own prompt via the `prompt`
    # argument. Documents are encoded without a prompt.
    query_embeddings = model.encode(queries, prompt_name="query")
    document_embeddings = model.encode(documents)

# Compute the (cosine) similarity between the query and document embeddings:
# row i holds the scores of queries[i] against each entry of `documents`.
similarity = model.similarity(query_embeddings, document_embeddings)
print(similarity)
# Reference output (half-precision runs may differ in the last digits):
# tensor([[0.7493, 0.0751],
#         [0.0880, 0.6318]])