From cf7fb1d08f928f48725f15e595cbb84793278379 Mon Sep 17 00:00:00 2001 From: Someone Serge Date: Mon, 27 Feb 2023 14:54:09 +0200 Subject: [PATCH] python3Packages.tensorflow: add cudaCapabilities argument Rearrange tensorflow to allow overriding cudaCapabilities. This is needed when debugging the tensorflow derivation --- pkgs/development/compilers/cudatoolkit/flags.nix | 3 +-- pkgs/development/python-modules/tensorflow/default.nix | 8 +++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/pkgs/development/compilers/cudatoolkit/flags.nix b/pkgs/development/compilers/cudatoolkit/flags.nix index 8e1e54723b2e..9d7b7f884ad2 100644 --- a/pkgs/development/compilers/cudatoolkit/flags.nix +++ b/pkgs/development/compilers/cudatoolkit/flags.nix @@ -1,6 +1,6 @@ { config , lib -, cudatoolkit +, cudaVersion }: # Type aliases @@ -13,7 +13,6 @@ let inherit (lib) attrsets lists strings trivial versions; - cudaVersion = cudatoolkit.version; # Flags are determined based on your CUDA toolkit by default. You may benefit # from improved performance, reduced file size, or greater hardware suppport by diff --git a/pkgs/development/python-modules/tensorflow/default.nix b/pkgs/development/python-modules/tensorflow/default.nix index f7d920c37221..f18a924c31fa 100644 --- a/pkgs/development/python-modules/tensorflow/default.nix +++ b/pkgs/development/python-modules/tensorflow/default.nix @@ -17,7 +17,9 @@ # that in nix as well. It would make some things easier and less confusing, but # it would also make the default tensorflow package unfree. See # https://groups.google.com/a/tensorflow.org/forum/#!topic/developers/iRCt5m4qUz0 -, cudaSupport ? false, cudaPackages ? {} +, cudaSupport ? false +, cudaPackages ? { } +, cudaCapabilities ? cudaPackages.cudaFlags.cudaCapabilities , mklSupport ? false, mkl , tensorboardSupport ? true # XLA without CUDA is broken @@ -30,7 +32,7 @@ }: let - inherit (cudaPackages) cudatoolkit cudaFlags cudnn nccl; + inherit (cudaPackages) cudatoolkit cudnn nccl; in assert cudaSupport -> cudatoolkit != null @@ -301,7 +303,7 @@ let TF_CUDA_PATHS = lib.optionalString cudaSupport "${cudatoolkit_joined},${cudnn},${nccl}"; GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin"; GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin/gcc"; - TF_CUDA_COMPUTE_CAPABILITIES = builtins.concatStringsSep "," cudaFlags.cudaRealArches; + TF_CUDA_COMPUTE_CAPABILITIES = lib.concatStringsSep "," cudaCapabilities; postPatch = '' # bazel 3.3 should work just as well as bazel 3.1