summaryrefslogtreecommitdiff
path: root/derivations/flash-attn/nvidia.nix
diff options
context:
space:
mode:
Diffstat (limited to 'derivations/flash-attn/nvidia.nix')
-rw-r--r--derivations/flash-attn/nvidia.nix24
1 files changed, 24 insertions, 0 deletions
diff --git a/derivations/flash-attn/nvidia.nix b/derivations/flash-attn/nvidia.nix
new file mode 100644
index 0000000..7b83e84
--- /dev/null
+++ b/derivations/flash-attn/nvidia.nix
@@ -0,0 +1,24 @@
+{
+ symlinkJoin,
+ cudaPackages,
+ pkgs,
+ cudaCapabilities ? pkgs.cudaPackages.flags.cudaCapabilities,
+ lib,
+}: {
+ BUILD_CUDA_EXT = "1";
+
+ CUDA_HOME = symlinkJoin {
+ name = "cuda-redist";
+ paths = with cudaPackages; [
+ cuda_cudart # cuda_runtime.h
+ cuda_nvcc
+ ];
+ };
+
+ CUDA_VERSION = cudaPackages.cudaMajorMinorVersion;
+
+ preBuild = ''
+ export PATH=${pkgs.gcc13Stdenv.cc}/bin:$PATH
+ export TORCH_CUDA_ARCH_LIST="${lib.concatStringsSep ";" cudaCapabilities}"
+ '';
+}