Merge pull request #169279 from samuela/samuela/fix-jaxlibwithcuda

Fix python3Packages.jaxlibWithCuda
This commit is contained in:
Samuel Ainsworth 2022-04-20 10:05:39 -07:00 committed by GitHub
commit 8c10278ac1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 24 additions and 8 deletions

View File

@ -44,16 +44,16 @@ gcc = "gcc9"
version = "11.4.2" version = "11.4.2"
url = "https://developer.download.nvidia.com/compute/cuda/11.4.2/local_installers/cuda_11.4.2_470.57.02_linux.run" url = "https://developer.download.nvidia.com/compute/cuda/11.4.2/local_installers/cuda_11.4.2_470.57.02_linux.run"
sha256 = "sha256-u9h8oOkT+DdFSnljZ0c1E83e9VUILk2G7Zo4ZZzIHwo=" sha256 = "sha256-u9h8oOkT+DdFSnljZ0c1E83e9VUILk2G7Zo4ZZzIHwo="
gcc = "gcc10" # can bump to 11 along with stdenv.cc gcc = "gcc11"
["11.5"] ["11.5"]
version = "11.5.0" version = "11.5.0"
url = "https://developer.download.nvidia.com/compute/cuda/11.5.0/local_installers/cuda_11.5.0_495.29.05_linux.run" url = "https://developer.download.nvidia.com/compute/cuda/11.5.0/local_installers/cuda_11.5.0_495.29.05_linux.run"
sha256 = "sha256-rgoWk9lJfPPYHmlIlD43lGNpANtxyY1Y7v2sr38aHkw=" sha256 = "sha256-rgoWk9lJfPPYHmlIlD43lGNpANtxyY1Y7v2sr38aHkw="
gcc = "gcc10" # can bump to 11 along with stdenv.cc gcc = "gcc11"
["11.6"] ["11.6"]
version = "11.6.1" version = "11.6.1"
url = "https://developer.download.nvidia.com/compute/cuda/11.6.1/local_installers/cuda_11.6.1_510.47.03_linux.run" url = "https://developer.download.nvidia.com/compute/cuda/11.6.1/local_installers/cuda_11.6.1_510.47.03_linux.run"
sha256 = "sha256-qyGa/OALdCABEyaYZvv/derQN7z8I1UagzjCaEyYTX4=" sha256 = "sha256-qyGa/OALdCABEyaYZvv/derQN7z8I1UagzjCaEyYTX4="
gcc = "gcc10" # can bump to 11 along with stdenv.cc gcc = "gcc11"

View File

@ -3579,7 +3579,12 @@ with pkgs;
inherit (darwin.apple_sdk.frameworks) Security; inherit (darwin.apple_sdk.frameworks) Security;
}; };
gpu-burn = callPackage ../applications/misc/gpu-burn { }; gpu-burn = callPackage ../applications/misc/gpu-burn {
# gpu-burn doesn't build on gcc11. CUDA 11.3 is the last version to use
# pre-gcc11, in particular gcc9.
cudatoolkit = cudaPackages_11_3.cudatoolkit;
stdenv = gcc9Stdenv;
};
greg = callPackage ../applications/audio/greg { greg = callPackage ../applications/audio/greg {
pythonPackages = python3Packages; pythonPackages = python3Packages;
@ -23097,6 +23102,10 @@ with pkgs;
librealsenseWithCuda = callPackage ../development/libraries/librealsense { librealsenseWithCuda = callPackage ../development/libraries/librealsense {
cudaSupport = true; cudaSupport = true;
# librealsenseWithCuda doesn't build on gcc11. CUDA 11.3 is the last version
# to use pre-gcc11, in particular gcc9.
cudaPackages = cudaPackages_11_3;
stdenv = gcc9Stdenv;
}; };
librealsenseWithoutCuda = callPackage ../development/libraries/librealsense { librealsenseWithoutCuda = callPackage ../development/libraries/librealsense {

View File

@ -1958,7 +1958,11 @@ in {
cufflinks = callPackage ../development/python-modules/cufflinks { }; cufflinks = callPackage ../development/python-modules/cufflinks { };
cupy = callPackage ../development/python-modules/cupy { }; cupy = callPackage ../development/python-modules/cupy {
# cupy doesn't build on gcc11. CUDA 11.3 is the last version to use
# pre-gcc11, in particular gcc9.
cudaPackages = pkgs.cudaPackages_11_3;
};
curio = callPackage ../development/python-modules/curio { }; curio = callPackage ../development/python-modules/curio { };
@ -4243,20 +4247,23 @@ in {
jaxlib-bin = callPackage ../development/python-modules/jaxlib/bin.nix { jaxlib-bin = callPackage ../development/python-modules/jaxlib/bin.nix {
cudaSupport = pkgs.config.cudaSupport or false; cudaSupport = pkgs.config.cudaSupport or false;
inherit (self.tensorflow) cudaPackages; # At the time of writing (2022-04-18), `cudaPackages.nccl` is broken, so we
# pin to `cudaPackages_11_6` instead.
cudaPackages = pkgs.cudaPackages_11_6;
}; };
jaxlib-build = callPackage ../development/python-modules/jaxlib { jaxlib-build = callPackage ../development/python-modules/jaxlib {
# Some platforms don't have `cudaSupport` defined, hence the need for 'or false'. # Some platforms don't have `cudaSupport` defined, hence the need for 'or false'.
cudaSupport = pkgs.config.cudaSupport or false; cudaSupport = pkgs.config.cudaSupport or false;
inherit (self.tensorflow) cudaPackages; # At the time of writing (2022-04-18), `cudaPackages.nccl` is broken, so we
# pin to `cudaPackages_11_6` instead.
cudaPackages = pkgs.cudaPackages_11_6;
}; };
jaxlib = self.jaxlib-build; jaxlib = self.jaxlib-build;
jaxlibWithCuda = self.jaxlib-build.override { jaxlibWithCuda = self.jaxlib-build.override {
cudaSupport = true; cudaSupport = true;
}; };
jaxlibWithoutCuda = self.jaxlib-build.override { jaxlibWithoutCuda = self.jaxlib-build.override {