author      polwex <polwex@sortug.com>    2025-09-23 03:50:53 +0700
committer   polwex <polwex@sortug.com>    2025-09-23 03:50:53 +0700
commit      57aaafdb137fe49930711f6ed5ccc83b3a119cd2 (patch)
tree        1a7556927bed94377630d33dd29c3bf07d159619    /sentence.py
init
Diffstat (limited to 'sentence.py')
-rw-r--r--    sentence.py    42
1 file changed, 42 insertions, 0 deletions
diff --git a/sentence.py b/sentence.py
new file mode 100644
index 0000000..f29927a
--- /dev/null
+++ b/sentence.py
@@ -0,0 +1,42 @@
+# Requires transformers>=4.51.0
+# Requires sentence-transformers>=2.7.0
+import torch
+from sentence_transformers import SentenceTransformer
+
+# Load the model
+# model = SentenceTransformer("Qwen/Qwen3-Embedding-8B")
+
+# We recommend enabling flash_attention_2 for better acceleration and memory saving,
+# together with setting `padding_side` to "left":
+model = SentenceTransformer(
+    "Qwen/Qwen3-Embedding-8B",
+    model_kwargs={"attn_implementation": "flash_attention_2", "device_map": "auto", "torch_dtype": torch.float16},
+    tokenizer_kwargs={"padding_side": "left"},
+)
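+
+# If flash-attn is not installed, the model can also be loaded without it (a sketch of the
+# fallback, kept commented out here; it uses the default attention implementation):
+# model = SentenceTransformer(
+#     "Qwen/Qwen3-Embedding-8B",
+#     model_kwargs={"device_map": "auto", "torch_dtype": torch.float16},
+#     tokenizer_kwargs={"padding_side": "left"},
+# )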
+
+
+# Note: Flash Attention 2 only supports torch.float16 and torch.bfloat16. If the model were
+# loaded in torch.float32, inference would have to run under Automatic Mixed Precision
+# (e.g. `with torch.autocast(device_type="cuda"):`); passing a half-precision `torch_dtype`
+# when loading, as above, avoids that.
+
+# The queries and documents to embed
+queries = [
+ "What is the capital of China?",
+ "Explain gravity",
+]
+documents = [
+ "The capital of China is Beijing.",
+ "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
+]
+
+# Alternatively, the calls below could run under autocast, e.g.
+# `with torch.autocast(device_type="cuda"):`
+with torch.no_grad():
+    # Encode the queries and documents. Note that queries benefit from using a prompt.
+    # Here we use the prompt called "query" stored under `model.prompts`, but you can
+    # also pass your own prompt via the `prompt` argument.
+    query_embeddings = model.encode(queries, prompt_name="query")
+    document_embeddings = model.encode(documents)
+
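+# The prompts shipped with the model can be inspected via `model.prompts`, and a custom
+# instruction can be passed per call instead of a named prompt (sketch; the instruction
+# text below is only an illustrative example):
+# print(model.prompts)
+# custom_embeddings = model.encode(queries, prompt="Instruct: Retrieve passages that answer the query\nQuery: ")
+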
+# Compute the (cosine) similarity between the query and document embeddings
+similarity = model.similarity(query_embeddings, document_embeddings)
+print(similarity)
+# tensor([[0.7493, 0.0751],
+# [0.0880, 0.6318]])
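+
+# Sanity check (sketch, assuming the default cosine similarity function): the matrix above
+# can be reproduced manually from the raw embeddings by normalizing and taking dot products.
+q = torch.as_tensor(query_embeddings)
+d = torch.as_tensor(document_embeddings)
+manual = torch.nn.functional.normalize(q, dim=-1) @ torch.nn.functional.normalize(d, dim=-1).T
+print(manual)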