{
  lib,
  python3Packages,
  fetchFromGitHub,
  symlinkJoin,
  pkgs,
}: let
  inherit (python3Packages.torch) cudaCapabilities cudaPackages;
  inherit (cudaPackages) backendStdenv;
  nvidia = pkgs.callPackage ./nvidia.nix {};
in
  python3Packages.buildPythonPackage rec {
    # Shared CUDA build settings provided by ./nvidia.nix.
    inherit (nvidia) BUILD_CUDA_EXT CUDA_VERSION preBuild;

    pname = "flash-attn";
    version = "2.8.2";
    pyproject = true;

    src = fetchFromGitHub {
      owner = "Dao-AILab";
      repo = "flash-attention";
      rev = "v${version}";
      hash = "sha256-iHxfDh+rGanhymP5F7g8rQcQUlP0oXliVF+y+ur/iJ0=";
      fetchSubmodules = true;
    };
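
    # Compile the CUDA extension with the same toolchain and GPU architectures
    # that the torch package itself was built with.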
    preConfigure = ''
      export CC=${lib.getExe' backendStdenv.cc "cc"}
      export CXX=${lib.getExe' backendStdenv.cc "c++"}
      export TORCH_CUDA_ARCH_LIST="${lib.concatStringsSep ";" cudaCapabilities}"
      export FORCE_CUDA=1
    '';

    build-system = with python3Packages; [
      setuptools
      wheel
    ];

    nativeBuildInputs = with pkgs; [
      git
      which
      ninja
    ];
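
    # Point CUDA_HOME at a single merged tree assembled from the split CUDA
    # redist packages below, since the extension build looks for one CUDA root.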
    env.CUDA_HOME = symlinkJoin {
      name = "cuda-redist";
      paths = buildInputs;
    };

    buildInputs = with cudaPackages; [
      cuda_cudart # cuda_runtime_api.h
      libcusparse # cusparse.h
      cuda_cccl # nv/target
      libcublas # cublas_v2.h
      libcusolver # cusolverDn.h
      libcurand # curand_kernel.h
      cuda_nvcc
    ];

    dependencies = with python3Packages; [
      torch
      psutil
      ninja
      einops
      packaging
    ];

    pythonImportsCheck = ["flash_attn"];

    meta = with lib; {
      description = "Fast and memory-efficient exact attention";
      homepage = "https://github.com/Dao-AILab/flash-attention";
      license = licenses.bsd3;
    };
  }