mirror of
https://github.com/NixOS/nixpkgs.git
synced 2024-11-23 07:23:20 +00:00
python3Packages.xformers: fix building with cudaSupport
(cherry picked from commit 0c3b659e77
)
This commit is contained in:
parent
3680bb6b00
commit
dd18a6e7ca
@ -1,5 +1,6 @@
|
||||
{
|
||||
lib,
|
||||
stdenv,
|
||||
buildPythonPackage,
|
||||
pythonOlder,
|
||||
fetchFromGitHub,
|
||||
@ -52,12 +53,14 @@ buildPythonPackage {
|
||||
# noqa: C801
|
||||
__version__ = "${version}"
|
||||
EOF
|
||||
''
|
||||
+ lib.optionalString cudaSupport ''
|
||||
export CUDA_HOME=${cudaPackages.cuda_nvcc}
|
||||
export TORCH_CUDA_ARCH_LIST="${lib.concatStringsSep ";" cudaCapabilities}"
|
||||
'';
|
||||
|
||||
env = lib.attrsets.optionalAttrs cudaSupport {
|
||||
TORCH_CUDA_ARCH_LIST = "${lib.concatStringsSep ";" torch.cudaCapabilities}";
|
||||
};
|
||||
|
||||
stdenv = if cudaSupport then cudaPackages.backendStdenv else stdenv;
|
||||
|
||||
buildInputs = lib.optionals cudaSupport (
|
||||
with cudaPackages;
|
||||
[
|
||||
@ -71,7 +74,9 @@ buildPythonPackage {
|
||||
]
|
||||
);
|
||||
|
||||
nativeBuildInputs = [ which ];
|
||||
nativeBuildInputs = [ which ] ++ lib.optionals cudaSupport (with cudaPackages; [
|
||||
cuda_nvcc
|
||||
]);
|
||||
|
||||
propagatedBuildInputs = [
|
||||
numpy
|
||||
|
Loading…
Reference in New Issue
Block a user