import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoModel, AutoTokenizer


def last_token_pool(last_hidden_states: Tensor,
                    attention_mask: Tensor) -> Tensor:
    # With left padding, every sequence ends at position -1; otherwise gather the
    # hidden state at each sequence's last non-padded token.
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    else:
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]


def get_detailed_instruct(task_description: str, query: str) -> str:
    return f'Instruct: {task_description}\nQuery:{query}'


# Each query must come with a one-sentence instruction that describes the task
task = 'Given a web search query, retrieve relevant passages that answer the query'

queries = [
    get_detailed_instruct(task, 'What is the capital of China?'),
    get_detailed_instruct(task, 'Explain gravity')
]
# No need to add instruction for retrieval documents
documents = [
    "The capital of China is Beijing.",
    "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."
]
# The queries and documents to embed
input_texts = queries + documents

# We recommend enabling flash_attention_2 for better acceleration and memory saving,
# together with setting `padding_side` to "left":
model = AutoModel.from_pretrained(
    "Qwen/Qwen3-Embedding-8B",
    attn_implementation="flash_attention_2",
    device_map="auto",
    torch_dtype=torch.float16,
)
tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-Embedding-8B', padding_side="left")

max_length = 8192

# Tokenize the input texts
batch_dict = tokenizer(
    input_texts,
    padding=True,
    truncation=True,
    max_length=max_length,
    return_tensors="pt",
)
batch_dict.to(model.device)

with torch.no_grad():
    outputs = model(**batch_dict)
    embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

# normalize embeddings
embeddings = F.normalize(embeddings, p=2, dim=1)
scores = (embeddings[:2] @ embeddings[2:].T)
print(scores.tolist())
# [[0.7645568251609802, 0.14142508804798126], [0.13549736142158508, 0.5999549627304077]]
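
# Optional follow-up (a sketch, not part of the reference snippet above): because the
# embeddings are L2-normalized, the 2x2 `scores` matrix holds query-document cosine
# similarities. The lines below only reuse names defined above (`scores`, `documents`)
# plus standard torch calls to print a per-query document ranking.
topk = torch.topk(scores, k=scores.shape[1], dim=1)
for query_idx, (values, indices) in enumerate(zip(topk.values, topk.indices)):
    ranked = [(documents[int(i)], round(float(v), 4)) for v, i in zip(values, indices)]
    print(f"Query {query_idx} ranking: {ranked}")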