mirror of
https://github.com/NixOS/nixpkgs.git
synced 2024-11-27 09:23:01 +00:00
Merge pull request #158186 from samuela/samuela/jaxlib
python3Packages.jaxlib-bin: 0.1.71 -> 0.1.75
This commit is contained in:
commit
b6558a0aec
@ -16,19 +16,23 @@
|
||||
{ addOpenGLRunpath, autoPatchelfHook, buildPythonPackage, config
|
||||
, fetchurl, isPy39, lib, stdenv
|
||||
# propagatedBuildInputs
|
||||
, absl-py, flatbuffers, scipy, cudatoolkit_11
|
||||
, absl-py, flatbuffers, scipy, cudatoolkit_11, cudnn
|
||||
# Options:
|
||||
, cudaSupport ? config.cudaSupport or false
|
||||
}:
|
||||
|
||||
# Note that these values are tied to the specific version of the GPU wheel that
|
||||
# we fetch. When updating, try to go for the latest possible versions that are
|
||||
# still compatible with the cudatoolkit and cudnn versions available in nixpkgs.
|
||||
assert cudaSupport -> lib.versionAtLeast cudatoolkit_11.version "11.1";
|
||||
assert cudaSupport -> lib.versionAtLeast cudnn.version "8.0.5";
|
||||
|
||||
let
|
||||
device = if cudaSupport then "gpu" else "cpu";
|
||||
in
|
||||
buildPythonPackage rec {
|
||||
pname = "jaxlib";
|
||||
version = "0.1.71";
|
||||
version = "0.1.75";
|
||||
format = "wheel";
|
||||
|
||||
# At the time of writing (8/19/21), there are releases for 3.7-3.9. Supporting
|
||||
@ -36,14 +40,23 @@ buildPythonPackage rec {
|
||||
# version.
|
||||
disabled = !isPy39;
|
||||
|
||||
# Find new releases at https://storage.googleapis.com/jax-releases.
|
||||
src = {
|
||||
cpu = fetchurl {
|
||||
url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl";
|
||||
sha256 = "sha256:0rqhs6qabydizlv5d3rb20dbv6612rr7dqfniy9r6h4kazdinsn6";
|
||||
sha256 = "1davmx9dvai8dq3h5ac82634gjhv6l46kq6baajrxjqczbp0w7m6";
|
||||
};
|
||||
gpu = fetchurl {
|
||||
url = "https://storage.googleapis.com/jax-releases/cuda111/jaxlib-${version}+cuda111-cp39-none-manylinux2010_x86_64.whl";
|
||||
sha256 = "sha256:065kyzjsk9m84d138p99iymdiiicm1qz8a3iwxz8rspl43rwrw89";
|
||||
# Note that there's also a release targeting cuDNN 8.2, but unfortunately
|
||||
# we don't yet have that packaged at the time of writing (02/03/2022).
|
||||
# Check pkgs/development/libraries/science/math/cudnn/default.nix for more
|
||||
# details.
|
||||
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl";
|
||||
sha256 = "1mk618lq1q5x0dc3xbid8bim59l9j6l47xq232gdbn401ykrid7r";
|
||||
|
||||
# This is what the cuDNN 8.2 download looks like for future reference:
|
||||
# url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl";
|
||||
# sha256 = "000mnm2masm3sx3haddcmgw43j4gxa3m4fcm14p9nb8dnncjkgpb";
|
||||
};
|
||||
}.${device};
|
||||
|
||||
@ -71,7 +84,7 @@ buildPythonPackage rec {
|
||||
rpath=$(patchelf --print-rpath $file)
|
||||
# For some reason `makeLibraryPath` on `cudatoolkit_11` maps to
|
||||
# <cudatoolkit_11.lib>/lib which is different from <cudatoolkit_11>/lib.
|
||||
patchelf --set-rpath "$rpath:${cudatoolkit_11}/lib:${lib.makeLibraryPath [ cudatoolkit_11.lib ]}" $file
|
||||
patchelf --set-rpath "$rpath:${cudatoolkit_11}/lib:${lib.makeLibraryPath [ cudatoolkit_11.lib cudnn ]}" $file
|
||||
done
|
||||
'';
|
||||
|
||||
|
@ -4139,7 +4139,11 @@ in {
|
||||
|
||||
jax = callPackage ../development/python-modules/jax { };
|
||||
|
||||
jaxlib-bin = callPackage ../development/python-modules/jaxlib/bin.nix { };
|
||||
jaxlib-bin = callPackage ../development/python-modules/jaxlib/bin.nix {
|
||||
cudaSupport = pkgs.config.cudaSupport or false;
|
||||
cudatoolkit_11 = tensorflow_compat_cudatoolkit;
|
||||
cudnn = tensorflow_compat_cudnn;
|
||||
};
|
||||
|
||||
jaxlib-build = callPackage ../development/python-modules/jaxlib {
|
||||
# Some platforms don't have `cudaSupport` defined, hence the need for 'or false'.
|
||||
|
Loading…
Reference in New Issue
Block a user