diff options
Diffstat (limited to 'derivations/flash-attn/nvidia.nix')
| -rw-r--r-- | derivations/flash-attn/nvidia.nix | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/derivations/flash-attn/nvidia.nix b/derivations/flash-attn/nvidia.nix new file mode 100644 index 0000000..7b83e84 --- /dev/null +++ b/derivations/flash-attn/nvidia.nix @@ -0,0 +1,24 @@ +{ + symlinkJoin, + cudaPackages, + pkgs, + cudaCapabilities ? pkgs.cudaPackages.flags.cudaCapabilities, + lib, +}: { + BUILD_CUDA_EXT = "1"; + + CUDA_HOME = symlinkJoin { + name = "cuda-redist"; + paths = with cudaPackages; [ + cuda_cudart # cuda_runtime.h + cuda_nvcc + ]; + }; + + CUDA_VERSION = cudaPackages.cudaMajorMinorVersion; + + preBuild = '' + export PATH=${pkgs.gcc13Stdenv.cc}/bin:$PATH + export TORCH_CUDA_ARCH_LIST="${lib.concatStringsSep ";" cudaCapabilities}" + ''; +} |
