Diffstat (limited to 'cuda/flashattn')
-rw-r--r--  cuda/flashattn/flake.bkp    | 69
-rw-r--r--  cuda/flashattn/flake.lock   | 48
-rw-r--r--  cuda/flashattn/flake.nix    | 75
-rw-r--r--  cuda/flashattn/sentence.py  | 64
4 files changed, 256 insertions(+), 0 deletions(-)
diff --git a/cuda/flashattn/flake.bkp b/cuda/flashattn/flake.bkp
new file mode 100644
index 0000000..a24ebff
--- /dev/null
+++ b/cuda/flashattn/flake.bkp
@@ -0,0 +1,69 @@
+{
+  description = "Torch cuda flake using nix-community cachix";
+
+  nixConfig = {
+    extra-substituters = [
+      "https://nix-community.cachix.org"
+      "https://nix-ai-stuff.cachix.org"
+      "https://ai.cachix.org"
+      "https://cuda-maintainers.cachix.org"
+    ];
+    extra-trusted-public-keys = [
+      "nix-community.cachix.org-1:mB9FSh9qf2dCimDSUo8Zy7bkq5CX+/rkCWyvRCYg3Fs="
+      "cuda-maintainers.cachix.org-1:0dq3bujKpuEPMCX6U4WylrUDZ9JyUG0VpVZa7CNfq5E="
+      "ai.cachix.org-1:N9dzRK+alWwoKXQlnn0H6aUx0lU/mspIoz8hMvGvbbc="
+      "nix-ai-stuff.cachix.org-1:WlUGeVCs26w9xF0/rjyg32PujDqbVMlSHufpj1fqix8="
+    ];
+  };
+  inputs = {
+    nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable";
+    nix-ai-stuff = {
+      url = "github:BatteredBunny/nix-ai-stuff";
+      inputs.nixpkgs.follows = "nixpkgs";
+    };
+  };
+
+  outputs = {
+    self,
+    nixpkgs,
+    nix-ai-stuff,
+    ...
+  }: let
+    pkgs = import nixpkgs {
+      system = "x86_64-linux";
+      config.allowUnfree = true;
+      config.cudaSupport = true;
+    };
+  in {
+    devShell.x86_64-linux = with pkgs;
+      mkShell rec {
+        venvDir = "./.venv";
+        buildInputs = [
+          (pkgs.python3.withPackages (
+            ps:
+              with ps; [
+                torch
+                accelerate
+                transformers
+                typing-extensions
+                psutil
+                ninja
+                einops
+                packaging
+                sentence-transformers
+                nix-ai-stuff.packages.${pkgs.system}.flash-attn
+              ]
+          ))
+          pkgs.virtualenv
+          pkgs.python3Packages.venvShellHook
+        ];
+        HENLO = "${pkgs.lib.makeLibraryPath buildInputs}";
+        postVenvCreation = ''
+          unset SOURCE_DATE_EPOCH
+        '';
+        shellHook = ''
+          fish
+        '';
+      };
+  };
+}
diff --git a/cuda/flashattn/flake.lock b/cuda/flashattn/flake.lock
new file mode 100644
index 0000000..731a240
--- /dev/null
+++ b/cuda/flashattn/flake.lock
@@ -0,0 +1,48 @@
+{
+  "nodes": {
+    "nix-ai-stuff": {
+      "inputs": {
+        "nixpkgs": [
+          "nixpkgs"
+        ]
+      },
+      "locked": {
+        "lastModified": 1756911915,
+        "narHash": "sha256-2b+GPPCM3Av2rZyuqALsOhnN2LTDmg6GmqBGUm8x/ww=",
+        "owner": "BatteredBunny",
+        "repo": "nix-ai-stuff",
+        "rev": "84db92a097d2c87234e096b880e685cd6423eb88",
+        "type": "github"
+      },
+      "original": {
+        "owner": "BatteredBunny",
+        "repo": "nix-ai-stuff",
+        "type": "github"
+      }
+    },
+    "nixpkgs": {
+      "locked": {
+        "lastModified": 1758277210,
+        "narHash": "sha256-iCGWf/LTy+aY0zFu8q12lK8KuZp7yvdhStehhyX1v8w=",
+        "owner": "nixos",
+        "repo": "nixpkgs",
+        "rev": "8eaee110344796db060382e15d3af0a9fc396e0e",
+        "type": "github"
+      },
+      "original": {
+        "owner": "nixos",
+        "ref": "nixos-unstable",
+        "repo": "nixpkgs",
+        "type": "github"
+      }
+    },
+    "root": {
+      "inputs": {
+        "nix-ai-stuff": "nix-ai-stuff",
+        "nixpkgs": "nixpkgs"
+      }
+    }
+  },
+  "root": "root",
+  "version": 7
+}
diff --git a/cuda/flashattn/flake.nix b/cuda/flashattn/flake.nix
new file mode 100644
index 0000000..978bff2
--- /dev/null
+++ b/cuda/flashattn/flake.nix
@@ -0,0 +1,75 @@
+{
+  description = "Torch cuda flake using nix-community cachix";
+
+  nixConfig = {
+    extra-substituters = [
+      "https://nix-community.cachix.org"
+      "https://nix-ai-stuff.cachix.org"
+      "https://ai.cachix.org"
+      "https://cuda-maintainers.cachix.org"
+    ];
+    extra-trusted-public-keys = [
+      "nix-community.cachix.org-1:mB9FSh9qf2dCimDSUo8Zy7bkq5CX+/rkCWyvRCYg3Fs="
+      "cuda-maintainers.cachix.org-1:0dq3bujKpuEPMCX6U4WylrUDZ9JyUG0VpVZa7CNfq5E="
+      "ai.cachix.org-1:N9dzRK+alWwoKXQlnn0H6aUx0lU/mspIoz8hMvGvbbc="
+      "nix-ai-stuff.cachix.org-1:WlUGeVCs26w9xF0/rjyg32PujDqbVMlSHufpj1fqix8="
+    ];
+  };
+  inputs = {
+    nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable";
+    nix-ai-stuff = {
+      url = "github:BatteredBunny/nix-ai-stuff";
+      inputs.nixpkgs.follows = "nixpkgs";
+    };
+  };
+
+  outputs = {
+    self,
+    nixpkgs,
+    nix-ai-stuff,
+    ...
+  }: let
+    pkgs = import nixpkgs {
+      system = "x86_64-linux";
+      config.allowUnfree = true;
+      config.cudaSupport = true;
+    };
+  in {
+    devShell.x86_64-linux = with pkgs;
+      mkShell rec {
+        venvDir = "./.venv";
+        buildInputs = [
+          (pkgs.python3.withPackages (
+            ps:
+              with ps; [
+                torch
+                accelerate
+                transformers
+                typing-extensions
+                psutil
+                ninja
+                einops
+                packaging
+                sentence-transformers
+                nix-ai-stuff.packages.${pkgs.system}.flash-attn
+              ]
+          ))
+          pkgs.virtualenv
+          pkgs.python3Packages.venvShellHook
+        ];
+        lulz = [
+          pkgs.python3Packages.sentence-transformers
+          nix-ai-stuff.packages.${pkgs.system}.flash-attn
+        ];
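+        # lib.makeLibraryPath yields a ':'-separated list of these packages'
+        # lib/ directories, presumably for LD_LIBRARY_PATH use inside the venv.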
+        HENLO = "${pkgs.lib.makeLibraryPath lulz}";
+        postVenvCreation = ''
+          unset SOURCE_DATE_EPOCH
+        '';
+        shellHook = ''
+          fish
+        '';
+      };
+  };
+}
diff --git a/cuda/flashattn/sentence.py b/cuda/flashattn/sentence.py
new file mode 100644
index 0000000..f29927a
--- /dev/null
+++ b/cuda/flashattn/sentence.py
@@ -0,0 +1,64 @@
+# Requires transformers>=4.51.0
+# Requires sentence-transformers>=2.7.0
+import torch
+from sentence_transformers import SentenceTransformer
+
+# Load the model
+# model = SentenceTransformer("Qwen/Qwen3-Embedding-8B")
+
+# We recommend enabling flash_attention_2 for better acceleration and memory saving,
+# together with setting `padding_side` to "left":
+model = SentenceTransformer(
+ "Qwen/Qwen3-Embedding-8B",
+ model_kwargs={"attn_implementation": "flash_attention_2", "device_map": "auto", "dtype": torch.float16},
+ tokenizer_kwargs={"padding_side": "left"},
+)
+
+# Note: Flash Attention 2 supports only torch.float16 and torch.bfloat16, while
+# Qwen3Model defaults to torch.float32, hence the explicit float16 above (the
+# alternative is automatic mixed precision via torch.autocast, as sketched below).
+
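+# If flash-attn is unavailable (e.g. no CUDA GPU in the shell), a fallback
+# sketch is to request PyTorch's built-in scaled-dot-product attention
+# instead ("sdpa" is a standard transformers `attn_implementation` value):
+# model = SentenceTransformer(
+#     "Qwen/Qwen3-Embedding-8B",
+#     model_kwargs={"attn_implementation": "sdpa", "device_map": "auto", "dtype": torch.float16},
+#     tokenizer_kwargs={"padding_side": "left"},
+# )
+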
+# The queries and documents to embed
+queries = [
+ "What is the capital of China?",
+ "Explain gravity",
+]
+documents = [
+ "The capital of China is Beijing.",
+ "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
+]
+
+# with torch.autocast(device_type="cuda"):  # AMP alternative to a fixed dtype
+with torch.no_grad():
+    # Encode the queries and documents. Note that queries benefit from using a
+    # prompt. Here we use the prompt called "query" stored under `model.prompts`,
+    # but you can also pass your own prompt via the `prompt` argument.
+    query_embeddings = model.encode(queries, prompt_name="query")
+    document_embeddings = model.encode(documents)
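+
+    # A custom instruction could be passed instead; a sketch (the wording is
+    # illustrative, not taken from this repo):
+    # query_embeddings = model.encode(
+    #     queries,
+    #     prompt="Instruct: Given a search query, retrieve relevant passages\nQuery: ",
+    # )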
+
+# Compute the (cosine) similarity between the query and document embeddings
+similarity = model.similarity(query_embeddings, document_embeddings)
+print(similarity)
+# tensor([[0.7493, 0.0751],
+# [0.0880, 0.6318]])
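+
+# To try this end-to-end inside the dev shell defined by flake.nix (assuming
+# it provides the CUDA-enabled torch and flash-attn used above):
+#   nix develop
+#   python sentence.py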