Merge pull request #215047 from bcdarwin/blackjax

This commit is contained in:
Sandro 2023-02-16 22:23:24 +01:00 committed by GitHub
commit 92c9c2638c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 133 additions and 4 deletions

View File

@ -4,9 +4,11 @@
, fetchFromGitHub
, pytestCheckHook
, arviz
, blackjax
, formulae
, graphviz
, numpy
, numpyro
, pandas
, pymc
, scipy
@ -35,14 +37,16 @@ buildPythonPackage rec {
preCheck = ''export HOME=$(mktemp -d)'';
nativeCheckInputs = [ graphviz pytestCheckHook ];
nativeCheckInputs = [
blackjax
graphviz
numpyro
pytestCheckHook
];
disabledTests = [
# attempt to fetch data:
"test_data_is_copied"
"test_predict_offset"
# require blackjax (not in Nixpkgs), numpyro, and jax:
"test_logistic_regression_categoric_alternative_samplers"
"test_regression_alternative_samplers"
];
pythonImportsCheck = [ "bambi" ];

View File

@ -0,0 +1,62 @@
{ lib
, buildPythonPackage
, pythonOlder
, fetchFromGitHub
, fetchpatch
, pytestCheckHook
, fastprogress
, jax
, jaxlib
, jaxopt
, optax
, typing-extensions
}:
buildPythonPackage rec {
pname = "blackjax";
version = "0.9.6";
disabled = pythonOlder "3.7";
src = fetchFromGitHub {
owner = "blackjax-devs";
repo = pname;
rev = "refs/tags/${version}";
hash = "sha256-EieDu9SJxi2cp1bHlxX4vvFZeDGMGIm24GoR8nSyjvE=";
};
patches = [
# remove in next release
(fetchpatch {
name = "fix-lbfgs-args";
url = "https://github.com/blackjax-devs/blackjax/commit/1aaa6f64bbcb0557b658604b2daba826e260cbc6.patch";
hash = "sha256-XyjorXPH5Ap35Tv1/lTeTWamjplJF29SsvOq59ypftE=";
})
];
propagatedBuildInputs = [
fastprogress
jax
jaxlib
jaxopt
optax
typing-extensions
];
nativeCheckInputs = [ pytestCheckHook ];
disabledTestPaths = [ "tests/test_benchmarks.py" ];
disabledTests = [
# too slow
"test_adaptive_tempered_smc"
];
pythonImportsCheck = [
"blackjax"
];
meta = with lib; {
homepage = "https://blackjax-devs.github.io/blackjax";
description = "Sampling library designed for ease of use, speed and modularity";
license = licenses.asl20;
maintainers = with maintainers; [ bcdarwin ];
};
}

View File

@ -0,0 +1,59 @@
{ lib
, buildPythonPackage
, pythonOlder
, fetchFromGitHub
, pytestCheckHook
, absl-py
, cvxpy
, jax
, jaxlib
, matplotlib
, numpy
, optax
, scipy
, scikitlearn
}:
buildPythonPackage rec {
pname = "jaxopt";
version = "0.5.5";
disabled = pythonOlder "3.7";
src = fetchFromGitHub {
owner = "google";
repo = pname;
rev = "refs/tags/${pname}-v${version}";
hash = "sha256-WOsr/Dvguu9/qX6+LMlAKM3EANtYPtDu8Uo2157+bs0=";
};
propagatedBuildInputs = [
absl-py
jax
jaxlib
matplotlib
numpy
scipy
];
nativeCheckInputs = [
pytestCheckHook
cvxpy
optax
scikitlearn
];
pythonImportsCheck = [
"jaxopt"
"jaxopt.implicit_diff"
"jaxopt.linear_solve"
"jaxopt.loss"
"jaxopt.tree_util"
];
meta = with lib; {
homepage = "https://jaxopt.github.io";
description = "Hardware accelerated, batchable and differentiable optimizers in JAX";
license = licenses.asl20;
maintainers = with maintainers; [ bcdarwin ];
};
}

View File

@ -1300,6 +1300,8 @@ self: super: with self; {
black = callPackage ../development/python-modules/black { };
blackjax = callPackage ../development/python-modules/blackjax { };
black-macchiato = callPackage ../development/python-modules/black-macchiato { };
bleach = callPackage ../development/python-modules/bleach { };
@ -4868,6 +4870,8 @@ self: super: with self; {
cudaSupport = false;
};
jaxopt = callPackage ../development/python-modules/jaxopt { };
JayDeBeApi = callPackage ../development/python-modules/JayDeBeApi { };
jc = callPackage ../development/python-modules/jc { };