Diffstat (limited to 'cuda/flashattn/sentence.py')
-rw-r--r-- | cuda/flashattn/sentence.py | 42 |
1 file changed, 42 insertions, 0 deletions
diff --git a/cuda/flashattn/sentence.py b/cuda/flashattn/sentence.py
new file mode 100644
index 0000000..f29927a
--- /dev/null
+++ b/cuda/flashattn/sentence.py
@@ -0,0 +1,42 @@
+# Requires transformers>=4.51.0
+# Requires sentence-transformers>=2.7.0
+import torch
+from sentence_transformers import SentenceTransformer
+
+# Load the model
+# model = SentenceTransformer("Qwen/Qwen3-Embedding-8B")
+
+# We recommend enabling flash_attention_2 for better acceleration and memory saving,
+# together with setting `padding_side` to "left":
+model = SentenceTransformer(
+    "Qwen/Qwen3-Embedding-8B",
+    model_kwargs={"attn_implementation": "flash_attention_2", "device_map": "auto", "torch_dtype": torch.float16},
+    tokenizer_kwargs={"padding_side": "left"},
+)
+
+
+# Flash Attention 2 only supports the torch.float16 and torch.bfloat16 dtypes, but the default dtype in Qwen3Model is torch.float32. Either run inference under Automatic Mixed-Precision via `with torch.autocast(device_type="cuda"):`, or load the model with the `torch_dtype` argument, as done above. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
+
+# The queries and documents to embed
+queries = [
+    "What is the capital of China?",
+    "Explain gravity",
+]
+documents = [
+    "The capital of China is Beijing.",
+    "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
+]
+
+# Alternative: with torch.autocast(device_type="cuda"):
+with torch.no_grad():
+    # Encode the queries and documents. Note that queries benefit from using a prompt.
+    # Here we use the prompt called "query" stored under `model.prompts`, but you can
+    # also pass your own prompt via the `prompt` argument.
+    query_embeddings = model.encode(queries, prompt_name="query")
+    document_embeddings = model.encode(documents)
+
+# Compute the (cosine) similarity between the query and document embeddings
+similarity = model.similarity(query_embeddings, document_embeddings)
+print(similarity)
+# tensor([[0.7493, 0.0751],
+#         [0.0880, 0.6318]])
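
The long comment kept in the script quotes the Transformers warning about Flash Attention 2 and float32 weights. As a minimal sketch of the Automatic Mixed-Precision alternative that warning mentions (an illustration, not part of this commit; it assumes a CUDA device and the same Qwen/Qwen3-Embedding-8B checkpoint):

    import torch
    from sentence_transformers import SentenceTransformer

    # Keep the default float32 weights; autocast runs the forward pass in float16,
    # which satisfies Flash Attention 2's fp16/bf16 requirement at compute time.
    model = SentenceTransformer(
        "Qwen/Qwen3-Embedding-8B",
        model_kwargs={"attn_implementation": "flash_attention_2", "device_map": "auto"},
        tokenizer_kwargs={"padding_side": "left"},
    )

    with torch.autocast(device_type="cuda", dtype=torch.float16):
        embeddings = model.encode(["What is the capital of China?"], prompt_name="query")

Loading with torch_dtype=torch.float16, as the committed script does, also halves the memory footprint of the weights themselves; autocast only changes the compute dtype.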
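As a small usage note on the final similarity matrix (the variable names follow the script above; the ranking logic itself is an added illustration, not from the commit):

    import torch

    # similarity[i][j] is the cosine similarity between queries[i] and documents[j],
    # so sorting each row in descending order ranks the documents per query.
    ranking = torch.argsort(similarity, dim=1, descending=True)
    for i, query in enumerate(queries):
        best = ranking[i, 0].item()
        score = similarity[i, best].item()
        print(f"{query!r} -> {documents[best]!r} (score={score:.4f})")

With the example data, this pairs each query with its matching document, mirroring the diagonal dominance of the printed tensor.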