python3Packages.tensorflow: add cudaCapabilities argument

Rearrange tensorflow to allow overriding cudaCapabilities.
This is needed when debugging the tensorflow derivation
This commit is contained in:
Someone Serge 2023-02-27 14:54:09 +02:00
parent 94bbbb0471
commit cf7fb1d08f
No known key found for this signature in database
GPG Key ID: 7B0E3B1390D61DA4
2 changed files with 6 additions and 5 deletions

View File

@ -1,6 +1,6 @@
{ config
, lib
, cudatoolkit
, cudaVersion
}:
# Type aliases
@ -13,7 +13,6 @@
let
inherit (lib) attrsets lists strings trivial versions;
cudaVersion = cudatoolkit.version;
# Flags are determined based on your CUDA toolkit by default. You may benefit
# from improved performance, reduced file size, or greater hardware suppport by

View File

@ -17,7 +17,9 @@
# that in nix as well. It would make some things easier and less confusing, but
# it would also make the default tensorflow package unfree. See
# https://groups.google.com/a/tensorflow.org/forum/#!topic/developers/iRCt5m4qUz0
, cudaSupport ? false, cudaPackages ? {}
, cudaSupport ? false
, cudaPackages ? { }
, cudaCapabilities ? cudaPackages.cudaFlags.cudaCapabilities
, mklSupport ? false, mkl
, tensorboardSupport ? true
# XLA without CUDA is broken
@ -30,7 +32,7 @@
}:
let
inherit (cudaPackages) cudatoolkit cudaFlags cudnn nccl;
inherit (cudaPackages) cudatoolkit cudnn nccl;
in
assert cudaSupport -> cudatoolkit != null
@ -301,7 +303,7 @@ let
TF_CUDA_PATHS = lib.optionalString cudaSupport "${cudatoolkit_joined},${cudnn},${nccl}";
GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin";
GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin/gcc";
TF_CUDA_COMPUTE_CAPABILITIES = builtins.concatStringsSep "," cudaFlags.cudaRealArches;
TF_CUDA_COMPUTE_CAPABILITIES = lib.concatStringsSep "," cudaCapabilities;
postPatch = ''
# bazel 3.3 should work just as well as bazel 3.1