From e754f2946b9d101b94d8fea49cb545c17aa3846f Mon Sep 17 00:00:00 2001 From: Gaetan Lepage Date: Wed, 14 Jun 2023 10:18:15 +0200 Subject: [PATCH 1/6] buildBazelPackage: add support for bazel run targets --- .../build-bazel-package/default.nix | 39 ++++++++++++++----- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/pkgs/build-support/build-bazel-package/default.nix b/pkgs/build-support/build-bazel-package/default.nix index f9de0ad468b2..3ffff74f70e2 100644 --- a/pkgs/build-support/build-bazel-package/default.nix +++ b/pkgs/build-support/build-bazel-package/default.nix @@ -10,9 +10,12 @@ args@{ , bazelFlags ? [] , bazelBuildFlags ? [] , bazelTestFlags ? [] +, bazelRunFlags ? [] +, runTargetFlags ? [] , bazelFetchFlags ? [] -, bazelTargets +, bazelTargets ? [] , bazelTestTargets ? [] +, bazelRunTarget ? null , buildAttrs , fetchAttrs @@ -46,17 +49,23 @@ args@{ let fArgs = removeAttrs args [ "buildAttrs" "fetchAttrs" "removeRulesCC" ] // { - name = name; - bazelFlags = bazelFlags; - bazelBuildFlags = bazelBuildFlags; - bazelTestFlags = bazelTestFlags; - bazelFetchFlags = bazelFetchFlags; - bazelTestTargets = bazelTestTargets; - dontAddBazelOpts = dontAddBazelOpts; + inherit + name + bazelFlags + bazelBuildFlags + bazelTestFlags + bazelRunFlags + runTargetFlags + bazelFetchFlags + bazelTargets + bazelTestTargets + bazelRunTarget + dontAddBazelOpts + ; }; fBuildAttrs = fArgs // buildAttrs; fFetchAttrs = fArgs // removeAttrs fetchAttrs [ "sha256" ]; - bazelCmd = { cmd, additionalFlags, targets }: + bazelCmd = { cmd, additionalFlags, targets, targetRunFlags ? [ ] }: lib.optionalString (targets != [ ]) '' # See footnote called [USER and BAZEL_USE_CPP_ONLY_TOOLCHAIN variables] BAZEL_USE_CPP_ONLY_TOOLCHAIN=1 \ @@ -73,7 +82,8 @@ let "''${host_linkopts[@]}" \ $bazelFlags \ ${lib.strings.concatStringsSep " " additionalFlags} \ - ${lib.strings.concatStringsSep " " targets} + ${lib.strings.concatStringsSep " " targets} \ + ${lib.optionalString (targetRunFlags != []) " -- " + lib.strings.concatStringsSep " " targetRunFlags} ''; # we need this to chmod dangling symlinks on darwin, gnu coreutils refuses to do so: # chmod: cannot operate on dangling symlink '$symlink' @@ -262,6 +272,15 @@ stdenv.mkDerivation (fBuildAttrs // { targets = fBuildAttrs.bazelTargets; } } + ${ + bazelCmd { + cmd = "run"; + additionalFlags = fBuildAttrs.bazelRunFlags ++ [ "--jobs" "$NIX_BUILD_CORES" ]; + # Bazel run only accepts a single target, but `bazelCmd` expects `targets` to be a list. + targets = lib.optionals (fBuildAttrs.bazelRunTarget != null) [ fBuildAttrs.bazelRunTarget ]; + targetRunFlags = fBuildAttrs.runTargetFlags; + } + } runHook postBuild ''; }) From 22114b44bd01225ee690cb87962d927b54d54319 Mon Sep 17 00:00:00 2001 From: Gaetan Lepage Date: Tue, 18 Apr 2023 17:06:01 +0200 Subject: [PATCH 2/6] python3Packages.ml-dtypes: init at 0.2.0 --- .../python-modules/ml-dtypes/default.nix | 60 +++++++++++++++++++ pkgs/top-level/python-packages.nix | 2 + 2 files changed, 62 insertions(+) create mode 100644 pkgs/development/python-modules/ml-dtypes/default.nix diff --git a/pkgs/development/python-modules/ml-dtypes/default.nix b/pkgs/development/python-modules/ml-dtypes/default.nix new file mode 100644 index 000000000000..c329196d51de --- /dev/null +++ b/pkgs/development/python-modules/ml-dtypes/default.nix @@ -0,0 +1,60 @@ +{ lib +, buildPythonPackage +, pythonOlder +, fetchFromGitHub +, setuptools +, pybind11 +, numpy +, pytestCheckHook +, absl-py +}: + +buildPythonPackage rec { + pname = "ml-dtypes"; + version = "0.2.0"; + format = "pyproject"; + + disabled = pythonOlder "3.7"; + + src = fetchFromGitHub { + owner = "jax-ml"; + repo = "ml_dtypes"; + rev = "refs/tags/v${version}"; + hash = "sha256-eqajWUwylIYsS8gzEaCZLLr+1+34LXWhfKBjuwsEhhI="; + # Since this upstream patch (https://github.com/jax-ml/ml_dtypes/commit/1bfd097e794413b0d465fa34f2eff0f3828ff521), + # the attempts to use the nixpkgs packaged eigen dependency have failed. + # Hence, we rely on the bundled eigen library. + fetchSubmodules = true; + }; + + nativeBuildInputs = [ + setuptools + pybind11 + ]; + + propagatedBuildInputs = [ + numpy + ]; + + nativeCheckInputs = [ + pytestCheckHook + absl-py + ]; + + preCheck = '' + # remove src module, so tests use the installed module instead + mv ./ml_dtypes/tests ./tests + rm -rf ./ml_dtypes + ''; + + pythonImportsCheck = [ + "ml_dtypes" + ]; + + meta = with lib; { + description = "A stand-alone implementation of several NumPy dtype extensions used in machine learning libraries"; + homepage = "https://github.com/jax-ml/ml_dtypes"; + license = licenses.asl20; + maintainers = with maintainers; [ GaetanLepage samuela ]; + }; +} diff --git a/pkgs/top-level/python-packages.nix b/pkgs/top-level/python-packages.nix index 5c498ff9519a..21460fef30c0 100644 --- a/pkgs/top-level/python-packages.nix +++ b/pkgs/top-level/python-packages.nix @@ -6564,6 +6564,8 @@ self: super: with self; { ml-collections = callPackage ../development/python-modules/ml-collections { }; + ml-dtypes = callPackage ../development/python-modules/ml-dtypes { }; + mlflow = callPackage ../development/python-modules/mlflow { }; mlrose = callPackage ../development/python-modules/mlrose { }; From 7b16d5d8cb8f2bda4a9bf4948b58fb6d4714061d Mon Sep 17 00:00:00 2001 From: Gaetan Lepage Date: Tue, 18 Apr 2023 17:06:54 +0200 Subject: [PATCH 3/6] python3Packages.jaxlib-bin: 0.4.4 -> 0.4.14 --- .../development/python-modules/jaxlib/bin.nix | 84 +++++++++++-------- .../python-modules/jaxlib/prefetch.sh | 22 +++-- 2 files changed, 66 insertions(+), 40 deletions(-) diff --git a/pkgs/development/python-modules/jaxlib/bin.nix b/pkgs/development/python-modules/jaxlib/bin.nix index b3d3138ab443..c0773878c1d8 100644 --- a/pkgs/development/python-modules/jaxlib/bin.nix +++ b/pkgs/development/python-modules/jaxlib/bin.nix @@ -18,11 +18,12 @@ , autoPatchelfHook , buildPythonPackage , config -, cudnn ? cudaPackages.cudnn +, fetchPypi , fetchurl , flatbuffers -, isPy39 +, jaxlib-build , lib +, ml-dtypes , python , scipy , stdenv @@ -35,46 +36,57 @@ let inherit (cudaPackages) cudatoolkit cudnn; in -assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1"; -assert cudaSupport -> lib.versionAtLeast cudnn.version "8.2"; +assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1" && lib.versionAtLeast cudnn.version "8.2" && stdenv.isLinux; let - version = "0.4.4"; + version = "0.4.14"; + + inherit (python) pythonVersion; + + # As of 2023-06-06, google/jax upstream is no longer publishing CPU-only wheels to their GCS bucket. Instead the + # official instructions recommend installing CPU-only versions via PyPI. + cpuSrcs = + let + getSrcFromPypi = { platform, hash }: fetchPypi { + inherit version platform hash; + pname = "jaxlib"; + format = "wheel"; + # See the `disabled` attr comment below. + dist = "cp310"; + python = "cp310"; + abi = "cp310"; + }; + in + { + "x86_64-linux" = getSrcFromPypi { + platform = "manylinux2014_x86_64"; + hash = "sha256-nyylSZfqHeftlvVgJZFCN1ldjluZVJIYu4ZSsVxvXf8="; + }; + "aarch64-darwin" = getSrcFromPypi { + platform = "macosx_11_0_arm64"; + hash = "sha256-La3wYbGCjWTl7krBD6BaBRqyBD8R530Lckbz0AWv0FM="; + }; + "x86_64-darwin" = getSrcFromPypi { + platform = "macosx_10_14_x86_64"; + hash = "sha256-hDg5+qisgtgOrdvbjxsUgI73cW6Aah8NLjhPe4kMAsM="; + }; + }; - pythonVersion = python.pythonVersion; # Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html. # When upgrading, you can get these hashes from prefetch.sh. See - # https://github.com/google/jax/issues/12879 as to why this specific URL is - # the correct index. - cpuSrcs = { - "x86_64-linux" = fetchurl { - url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-cp310-manylinux2014_x86_64.whl"; - hash = "sha256-4VT909AB+ti5HzQvsaZWNY6MS/GItlVEFH9qeZnUuKQ="; - }; - "aarch64-darwin" = fetchurl { - url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_11_0_arm64.whl"; - hash = "sha256-wuOmoCeTldslSa0MommQeTe+RYKhUMam1ZXrgSov+8U="; - }; - "x86_64-darwin" = fetchurl { - url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_10_14_x86_64.whl"; - hash = "sha256-arfiTw8yafJwjRwJhKby2O7y3+4ksh3PjaKW9JgJ1ok="; - }; + # https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index. + gpuSrc = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl"; + hash = "sha256-CcQ5kjp4XfUX4/RwFY3T5G3kVKAeyoCTXu1Lo4O16Qo="; }; - gpuSrc = fetchurl { - url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl"; - hash = "sha256-bJ62DdzuPSV311ZI2R/LJQ3fOkDibtz2+8wDKw31FLk="; - }; in -buildPythonPackage rec { +buildPythonPackage { pname = "jaxlib"; inherit version; format = "wheel"; - # At the time of writing (2022-10-19), there are releases for <=3.10. - # Supporting all of them is a pain, so we focus on 3.10, the current nixpkgs - # python version. disabled = !(pythonVersion == "3.10"); # See https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6. @@ -87,9 +99,10 @@ buildPythonPackage rec { # Prebuilt wheels are dynamically linked against things that nix can't find. # Run `autoPatchelfHook` to automagically fix them. - nativeBuildInputs = lib.optionals cudaSupport [ autoPatchelfHook addOpenGLRunpath ]; + nativeBuildInputs = lib.optionals stdenv.isLinux [ autoPatchelfHook ] + ++ lib.optionals cudaSupport [ addOpenGLRunpath ]; # Dynamic link dependencies - buildInputs = [ stdenv.cc.cc ]; + buildInputs = [ stdenv.cc.cc.lib ]; # jaxlib contains shared libraries that open other shared libraries via dlopen # and these implicit dependencies are not recognized by ldd or @@ -113,7 +126,12 @@ buildPythonPackage rec { done ''; - propagatedBuildInputs = [ absl-py flatbuffers scipy ]; + propagatedBuildInputs = [ + absl-py + flatbuffers + ml-dtypes + scipy + ]; # Note that cudatoolkit is snecessary since jaxlib looks for "ptxas" in $PATH. # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for @@ -123,7 +141,7 @@ buildPythonPackage rec { ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas ''; - pythonImportsCheck = [ "jaxlib" ]; + inherit (jaxlib-build) pythonImportsCheck; meta = with lib; { description = "XLA library for JAX"; diff --git a/pkgs/development/python-modules/jaxlib/prefetch.sh b/pkgs/development/python-modules/jaxlib/prefetch.sh index 31db6530639f..3362e2d0b781 100755 --- a/pkgs/development/python-modules/jaxlib/prefetch.sh +++ b/pkgs/development/python-modules/jaxlib/prefetch.sh @@ -1,7 +1,15 @@ -version="$1" -nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl)" -nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-none-manylinux2010_x86_64.whl)" -nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl)" -nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl)" -nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp310-none-manylinux2010_x86_64.whl)" -nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-none-manylinux2010_x86_64.whl)" +#!/usr/bin/env bash + +prefetch () { + expr="(import { system = \"$1\"; config.cudaSupport = $2; }).python3.pkgs.jaxlib-bin.src.url" + url=$(NIX_PATH=.. nix-instantiate --eval -E "$expr" | jq -r) + echo "$url" + sha256=$(nix-prefetch-url "$url") + nix hash to-sri --type sha256 "$sha256" + echo +} + +prefetch "x86_64-linux" "false" +prefetch "aarch64-darwin" "false" +prefetch "x86_64-darwin" "false" +prefetch "x86_64-linux" "true" From 06ef57da87f0c88778fbf2b2496663f68a6927ed Mon Sep 17 00:00:00 2001 From: Gaetan Lepage Date: Tue, 18 Apr 2023 17:31:29 +0200 Subject: [PATCH 4/6] python3Packages.jax: 0.4.5 -> 0.4.14 --- .../python-modules/jax/default.nix | 40 +++++++------------ 1 file changed, 15 insertions(+), 25 deletions(-) diff --git a/pkgs/development/python-modules/jax/default.nix b/pkgs/development/python-modules/jax/default.nix index 4901467262f3..b22d82d7f22f 100644 --- a/pkgs/development/python-modules/jax/default.nix +++ b/pkgs/development/python-modules/jax/default.nix @@ -1,13 +1,14 @@ { lib -, absl-py , blas , buildPythonPackage -, etils +, setuptools +, importlib-metadata , fetchFromGitHub , jaxlib , jaxlib-bin , lapack , matplotlib +, ml-dtypes , numpy , opt-einsum , pytestCheckHook @@ -15,7 +16,6 @@ , pythonOlder , scipy , stdenv -, typing-extensions }: let @@ -27,30 +27,32 @@ let in buildPythonPackage rec { pname = "jax"; - version = "0.4.5"; - format = "setuptools"; + version = "0.4.14"; + format = "pyproject"; - disabled = pythonOlder "3.7"; + disabled = pythonOlder "3.9"; src = fetchFromGitHub { owner = "google"; repo = pname; # google/jax contains tags for jax and jaxlib. Only use jax tags! rev = "refs/tags/${pname}-v${version}"; - hash = "sha256-UJzX8zP3qaEUIV5hPJhiGiLJO7k8p962MHWxIHDY1ZA="; + hash = "sha256-0KnILQkahSiA1uuyT+kgy1XaCcZ3cpx1q114e2pecvg="; }; + nativeBuildInputs = [ + setuptools + ]; + # jaxlib is _not_ included in propagatedBuildInputs because there are # different versions of jaxlib depending on the desired target hardware. The # JAX project ships separate wheels for CPU, GPU, and TPU. propagatedBuildInputs = [ - absl-py - etils + ml-dtypes numpy opt-einsum scipy - typing-extensions - ] ++ etils.optional-dependencies.epath; + ] ++ lib.optional (pythonOlder "3.10") importlib-metadata; nativeCheckInputs = [ jaxlib' @@ -96,24 +98,12 @@ buildPythonPackage rec { "testScanGrad_jit_scan" ]; - # See https://github.com/google/jax/issues/11722. This is a temporary fix in - # order to unblock etils, and upgrading jax/jaxlib to the latest version. See - # https://github.com/NixOS/nixpkgs/issues/183173#issuecomment-1204074993. - disabledTestPaths = [ - "tests/api_test.py" - "tests/core_test.py" - "tests/lax_numpy_indexing_test.py" - "tests/lax_numpy_test.py" - "tests/nn_test.py" - "tests/random_test.py" - "tests/sparse_test.py" - ] ++ lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [ + disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [ # RuntimeWarning: invalid value encountered in cast "tests/lax_test.py" ]; - # As of 0.3.22, `import jax` does not work without jaxlib being installed. - pythonImportsCheck = [ ]; + pythonImportsCheck = [ "jax" ]; meta = with lib; { description = "Differentiate, compile, and transform Numpy code"; From 6232bc9c7bc1974770751a50c2aa4f3d14e4e133 Mon Sep 17 00:00:00 2001 From: Gaetan Lepage Date: Tue, 18 Apr 2023 17:31:13 +0200 Subject: [PATCH 5/6] python3Packages.jaxlib: 0.4.4 -> 0.4.14 --- .../python-modules/jaxlib/default.nix | 78 ++++++++++--------- pkgs/top-level/python-packages.nix | 1 - 2 files changed, 40 insertions(+), 39 deletions(-) diff --git a/pkgs/development/python-modules/jaxlib/default.nix b/pkgs/development/python-modules/jaxlib/default.nix index bf93bf1a5a26..8670231e0608 100644 --- a/pkgs/development/python-modules/jaxlib/default.nix +++ b/pkgs/development/python-modules/jaxlib/default.nix @@ -4,7 +4,7 @@ # Build-time dependencies: , addOpenGLRunpath -, bazel_5 +, bazel_6 , binutils , buildBazelPackage , buildPythonPackage @@ -21,11 +21,13 @@ , setuptools , symlinkJoin , wheel +, build , which # Python dependencies: , absl-py , flatbuffers +, ml-dtypes , numpy , scipy , six @@ -35,7 +37,6 @@ , giflib , grpc , libjpeg_turbo -, protobuf , python , snappy , zlib @@ -53,7 +54,7 @@ let inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl; pname = "jaxlib"; - version = "0.4.4"; + version = "0.4.14"; meta = with lib; { description = "JAX is Autograd and XLA, brought together for high-performance machine learning research."; @@ -99,7 +100,9 @@ let # "com_github_googleapis_googleapis" # "com_github_googlecloudplatform_google_cloud_cpp" "com_github_grpc_grpc" - "com_google_protobuf" + # ERROR: /build/output/external/bazel_tools/tools/proto/BUILD:25:6: no such target '@com_google_protobuf//:cc_toolchain': + # target 'cc_toolchain' not declared in package '' defined by /build/output/external/com_google_protobuf/BUILD.bazel + # "com_google_protobuf" # Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)' # "com_googlesource_code_re2" "curl" @@ -120,7 +123,9 @@ let "org_sqlite" "pasta" "png" - "pybind11" + # ERROR: /build/output/external/pybind11/BUILD.bazel: no such target '@pybind11//:osx': + # target 'osx' not declared in package '' defined by /build/output/external/pybind11/BUILD.bazel + # "pybind11" "six_archive" "snappy" "tblib_archive" @@ -138,14 +143,15 @@ let bazel-build = buildBazelPackage rec { name = "bazel-build-${pname}-${version}"; - bazel = bazel_5; + # See https://github.com/google/jax/blob/main/.bazelversion for the latest. + bazel = bazel_6; src = fetchFromGitHub { owner = "google"; repo = "jax"; # google/jax contains tags for jax and jaxlib. Only use jaxlib tags! rev = "refs/tags/${pname}-v${version}"; - hash = "sha256-DP68UwS9bg243iWU4MLHN0pwl8LaOcW3Sle1ZjsLOHo="; + hash = "sha256-0KnILQkahSiA1uuyT+kgy1XaCcZ3cpx1q114e2pecvg="; }; nativeBuildInputs = [ @@ -154,6 +160,7 @@ let git setuptools wheel + build which ] ++ lib.optionals stdenv.isDarwin [ cctools @@ -169,7 +176,7 @@ let numpy openssl pkgs.flatbuffers - protobuf + pkgs.protobuf pybind11 scipy six @@ -188,7 +195,8 @@ let rm -f .bazelversion ''; - bazelTargets = [ "//build:build_wheel" ]; + bazelRunTarget = "//jaxlib/tools:build_wheel"; + runTargetFlags = [ "--output_path=$out" "--cpu=${arch}" ]; removeRulesCC = false; @@ -207,7 +215,11 @@ let build --action_env=PYENV_ROOT build --python_path="${python}/bin/python" build --distinct_host_configuration=false - build --define PROTOBUF_INCLUDE_PATH="${protobuf}/include" + build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include" + '' + lib.optionalString (stdenv.targetPlatform.avxSupport && stdenv.targetPlatform.isUnix) '' + build --config=avx_posix + '' + lib.optionalString mklSupport '' + build --config=mkl_open_source_only '' + lib.optionalString cudaSupport '' build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}" build --action_env CUDNN_INSTALL_PATH="${cudnn}" @@ -234,7 +246,7 @@ let fetchAttrs = { TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs; # we have to force @mkl_dnn_v1 since it's not needed on darwin - bazelTargets = bazelTargets ++ [ "@mkl_dnn_v1//:mkl_dnn" ]; + bazelTargets = [ bazelRunTarget "@mkl_dnn_v1//:mkl_dnn" ]; bazelFlags = bazelFlags ++ [ "--config=avx_posix" ] ++ lib.optionals cudaSupport [ @@ -249,9 +261,9 @@ let sha256 = if cudaSupport then - "sha256-O6bM7Lc8eaFyO4Xzl5/hvBrbPioI+Yeqx9yNC97fvKk=" + "sha256-8QaXoZq6oITRsYn4RdLUXcKQv3PJ4Q3ItX9PkBwxGBI=" else - "sha256-gLMJfJSQIdGGY2Ivx4IgDWg0hc+mxzlqY11CUkSWcjI="; + "sha256-M/h5EZmyiV4QvzgKRjdz7V1LHENUJlc/ig1QAItnWVQ="; }; buildAttrs = { @@ -261,25 +273,13 @@ let "nsync" # fails to build on darwin ]); - bazelFlags = bazelFlags ++ lib.optionals (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [ - "--config=avx_posix" - ] ++ lib.optionals cudaSupport [ - "--config=cuda" - ] ++ lib.optionals mklSupport [ - "--config=mkl_open_source_only" - ]; # Note: we cannot do most of this patching at `patch` phase as the deps are not available yet. - # 1) Fix pybind11 include paths. - # 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on + # 1) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on # loading multiple extensions in the same python program due to duplicate protobuf DBs. - # 3) Patch python path in the compiler driver. - preBuild = '' - for src in ./jaxlib/*.{cc,h} ./jaxlib/cuda/*.{cc,h}; do - sed -i 's@include/pybind11@pybind11@g' $src - done - '' + lib.optionalString cudaSupport '' + # 2) Patch python path in the compiler driver. + preBuild = lib.optionalString cudaSupport '' export NIX_LDFLAGS+=" -L${backendStdenv.nixpkgsCompatibleLibstdcxx}/lib" - patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl + patchShebangs ../output/external/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl '' + lib.optionalString stdenv.isDarwin '' # Framework search paths aren't added by bintools hook # https://github.com/NixOS/nixpkgs/pull/41914 @@ -289,16 +289,12 @@ let substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \ --replace "/usr/bin/libtool" "${cctools}/bin/libtool" '' + (if stdenv.cc.isGNU then '' - sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD - sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD + sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD + sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD '' else if stdenv.cc.isClang then '' - sed -i 's@-lprotobuf@${protobuf}/lib/libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD - sed -i 's@-lprotoc@${protobuf}/lib/libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD + sed -i 's@-lprotobuf@${pkgs.protobuf}/lib/libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD + sed -i 's@-lprotoc@${pkgs.protobuf}/lib/libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD '' else throw "Unsupported stdenv.cc: ${stdenv.cc}"); - - installPhase = '' - ./bazel-bin/build/build_wheel --output_path=$out --cpu=${arch} - ''; }; inherit meta; @@ -345,13 +341,19 @@ buildPythonPackage { grpc jsoncpp libjpeg_turbo + ml-dtypes numpy scipy six snappy ]; - pythonImportsCheck = [ "jaxlib" ]; + pythonImportsCheck = [ + "jaxlib" + # `import jaxlib` loads surprisingly little. These imports are actually bugs that appeared in the 0.4.11 upgrade. + "jaxlib.cpu_feature_guard" + "jaxlib.xla_client" + ]; # Without it there are complaints about libcudart.so.11.0 not being found # because RPATH path entries added above are stripped. diff --git a/pkgs/top-level/python-packages.nix b/pkgs/top-level/python-packages.nix index 21460fef30c0..af7fa4a4693f 100644 --- a/pkgs/top-level/python-packages.nix +++ b/pkgs/top-level/python-packages.nix @@ -5310,7 +5310,6 @@ self: super: with self; { # Some platforms don't have `cudaSupport` defined, hence the need for 'or false'. inherit (pkgs.config) cudaSupport; IOKit = pkgs.darwin.apple_sdk_11_0.IOKit; - protobuf = pkgs.protobuf3_20; # jaxlib-build 0.3.15 won't build with protobuf 3.21 }; jaxlib = self.jaxlib-build; From f4d63170eeb5edc3b464c04ef79de1774f728a35 Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Thu, 3 Aug 2023 09:05:20 +0800 Subject: [PATCH 6/6] python3Packages.jaxlib: fix dependency hash on aarch64-linux --- pkgs/development/python-modules/jaxlib/default.nix | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pkgs/development/python-modules/jaxlib/default.nix b/pkgs/development/python-modules/jaxlib/default.nix index 8670231e0608..070516deefd9 100644 --- a/pkgs/development/python-modules/jaxlib/default.nix +++ b/pkgs/development/python-modules/jaxlib/default.nix @@ -259,11 +259,12 @@ let "--config=mkl_open_source_only" ]; - sha256 = - if cudaSupport then - "sha256-8QaXoZq6oITRsYn4RdLUXcKQv3PJ4Q3ItX9PkBwxGBI=" - else - "sha256-M/h5EZmyiV4QvzgKRjdz7V1LHENUJlc/ig1QAItnWVQ="; + sha256 = (if cudaSupport then { + x86_64-linux = "sha256-8QaXoZq6oITRsYn4RdLUXcKQv3PJ4Q3ItX9PkBwxGBI="; + } else { + x86_64-linux = "sha256-M/h5EZmyiV4QvzgKRjdz7V1LHENUJlc/ig1QAItnWVQ="; + aarch64-linux = "sha256-edkYcdlvOLNGRSanch1fGCZwq8SFn3TzcUNt1LhzG/E="; + }).${stdenv.system} or (throw "jaxlib: unsupported system: ${stdenv.system}"); }; buildAttrs = {