diff --git a/pkgs/development/python-modules/jax/default.nix b/pkgs/development/python-modules/jax/default.nix new file mode 100644 index 000000000000..edf18bbe2ed0 --- /dev/null +++ b/pkgs/development/python-modules/jax/default.nix @@ -0,0 +1,39 @@ +{ buildPythonPackage, fetchFromGitHub, lib +# propagatedBuildInputs +, absl-py, numpy, opt-einsum +# checkInputs +, jaxlib, pytestCheckHook +}: + +buildPythonPackage rec { + pname = "jax"; + version = "0.2.19"; + + # Fetching from pypi doesn't allow us to run the test suite. See https://discourse.nixos.org/t/pythonremovetestsdir-hook-being-run-before-checkphase/14612/3. + src = fetchFromGitHub { + owner = "google"; + repo = pname; + rev = "jax-v${version}"; + sha256 = "sha256-pVn62G7pydR7ybkf7gSbu0FlEq2c0US6H2GTBAljup4="; + }; + + # jaxlib is _not_ included in propagatedBuildInputs because there are + # different versions of jaxlib depending on the desired target hardware. The + # JAX project ships separate wheels for CPU, GPU, and TPU. Currently only the + # CPU wheel is packaged. + propagatedBuildInputs = [ absl-py numpy opt-einsum ]; + + checkInputs = [ jaxlib pytestCheckHook ]; + # NOTE: Don't run the tests in the expiremental directory as they require flax + # which creates a circular dependency. See https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2. + # Not a big deal, this is how the JAX docs suggest running the test suite + # anyhow. + pytestFlagsArray = [ "-W ignore::DeprecationWarning" "tests/" ]; + + meta = with lib; { + description = "Differentiate, compile, and transform Numpy code."; + homepage = "https://github.com/google/jax"; + license = licenses.asl20; + maintainers = with maintainers; [ samuela ]; + }; +} diff --git a/pkgs/top-level/python-packages.nix b/pkgs/top-level/python-packages.nix index fa17a8a7e7ee..c844eea23dce 100644 --- a/pkgs/top-level/python-packages.nix +++ b/pkgs/top-level/python-packages.nix @@ -3716,6 +3716,8 @@ in { javaproperties = callPackage ../development/python-modules/javaproperties { }; + jax = callPackage ../development/python-modules/jax { }; + jaxlib = callPackage ../development/python-modules/jaxlib { }; JayDeBeApi = callPackage ../development/python-modules/JayDeBeApi { };