tensorrt: dont break eval for unrelated packages

This commit is contained in:
Someone Serge 2023-11-24 22:46:28 +00:00
parent 5bda2ec626
commit 3ee37e4356
No known key found for this signature in database
GPG Key ID: 7B0E3B1390D61DA4
3 changed files with 32 additions and 13 deletions

View File

@ -17,16 +17,32 @@ final: prev: let
isSupported = fileData: elem cudaVersion fileData.supportedCudaVersions; isSupported = fileData: elem cudaVersion fileData.supportedCudaVersions;
# Return the first file that is supported. In practice there should only ever be one anyway. # Return the first file that is supported. In practice there should only ever be one anyway.
supportedFile = files: findFirst isSupported null files; supportedFile = files: findFirst isSupported null files;
# Supported versions with versions as keys and file as value
supportedVersions = filterAttrs (version: file: file !=null ) (mapAttrs (version: files: supportedFile files) tensorRTVersions);
# Compute versioned attribute name to be used in this package set # Compute versioned attribute name to be used in this package set
computeName = version: "tensorrt_${toUnderscore version}"; computeName = version: "tensorrt_${toUnderscore version}";
# Supported versions with versions as keys and file as value
supportedVersions = lib.recursiveUpdate
{
tensorrt = {
enable = false;
fileVersionCuda = null;
fileVersionCudnn = null;
fullVersion = "0.0.0";
sha256 = null;
tarball = null;
supportedCudaVersions = [ ];
};
}
(mapAttrs' (version: attrs: nameValuePair (computeName version) attrs)
(filterAttrs (version: file: file != null) (mapAttrs (version: files: supportedFile files) tensorRTVersions)));
# Add all supported builds as attributes # Add all supported builds as attributes
allBuilds = mapAttrs' (version: file: nameValuePair (computeName version) (buildTensorRTPackage (removeAttrs file ["fileVersionCuda"]))) supportedVersions; allBuilds = mapAttrs (name: file: buildTensorRTPackage (removeAttrs file ["fileVersionCuda"])) supportedVersions;
# Set the default attributes, e.g. tensorrt = tensorrt_8_4; # Set the default attributes, e.g. tensorrt = tensorrt_8_4;
defaultBuild = { "tensorrt" = if allBuilds ? ${computeName tensorRTDefaultVersion} defaultName = computeName tensorRTDefaultVersion;
then allBuilds.${computeName tensorRTDefaultVersion} defaultBuild = lib.optionalAttrs (allBuilds ? ${defaultName}) { tensorrt = allBuilds.${computeName tensorRTDefaultVersion}; };
else throw "tensorrt-${tensorRTDefaultVersion} does not support your cuda version ${cudaVersion}"; };
in { in {
inherit buildTensorRTPackage; inherit buildTensorRTPackage;
} // allBuilds // defaultBuild; } // allBuilds // defaultBuild;

View File

@ -8,20 +8,22 @@
, cudnn , cudnn
}: }:
{ fullVersion { enable ? true
, fullVersion
, fileVersionCudnn ? null , fileVersionCudnn ? null
, tarball , tarball
, sha256 , sha256
, supportedCudaVersions ? [ ] , supportedCudaVersions ? [ ]
}: }:
assert fileVersionCudnn == null || lib.assertMsg (lib.strings.versionAtLeast cudnn.version fileVersionCudnn) assert !enable || fileVersionCudnn == null || lib.assertMsg (lib.strings.versionAtLeast cudnn.version fileVersionCudnn)
"This version of TensorRT requires at least cuDNN ${fileVersionCudnn} (current version is ${cudnn.version})"; "This version of TensorRT requires at least cuDNN ${fileVersionCudnn} (current version is ${cudnn.version})";
backendStdenv.mkDerivation rec { backendStdenv.mkDerivation rec {
pname = "cudatoolkit-${cudatoolkit.majorVersion}-tensorrt"; pname = "cudatoolkit-${cudatoolkit.majorVersion}-tensorrt";
version = fullVersion; version = fullVersion;
src = requireFile rec { src = if !enable then null else
requireFile rec {
name = tarball; name = tarball;
inherit sha256; inherit sha256;
message = '' message = ''
@ -38,13 +40,13 @@ backendStdenv.mkDerivation rec {
outputs = [ "out" "dev" ]; outputs = [ "out" "dev" ];
nativeBuildInputs = [ nativeBuildInputs = lib.optionals enable [
autoPatchelfHook autoPatchelfHook
autoAddOpenGLRunpathHook autoAddOpenGLRunpathHook
]; ];
# Used by autoPatchelfHook # Used by autoPatchelfHook
buildInputs = [ buildInputs = lib.optionals enable [
backendStdenv.cc.cc.lib # libstdc++ backendStdenv.cc.cc.lib # libstdc++
cudatoolkit cudatoolkit
cudnn cudnn
@ -75,6 +77,7 @@ backendStdenv.mkDerivation rec {
''; '';
passthru.stdenv = backendStdenv; passthru.stdenv = backendStdenv;
passthru.enable = enable;
meta = with lib; { meta = with lib; {
# Check that the cudatoolkit version satisfies our min/max constraints (both # Check that the cudatoolkit version satisfies our min/max constraints (both
@ -82,7 +85,7 @@ backendStdenv.mkDerivation rec {
# official version constraints (as recorded in default.nix). In some cases # official version constraints (as recorded in default.nix). In some cases
# you _may_ be able to smudge version constraints, just know that you're # you _may_ be able to smudge version constraints, just know that you're
# embarking into unknown and unsupported territory when doing so. # embarking into unknown and unsupported territory when doing so.
broken = !(elem cudaVersion supportedCudaVersions); broken = !enable || !(elem cudaVersion supportedCudaVersions);
description = "TensorRT: a high-performance deep learning interface"; description = "TensorRT: a high-performance deep learning interface";
homepage = "https://developer.nvidia.com/tensorrt"; homepage = "https://developer.nvidia.com/tensorrt";
license = licenses.unfree; license = licenses.unfree;

View File

@ -13956,7 +13956,7 @@ self: super: with self; {
tensorly = callPackage ../development/python-modules/tensorly { }; tensorly = callPackage ../development/python-modules/tensorly { };
tensorrt = callPackage ../development/python-modules/tensorrt { }; tensorrt = callPackage ../development/python-modules/tensorrt { cudaPackages = pkgs.cudaPackages_11; };
tensorstore = callPackage ../development/python-modules/tensorstore { }; tensorstore = callPackage ../development/python-modules/tensorstore { };