Diffstat (limited to 'tf.py')
-rw-r--r--  tf.py  65
1 files changed, 65 insertions, 0 deletions
diff --git a/tf.py b/tf.py
new file mode 100644
index 0000000..9ffd868
--- /dev/null
+++ b/tf.py
@@ -0,0 +1,65 @@
+
+import torch
+import torch.nn.functional as F
+
+from torch import Tensor
+from transformers import AutoModel, AutoTokenizer
+
+
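+# Pool each sequence to the hidden state of its last non-padding token.
+# With left padding that is simply the final position; otherwise the
+# attention mask is used to find the last valid token in each row.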
+def last_token_pool(last_hidden_states: Tensor,
+                    attention_mask: Tensor) -> Tensor:
+    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
+    if left_padding:
+        return last_hidden_states[:, -1]
+    else:
+        sequence_lengths = attention_mask.sum(dim=1) - 1
+        batch_size = last_hidden_states.shape[0]
+        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
+
+
+def get_detailed_instruct(task_description: str, query: str) -> str:
+    return f'Instruct: {task_description}\nQuery:{query}'
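+# e.g. get_detailed_instruct(task, 'Explain gravity') ->
+# 'Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery:Explain gravity'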
+
+# Each query must come with a one-sentence instruction that describes the task
+task = 'Given a web search query, retrieve relevant passages that answer the query'
+
+queries = [
+    get_detailed_instruct(task, 'What is the capital of China?'),
+    get_detailed_instruct(task, 'Explain gravity')
+]
+# No need to add instruction for retrieval documents
+documents = [
+ "The capital of China is Beijing.",
+ "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."
+]
+input_texts = queries + documents
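+# Queries come first in the batch, so embeddings[:2] below are the query
+# vectors and embeddings[2:] are the document vectors.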
+
+
+# We recommend enabling flash_attention_2 for better acceleration and memory saving,
+# together with setting `padding_side` to "left":
+model = AutoModel.from_pretrained(
+ "Qwen/Qwen3-Embedding-8B", attn_implementation="flash_attention_2", device_map="auto", torch_dtype=torch.float16
+)
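+# Note: the flash_attention_2 implementation requires the flash-attn package and a
+# supported GPU; drop the attn_implementation argument to fall back to the default attention.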
+tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-Embedding-8B', padding_side="left")
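+# Left padding keeps every sequence's last real token at position -1, which is
+# what the fast path in last_token_pool relies on.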
+# Maximum sequence length used when tokenizing the queries and documents
+max_length = 8192
+
+# Tokenize the input texts
+batch_dict = tokenizer(
+    input_texts,
+    padding=True,
+    truncation=True,
+    max_length=max_length,
+    return_tensors="pt",
+)
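+# Move the tokenized batch onto the same device as the model weights.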
+batch_dict = batch_dict.to(model.device)
+with torch.no_grad():
+    outputs = model(**batch_dict)
+    embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
+
+    # normalize embeddings
+    embeddings = F.normalize(embeddings, p=2, dim=1)
+    scores = (embeddings[:2] @ embeddings[2:].T)
+
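+# Because the embeddings are unit-normalized, these dot products are cosine
+# similarities: rows are the two queries, columns are the two documents.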
+print(scores.tolist())
+# [[0.7645568251609802, 0.14142508804798126], [0.13549736142158508, 0.5999549627304077]]
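+# Each query scores highest against its matching document (0.76 vs 0.14, and 0.60 vs 0.14).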