mirror of
https://github.com/NixOS/nixpkgs.git
synced 2025-04-17 04:18:24 +00:00
Merge pull request #291705 from GaetanLepage/jax
python311Packages.{jax,jaxlib,jaxlib-bin}: 0.4.24 -> 0.4.28
This commit is contained in:
commit
3a993d3244
@ -16,7 +16,7 @@
|
||||
|
||||
buildPythonPackage rec {
|
||||
pname = "blackjax";
|
||||
version = "1.2.0";
|
||||
version = "1.2.1";
|
||||
pyproject = true;
|
||||
|
||||
disabled = pythonOlder "3.9";
|
||||
@ -25,7 +25,7 @@ buildPythonPackage rec {
|
||||
owner = "blackjax-devs";
|
||||
repo = "blackjax";
|
||||
rev = "refs/tags/${version}";
|
||||
hash = "sha256-vXyxK3xALKG61YGK7fmoqQNGfOiagHFrvnU02WKZThw=";
|
||||
hash = "sha256-VoWBCjFMyE5LVJyf7du/pKlnvDHj22lguiP6ZUzH9ak=";
|
||||
};
|
||||
|
||||
build-system = [
|
||||
@ -56,6 +56,10 @@ buildPythonPackage rec {
|
||||
disabledTests = [
|
||||
# too slow
|
||||
"test_adaptive_tempered_smc"
|
||||
] ++ lib.optionals (stdenv.isLinux && stdenv.isAarch64) [
|
||||
# Numerical test (AssertionError)
|
||||
# https://github.com/blackjax-devs/blackjax/issues/668
|
||||
"test_chees_adaptation"
|
||||
];
|
||||
|
||||
pythonImportsCheck = [
|
||||
|
@ -48,8 +48,21 @@ buildPythonPackage rec {
|
||||
pythonImportsCheck = [ "equinox" ];
|
||||
|
||||
disabledTests = [
|
||||
# Failed: DID NOT WARN. No warnings of type (<class 'UserWarning'>,) were emitted.
|
||||
"test_tracetime"
|
||||
# For simplicity, JAX has removed its internal frames from the traceback of the following exception.
|
||||
# https://github.com/patrick-kidger/equinox/issues/716
|
||||
"test_abstract"
|
||||
"test_complicated"
|
||||
"test_grad"
|
||||
"test_jvp"
|
||||
"test_mlp"
|
||||
"test_num_traces"
|
||||
"test_pytree_in"
|
||||
"test_simple"
|
||||
"test_vmap"
|
||||
|
||||
# AssertionError: assert 'foo:\n pri...pe=float32)\n' == 'foo:\n pri...pe=float32)\n'
|
||||
# Also reported in patrick-kidger/equinox#716
|
||||
"test_backward_nan"
|
||||
];
|
||||
|
||||
meta = with lib; {
|
||||
|
@ -25,7 +25,7 @@
|
||||
|
||||
buildPythonPackage rec {
|
||||
pname = "flax";
|
||||
version = "0.8.2";
|
||||
version = "0.8.3";
|
||||
pyproject = true;
|
||||
|
||||
disabled = pythonOlder "3.9";
|
||||
@ -34,16 +34,16 @@ buildPythonPackage rec {
|
||||
owner = "google";
|
||||
repo = "flax";
|
||||
rev = "refs/tags/v${version}";
|
||||
hash = "sha256-UABgJGe1grUSkwOJpjeIoFqhXsqG//HlC1YyYPxXV+g=";
|
||||
hash = "sha256-uDGTyksUZTTL6FiTJP+qteFLOjr75dcTj9yRJ6Jm8xU=";
|
||||
};
|
||||
|
||||
nativeBuildInputs = [
|
||||
build-system = [
|
||||
jaxlib
|
||||
pythonRelaxDepsHook
|
||||
setuptools-scm
|
||||
];
|
||||
|
||||
propagatedBuildInputs = [
|
||||
dependencies = [
|
||||
jax
|
||||
msgpack
|
||||
numpy
|
||||
|
@ -29,7 +29,7 @@ let
|
||||
in
|
||||
buildPythonPackage rec {
|
||||
pname = "jax";
|
||||
version = "0.4.25";
|
||||
version = "0.4.28";
|
||||
pyproject = true;
|
||||
|
||||
disabled = pythonOlder "3.9";
|
||||
@ -39,7 +39,7 @@ buildPythonPackage rec {
|
||||
repo = "jax";
|
||||
# google/jax contains tags for jax and jaxlib. Only use jax tags!
|
||||
rev = "refs/tags/jax-v${version}";
|
||||
hash = "sha256-poQQo2ZgEhPYzK3aCs+BjaHTNZbezJAECd+HOdY1Yok=";
|
||||
hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek=";
|
||||
};
|
||||
|
||||
nativeBuildInputs = [
|
||||
@ -81,6 +81,14 @@ buildPythonPackage rec {
|
||||
"tests/"
|
||||
];
|
||||
|
||||
# Prevents `tests/export_back_compat_test.py::CompatTest::test_*` tests from failing on darwin with
|
||||
# PermissionError: [Errno 13] Permission denied: '/tmp/back_compat_testdata/test_*.py'
|
||||
# See https://github.com/google/jax/blob/jaxlib-v0.4.27/jax/_src/internal_test_util/export_back_compat_test_util.py#L240-L241
|
||||
# NOTE: this doesn't seem to be an issue on linux
|
||||
preCheck = lib.optionalString stdenv.isDarwin ''
|
||||
export TEST_UNDECLARED_OUTPUTS_DIR=$(mktemp -d)
|
||||
'';
|
||||
|
||||
disabledTests = [
|
||||
# Exceeds tolerance when the machine is busy
|
||||
"test_custom_linear_solve_aux"
|
||||
|
@ -20,17 +20,17 @@
|
||||
, stdenv
|
||||
# Options:
|
||||
, cudaSupport ? config.cudaSupport
|
||||
, cudaPackagesGoogle
|
||||
, cudaPackages
|
||||
}:
|
||||
|
||||
let
|
||||
inherit (cudaPackagesGoogle) cudaVersion;
|
||||
inherit (cudaPackages) cudaVersion;
|
||||
|
||||
version = "0.4.24";
|
||||
version = "0.4.28";
|
||||
|
||||
inherit (python) pythonVersion;
|
||||
|
||||
cudaLibPath = lib.makeLibraryPath (with cudaPackagesGoogle; [
|
||||
cudaLibPath = lib.makeLibraryPath (with cudaPackages; [
|
||||
cuda_cudart.lib # libcudart.so
|
||||
cuda_cupti.lib # libcupti.so
|
||||
cudnn.lib # libcudnn.so
|
||||
@ -56,65 +56,65 @@ let
|
||||
"3.9-x86_64-linux" = getSrcFromPypi {
|
||||
platform = "manylinux2014_x86_64";
|
||||
dist = "cp39";
|
||||
hash = "sha256-6P5ArMoLZiUkHUoQ/mJccbNj5/7el/op+Qo6cGQ33xE=";
|
||||
hash = "sha256-Slbr8FtKTBeRaZ2HTgcvP4CPCYa0AQsU+1SaackMqdw=";
|
||||
};
|
||||
"3.9-aarch64-darwin" = getSrcFromPypi {
|
||||
platform = "macosx_11_0_arm64";
|
||||
dist = "cp39";
|
||||
hash = "sha256-23JQZRwMLtt7sK/JlCBqqRyfTVIAVJFN2sL+nAkQgvU=";
|
||||
hash = "sha256-sBVi7IrXVxm30DiXUkiel+trTctMjBE75JFjTVKCrTw=";
|
||||
};
|
||||
"3.9-x86_64-darwin" = getSrcFromPypi {
|
||||
platform = "macosx_10_14_x86_64";
|
||||
dist = "cp39";
|
||||
hash = "sha256-OgMedn9GHGs5THZf3pkP3Aw/jJ0vL5qK1b+Lzf634Ik=";
|
||||
hash = "sha256-T5jMg3srbG3P4Kt/+esQkxSSCUYRmqOvn6oTlxj/J4c=";
|
||||
};
|
||||
|
||||
"3.10-x86_64-linux" = getSrcFromPypi {
|
||||
platform = "manylinux2014_x86_64";
|
||||
dist = "cp310";
|
||||
hash = "sha256-/VwUIIa7mTs/wLz0ArsEfNrz2pGriVVT5GX9XRFRxfY=";
|
||||
hash = "sha256-47zcb45g+FVPQVwU2TATTmAuPKM8OOVGJ0/VRfh1dps=";
|
||||
};
|
||||
"3.10-aarch64-darwin" = getSrcFromPypi {
|
||||
platform = "macosx_11_0_arm64";
|
||||
dist = "cp310";
|
||||
hash = "sha256-LgICOyDGts840SQQJh+yOMobMASb62llvJjpGvhzrSw=";
|
||||
hash = "sha256-8Djmi9ENGjVUcisLvjbmpEg4RDenWqnSg/aW8O2fjAk=";
|
||||
};
|
||||
"3.10-x86_64-darwin" = getSrcFromPypi {
|
||||
platform = "macosx_10_14_x86_64";
|
||||
dist = "cp310";
|
||||
hash = "sha256-vhyULw+zBpz1UEi2tqgBMQEzY9a6YBgEIg6A4PPh3bQ=";
|
||||
hash = "sha256-pCHSN/jCXShQFm0zRgPGc925tsJvUrxJZwS4eCKXvWY=";
|
||||
};
|
||||
|
||||
"3.11-x86_64-linux" = getSrcFromPypi {
|
||||
platform = "manylinux2014_x86_64";
|
||||
dist = "cp311";
|
||||
hash = "sha256-VJO/VVwBFkOEtq4y/sLVgAV8Cung01JULiuT6W96E/8=";
|
||||
hash = "sha256-Rc4PPIQM/4I2z/JsN/Jsn/B4aV+T4MFiwyDCgfUEEnU=";
|
||||
};
|
||||
"3.11-aarch64-darwin" = getSrcFromPypi {
|
||||
platform = "macosx_11_0_arm64";
|
||||
dist = "cp311";
|
||||
hash = "sha256-VtuwXxurpSp1KI8ty1bizs5cdy8GEBN2MgS227sOCmE=";
|
||||
hash = "sha256-eThX+vN/Nxyv51L+pfyBH0NeQ7j7S1AgWERKf17M+Ck=";
|
||||
};
|
||||
"3.11-x86_64-darwin" = getSrcFromPypi {
|
||||
platform = "macosx_10_14_x86_64";
|
||||
dist = "cp311";
|
||||
hash = "sha256-4Dj5dEGKb9hpg3HlVogNO1Gc9UibJhy1eym2mjivxAQ=";
|
||||
hash = "sha256-L/gpDtx7ksfq5SUX9lSSYz4mey6QZ7rT5MMj0hPnfPU=";
|
||||
};
|
||||
|
||||
"3.12-x86_64-linux" = getSrcFromPypi {
|
||||
platform = "manylinux2014_x86_64";
|
||||
dist = "cp312";
|
||||
hash = "sha256-TlrGVtb3NTLmhnILWPLJR+jISCZ5SUV4wxNFpSfkCBo=";
|
||||
hash = "sha256-RqGqhX9P7uikP8upXA4Kti1AwmzJcwtsaWVZCLo1n40=";
|
||||
};
|
||||
"3.12-aarch64-darwin" = getSrcFromPypi {
|
||||
platform = "macosx_11_0_arm64";
|
||||
dist = "cp312";
|
||||
hash = "sha256-FIwK5CGykQjteuWzLZnbtAggIxLQeGV96bXlZGEytN0=";
|
||||
hash = "sha256-jdi//jhTcC9jzZJNoO4lc0pNGc1ckmvgM9dyun0cF10=";
|
||||
};
|
||||
"3.12-x86_64-darwin" = getSrcFromPypi {
|
||||
platform = "macosx_10_14_x86_64";
|
||||
dist = "cp312";
|
||||
hash = "sha256-9/jw/wr6oUD9pOadVAaMRL086iVMUXwVgnUMcG1UNvE=";
|
||||
hash = "sha256-1sCaVFMpciRhrwVuc1FG0sjHTCKsdCaoRetp8ya096A=";
|
||||
};
|
||||
};
|
||||
|
||||
@ -130,35 +130,19 @@ let
|
||||
gpuSrcs = {
|
||||
"cuda12.2-3.9" = fetchurl {
|
||||
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp39-cp39-manylinux2014_x86_64.whl";
|
||||
hash = "sha256-xdJKLPtx+CIza2CrWKM3M0cZJzyNFVTTTsvlgh38bfM=";
|
||||
hash = "sha256-d8LIl22gIvmWfoyKfXKElZJXicPQIZxdS4HumhwQGCw=";
|
||||
};
|
||||
"cuda12.2-3.10" = fetchurl {
|
||||
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl";
|
||||
hash = "sha256-QCjrOczD2mp+CDwVXBc0/4rJnAizeV62AK0Dpx9X6TE=";
|
||||
hash = "sha256-PXtWv+UEcMWF8LhWe6Z1UGkf14PG3dkJ0Iop0LiimnQ=";
|
||||
};
|
||||
"cuda12.2-3.11" = fetchurl {
|
||||
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp311-cp311-manylinux2014_x86_64.whl";
|
||||
hash = "sha256-Ipy3vk1yUplpNzECAFt63aOIhgEWgXG7hkoeTIk9bQQ=";
|
||||
hash = "sha256-QO2WSOzmJ48VaCha596mELiOfPsAGLpGctmdzcCHE/o=";
|
||||
};
|
||||
"cuda12.2-3.12" = fetchurl {
|
||||
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp312-cp312-manylinux2014_x86_64.whl";
|
||||
hash = "sha256-LSnZHaUga/8Z65iKXWBnZDk4yUpNykFTu3vukCchO6Q=";
|
||||
};
|
||||
"cuda11.8-3.9" = fetchurl {
|
||||
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp39-cp39-manylinux2014_x86_64.whl";
|
||||
hash = "sha256-UmyugL0VjlXkiD7fuDPWgW8XUpr/QaP5ggp6swoZTzU=";
|
||||
};
|
||||
"cuda11.8-3.10" = fetchurl {
|
||||
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl";
|
||||
hash = "sha256-luKULEiV1t/sO6eckDxddJTiOFa0dtJeDlrvp+WYmHk=";
|
||||
};
|
||||
"cuda11.8-3.11" = fetchurl {
|
||||
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp311-cp311-manylinux2014_x86_64.whl";
|
||||
hash = "sha256-4+uJ8Ij6mFGEmjFEgi3fLnSLZs+v18BRoOt7mZuqydw=";
|
||||
};
|
||||
"cuda11.8-3.12" = fetchurl {
|
||||
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp312-cp312-manylinux2014_x86_64.whl";
|
||||
hash = "sha256-bUDFb94Ar/65SzzR9RLIs/SL/HdjaPT1Su5whmjkS00=";
|
||||
hash = "sha256-ixWMaIChy4Ammsn23/3cCoala0lFibuUxyUr3tjfFKU=";
|
||||
};
|
||||
};
|
||||
|
||||
@ -213,7 +197,7 @@ buildPythonPackage {
|
||||
# for more info.
|
||||
postInstall = lib.optional cudaSupport ''
|
||||
mkdir -p $out/${python.sitePackages}/jaxlib/cuda/bin
|
||||
ln -s ${lib.getExe' cudaPackagesGoogle.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jaxlib/cuda/bin/ptxas
|
||||
ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jaxlib/cuda/bin/ptxas
|
||||
'';
|
||||
|
||||
inherit (jaxlib-build) pythonImportsCheck;
|
||||
@ -227,7 +211,7 @@ buildPythonPackage {
|
||||
platforms = [ "aarch64-darwin" "x86_64-linux" "x86_64-darwin" ];
|
||||
broken =
|
||||
!(cudaSupport -> lib.versionAtLeast cudaVersion "11.1")
|
||||
|| !(cudaSupport -> lib.versionAtLeast cudaPackagesGoogle.cudnn.version "8.2")
|
||||
|| !(cudaSupport -> lib.versionAtLeast cudaPackages.cudnn.version "8.2")
|
||||
|| !(cudaSupport -> stdenv.isLinux)
|
||||
|| !(cudaSupport -> (gpuSrcs ? "cuda${cudaVersion}-${pythonVersion}"))
|
||||
# Fails at pythonImportsCheckPhase:
|
||||
|
@ -13,7 +13,6 @@
|
||||
, curl
|
||||
, cython
|
||||
, fetchFromGitHub
|
||||
, fetchpatch
|
||||
, git
|
||||
, IOKit
|
||||
, jsoncpp
|
||||
@ -45,22 +44,22 @@
|
||||
, config
|
||||
# CUDA flags:
|
||||
, cudaSupport ? config.cudaSupport
|
||||
, cudaPackagesGoogle
|
||||
, cudaPackages
|
||||
|
||||
# MKL:
|
||||
, mklSupport ? true
|
||||
}@inputs:
|
||||
|
||||
let
|
||||
inherit (cudaPackagesGoogle) cudaFlags cudaVersion cudnn nccl;
|
||||
inherit (cudaPackages) cudaFlags cudaVersion cudnn nccl;
|
||||
|
||||
pname = "jaxlib";
|
||||
version = "0.4.24";
|
||||
version = "0.4.28";
|
||||
|
||||
# It's necessary to consistently use backendStdenv when building with CUDA
|
||||
# support, otherwise we get libstdc++ errors downstream
|
||||
stdenv = throw "Use effectiveStdenv instead";
|
||||
effectiveStdenv = if cudaSupport then cudaPackagesGoogle.backendStdenv else inputs.stdenv;
|
||||
effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else inputs.stdenv;
|
||||
|
||||
meta = with lib; {
|
||||
description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
|
||||
@ -78,7 +77,7 @@ let
|
||||
# These are necessary at build time and run time.
|
||||
cuda_libs_joined = symlinkJoin {
|
||||
name = "cuda-joined";
|
||||
paths = with cudaPackagesGoogle; [
|
||||
paths = with cudaPackages; [
|
||||
cuda_cudart.lib # libcudart.so
|
||||
cuda_cudart.static # libcudart_static.a
|
||||
cuda_cupti.lib # libcupti.so
|
||||
@ -92,11 +91,11 @@ let
|
||||
# These are only necessary at build time.
|
||||
cuda_build_deps_joined = symlinkJoin {
|
||||
name = "cuda-build-deps-joined";
|
||||
paths = with cudaPackagesGoogle; [
|
||||
paths = with cudaPackages; [
|
||||
cuda_libs_joined
|
||||
|
||||
# Binaries
|
||||
cudaPackagesGoogle.cuda_nvcc.bin # nvcc
|
||||
cudaPackages.cuda_nvcc.bin # nvcc
|
||||
|
||||
# Headers
|
||||
cuda_cccl.dev # block_load.cuh
|
||||
@ -181,19 +180,10 @@ let
|
||||
owner = "openxla";
|
||||
repo = "xla";
|
||||
# Update this according to https://github.com/google/jax/blob/jaxlib-v${version}/third_party/xla/workspace.bzl.
|
||||
rev = "12eee889e1f2ad41e27d7b0e970cb92d282d3ec5";
|
||||
hash = "sha256-68kjjgwYjRlcT0TVJo9BN6s+WTkdu5UMJqQcfHpBT90=";
|
||||
rev = "e8247c3ea1d4d7f31cf27def4c7ac6f2ce64ecd4";
|
||||
hash = "sha256-ZhgMIVs3Z4dTrkRWDqaPC/i7yJz2dsYXrZbjzqvPX3E=";
|
||||
};
|
||||
|
||||
patches = [
|
||||
# Resolves "could not convert ‘result’ from ‘SmallVector<[...],6>’ to
|
||||
# ‘SmallVector<[...],4>’" compilation error. See https://github.com/google/jax/issues/19814#issuecomment-1945141259.
|
||||
(fetchpatch {
|
||||
url = "https://github.com/openxla/xla/commit/7a614cd346594fc7ea2fe75570c9c53a4a444f60.patch";
|
||||
hash = "sha256-RtuQTH8wzNiJcOtISLhf+gMlH1gg8hekvxEB+4wX6BM=";
|
||||
})
|
||||
];
|
||||
|
||||
dontBuild = true;
|
||||
|
||||
# This is necessary for patchShebangs to know the right path to use.
|
||||
@ -220,7 +210,7 @@ let
|
||||
repo = "jax";
|
||||
# google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
|
||||
rev = "refs/tags/${pname}-v${version}";
|
||||
hash = "sha256-hmx7eo3pephc6BQfoJ3U0QwWBWmhkAc+7S4QmW32qQs=";
|
||||
hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek=";
|
||||
};
|
||||
|
||||
nativeBuildInputs = [
|
||||
@ -364,10 +354,10 @@ let
|
||||
];
|
||||
|
||||
sha256 = (if cudaSupport then {
|
||||
x86_64-linux = "sha256-8JilAoTbqOjOOJa/Zc/n/quaEDcpdcLXCNb34mfB+OM=";
|
||||
x86_64-linux = "sha256-VGNMf5/DgXbgsu1w5J1Pmrukw+7UO31BNU+crKVsX5k=";
|
||||
} else {
|
||||
x86_64-linux = "sha256-iqS+I1FQLNWXNMsA20cJp7YkyGUeshee5b2QfRBNZtk=";
|
||||
aarch64-linux = "sha256-qmJ0Fm/VGMTmko4PhKs1P8/GLEJmVxb8xg+ss/HsakY==";
|
||||
x86_64-linux = "sha256-uOoAyMBLHPX6jzdN43b5wZV5eW0yI8sCDD7BSX2h4oQ=";
|
||||
aarch64-linux = "sha256-+SnGKY9LIT1Qhu/x6Uh7sHRaAEjlc//qyKj1m4t16PA=";
|
||||
}).${effectiveStdenv.system} or (throw "jaxlib: unsupported system: ${effectiveStdenv.system}");
|
||||
};
|
||||
|
||||
@ -414,7 +404,7 @@ buildPythonPackage {
|
||||
# for more info.
|
||||
postInstall = lib.optionalString cudaSupport ''
|
||||
mkdir -p $out/bin
|
||||
ln -s ${cudaPackagesGoogle.cuda_nvcc.bin}/bin/ptxas $out/bin/ptxas
|
||||
ln -s ${cudaPackages.cuda_nvcc.bin}/bin/ptxas $out/bin/ptxas
|
||||
|
||||
find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
|
||||
patchelf --add-rpath "${lib.makeLibraryPath [cuda_libs_joined cudnn nccl]}" "$lib"
|
||||
@ -423,7 +413,7 @@ buildPythonPackage {
|
||||
|
||||
nativeBuildInputs = lib.optionals cudaSupport [ autoAddDriverRunpath ];
|
||||
|
||||
propagatedBuildInputs = [
|
||||
dependencies = [
|
||||
absl-py
|
||||
curl
|
||||
double-conversion
|
||||
|
@ -6,6 +6,7 @@
|
||||
, fetchpatch
|
||||
, pytest-xdist
|
||||
, pytestCheckHook
|
||||
, setuptools
|
||||
, absl-py
|
||||
, cvxpy
|
||||
, jax
|
||||
@ -20,7 +21,7 @@
|
||||
buildPythonPackage rec {
|
||||
pname = "jaxopt";
|
||||
version = "0.8.3";
|
||||
format = "setuptools";
|
||||
pyproject = true;
|
||||
|
||||
disabled = pythonOlder "3.8";
|
||||
|
||||
@ -41,7 +42,11 @@ buildPythonPackage rec {
|
||||
})
|
||||
];
|
||||
|
||||
propagatedBuildInputs = [
|
||||
build-system = [
|
||||
setuptools
|
||||
];
|
||||
|
||||
dependencies = [
|
||||
absl-py
|
||||
jax
|
||||
jaxlib
|
||||
@ -66,11 +71,20 @@ buildPythonPackage rec {
|
||||
"jaxopt.tree_util"
|
||||
];
|
||||
|
||||
disabledTests = lib.optionals (stdenv.isLinux && stdenv.isAarch64) [
|
||||
disabledTests = [
|
||||
# https://github.com/google/jaxopt/issues/592
|
||||
"test_solve_sparse"
|
||||
] ++ lib.optionals (stdenv.isLinux && stdenv.isAarch64) [
|
||||
# https://github.com/google/jaxopt/issues/577
|
||||
"test_binary_logit_log_likelihood"
|
||||
"test_solve_sparse"
|
||||
"test_logreg_with_intercept_manual_loop3"
|
||||
|
||||
# https://github.com/google/jaxopt/issues/593
|
||||
# Makes the test suite crash
|
||||
"test_dtype_consistency"
|
||||
# AssertionError: Array(0.01411963, dtype=float32) not less than or equal to 0.01
|
||||
"test_multiclass_logreg6"
|
||||
];
|
||||
|
||||
meta = with lib; {
|
||||
|
@ -51,8 +51,10 @@ buildPythonPackage rec {
|
||||
scipy
|
||||
torch
|
||||
tensorflow
|
||||
jax
|
||||
jaxlib
|
||||
# Uncomment at next release (1.9.3)
|
||||
# See https://github.com/wjakob/nanobind/issues/578
|
||||
# jax
|
||||
# jaxlib
|
||||
];
|
||||
|
||||
meta = with lib; {
|
||||
|
@ -1,7 +1,6 @@
|
||||
{ lib
|
||||
, buildPythonPackage
|
||||
, fetchFromGitHub
|
||||
, fetchpatch
|
||||
, jax
|
||||
, jaxlib
|
||||
, keras
|
||||
@ -30,7 +29,12 @@ buildPythonPackage rec {
|
||||
hash = "sha256-WD+pmR8cEay4iziRXqF3sHUzCMBjmLJ3wZ3iYOD+hzk=";
|
||||
};
|
||||
|
||||
nativeBuildInputs = [
|
||||
patches = [
|
||||
# Issue reported upstream: https://github.com/google/objax/issues/270
|
||||
./replace-deprecated-device_buffers.patch
|
||||
];
|
||||
|
||||
build-system = [
|
||||
setuptools
|
||||
];
|
||||
|
||||
@ -40,7 +44,7 @@ buildPythonPackage rec {
|
||||
jaxlib
|
||||
];
|
||||
|
||||
propagatedBuildInputs = [
|
||||
dependencies = [
|
||||
jax
|
||||
numpy
|
||||
parameterized
|
||||
|
@ -0,0 +1,14 @@
|
||||
diff --git a/objax/util/util.py b/objax/util/util.py
|
||||
index c31a356..344cf9a 100644
|
||||
--- a/objax/util/util.py
|
||||
+++ b/objax/util/util.py
|
||||
@@ -117,7 +117,8 @@ def get_local_devices():
|
||||
if _local_devices is None:
|
||||
x = jn.zeros((jax.local_device_count(), 1), dtype=jn.float32)
|
||||
sharded_x = map_to_device(x)
|
||||
- _local_devices = [b.device() for b in sharded_x.device_buffers]
|
||||
+ device_buffers = [buf.data for buf in sharded_x.addressable_shards]
|
||||
+ _local_devices = [list(b.devices())[0] for b in device_buffers]
|
||||
return _local_devices
|
||||
|
||||
|
@ -22,7 +22,7 @@
|
||||
, tensorboard
|
||||
, config
|
||||
, cudaSupport ? config.cudaSupport
|
||||
, cudaPackagesGoogle
|
||||
, cudaPackages
|
||||
, zlib
|
||||
, python
|
||||
, keras-applications
|
||||
@ -43,7 +43,7 @@ assert ! (stdenv.isDarwin && cudaSupport);
|
||||
|
||||
let
|
||||
packages = import ./binary-hashes.nix;
|
||||
inherit (cudaPackagesGoogle) cudatoolkit cudnn;
|
||||
inherit (cudaPackages) cudatoolkit cudnn;
|
||||
in buildPythonPackage {
|
||||
pname = "tensorflow" + lib.optionalString cudaSupport "-gpu";
|
||||
inherit (packages) version;
|
||||
@ -199,10 +199,6 @@ in buildPythonPackage {
|
||||
"tensorflow.python.framework"
|
||||
];
|
||||
|
||||
passthru = {
|
||||
cudaPackages = cudaPackagesGoogle;
|
||||
};
|
||||
|
||||
meta = with lib; {
|
||||
description = "Computation using data flow graphs for scalable machine learning";
|
||||
homepage = "http://tensorflow.org";
|
||||
|
@ -19,8 +19,8 @@
|
||||
# https://groups.google.com/a/tensorflow.org/forum/#!topic/developers/iRCt5m4qUz0
|
||||
, config
|
||||
, cudaSupport ? config.cudaSupport
|
||||
, cudaPackagesGoogle
|
||||
, cudaCapabilities ? cudaPackagesGoogle.cudaFlags.cudaCapabilities
|
||||
, cudaPackages
|
||||
, cudaCapabilities ? cudaPackages.cudaFlags.cudaCapabilities
|
||||
, mklSupport ? false, mkl
|
||||
, tensorboardSupport ? true
|
||||
# XLA without CUDA is broken
|
||||
@ -50,15 +50,15 @@ let
|
||||
# __ZN4llvm11SmallPtrSetIPKNS_10AllocaInstELj8EED1Ev in any of the
|
||||
# translation units, so the build fails at link time
|
||||
stdenv =
|
||||
if cudaSupport then cudaPackagesGoogle.backendStdenv
|
||||
if cudaSupport then cudaPackages.backendStdenv
|
||||
else if originalStdenv.isDarwin then llvmPackages.stdenv
|
||||
else originalStdenv;
|
||||
inherit (cudaPackagesGoogle) cudatoolkit nccl;
|
||||
inherit (cudaPackages) cudatoolkit nccl;
|
||||
# use compatible cuDNN (https://www.tensorflow.org/install/source#gpu)
|
||||
# cudaPackages.cudnn led to this:
|
||||
# https://github.com/tensorflow/tensorflow/issues/60398
|
||||
cudnnAttribute = "cudnn_8_6";
|
||||
cudnn = cudaPackagesGoogle.${cudnnAttribute};
|
||||
cudnn = cudaPackages.${cudnnAttribute};
|
||||
gentoo-patches = fetchzip {
|
||||
url = "https://dev.gentoo.org/~perfinion/patches/tensorflow-patches-2.12.0.tar.bz2";
|
||||
hash = "sha256-SCRX/5/zML7LmKEPJkcM5Tebez9vv/gmE4xhT/jyqWs=";
|
||||
@ -490,8 +490,8 @@ let
|
||||
broken =
|
||||
stdenv.isDarwin
|
||||
|| !(xlaSupport -> cudaSupport)
|
||||
|| !(cudaSupport -> builtins.hasAttr cudnnAttribute cudaPackagesGoogle)
|
||||
|| !(cudaSupport -> cudaPackagesGoogle ? cudatoolkit);
|
||||
|| !(cudaSupport -> builtins.hasAttr cudnnAttribute cudaPackages)
|
||||
|| !(cudaSupport -> cudaPackages ? cudatoolkit);
|
||||
} // lib.optionalAttrs stdenv.isDarwin {
|
||||
timeout = 86400; # 24 hours
|
||||
maxSilent = 14400; # 4h, double the default of 7200s
|
||||
@ -594,7 +594,6 @@ in buildPythonPackage {
|
||||
# Regression test for #77626 removed because not more `tensorflow.contrib`.
|
||||
|
||||
passthru = {
|
||||
cudaPackages = cudaPackagesGoogle;
|
||||
deps = bazel-build.deps;
|
||||
libtensorflow = bazel-build.out;
|
||||
};
|
||||
|
@ -3,7 +3,6 @@
|
||||
recurseIntoAttrs,
|
||||
|
||||
cudaPackages,
|
||||
cudaPackagesGoogle,
|
||||
|
||||
cudaPackages_10_0,
|
||||
cudaPackages_10_1,
|
||||
|
@ -7125,10 +7125,6 @@ with pkgs;
|
||||
cudaPackages_12_3 = callPackage ./cuda-packages.nix { cudaVersion = "12.3"; };
|
||||
cudaPackages_12 = cudaPackages_12_2; # Latest supported by cudnn
|
||||
|
||||
# Use the older cudaPackages for tensorflow and jax, as determined by cudnn
|
||||
# compatibility: https://www.tensorflow.org/install/source#gpu
|
||||
cudaPackagesGoogle = cudaPackages_11;
|
||||
|
||||
cudaPackages = recurseIntoAttrs cudaPackages_12;
|
||||
|
||||
# TODO: move to alias
|
||||
|
@ -14885,6 +14885,8 @@ self: super: with self; {
|
||||
|
||||
tensorflow-bin = callPackage ../development/python-modules/tensorflow/bin.nix {
|
||||
inherit (pkgs.config) cudaSupport;
|
||||
# https://www.tensorflow.org/install/source#gpu
|
||||
cudaPackages = pkgs.cudaPackages_11;
|
||||
};
|
||||
|
||||
tensorflow-build = let
|
||||
@ -14892,6 +14894,8 @@ self: super: with self; {
|
||||
protobufTF = pkgs.protobuf_21.override {
|
||||
abseil-cpp = pkgs.abseil-cpp_202301;
|
||||
};
|
||||
# https://www.tensorflow.org/install/source#gpu
|
||||
cudaPackagesTF = pkgs.cudaPackages_11;
|
||||
grpcTF = (pkgs.grpc.overrideAttrs (
|
||||
oldAttrs: rec {
|
||||
# nvcc fails on recent grpc versions, so we use the latest patch level
|
||||
@ -14937,6 +14941,7 @@ self: super: with self; {
|
||||
inherit (pkgs.darwin.apple_sdk.frameworks) Foundation Security;
|
||||
flatbuffers-core = pkgs.flatbuffers;
|
||||
flatbuffers-python = self.flatbuffers;
|
||||
cudaPackages = compat.cudaPackagesTF;
|
||||
protobuf-core = compat.protobufTF;
|
||||
protobuf-python = compat.protobuf-pythonTF;
|
||||
grpc = compat.grpcTF;
|
||||
|
Loading…
Reference in New Issue
Block a user