python3Packages.xformers: fix building with cudaSupport

(cherry picked from commit 0c3b659e77)
This commit is contained in:
Robin Appelman 2024-06-24 20:19:26 +02:00 committed by github-actions[bot]
parent 3680bb6b00
commit dd18a6e7ca

View File

@ -1,5 +1,6 @@
{
lib,
stdenv,
buildPythonPackage,
pythonOlder,
fetchFromGitHub,
@ -52,12 +53,14 @@ buildPythonPackage {
# noqa: C801
__version__ = "${version}"
EOF
''
+ lib.optionalString cudaSupport ''
export CUDA_HOME=${cudaPackages.cuda_nvcc}
export TORCH_CUDA_ARCH_LIST="${lib.concatStringsSep ";" cudaCapabilities}"
'';
env = lib.attrsets.optionalAttrs cudaSupport {
TORCH_CUDA_ARCH_LIST = "${lib.concatStringsSep ";" torch.cudaCapabilities}";
};
stdenv = if cudaSupport then cudaPackages.backendStdenv else stdenv;
buildInputs = lib.optionals cudaSupport (
with cudaPackages;
[
@ -71,7 +74,9 @@ buildPythonPackage {
]
);
nativeBuildInputs = [ which ];
nativeBuildInputs = [ which ] ++ lib.optionals cudaSupport (with cudaPackages; [
cuda_nvcc
]);
propagatedBuildInputs = [
numpy