python311Packages.diffusers: init at 0.24.0

This commit is contained in:
natsukium 2023-12-09 18:15:46 +09:00
parent cbef97d927
commit be762d1df8
No known key found for this signature in database
GPG Key ID: 9EA45A31DB994C53
2 changed files with 155 additions and 0 deletions

View File

@ -0,0 +1,153 @@
{ lib
, stdenv
, buildPythonPackage
, fetchFromGitHub
, pythonOlder
, writeText
, setuptools
, wheel
, filelock
, huggingface-hub
, importlib-metadata
, numpy
, pillow
, regex
, requests
, safetensors
# optional dependencies
, accelerate
, datasets
, flax
, jax
, jaxlib
, jinja2
, protobuf
, tensorboard
, torch
# test dependencies
, parameterized
, pytest-timeout
, pytest-xdist
, pytestCheckHook
, requests-mock
, ruff
, scipy
, sentencepiece
, torchsde
, transformers
}:
buildPythonPackage rec {
pname = "diffusers";
version = "0.24.0";
pyproject = true;
disabled = pythonOlder "3.8";
src = fetchFromGitHub {
owner = "huggingface";
repo = "diffusers";
rev = "refs/tags/v${version}";
hash = "sha256-ccWF8hQzPhFY/kqRum2tbanI+cQiT25MmvPZN+hGadc=";
};
nativeBuildInputs = [
setuptools
wheel
];
propagatedBuildInputs = [
filelock
huggingface-hub
importlib-metadata
numpy
pillow
regex
requests
safetensors
];
passthru.optional-dependencies = {
flax = [
flax
jax
jaxlib
];
torch = [
accelerate
torch
];
training = [
accelerate
datasets
jinja2
protobuf
tensorboard
];
};
pythonImportsCheck = [
"diffusers"
];
# tests crash due to torch segmentation fault
doCheck = !(stdenv.isLinux && stdenv.isAarch64);
nativeCheckInputs = [
parameterized
pytest-timeout
pytest-xdist
pytestCheckHook
requests-mock
ruff
scipy
sentencepiece
torchsde
transformers
] ++ passthru.optional-dependencies.torch;
preCheck = let
# This pytest hook mocks and catches attempts at accessing the network
# tests that try to access the network will raise, get caught, be marked as skipped and tagged as xfailed.
# cf. python3Packages.shap
conftestSkipNetworkErrors = writeText "conftest.py" ''
from _pytest.runner import pytest_runtest_makereport as orig_pytest_runtest_makereport
import urllib3
class NetworkAccessDeniedError(RuntimeError): pass
def deny_network_access(*a, **kw):
raise NetworkAccessDeniedError
urllib3.connection.HTTPSConnection._new_conn = deny_network_access
def pytest_runtest_makereport(item, call):
tr = orig_pytest_runtest_makereport(item, call)
if call.excinfo is not None and call.excinfo.type is NetworkAccessDeniedError:
tr.outcome = 'skipped'
tr.wasxfail = "reason: Requires network access."
return tr
'';
in ''
export HOME=$TMPDIR
cat ${conftestSkipNetworkErrors} >> tests/conftest.py
'';
pytestFlagsArray = [
"tests/"
];
disabledTests = [
# depends on current working directory
"test_deprecate_stacklevel"
# fails due to precision of floating point numbers
"test_model_cpu_offload_forward_pass"
];
meta = with lib; {
description = "State-of-the-art diffusion models for image and audio generation in PyTorch";
homepage = "https://github.com/huggingface/diffusers";
changelog = "https://github.com/huggingface/diffusers/releases/tag/${src.rev}";
license = licenses.asl20;
maintainers = with maintainers; [ natsukium ];
};
}

View File

@ -2829,6 +2829,8 @@ self: super: with self; {
diffsync = callPackage ../development/python-modules/diffsync { };
diffusers = callPackage ../development/python-modules/diffusers { };
digital-ocean = callPackage ../development/python-modules/digitalocean { };
digi-xbee = callPackage ../development/python-modules/digi-xbee { };