summaryrefslogtreecommitdiff
path: root/derivations/flash-attn/nvidia.nix
blob: 7b83e84641df83cf5bf43cc97dc27e47f97de371 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# CUDA-specific build settings for flash-attn.
#
# Returns an attribute set containing environment variables (BUILD_CUDA_EXT,
# CUDA_HOME, CUDA_VERSION) and a `preBuild` shell hook.
# NOTE(review): presumably merged into the flash-attn derivation's attributes
# by the caller — confirm at the call site.
#
# Arguments:
#   symlinkJoin      – used to assemble CUDA_HOME from several CUDA packages.
#   cudaPackages     – CUDA package set supplying cudart, nvcc and the version.
#   pkgs             – used only for the pinned host toolchain (gcc13Stdenv).
#   cudaCapabilities – list of GPU capabilities to compile for; defaults to the
#                      capabilities of the `cudaPackages` argument, so callers
#                      that override `cudaPackages` get a matching default.
#   lib              – nixpkgs library (concatStringsSep).
{
  symlinkJoin,
  cudaPackages,
  pkgs,
  # Derive the default from the injected `cudaPackages` (not
  # `pkgs.cudaPackages`) so the package set and the capability list
  # cannot silently come from two different CUDA sets.
  cudaCapabilities ? cudaPackages.flags.cudaCapabilities,
  lib,
}: {
  # NOTE(review): presumably tells flash-attn's build script to compile the
  # CUDA extension — confirm against the flash-attn sources.
  BUILD_CUDA_EXT = "1";

  # One directory tree combining the CUDA runtime and the nvcc compiler,
  # exported as CUDA_HOME.
  # NOTE(review): if cuda_cudart is split into multiple outputs, the headers
  # may live in a non-default output; verify cuda_runtime.h is present here.
  CUDA_HOME = symlinkJoin {
    name = "cuda-redist";
    paths = with cudaPackages; [
      cuda_cudart # cuda_runtime.h
      cuda_nvcc
    ];
  };

  # Presumably the "MAJOR.MINOR" version string of the CUDA package set in use.
  CUDA_VERSION = cudaPackages.cudaMajorMinorVersion;

  # Prepends the GCC 13 toolchain's bin/ to PATH (presumably so it is picked
  # up as the host compiler) and exports TORCH_CUDA_ARCH_LIST as the
  # semicolon-joined `cudaCapabilities`.
  # NOTE(review): nvcc supports host GCC only up to a CUDA-version-dependent
  # maximum; re-check the gcc13 pin when changing the CUDA version.
  preBuild = ''
    export PATH=${pkgs.gcc13Stdenv.cc}/bin:$PATH
    export TORCH_CUDA_ARCH_LIST="${lib.concatStringsSep ";" cudaCapabilities}"
  '';
}