Diffstat (limited to 'cuda/flashattn')
-rw-r--r--   cuda/flashattn/flake.bkp     69
-rw-r--r--   cuda/flashattn/flake.lock    48
-rw-r--r--   cuda/flashattn/flake.nix     73
-rw-r--r--   cuda/flashattn/sentence.py   42
4 files changed, 232 insertions, 0 deletions
diff --git a/cuda/flashattn/flake.bkp b/cuda/flashattn/flake.bkp
new file mode 100644
index 0000000..a24ebff
--- /dev/null
+++ b/cuda/flashattn/flake.bkp
@@ -0,0 +1,69 @@
+{
+  description = "Torch cuda flake using nix-community cachix";
+
+  nixConfig = {
+    extra-substituters = [
+      "https://nix-community.cachix.org"
+      "https://nix-ai-stuff.cachix.org"
+      "https://ai.cachix.org"
+      "https://cuda-maintainers.cachix.org"
+    ];
+    extra-trusted-public-keys = [
+      "nix-community.cachix.org-1:mB9FSh9qf2dCimDSUo8Zy7bkq5CX+/rkCWyvRCYg3Fs="
+      "cuda-maintainers.cachix.org-1:0dq3bujKpuEPMCX6U4WylrUDZ9JyUG0VpVZa7CNfq5E="
+      "ai.cachix.org-1:N9dzRK+alWwoKXQlnn0H6aUx0lU/mspIoz8hMvGvbbc="
+      "nix-ai-stuff.cachix.org-1:WlUGeVCs26w9xF0/rjyg32PujDqbVMlSHufpj1fqix8="
+    ];
+  };
+  inputs = {
+    nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable";
+    nix-ai-stuff = {
+      url = "github:BatteredBunny/nix-ai-stuff";
+      inputs.nixpkgs.follows = "nixpkgs";
+    };
+  };
+
+  outputs = {
+    self,
+    nixpkgs,
+    nix-ai-stuff,
+    ...
+  }: let
+    pkgs = import nixpkgs {
+      system = "x86_64-linux";
+      config.allowUnfree = true;
+      config.cudaSupport = true;
+    };
+  in {
+    devShell.x86_64-linux = with pkgs;
+      mkShell rec {
+        venvDir = "./.venv";
+        buildInputs = [
+          (pkgs.python3.withPackages (
+            ps:
+              with ps; [
+                torch
+                accelerate
+                transformers
+                typing-extensions
+                psutil
+                ninja
+                einops
+                packaging
+                sentence-transformers
+                nix-ai-stuff.packages.${pkgs.system}.flash-attn
+              ]
+          ))
+          pkgs.virtualenv
+          pkgs.python3Packages.venvShellHook
+        ];
+        HENLO = "${pkgs.lib.makeLibraryPath buildInputs}";
+        postVenvCreation = ''
+          unset SOURCE_DATE_EPOCH
+        '';
+        shellHook = ''
+          fish
+        '';
+      };
+  };
+}
diff --git a/cuda/flashattn/flake.lock b/cuda/flashattn/flake.lock
new file mode 100644
index 0000000..731a240
--- /dev/null
+++ b/cuda/flashattn/flake.lock
@@ -0,0 +1,48 @@
+{
+  "nodes": {
+    "nix-ai-stuff": {
+      "inputs": {
+        "nixpkgs": [
+          "nixpkgs"
+        ]
+      },
+      "locked": {
+        "lastModified": 1756911915,
+        "narHash": "sha256-2b+GPPCM3Av2rZyuqALsOhnN2LTDmg6GmqBGUm8x/ww=",
+        "owner": "BatteredBunny",
+        "repo": "nix-ai-stuff",
+        "rev": "84db92a097d2c87234e096b880e685cd6423eb88",
+        "type": "github"
+      },
+      "original": {
+        "owner": "BatteredBunny",
+        "repo": "nix-ai-stuff",
+        "type": "github"
+      }
+    },
+    "nixpkgs": {
+      "locked": {
+        "lastModified": 1758277210,
+        "narHash": "sha256-iCGWf/LTy+aY0zFu8q12lK8KuZp7yvdhStehhyX1v8w=",
+        "owner": "nixos",
+        "repo": "nixpkgs",
+        "rev": "8eaee110344796db060382e15d3af0a9fc396e0e",
+        "type": "github"
+      },
+      "original": {
+        "owner": "nixos",
+        "ref": "nixos-unstable",
+        "repo": "nixpkgs",
+        "type": "github"
+      }
+    },
+    "root": {
+      "inputs": {
+        "nix-ai-stuff": "nix-ai-stuff",
+        "nixpkgs": "nixpkgs"
+      }
+    }
+  },
+  "root": "root",
+  "version": 7
+}
diff --git a/cuda/flashattn/flake.nix b/cuda/flashattn/flake.nix
new file mode 100644
index 0000000..978bff2
--- /dev/null
+++ b/cuda/flashattn/flake.nix
@@ -0,0 +1,73 @@
+{
+  description = "Torch cuda flake using nix-community cachix";
+
+  nixConfig = {
+    extra-substituters = [
+      "https://nix-community.cachix.org"
+      "https://nix-ai-stuff.cachix.org"
+      "https://ai.cachix.org"
+      "https://cuda-maintainers.cachix.org"
+    ];
+    extra-trusted-public-keys = [
+      "nix-community.cachix.org-1:mB9FSh9qf2dCimDSUo8Zy7bkq5CX+/rkCWyvRCYg3Fs="
+      "cuda-maintainers.cachix.org-1:0dq3bujKpuEPMCX6U4WylrUDZ9JyUG0VpVZa7CNfq5E="
+      "ai.cachix.org-1:N9dzRK+alWwoKXQlnn0H6aUx0lU/mspIoz8hMvGvbbc="
+      "nix-ai-stuff.cachix.org-1:WlUGeVCs26w9xF0/rjyg32PujDqbVMlSHufpj1fqix8="
+    ];
+  };
+  inputs = {
+    nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable";
+    nix-ai-stuff = {
+      url = "github:BatteredBunny/nix-ai-stuff";
+      inputs.nixpkgs.follows = "nixpkgs";
+    };
+  };
+
+  outputs = {
+    self,
+    nixpkgs,
+    nix-ai-stuff,
+    ...
+  }: let
+    pkgs = import nixpkgs {
+      system = "x86_64-linux";
+      config.allowUnfree = true;
+      config.cudaSupport = true;
+    };
+  in {
+    devShell.x86_64-linux = with pkgs;
+      mkShell rec {
+        venvDir = "./.venv";
+        buildInputs = [
+          (pkgs.python3.withPackages (
+            ps:
+              with ps; [
+                torch
+                accelerate
+                transformers
+                typing-extensions
+                psutil
+                ninja
+                einops
+                packaging
+                sentence-transformers
+                nix-ai-stuff.packages.${pkgs.system}.flash-attn
+              ]
+          ))
+          pkgs.virtualenv
+          pkgs.python3Packages.venvShellHook
+        ];
+        lulz = [
+          pkgs.python3Packages.sentence-transformers
+          nix-ai-stuff.packages.${pkgs.system}.flash-attn
+        ];
+        HENLO = "${pkgs.lib.makeLibraryPath lulz}";
+        postVenvCreation = ''
+          unset SOURCE_DATE_EPOCH
+        '';
+        shellHook = ''
+          fish
+        '';
+      };
+  };
+}
diff --git a/cuda/flashattn/sentence.py b/cuda/flashattn/sentence.py
new file mode 100644
index 0000000..f29927a
--- /dev/null
+++ b/cuda/flashattn/sentence.py
@@ -0,0 +1,42 @@
+# Requires transformers>=4.51.0
+# Requires sentence-transformers>=2.7.0
+import torch
+from sentence_transformers import SentenceTransformer
+
+# Load the model
+# model = SentenceTransformer("Qwen/Qwen3-Embedding-8B")
+
+# We recommend enabling flash_attention_2 for better acceleration and memory saving,
+# together with setting `padding_side` to "left":
+model = SentenceTransformer(
+    "Qwen/Qwen3-Embedding-8B",
+    model_kwargs={"attn_implementation": "flash_attention_2", "device_map": "auto", "dtype": torch.float16},
+    tokenizer_kwargs={"padding_side": "left"},
+)
+
+
+# Flash Attention 2 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in Qwen3Model is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
+
+# The queries and documents to embed
+queries = [
+    "What is the capital of China?",
+    "Explain gravity",
+]
+documents = [
+    "The capital of China is Beijing.",
+    "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
+]
+
+# with torch.autocast(device_type='torch_device'):
+with torch.no_grad():
+    # Encode the queries and documents. Note that queries benefit from using a prompt.
+    # Here we use the prompt called "query" stored under `model.prompts`, but you can
+    # also pass your own prompt via the `prompt` argument.
+    query_embeddings = model.encode(queries, prompt_name="query")
+    document_embeddings = model.encode(documents)
+
+# Compute the (cosine) similarity between the query and document embeddings
+similarity = model.similarity(query_embeddings, document_embeddings)
+print(similarity)
+# tensor([[0.7493, 0.0751],
+#         [0.0880, 0.6318]])