Diffstat (limited to 'derivations')
-rw-r--r--  derivations/flash-attn/default.nix  75
-rw-r--r--  derivations/flash-attn/flake.lock   27
-rw-r--r--  derivations/flash-attn/flake.nix    27
-rw-r--r--  derivations/flash-attn/nvidia.nix   24
4 files changed, 153 insertions, 0 deletions
diff --git a/derivations/flash-attn/default.nix b/derivations/flash-attn/default.nix
new file mode 100644
index 0000000..5cadaff
--- /dev/null
+++ b/derivations/flash-attn/default.nix
@@ -0,0 +1,75 @@
+{
+ lib,
+ python3Packages,
+ fetchFromGitHub,
+ symlinkJoin,
+ pkgs,
+}: let
+ inherit (python3Packages.torch) cudaCapabilities cudaPackages;
+ inherit (cudaPackages) backendStdenv;
+
+ nvidia = pkgs.callPackage ./nvidia.nix {};
+in
+ python3Packages.buildPythonPackage rec {
+ inherit (nvidia) BUILD_CUDA_EXT CUDA_VERSION preBuild;
+ pname = "flash-attn";
+ version = "2.8.2";
+ pyproject = true;
+
+ src = fetchFromGitHub {
+ owner = "Dao-AILab";
+ repo = "flash-attention";
+ rev = "v${version}";
+ hash = "sha256-iHxfDh+rGanhymP5F7g8rQcQUlP0oXliVF+y+ur/iJ0=";
+ fetchSubmodules = true;
+ };
+
+ preConfigure = ''
+ export CC=${lib.getExe' backendStdenv.cc "cc"}
+ export CXX=${lib.getExe' backendStdenv.cc "c++"}
+ export TORCH_CUDA_ARCH_LIST="${lib.concatStringsSep ";" cudaCapabilities}"
+ export FORCE_CUDA=1
+ '';
+
+ build-system = with python3Packages; [
+ setuptools
+ wheel
+ ];
+
+ nativeBuildInputs = with pkgs; [
+ git
+ which
+ ninja
+ ];
+
+ env.CUDA_HOME = symlinkJoin {
+ name = "cuda-redist";
+ paths = buildInputs;
+ };
+
+ buildInputs = with cudaPackages; [
+ cuda_cudart # cuda_runtime_api.h
+ libcusparse # cusparse.h
+ cuda_cccl # nv/target
+ libcublas # cublas_v2.h
+ libcusolver # cusolverDn.h
+ libcurand # curand_kernel.h
+ cuda_nvcc
+ ];
+
+ dependencies = with python3Packages; [
+ torch
+ psutil
+ ninja
+ einops
+ packaging
+ ];
+
+ pythonImportsCheck = ["flash_attn"];
+
+ meta = with lib; {
+ description = "Fast and memory-efficient exact attention";
+ homepage = "https://github.com/Dao-AILab/flash-attention";
+ license = licenses.bsd3;
+ };
+ }
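
Usage note (not part of the committed diff): a minimal sketch of consuming default.nix directly with callPackage, assuming evaluation from the repository root against a nixpkgs instance with allowUnfree and cudaSupport enabled; the withPackages environment is illustrative only.

    # Hypothetical consumer expression; not part of this commit.
    let
      pkgs = import <nixpkgs> {
        config = {
          allowUnfree = true;
          cudaSupport = true;
        };
      };
      flash-attn = pkgs.callPackage ./derivations/flash-attn/default.nix { };
    in
      # Example environment: CUDA-enabled torch plus the freshly built flash-attn.
      pkgs.python3.withPackages (ps: [ ps.torch flash-attn ])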
diff --git a/derivations/flash-attn/flake.lock b/derivations/flash-attn/flake.lock
new file mode 100644
index 0000000..80b4b37
--- /dev/null
+++ b/derivations/flash-attn/flake.lock
@@ -0,0 +1,27 @@
+{
+ "nodes": {
+ "nixpkgs": {
+ "locked": {
+ "lastModified": 1758035966,
+ "narHash": "sha256-qqIJ3yxPiB0ZQTT9//nFGQYn8X/PBoJbofA7hRKZnmE=",
+ "owner": "NixOS",
+ "repo": "nixpkgs",
+ "rev": "8d4ddb19d03c65a36ad8d189d001dc32ffb0306b",
+ "type": "github"
+ },
+ "original": {
+ "owner": "NixOS",
+ "ref": "nixos-unstable",
+ "repo": "nixpkgs",
+ "type": "github"
+ }
+ },
+ "root": {
+ "inputs": {
+ "nixpkgs": "nixpkgs"
+ }
+ }
+ },
+ "root": "root",
+ "version": 7
+}
diff --git a/derivations/flash-attn/flake.nix b/derivations/flash-attn/flake.nix
new file mode 100644
index 0000000..0e09fdf
--- /dev/null
+++ b/derivations/flash-attn/flake.nix
@@ -0,0 +1,27 @@
+# https://github.com/BatteredBunny/nix-ai-stuff/tree/main
+{
+ description = "flash-attn flake";
+
+ inputs = {
+ nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
+ };
+
+ outputs = {
+ self,
+ nixpkgs,
+ }: let
+ system = "x86_64-linux";
+ pkgs = import nixpkgs {
+ inherit system;
+ config = {
+ allowUnfree = true;
+ cudaSupport = true;
+ };
+ };
+ in {
+ packages.${system}.default = pkgs.callPackage ./default.nix {
+ # inherit (pkgs) lib python3Packages fetchFromGitHub symlinkJoin;
+ # pkgs = pkgs;
+ };
+ };
+}
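
For reference, a hedged sketch of consuming this flake from another flake; the input name and the path: URL are illustrative placeholders, not part of the commit. From inside derivations/flash-attn the package can also be built directly with "nix build .#default".

    # Hypothetical consumer flake; adjust the path: URL to wherever this
    # repository is checked out.
    {
      inputs.flash-attn.url = "path:/path/to/this-repo/derivations/flash-attn";

      outputs = { self, flash-attn }: {
        packages.x86_64-linux.default =
          flash-attn.packages.x86_64-linux.default;
      };
    }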
diff --git a/derivations/flash-attn/nvidia.nix b/derivations/flash-attn/nvidia.nix
new file mode 100644
index 0000000..7b83e84
--- /dev/null
+++ b/derivations/flash-attn/nvidia.nix
@@ -0,0 +1,24 @@
+{
+ symlinkJoin,
+ cudaPackages,
+ pkgs,
+ cudaCapabilities ? pkgs.cudaPackages.flags.cudaCapabilities,
+ lib,
+}: {
+ BUILD_CUDA_EXT = "1";
+
+ CUDA_HOME = symlinkJoin {
+ name = "cuda-redist";
+ paths = with cudaPackages; [
+ cuda_cudart # cuda_runtime.h
+ cuda_nvcc
+ ];
+ };
+
+ CUDA_VERSION = cudaPackages.cudaMajorMinorVersion;
+
+ preBuild = ''
+ export PATH=${pkgs.gcc13Stdenv.cc}/bin:$PATH
+ export TORCH_CUDA_ARCH_LIST="${lib.concatStringsSep ";" cudaCapabilities}"
+ '';
+}
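
For reference, a minimal sketch of narrowing the cudaCapabilities argument, which nvidia.nix otherwise defaults to cudaPackages.flags.cudaCapabilities; the "8.6" value is an illustrative compute capability, not something fixed by this commit.

    # Hypothetical override, e.g. in place of the plain callPackage call
    # in default.nix; "8.6" (Ampere) is only an example capability.
    nvidia = pkgs.callPackage ./nvidia.nix {
      cudaCapabilities = [ "8.6" ];
    };

With that override, the preBuild hook exports TORCH_CUDA_ARCH_LIST="8.6" instead of the full list derived from the configured cudaPackages.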