mirror of
https://github.com/NixOS/nixpkgs.git
synced 2024-11-30 19:02:57 +00:00
commit
9f38e1c714
@ -8,6 +8,7 @@
|
|||||||
fetchFromGitHub,
|
fetchFromGitHub,
|
||||||
jaxlib,
|
jaxlib,
|
||||||
jaxlib-bin,
|
jaxlib-bin,
|
||||||
|
jaxlib-build,
|
||||||
hypothesis,
|
hypothesis,
|
||||||
lapack,
|
lapack,
|
||||||
matplotlib,
|
matplotlib,
|
||||||
@ -23,10 +24,6 @@
|
|||||||
|
|
||||||
let
|
let
|
||||||
usingMKL = blas.implementation == "mkl" || lapack.implementation == "mkl";
|
usingMKL = blas.implementation == "mkl" || lapack.implementation == "mkl";
|
||||||
# jaxlib is broken on aarch64-* as of 2023-03-05, but the binary wheels work
|
|
||||||
# fine. jaxlib is only used in the checkPhase, so switching backends does not
|
|
||||||
# impact package behavior. Get rid of this once jaxlib is fixed on aarch64-*.
|
|
||||||
jaxlib' = if jaxlib.meta.broken then jaxlib-bin else jaxlib;
|
|
||||||
in
|
in
|
||||||
buildPythonPackage rec {
|
buildPythonPackage rec {
|
||||||
pname = "jax";
|
pname = "jax";
|
||||||
@ -61,7 +58,7 @@ buildPythonPackage rec {
|
|||||||
|
|
||||||
nativeCheckInputs = [
|
nativeCheckInputs = [
|
||||||
hypothesis
|
hypothesis
|
||||||
jaxlib'
|
jaxlib
|
||||||
matplotlib
|
matplotlib
|
||||||
pytestCheckHook
|
pytestCheckHook
|
||||||
pytest-xdist
|
pytest-xdist
|
||||||
@ -130,7 +127,11 @@ buildPythonPackage rec {
|
|||||||
"testQdwhWithOnRankDeficientInput5"
|
"testQdwhWithOnRankDeficientInput5"
|
||||||
];
|
];
|
||||||
|
|
||||||
disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [
|
disabledTestPaths = [
|
||||||
|
# Segmentation fault. See https://gist.github.com/zimbatm/e9b61891f3bcf5e4aaefd13f94344fba
|
||||||
|
"tests/linalg_test.py"
|
||||||
|
]
|
||||||
|
++ lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [
|
||||||
# RuntimeWarning: invalid value encountered in cast
|
# RuntimeWarning: invalid value encountered in cast
|
||||||
"tests/lax_test.py"
|
"tests/lax_test.py"
|
||||||
];
|
];
|
||||||
@ -147,7 +148,7 @@ buildPythonPackage rec {
|
|||||||
# NIXPKGS_ALLOW_UNFREE=1 nixglhost -- nix run --impure .#python3Packages.jax.passthru.tests.test_cuda_jaxlibBin
|
# NIXPKGS_ALLOW_UNFREE=1 nixglhost -- nix run --impure .#python3Packages.jax.passthru.tests.test_cuda_jaxlibBin
|
||||||
passthru.tests = {
|
passthru.tests = {
|
||||||
test_cuda_jaxlibSource = callPackage ./test-cuda.nix {
|
test_cuda_jaxlibSource = callPackage ./test-cuda.nix {
|
||||||
jaxlib = jaxlib.override { cudaSupport = true; };
|
jaxlib = jaxlib-build.override { cudaSupport = true; };
|
||||||
};
|
};
|
||||||
test_cuda_jaxlibBin = callPackage ./test-cuda.nix {
|
test_cuda_jaxlibBin = callPackage ./test-cuda.nix {
|
||||||
jaxlib = jaxlib-bin.override { cudaSupport = true; };
|
jaxlib = jaxlib-bin.override { cudaSupport = true; };
|
||||||
@ -158,7 +159,11 @@ buildPythonPackage rec {
|
|||||||
passthru.skipBulkUpdate = true;
|
passthru.skipBulkUpdate = true;
|
||||||
|
|
||||||
meta = with lib; {
|
meta = with lib; {
|
||||||
description = "Differentiate, compile, and transform Numpy code";
|
description = "Source-built JAX frontend: differentiate, compile, and transform Numpy code";
|
||||||
|
longDescription = ''
|
||||||
|
This is the JAX frontend package, it's meant to be used together with one of the jaxlib implementations,
|
||||||
|
e.g. `python3Packages.jaxlib`, `python3Packages.jaxlib-bin`, or `python3Packages.jaxlibWithCuda`.
|
||||||
|
'';
|
||||||
homepage = "https://github.com/google/jax";
|
homepage = "https://github.com/google/jax";
|
||||||
license = licenses.asl20;
|
license = licenses.asl20;
|
||||||
maintainers = with maintainers; [ samuela ];
|
maintainers = with maintainers; [ samuela ];
|
||||||
|
@ -225,7 +225,7 @@ buildPythonPackage {
|
|||||||
inherit (jaxlib-build) pythonImportsCheck;
|
inherit (jaxlib-build) pythonImportsCheck;
|
||||||
|
|
||||||
meta = with lib; {
|
meta = with lib; {
|
||||||
description = "XLA library for JAX";
|
description = "Prebuilt jaxlib backend from PyPi";
|
||||||
homepage = "https://github.com/google/jax";
|
homepage = "https://github.com/google/jax";
|
||||||
sourceProvenance = with sourceTypes; [ binaryNativeCode ];
|
sourceProvenance = with sourceTypes; [ binaryNativeCode ];
|
||||||
license = licenses.asl20;
|
license = licenses.asl20;
|
||||||
|
@ -67,16 +67,17 @@ let
|
|||||||
effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else inputs.stdenv;
|
effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else inputs.stdenv;
|
||||||
|
|
||||||
meta = with lib; {
|
meta = with lib; {
|
||||||
description = "JAX is Autograd and XLA, brought together for high-performance machine learning research";
|
description = "Source-built JAX backend. JAX is Autograd and XLA, brought together for high-performance machine learning research";
|
||||||
homepage = "https://github.com/google/jax";
|
homepage = "https://github.com/google/jax";
|
||||||
license = licenses.asl20;
|
license = licenses.asl20;
|
||||||
maintainers = with maintainers; [ ndl ];
|
maintainers = with maintainers; [ ndl ];
|
||||||
platforms = platforms.unix;
|
|
||||||
|
# Make this platforms.unix once Darwin is supported.
|
||||||
|
# The top-level jaxlib now falls back to jaxlib-bin on unsupported platforms.
|
||||||
# aarch64-darwin is broken because of https://github.com/bazelbuild/rules_cc/pull/136
|
# aarch64-darwin is broken because of https://github.com/bazelbuild/rules_cc/pull/136
|
||||||
# however even with that fix applied, it doesn't work for everyone:
|
# however even with that fix applied, it doesn't work for everyone:
|
||||||
# https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129
|
# https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129
|
||||||
# NOTE: We always build with NCCL; if it is unsupported, then our build is broken.
|
platforms = platforms.linux;
|
||||||
broken = effectiveStdenv.isDarwin || nccl.meta.unsupported;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
# Bazel wants a merged cudnn at configuration time
|
# Bazel wants a merged cudnn at configuration time
|
||||||
|
@ -6121,13 +6121,14 @@ self: super: with self; {
|
|||||||
IOKit = pkgs.darwin.apple_sdk_11_0.IOKit;
|
IOKit = pkgs.darwin.apple_sdk_11_0.IOKit;
|
||||||
};
|
};
|
||||||
|
|
||||||
jaxlib = self.jaxlib-build;
|
# Use the -bin on macOS since the source build doesn't support it (see #323154)
|
||||||
|
jaxlib = if jaxlib-build.meta.unsupported then jaxlib-bin else jaxlib-build;
|
||||||
|
|
||||||
jaxlibWithCuda = self.jaxlib-build.override {
|
jaxlibWithCuda = self.jaxlib.override {
|
||||||
cudaSupport = true;
|
cudaSupport = true;
|
||||||
};
|
};
|
||||||
|
|
||||||
jaxlibWithoutCuda = self.jaxlib-build.override {
|
jaxlibWithoutCuda = self.jaxlib.override {
|
||||||
cudaSupport = false;
|
cudaSupport = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user