Merge pull request #323154 from zimbatm/jax-fixes

Jax fixes
This commit is contained in:
Jonas Chevalier 2024-07-04 11:20:04 +02:00 committed by GitHub
commit 9f38e1c714
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 23 additions and 16 deletions

View File

@ -8,6 +8,7 @@
fetchFromGitHub, fetchFromGitHub,
jaxlib, jaxlib,
jaxlib-bin, jaxlib-bin,
jaxlib-build,
hypothesis, hypothesis,
lapack, lapack,
matplotlib, matplotlib,
@ -23,10 +24,6 @@
let let
usingMKL = blas.implementation == "mkl" || lapack.implementation == "mkl"; usingMKL = blas.implementation == "mkl" || lapack.implementation == "mkl";
# jaxlib is broken on aarch64-* as of 2023-03-05, but the binary wheels work
# fine. jaxlib is only used in the checkPhase, so switching backends does not
# impact package behavior. Get rid of this once jaxlib is fixed on aarch64-*.
jaxlib' = if jaxlib.meta.broken then jaxlib-bin else jaxlib;
in in
buildPythonPackage rec { buildPythonPackage rec {
pname = "jax"; pname = "jax";
@ -61,7 +58,7 @@ buildPythonPackage rec {
nativeCheckInputs = [ nativeCheckInputs = [
hypothesis hypothesis
jaxlib' jaxlib
matplotlib matplotlib
pytestCheckHook pytestCheckHook
pytest-xdist pytest-xdist
@ -130,7 +127,11 @@ buildPythonPackage rec {
"testQdwhWithOnRankDeficientInput5" "testQdwhWithOnRankDeficientInput5"
]; ];
disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [ disabledTestPaths = [
# Segmentation fault. See https://gist.github.com/zimbatm/e9b61891f3bcf5e4aaefd13f94344fba
"tests/linalg_test.py"
]
++ lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [
# RuntimeWarning: invalid value encountered in cast # RuntimeWarning: invalid value encountered in cast
"tests/lax_test.py" "tests/lax_test.py"
]; ];
@ -147,7 +148,7 @@ buildPythonPackage rec {
# NIXPKGS_ALLOW_UNFREE=1 nixglhost -- nix run --impure .#python3Packages.jax.passthru.tests.test_cuda_jaxlibBin # NIXPKGS_ALLOW_UNFREE=1 nixglhost -- nix run --impure .#python3Packages.jax.passthru.tests.test_cuda_jaxlibBin
passthru.tests = { passthru.tests = {
test_cuda_jaxlibSource = callPackage ./test-cuda.nix { test_cuda_jaxlibSource = callPackage ./test-cuda.nix {
jaxlib = jaxlib.override { cudaSupport = true; }; jaxlib = jaxlib-build.override { cudaSupport = true; };
}; };
test_cuda_jaxlibBin = callPackage ./test-cuda.nix { test_cuda_jaxlibBin = callPackage ./test-cuda.nix {
jaxlib = jaxlib-bin.override { cudaSupport = true; }; jaxlib = jaxlib-bin.override { cudaSupport = true; };
@ -158,7 +159,11 @@ buildPythonPackage rec {
passthru.skipBulkUpdate = true; passthru.skipBulkUpdate = true;
meta = with lib; { meta = with lib; {
description = "Differentiate, compile, and transform Numpy code"; description = "Source-built JAX frontend: differentiate, compile, and transform Numpy code";
longDescription = ''
This is the JAX frontend package, it's meant to be used together with one of the jaxlib implementations,
e.g. `python3Packages.jaxlib`, `python3Packages.jaxlib-bin`, or `python3Packages.jaxlibWithCuda`.
'';
homepage = "https://github.com/google/jax"; homepage = "https://github.com/google/jax";
license = licenses.asl20; license = licenses.asl20;
maintainers = with maintainers; [ samuela ]; maintainers = with maintainers; [ samuela ];

View File

@ -225,7 +225,7 @@ buildPythonPackage {
inherit (jaxlib-build) pythonImportsCheck; inherit (jaxlib-build) pythonImportsCheck;
meta = with lib; { meta = with lib; {
description = "XLA library for JAX"; description = "Prebuilt jaxlib backend from PyPi";
homepage = "https://github.com/google/jax"; homepage = "https://github.com/google/jax";
sourceProvenance = with sourceTypes; [ binaryNativeCode ]; sourceProvenance = with sourceTypes; [ binaryNativeCode ];
license = licenses.asl20; license = licenses.asl20;

View File

@ -67,16 +67,17 @@ let
effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else inputs.stdenv; effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else inputs.stdenv;
meta = with lib; { meta = with lib; {
description = "JAX is Autograd and XLA, brought together for high-performance machine learning research"; description = "Source-built JAX backend. JAX is Autograd and XLA, brought together for high-performance machine learning research";
homepage = "https://github.com/google/jax"; homepage = "https://github.com/google/jax";
license = licenses.asl20; license = licenses.asl20;
maintainers = with maintainers; [ ndl ]; maintainers = with maintainers; [ ndl ];
platforms = platforms.unix;
# Make this platforms.unix once Darwin is supported.
# The top-level jaxlib now falls back to jaxlib-bin on unsupported platforms.
# aarch64-darwin is broken because of https://github.com/bazelbuild/rules_cc/pull/136 # aarch64-darwin is broken because of https://github.com/bazelbuild/rules_cc/pull/136
# however even with that fix applied, it doesn't work for everyone: # however even with that fix applied, it doesn't work for everyone:
# https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129 # https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129
# NOTE: We always build with NCCL; if it is unsupported, then our build is broken. platforms = platforms.linux;
broken = effectiveStdenv.isDarwin || nccl.meta.unsupported;
}; };
# Bazel wants a merged cudnn at configuration time # Bazel wants a merged cudnn at configuration time

View File

@ -6121,13 +6121,14 @@ self: super: with self; {
IOKit = pkgs.darwin.apple_sdk_11_0.IOKit; IOKit = pkgs.darwin.apple_sdk_11_0.IOKit;
}; };
jaxlib = self.jaxlib-build; # Use the -bin on macOS since the source build doesn't support it (see #323154)
jaxlib = if jaxlib-build.meta.unsupported then jaxlib-bin else jaxlib-build;
jaxlibWithCuda = self.jaxlib-build.override { jaxlibWithCuda = self.jaxlib.override {
cudaSupport = true; cudaSupport = true;
}; };
jaxlibWithoutCuda = self.jaxlib-build.override { jaxlibWithoutCuda = self.jaxlib.override {
cudaSupport = false; cudaSupport = false;
}; };