mirror of
https://github.com/NixOS/nixpkgs.git
synced 2025-01-04 03:53:56 +00:00
python312Packages.numpyro: 0.15.3 -> 0.16.1
Diff: https://github.com/pyro-ppl/numpyro/compare/refs/tags/0.15.3...0.16.1 Changelog: https://github.com/pyro-ppl/numpyro/releases/tag/0.16.1
This commit is contained in:
parent
b0d48f0a52
commit
a6d98860c4
@ -14,30 +14,28 @@
|
||||
tqdm,
|
||||
|
||||
# tests
|
||||
# Our current version of tensorflow (2.13.0) is too old and doesn't support python>=3.12
|
||||
# We remove optional test dependencies that require tensorflow and skip the corresponding tests to
|
||||
# avoid introducing a useless incompatibility with python 3.12:
|
||||
# dm-haiku,
|
||||
# flax,
|
||||
# tensorflow-probability,
|
||||
dm-haiku,
|
||||
flax,
|
||||
funsor,
|
||||
graphviz,
|
||||
optax,
|
||||
pyro-api,
|
||||
pytest-xdist,
|
||||
pytestCheckHook,
|
||||
scikit-learn,
|
||||
tensorflow-probability,
|
||||
}:
|
||||
|
||||
buildPythonPackage rec {
|
||||
pname = "numpyro";
|
||||
version = "0.15.3";
|
||||
version = "0.16.1";
|
||||
pyproject = true;
|
||||
|
||||
src = fetchFromGitHub {
|
||||
owner = "pyro-ppl";
|
||||
repo = "numpyro";
|
||||
rev = "refs/tags/${version}";
|
||||
hash = "sha256-g+ep221hhLbCjQasKpiEAXkygI5A3Hglqo1tV8lv5eg=";
|
||||
tag = version;
|
||||
hash = "sha256-6i7LPdmMakGeLujhA9d7Ep9oiVcND3ni/jzUkqgEqxw=";
|
||||
};
|
||||
|
||||
build-system = [ setuptools ];
|
||||
@ -51,19 +49,29 @@ buildPythonPackage rec {
|
||||
];
|
||||
|
||||
nativeCheckInputs = [
|
||||
# dm-haiku
|
||||
# flax
|
||||
dm-haiku
|
||||
flax
|
||||
funsor
|
||||
graphviz
|
||||
optax
|
||||
pyro-api
|
||||
pytest-xdist
|
||||
pytestCheckHook
|
||||
scikit-learn
|
||||
# tensorflow-probability
|
||||
tensorflow-probability
|
||||
];
|
||||
|
||||
pythonImportsCheck = [ "numpyro" ];
|
||||
|
||||
pytestFlagsArray = [
|
||||
# A few tests fail with:
|
||||
# UserWarning: There are not enough devices to run parallel chains: expected 2 but got 1.
|
||||
# Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(2)` at the beginning of your program.
|
||||
# You can double-check how many devices are available in your system using `jax.local_device_count()`.
|
||||
"-W"
|
||||
"ignore::UserWarning"
|
||||
];
|
||||
|
||||
disabledTests = [
|
||||
# AssertionError due to tolerance issues
|
||||
"test_beta_binomial_log_prob"
|
||||
@ -86,38 +94,13 @@ buildPythonPackage rec {
|
||||
# NameError: unbound axis name: _provenance
|
||||
"test_model_transformation"
|
||||
|
||||
# require dm-haiku
|
||||
"test_flax_state_dropout_smoke"
|
||||
"test_flax_module"
|
||||
"test_random_module_mcmc"
|
||||
|
||||
# require flax
|
||||
"test_haiku_state_dropout_smoke"
|
||||
"test_haiku_module"
|
||||
"test_random_module_mcmc"
|
||||
|
||||
# require tensorflow-probability
|
||||
"test_modified_bessel_first_kind_vect"
|
||||
"test_diag_spectral_density_periodic"
|
||||
"test_kernel_approx_periodic"
|
||||
"test_modified_bessel_first_kind_one_dim"
|
||||
"test_modified_bessel_first_kind_vect"
|
||||
"test_periodic_gp_one_dim_model"
|
||||
"test_no_tracer_leak_at_lazy_property_sample"
|
||||
|
||||
# flaky on darwin
|
||||
# TODO: uncomment at next release (0.15.4) as it has been fixed:
|
||||
# https://github.com/pyro-ppl/numpyro/pull/1863
|
||||
"test_change_point_x64"
|
||||
# ValueError: compiling computation that requires 2 logical devices, but only 1 XLA devices are available (num_replicas=2)
|
||||
"test_chain"
|
||||
];
|
||||
|
||||
disabledTestPaths = [
|
||||
# require jaxns (unpackaged)
|
||||
"test/contrib/test_nested_sampling.py"
|
||||
|
||||
# requires tensorflow-probability
|
||||
"test/contrib/test_tfp.py"
|
||||
"test/test_distributions.py"
|
||||
];
|
||||
|
||||
meta = {
|
||||
|
Loading…
Reference in New Issue
Block a user