From 4337d63f1fae717d87efa089a96b1a2052255195 Mon Sep 17 00:00:00 2001 From: cfhammill Date: Mon, 23 Sep 2024 12:05:57 -0400 Subject: [PATCH] python3Packages.mamba-ssm: init at 2.2.2 --- .../python-modules/causal-conv1d/default.nix | 69 +++++++++++++++++ .../python-modules/mamba-ssm/default.nix | 77 +++++++++++++++++++ pkgs/top-level/python-packages.nix | 4 + 3 files changed, 150 insertions(+) create mode 100644 pkgs/development/python-modules/causal-conv1d/default.nix create mode 100644 pkgs/development/python-modules/mamba-ssm/default.nix diff --git a/pkgs/development/python-modules/causal-conv1d/default.nix b/pkgs/development/python-modules/causal-conv1d/default.nix new file mode 100644 index 000000000000..0653959ed6da --- /dev/null +++ b/pkgs/development/python-modules/causal-conv1d/default.nix @@ -0,0 +1,69 @@ +{ + lib, + buildPythonPackage, + fetchFromGitHub, + ninja, + setuptools, + torch, + cudaPackages, + rocmPackages, + config, + cudaSupport ? config.cudaSupport, + which, +}: + +buildPythonPackage rec { + pname = "causal-conv1d"; + version = "1.4.0"; + pyproject = true; + + src = fetchFromGitHub { + owner = "Dao-AILab"; + repo = "causal-conv1d"; + rev = "refs/tags/v${version}"; + hash = "sha256-p5x5u3zEmEMN3mWd88o3jmcpKUnovTvn7I9jIOj/ie0="; + }; + + build-system = [ + ninja + setuptools + torch + ]; + + nativeBuildInputs = [ which ]; + + buildInputs = ( + lib.optionals cudaSupport ( + with cudaPackages; + [ + cuda_cudart # cuda_runtime.h, -lcudart + cuda_cccl + libcusparse # cusparse.h + libcusolver # cusolverDn.h + cuda_nvcc + libcublas + ] + ) + ); + + dependencies = [ + torch + ]; + + # pytest tests not enabled due to nvidia GPU dependency + pythonImportsCheck = [ "causal_conv1d" ]; + + env = { + CAUSAL_CONV1D_FORCE_BUILD = "TRUE"; + } // lib.optionalAttrs cudaSupport { CUDA_HOME = "${lib.getDev cudaPackages.cuda_nvcc}"; }; + + meta = with lib; { + description = "Causal depthwise conv1d in CUDA with a PyTorch interface"; + homepage = "https://github.com/Dao-AILab/causal-conv1d"; + license = licenses.bsd3; + maintainers = with maintainers; [ cfhammill ]; + # The package requires CUDA or ROCm, the ROCm build hasn't + # been completed or tested, so broken if not using cuda. + broken = !cudaSupport; + }; +} diff --git a/pkgs/development/python-modules/mamba-ssm/default.nix b/pkgs/development/python-modules/mamba-ssm/default.nix new file mode 100644 index 000000000000..11ac68c1e19d --- /dev/null +++ b/pkgs/development/python-modules/mamba-ssm/default.nix @@ -0,0 +1,77 @@ +{ + lib, + buildPythonPackage, + fetchFromGitHub, + causal-conv1d, + einops, + ninja, + setuptools, + torch, + transformers, + triton, + cudaPackages, + rocmPackages, + config, + cudaSupport ? config.cudaSupport, + which, +}: + +buildPythonPackage rec { + pname = "mamba"; + version = "2.2.2"; + pyproject = true; + + src = fetchFromGitHub { + owner = "state-spaces"; + repo = "mamba"; + rev = "refs/tags/v${version}"; + hash = "sha256-R702JjM3AGk7upN7GkNK8u1q4ekMK9fYQkpO6Re45Ng="; + }; + + build-system = [ + ninja + setuptools + torch + ]; + + nativeBuildInputs = [ which ]; + + buildInputs = ( + lib.optionals cudaSupport ( + with cudaPackages; + [ + cuda_cudart # cuda_runtime.h, -lcudart + cuda_cccl + libcusparse # cusparse.h + libcusolver # cusolverDn.h + cuda_nvcc + libcublas + ] + ) + ); + + dependencies = [ + causal-conv1d + einops + torch + transformers + triton + ]; + + env = { + MAMBA_FORCE_BUILD = "TRUE"; + } // lib.optionalAttrs cudaSupport { CUDA_HOME = "${lib.getDev cudaPackages.cuda_nvcc}"; }; + + # pytest tests not enabled due to nvidia GPU dependency + pythonImportsCheck = [ "mamba_ssm" ]; + + meta = with lib; { + description = "Linear-Time Sequence Modeling with Selective State Spaces"; + homepage = "https://github.com/state-spaces/mamba"; + license = licenses.asl20; + maintainers = with maintainers; [ cfhammill ]; + # The package requires CUDA or ROCm, the ROCm build hasn't + # been completed or tested, so broken if not using cuda. + broken = !cudaSupport; + }; +} diff --git a/pkgs/top-level/python-packages.nix b/pkgs/top-level/python-packages.nix index 091be094c5b2..8ae4f18ef7e7 100644 --- a/pkgs/top-level/python-packages.nix +++ b/pkgs/top-level/python-packages.nix @@ -2050,6 +2050,8 @@ self: super: with self; { cattrs = callPackage ../development/python-modules/cattrs { }; + causal-conv1d = callPackage ../development/python-modules/causal-conv1d { }; + cbor2 = callPackage ../development/python-modules/cbor2 { }; cbor = callPackage ../development/python-modules/cbor { }; @@ -7539,6 +7541,8 @@ self: super: with self; { malduck = callPackage ../development/python-modules/malduck { }; + mamba-ssm = callPackage ../development/python-modules/mamba-ssm { }; + managesieve = callPackage ../development/python-modules/managesieve { }; mando = callPackage ../development/python-modules/mando { };