summaryrefslogtreecommitdiff
path: root/derivations/flash-attn/default.nix
blob: 5cadaffe638dec248f4f142ae580078649310137 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# Builds flash-attn (Dao-AILab/flash-attention) from source, compiling its CUDA
# extension against the CUDA package set that python3Packages.torch uses.
{
  lib,
  python3Packages,
  fetchFromGitHub,
  symlinkJoin,
  pkgs,
}: let
  # Reuse torch's CUDA package set and target GPU capabilities so the extension
  # is built for the same toolkit and architectures as the torch it links to.
  inherit (python3Packages.torch) cudaCapabilities cudaPackages;
  # Stdenv exposed by the CUDA package set; its compiler is used for CC/CXX
  # below (presumably chosen to be nvcc-compatible — see nixpkgs CUDA docs).
  inherit (cudaPackages) backendStdenv;

  # Shared NVIDIA settings. ./nvidia.nix is not visible here; it must provide
  # BUILD_CUDA_EXT, CUDA_VERSION and preBuild, which are inherited below.
  nvidia = pkgs.callPackage ./nvidia.nix {};
in
  python3Packages.buildPythonPackage rec {
    inherit (nvidia) BUILD_CUDA_EXT CUDA_VERSION preBuild;
    pname = "flash-attn";
    version = "2.8.2";
    pyproject = true;

    src = fetchFromGitHub {
      owner = "Dao-AILab";
      repo = "flash-attention";
      rev = "v${version}";
      hash = "sha256-iHxfDh+rGanhymP5F7g8rQcQUlP0oXliVF+y+ur/iJ0=";
      # The CUDA kernels build against vendored submodules (e.g. CUTLASS).
      fetchSubmodules = true;
    };

    # Point the extension build at the CUDA-compatible host compiler and
    # restrict compilation to the capabilities torch was built for
    # (semicolon-separated, as accepted by TORCH_CUDA_ARCH_LIST).
    preConfigure = ''
      export CC=${lib.getExe' backendStdenv.cc "cc"}
      export CXX=${lib.getExe' backendStdenv.cc "c++"}
      export TORCH_CUDA_ARCH_LIST="${lib.concatStringsSep ";" cudaCapabilities}"
      export FORCE_CUDA=1
    '';

    # PEP 517 build backend. `build-system` is the attribute buildPythonPackage
    # recognises; an unknown name such as `build-tools` would not be placed on
    # the build PATH.
    build-system = with python3Packages; [
      setuptools
      wheel
    ];

    # Command-line tools used while building (ninja drives the C++/CUDA
    # extension compilation).
    nativeBuildInputs = with pkgs; [
      git
      which
      ninja
    ];

    env = {
      # torch.utils.cpp_extension locates nvcc, headers and libraries through
      # CUDA_HOME. The redistributable CUDA packages live in separate store
      # paths, so join them (buildInputs below) into a single tree.
      CUDA_HOME = symlinkJoin {
        name = "cuda-redist";
        paths = buildInputs;
      };
      # Upstream setup.py otherwise first tries to download a prebuilt wheel
      # from GitHub and only compiles when that fails. Force the source build
      # so the output never depends on network access.
      FLASH_ATTENTION_FORCE_BUILD = "TRUE";
    };

    buildInputs = with cudaPackages; [
      cuda_cudart # cuda_runtime_api.h
      libcusparse # cusparse.h
      cuda_cccl # nv/target
      libcublas # cublas_v2.h
      libcusolver # cusolverDn.h
      libcurand # curand_kernel.h
      cuda_nvcc # nvcc; kept here so it is also part of CUDA_HOME
    ];

    dependencies = with python3Packages; [
      torch
      psutil
      ninja
      einops
      packaging
    ];

    pythonImportsCheck = ["flash_attn"];

    meta = {
      description = "Fast and memory-efficient exact attention";
      homepage = "https://github.com/Dao-AILab/flash-attention";
      license = lib.licenses.bsd3;
    };
  }