Diffstat (limited to 'derivations/flash-attn/default.nix')
-rw-r--r--  derivations/flash-attn/default.nix | 75
1 file changed, 75 insertions, 0 deletions
diff --git a/derivations/flash-attn/default.nix b/derivations/flash-attn/default.nix
new file mode 100644
index 0000000..5cadaff
--- /dev/null
+++ b/derivations/flash-attn/default.nix
@@ -0,0 +1,75 @@
+{
+  lib,
+  python3Packages,
+  fetchFromGitHub,
+  symlinkJoin,
+  pkgs,
+}: let
+  inherit (python3Packages.torch) cudaCapabilities cudaPackages;
+  inherit (cudaPackages) backendStdenv;
+
+  nvidia = pkgs.callPackage ./nvidia.nix {};
+in
+  python3Packages.buildPythonPackage rec {
+    inherit (nvidia) BUILD_CUDA_EXT CUDA_VERSION preBuild;
+    pname = "flash-attn";
+    version = "2.8.2";
+    pyproject = true;
+
+    src = fetchFromGitHub {
+      owner = "Dao-AILab";
+      repo = "flash-attention";
+      rev = "v${version}";
+      hash = "sha256-iHxfDh+rGanhymP5F7g8rQcQUlP0oXliVF+y+ur/iJ0=";
+      fetchSubmodules = true;
+    };
+
+    preConfigure = ''
+      export CC=${lib.getExe' backendStdenv.cc "cc"}
+      export CXX=${lib.getExe' backendStdenv.cc "c++"}
+      export TORCH_CUDA_ARCH_LIST="${lib.concatStringsSep ";" cudaCapabilities}"
+      export FORCE_CUDA=1
+    '';
+
+    build-system = with python3Packages; [
+      setuptools
+      wheel
+    ];
+
+    nativeBuildInputs = with pkgs; [
+      git
+      which
+      ninja
+    ];
+
+    env.CUDA_HOME = symlinkJoin {
+      name = "cuda-redist";
+      paths = buildInputs;
+    };
+
+    buildInputs = with cudaPackages; [
+      cuda_cudart # cuda_runtime_api.h
+      libcusparse # cusparse.h
+      cuda_cccl # nv/target
+      libcublas # cublas_v2.h
+      libcusolver # cusolverDn.h
+      libcurand # curand_kernel.h
+      cuda_nvcc
+    ];
+
+    dependencies = with python3Packages; [
+      torch
+      psutil
+      ninja
+      einops
+      packaging
+    ];
+
+    pythonImportsCheck = ["flash_attn"];
+
+    meta = with lib; {
+      description = "Fast and memory-efficient exact attention";
+      homepage = "https://github.com/Dao-AILab/flash-attention";
+      license = licenses.bsd3;
+    };
+  }
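The derivation rebuilds a single CUDA toolkit tree from the split cudaPackages redistributables via symlinkJoin and points CUDA_HOME at it, since flash-attention's setup.py (through torch.utils.cpp_extension) expects a monolithic CUDA_HOME layout; cudaCapabilities and cudaPackages are inherited from python3Packages.torch so the extension builds against the same toolchain as PyTorch itself. A minimal consumer sketch, assuming conventional callPackage wiring (the nixpkgs import and attribute names below are illustrative, not taken from this commit; ./nvidia.nix must exist as a sibling file in the repo providing BUILD_CUDA_EXT, CUDA_VERSION, and preBuild):

  # Hypothetical usage sketch — paths and config flags are assumptions.
  let
    pkgs = import <nixpkgs> {
      config.allowUnfree = true;  # CUDA redistributables are unfree
      config.cudaSupport = true;  # build torch with CUDA enabled
    };
    # callPackage supplies lib, python3Packages, fetchFromGitHub,
    # symlinkJoin, and pkgs from the package set above.
    flash-attn = pkgs.callPackage ./derivations/flash-attn {};
  in
    # Environment with torch and flash-attn importable side by side.
    pkgs.python3.withPackages (ps: [ps.torch flash-attn])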
