tensorrt: dont break eval for unrelated packages

This commit is contained in:
Someone Serge 2023-11-24 22:46:28 +00:00
parent 5bda2ec626
commit 3ee37e4356
No known key found for this signature in database
GPG Key ID: 7B0E3B1390D61DA4
3 changed files with 32 additions and 13 deletions

View File

@ -17,16 +17,32 @@ final: prev: let
isSupported = fileData: elem cudaVersion fileData.supportedCudaVersions;
# Return the first file that is supported. In practice there should only ever be one anyway.
supportedFile = files: findFirst isSupported null files;
# Supported versions with versions as keys and file as value
supportedVersions = filterAttrs (version: file: file !=null ) (mapAttrs (version: files: supportedFile files) tensorRTVersions);
# Compute versioned attribute name to be used in this package set
computeName = version: "tensorrt_${toUnderscore version}";
# Supported versions with versions as keys and file as value
supportedVersions = lib.recursiveUpdate
{
tensorrt = {
enable = false;
fileVersionCuda = null;
fileVersionCudnn = null;
fullVersion = "0.0.0";
sha256 = null;
tarball = null;
supportedCudaVersions = [ ];
};
}
(mapAttrs' (version: attrs: nameValuePair (computeName version) attrs)
(filterAttrs (version: file: file != null) (mapAttrs (version: files: supportedFile files) tensorRTVersions)));
# Add all supported builds as attributes
allBuilds = mapAttrs' (version: file: nameValuePair (computeName version) (buildTensorRTPackage (removeAttrs file ["fileVersionCuda"]))) supportedVersions;
allBuilds = mapAttrs (name: file: buildTensorRTPackage (removeAttrs file ["fileVersionCuda"])) supportedVersions;
# Set the default attributes, e.g. tensorrt = tensorrt_8_4;
defaultBuild = { "tensorrt" = if allBuilds ? ${computeName tensorRTDefaultVersion}
then allBuilds.${computeName tensorRTDefaultVersion}
else throw "tensorrt-${tensorRTDefaultVersion} does not support your cuda version ${cudaVersion}"; };
defaultName = computeName tensorRTDefaultVersion;
defaultBuild = lib.optionalAttrs (allBuilds ? ${defaultName}) { tensorrt = allBuilds.${computeName tensorRTDefaultVersion}; };
in {
inherit buildTensorRTPackage;
} // allBuilds // defaultBuild;

View File

@ -8,20 +8,22 @@
, cudnn
}:
{ fullVersion
{ enable ? true
, fullVersion
, fileVersionCudnn ? null
, tarball
, sha256
, supportedCudaVersions ? [ ]
}:
assert fileVersionCudnn == null || lib.assertMsg (lib.strings.versionAtLeast cudnn.version fileVersionCudnn)
assert !enable || fileVersionCudnn == null || lib.assertMsg (lib.strings.versionAtLeast cudnn.version fileVersionCudnn)
"This version of TensorRT requires at least cuDNN ${fileVersionCudnn} (current version is ${cudnn.version})";
backendStdenv.mkDerivation rec {
pname = "cudatoolkit-${cudatoolkit.majorVersion}-tensorrt";
version = fullVersion;
src = requireFile rec {
src = if !enable then null else
requireFile rec {
name = tarball;
inherit sha256;
message = ''
@ -38,13 +40,13 @@ backendStdenv.mkDerivation rec {
outputs = [ "out" "dev" ];
nativeBuildInputs = [
nativeBuildInputs = lib.optionals enable [
autoPatchelfHook
autoAddOpenGLRunpathHook
];
# Used by autoPatchelfHook
buildInputs = [
buildInputs = lib.optionals enable [
backendStdenv.cc.cc.lib # libstdc++
cudatoolkit
cudnn
@ -75,6 +77,7 @@ backendStdenv.mkDerivation rec {
'';
passthru.stdenv = backendStdenv;
passthru.enable = enable;
meta = with lib; {
# Check that the cudatoolkit version satisfies our min/max constraints (both
@ -82,7 +85,7 @@ backendStdenv.mkDerivation rec {
# official version constraints (as recorded in default.nix). In some cases
# you _may_ be able to smudge version constraints, just know that you're
# embarking into unknown and unsupported territory when doing so.
broken = !(elem cudaVersion supportedCudaVersions);
broken = !enable || !(elem cudaVersion supportedCudaVersions);
description = "TensorRT: a high-performance deep learning interface";
homepage = "https://developer.nvidia.com/tensorrt";
license = licenses.unfree;

View File

@ -13956,7 +13956,7 @@ self: super: with self; {
tensorly = callPackage ../development/python-modules/tensorly { };
tensorrt = callPackage ../development/python-modules/tensorrt { };
tensorrt = callPackage ../development/python-modules/tensorrt { cudaPackages = pkgs.cudaPackages_11; };
tensorstore = callPackage ../development/python-modules/tensorstore { };