author | polwex <polwex@sortug.com> | 2025-09-23 03:50:53 +0700
committer | polwex <polwex@sortug.com> | 2025-09-23 03:50:53 +0700
commit | 57aaafdb137fe49930711f6ed5ccc83b3a119cd2 (patch)
tree | 1a7556927bed94377630d33dd29c3bf07d159619 /tf.py
init
Diffstat (limited to 'tf.py')
-rw-r--r-- | tf.py | 65
1 files changed, 65 insertions, 0 deletions
@@ -0,0 +1,65 @@
+
+import torch
+import torch.nn.functional as F
+
+from torch import Tensor
+from transformers import AutoModel, AutoTokenizer
+
+
+def last_token_pool(last_hidden_states: Tensor,
+                    attention_mask: Tensor) -> Tensor:
+    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
+    if left_padding:
+        return last_hidden_states[:, -1]
+    else:
+        sequence_lengths = attention_mask.sum(dim=1) - 1
+        batch_size = last_hidden_states.shape[0]
+        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
+
+
+def get_detailed_instruct(task_description: str, query: str) -> str:
+    return f'Instruct: {task_description}\nQuery:{query}'
+
+
+# Each query must come with a one-sentence instruction that describes the task
+task = 'Given a web search query, retrieve relevant passages that answer the query'
+
+queries = [
+    get_detailed_instruct(task, 'What is the capital of China?'),
+    get_detailed_instruct(task, 'Explain gravity')
+]
+# No need to add instruction for retrieval documents
+documents = [
+    "The capital of China is Beijing.",
+    "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."
+]
+input_texts = queries + documents
+
+
+# We recommend enabling flash_attention_2 for better acceleration and memory saving,
+# together with setting `padding_side` to "left":
+model = AutoModel.from_pretrained(
+    "Qwen/Qwen3-Embedding-8B", attn_implementation="flash_attention_2", device_map="auto", torch_dtype=torch.float16
+)
+tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-Embedding-8B', padding_side="left")
+# Maximum sequence length used when tokenizing the queries and documents
+max_length = 8192
+
+# Tokenize the input texts
+batch_dict = tokenizer(
+    input_texts,
+    padding=True,
+    truncation=True,
+    max_length=max_length,
+    return_tensors="pt",
+)
+batch_dict.to(model.device)
+with torch.no_grad():
+    outputs = model(**batch_dict)
+    embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
+
+    # normalize embeddings
+    embeddings = F.normalize(embeddings, p=2, dim=1)
+    scores = (embeddings[:2] @ embeddings[2:].T)
+
print(scores.tolist())
# [[0.7645568251609802, 0.14142508804798126], [0.13549736142158508, 0.5999549627304077]]
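The script above assumes flash-attn is installed. If it is not, the same pipeline should still run with the default attention implementation, since `attn_implementation` only affects speed and memory, not what the model computes. A minimal sketch of that fallback load, keeping the same model id, dtype, and left padding:

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Fallback load without flash_attention_2 (assumption: the flag is optional and
# the resulting embeddings are equivalent up to fp16 numerical noise).
model = AutoModel.from_pretrained(
    "Qwen/Qwen3-Embedding-8B",
    device_map="auto",
    torch_dtype=torch.float16,
)
# padding_side="left" is still needed: with left padding, the last position of
# every row in a padded batch is a real token, which is what last_token_pool reads.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-8B", padding_side="left")
```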