python312Packages.flax: 0.8.5 -> 0.9.0 (#342970)

This commit is contained in:
Peder Bergebakken Sundt 2024-09-20 20:58:26 +02:00 committed by GitHub
commit ffe5c6cc56
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 77 additions and 6 deletions

View File

@ -1,7 +1,6 @@
{
lib,
buildPythonPackage,
pythonOlder,
fetchFromGitHub,
# build-system
@ -26,6 +25,7 @@
pytest-xdist,
pytestCheckHook,
tensorflow,
treescope,
# optional-dependencies
matplotlib,
@ -33,16 +33,14 @@
buildPythonPackage rec {
pname = "flax";
version = "0.8.5";
version = "0.9.0";
pyproject = true;
disabled = pythonOlder "3.9";
src = fetchFromGitHub {
owner = "google";
repo = "flax";
rev = "refs/tags/v${version}";
hash = "sha256-6WOFq0758gtNdrlWqSQBlKmWVIGe5e4PAaGrvHoGjr0=";
hash = "sha256-iDWuUJKO7V4QrbVsS4ALgy6fbllOC43o7W4mhjtZ9xc=";
};
build-system = [
@ -75,6 +73,7 @@ buildPythonPackage rec {
pytest-xdist
pytestCheckHook
tensorflow
treescope
];
pytestFlagsArray = [
@ -95,13 +94,18 @@ buildPythonPackage rec {
"flax/nnx/examples/*"
# See https://github.com/google/flax/issues/3232.
"tests/jax_utils_test.py"
# Requires tree
# Too old version of tensorflow:
# ModuleNotFoundError: No module named 'keras.api._v2'
"tests/tensorboard_test.py"
];
disabledTests = [
# ValueError: Checkpoint path should be absolute
"test_overwrite_checkpoints0"
# Fixed in more recent versions of jax: https://github.com/google/flax/issues/4211
# TODO: Re-enable when jax>0.4.28 will be available in nixpkgs
"test_vmap_and_cond_passthrough" # ValueError: vmap has mapped output but out_axes is None
"test_vmap_and_cond_passthrough_error" # AssertionError: "at vmap.*'broadcast'.*got axis spec ...
];
meta = {

View File

@ -0,0 +1,65 @@
{
lib,
buildPythonPackage,
fetchFromGitHub,
# build-system
flit-core,
# dependencies
numpy,
# optional-dependencies
ipython,
jax,
palettable,
# tests
absl-py,
jaxlib,
pytestCheckHook,
torch,
}:
buildPythonPackage rec {
pname = "treescope";
version = "0.1.5";
pyproject = true;
src = fetchFromGitHub {
owner = "google-deepmind";
repo = "treescope";
rev = "refs/tags/v${version}";
hash = "sha256-+Hm60O9tEXIiE0av1O0BsOdMln4e1s7ijb3WNiQ74jE=";
};
build-system = [ flit-core ];
dependencies = [ numpy ];
optional-dependencies = {
notebook = [
ipython
jax
palettable
];
};
pythonImportsCheck = [ "treescope" ];
nativeCheckInputs = [
absl-py
jax
jaxlib
pytestCheckHook
torch
];
meta = {
description = "An interactive HTML pretty-printer for machine learning research in IPython notebooks";
homepage = "https://github.com/google-deepmind/treescope";
changelog = "https://github.com/google-deepmind/treescope/releases/tag/v${version}";
license = lib.licenses.asl20;
maintainers = with lib.maintainers; [ GaetanLepage ];
};
}

View File

@ -15778,6 +15778,8 @@ self: super: with self; {
treeo = callPackage ../development/python-modules/treeo { };
treescope = callPackage ../development/python-modules/treescope { };
treex = callPackage ../development/python-modules/treex { };
treq = callPackage ../development/python-modules/treq { };