nixpkgs/pkgs/development/cuda-modules/flags.nix
Connor Baker 9bebd9e72d tree-wide: cudaPackages should not break default eval
cudaPackages: guard expressions against null values
2024-01-10 01:15:01 +00:00

382 lines
14 KiB
Nix

# Type aliases
# Gpu :: AttrSet
# - See the documentation in ./gpus.nix.
{
config,
cudaCapabilities ? (config.cudaCapabilities or []),
cudaForwardCompat ? (config.cudaForwardCompat or true),
lib,
cudaVersion,
hostPlatform,
# gpus :: List Gpu
gpus,
}:
let
inherit (lib)
asserts
attrsets
lists
strings
trivial
;
# Flags are determined based on your CUDA toolkit by default. You may benefit
# from improved performance, reduced file size, or greater hardware support by
# passing a configuration based on your specific GPU environment.
#
# cudaCapabilities :: List Capability
# List of hardware generations to build.
# E.g. [ "8.0" ]
# Currently, the last item is considered the optional forward-compatibility arch,
# but this may change in the future.
#
# cudaForwardCompat :: Bool
# Whether to include the forward compatibility gencode (+PTX)
# to support future GPU generations.
# E.g. true
#
# Please see the accompanying documentation or https://github.com/NixOS/nixpkgs/pull/205351
# isSupported :: Gpu -> Bool
isSupported =
gpu:
let
inherit (gpu) minCudaVersion maxCudaVersion;
lowerBoundSatisfied = strings.versionAtLeast cudaVersion minCudaVersion;
upperBoundSatisfied =
(maxCudaVersion == null) || !(strings.versionOlder maxCudaVersion cudaVersion);
in
lowerBoundSatisfied && upperBoundSatisfied;
# NOTE: Jetson is never built by default.
# isDefault :: Gpu -> Bool
isDefault =
gpu:
let
inherit (gpu) dontDefaultAfter isJetson;
newGpu = dontDefaultAfter == null;
recentGpu = newGpu || strings.versionAtLeast dontDefaultAfter cudaVersion;
in
recentGpu && !isJetson;
# supportedGpus :: List Gpu
# GPUs which are supported by the provided CUDA version.
supportedGpus = builtins.filter isSupported gpus;
# defaultGpus :: List Gpu
# GPUs which are supported by the provided CUDA version and we want to build for by default.
defaultGpus = builtins.filter isDefault supportedGpus;
# supportedCapabilities :: List Capability
supportedCapabilities = lists.map (gpu: gpu.computeCapability) supportedGpus;
# defaultCapabilities :: List Capability
# The default capabilities to target, if not overridden by the user.
defaultCapabilities = lists.map (gpu: gpu.computeCapability) defaultGpus;
# cudaArchNameToVersions :: AttrSet String (List String)
# Maps the name of a GPU architecture to different versions of that architecture.
# For example, "Ampere" maps to [ "8.0" "8.6" "8.7" ].
cudaArchNameToVersions =
lists.groupBy' (versions: gpu: versions ++ [gpu.computeCapability]) [] (gpu: gpu.archName)
supportedGpus;
# cudaComputeCapabilityToName :: AttrSet String String
# Maps the version of a GPU architecture to the name of that architecture.
# For example, "8.0" maps to "Ampere".
cudaComputeCapabilityToName = builtins.listToAttrs (
lists.map (gpu: attrsets.nameValuePair gpu.computeCapability gpu.archName) supportedGpus
);
# cudaComputeCapabilityToIsJetson :: AttrSet String Boolean
cudaComputeCapabilityToIsJetson = builtins.listToAttrs (
lists.map (attrs: attrsets.nameValuePair attrs.computeCapability attrs.isJetson) supportedGpus
);
# jetsonComputeCapabilities :: List String
jetsonComputeCapabilities = trivial.pipe cudaComputeCapabilityToIsJetson [
(attrsets.filterAttrs (_: isJetson: isJetson))
builtins.attrNames
];
# Find the intersection with the user-specified list of cudaCapabilities.
# NOTE: Jetson devices are never built by default because they cannot be targeted along with
# non-Jetson devices and require an aarch64 host platform. As such, if they're present anywhere,
# they must be in the user-specified cudaCapabilities.
# NOTE: We don't need to worry about mixes of Jetson and non-Jetson devices here -- there's
# sanity-checking for all that in below.
jetsonTargets = lists.intersectLists jetsonComputeCapabilities cudaCapabilities;
# dropDot :: String -> String
dropDot = ver: builtins.replaceStrings ["."] [""] ver;
# archMapper :: String -> List String -> List String
# Maps a feature across a list of architecture versions to produce a list of architectures.
# For example, "sm" and [ "8.0" "8.6" "8.7" ] produces [ "sm_80" "sm_86" "sm_87" ].
archMapper = feat: lists.map (computeCapability: "${feat}_${dropDot computeCapability}");
# gencodeMapper :: String -> List String -> List String
# Maps a feature across a list of architecture versions to produce a list of gencode arguments.
# For example, "sm" and [ "8.0" "8.6" "8.7" ] produces [ "-gencode=arch=compute_80,code=sm_80"
# "-gencode=arch=compute_86,code=sm_86" "-gencode=arch=compute_87,code=sm_87" ].
gencodeMapper =
feat:
lists.map (
computeCapability:
"-gencode=arch=compute_${dropDot computeCapability},code=${feat}_${dropDot computeCapability}"
);
# Maps Nix system to NVIDIA redist arch.
# NOTE: We swap out the default `linux-sbsa` redist (for server-grade ARM chips) with the
# `linux-aarch64` redist (which is for Jetson devices) if we're building any Jetson devices.
# Since both are based on aarch64, we can only have one or the other, otherwise there's an
# ambiguity as to which should be used.
# NOTE: This function *will* be called by unsupported systems because `cudaPackages` is part of
# `all-packages.nix`, which is evaluated on all systems. As such, we need to handle unsupported
# systems gracefully.
# getRedistArch :: String -> String
getRedistArch = nixSystem: attrsets.attrByPath [ nixSystem ] "unsupported" {
aarch64-linux = if jetsonTargets != [] then "linux-aarch64" else "linux-sbsa";
x86_64-linux = "linux-x86_64";
ppc64le-linux = "linux-ppc64le";
x86_64-windows = "windows-x86_64";
};
# Maps NVIDIA redist arch to Nix system.
# NOTE: This function *will* be called by unsupported systems because `cudaPackages` is part of
# `all-packages.nix`, which is evaluated on all systems. As such, we need to handle unsupported
# systems gracefully.
# getNixSystem :: String -> String
getNixSystem = redistArch: attrsets.attrByPath [ redistArch ] "unsupported-${redistArch}" {
linux-sbsa = "aarch64-linux";
linux-aarch64 = "aarch64-linux";
linux-x86_64 = "x86_64-linux";
linux-ppc64le = "ppc64le-linux";
windows-x86_64 = "x86_64-windows";
};
formatCapabilities =
{
cudaCapabilities,
enableForwardCompat ? true,
}:
rec {
inherit cudaCapabilities enableForwardCompat;
# archNames :: List String
# E.g. [ "Turing" "Ampere" ]
#
# Unknown architectures are rendered as sm_XX gencode flags.
archNames = lists.unique (
lists.map (cap: cudaComputeCapabilityToName.${cap} or "sm_${dropDot cap}") cudaCapabilities
);
# realArches :: List String
# The real architectures are physical architectures supported by the CUDA version.
# E.g. [ "sm_75" "sm_86" ]
realArches = archMapper "sm" cudaCapabilities;
# virtualArches :: List String
# The virtual architectures are typically used for forward compatibility, when trying to support
# an architecture newer than the CUDA version allows.
# E.g. [ "compute_75" "compute_86" ]
virtualArches = archMapper "compute" cudaCapabilities;
# arches :: List String
# By default, build for all supported architectures and forward compatibility via a virtual
# architecture for the newest supported architecture.
# E.g. [ "sm_75" "sm_86" "compute_86" ]
arches = realArches ++ lists.optional enableForwardCompat (lists.last virtualArches);
# gencode :: List String
# A list of CUDA gencode arguments to pass to NVCC.
# E.g. [ "-gencode=arch=compute_75,code=sm_75" ... "-gencode=arch=compute_86,code=compute_86" ]
gencode =
let
base = gencodeMapper "sm" cudaCapabilities;
forward = gencodeMapper "compute" [(lists.last cudaCapabilities)];
in
base ++ lib.optionals enableForwardCompat forward;
# gencodeString :: String
# A space-separated string of CUDA gencode arguments to pass to NVCC.
# E.g. "-gencode=arch=compute_75,code=sm_75 ... -gencode=arch=compute_86,code=compute_86"
gencodeString = strings.concatStringsSep " " gencode;
# Jetson devices cannot be targeted by the same binaries which target non-Jetson devices. While
# NVIDIA provides both `linux-aarch64` and `linux-sbsa` packages, which both target `aarch64`,
# they are built with different settings and cannot be mixed.
# isJetsonBuild :: Boolean
isJetsonBuild =
let
requestedJetsonDevices =
lists.filter (cap: cudaComputeCapabilityToIsJetson.${cap} or false)
cudaCapabilities;
requestedNonJetsonDevices =
lists.filter (cap: !(builtins.elem cap requestedJetsonDevices))
cudaCapabilities;
jetsonBuildSufficientCondition = requestedJetsonDevices != [];
jetsonBuildNecessaryCondition = requestedNonJetsonDevices == [] && hostPlatform.isAarch64;
in
trivial.throwIf (jetsonBuildSufficientCondition && !jetsonBuildNecessaryCondition)
''
Jetson devices cannot be targeted with non-Jetson devices. Additionally, they require hostPlatform to be aarch64.
You requested ${builtins.toJSON cudaCapabilities} for host platform ${hostPlatform.system}.
Requested Jetson devices: ${builtins.toJSON requestedJetsonDevices}.
Requested non-Jetson devices: ${builtins.toJSON requestedNonJetsonDevices}.
Exactly one of the following must be true:
- All CUDA capabilities belong to Jetson devices and hostPlatform is aarch64.
- No CUDA capabilities belong to Jetson devices.
See ${./gpus.nix} for a list of architectures supported by this version of Nixpkgs.
''
jetsonBuildSufficientCondition
&& jetsonBuildNecessaryCondition;
};
in
# When changing names or formats: pause, validate, and update the assert
assert let
expected = {
cudaCapabilities = [
"7.5"
"8.6"
];
enableForwardCompat = true;
archNames = [
"Turing"
"Ampere"
];
realArches = [
"sm_75"
"sm_86"
];
virtualArches = [
"compute_75"
"compute_86"
];
arches = [
"sm_75"
"sm_86"
"compute_86"
];
gencode = [
"-gencode=arch=compute_75,code=sm_75"
"-gencode=arch=compute_86,code=sm_86"
"-gencode=arch=compute_86,code=compute_86"
];
gencodeString = "-gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_86,code=compute_86";
isJetsonBuild = false;
};
actual = formatCapabilities {
cudaCapabilities = [
"7.5"
"8.6"
];
};
actualWrapped = (builtins.tryEval (builtins.deepSeq actual actual)).value;
in
asserts.assertMsg ((strings.versionAtLeast cudaVersion "11.2") -> (expected == actualWrapped)) ''
This test should only fail when using a version of CUDA older than 11.2, the first to support
8.6.
Expected: ${builtins.toJSON expected}
Actual: ${builtins.toJSON actualWrapped}
'';
# Check mixed Jetson and non-Jetson devices
assert let
expected = false;
actual = formatCapabilities {
cudaCapabilities = [
"7.2"
"7.5"
];
};
actualWrapped = (builtins.tryEval (builtins.deepSeq actual actual)).value;
in
asserts.assertMsg (expected == actualWrapped) ''
Jetson devices capabilities cannot be mixed with non-jetson devices.
Capability 7.5 is non-Jetson and should not be allowed with Jetson 7.2.
Expected: ${builtins.toJSON expected}
Actual: ${builtins.toJSON actualWrapped}
'';
# Check Jetson-only
assert let
expected = {
cudaCapabilities = [
"6.2"
"7.2"
];
enableForwardCompat = true;
archNames = [
"Pascal"
"Volta"
];
realArches = [
"sm_62"
"sm_72"
];
virtualArches = [
"compute_62"
"compute_72"
];
arches = [
"sm_62"
"sm_72"
"compute_72"
];
gencode = [
"-gencode=arch=compute_62,code=sm_62"
"-gencode=arch=compute_72,code=sm_72"
"-gencode=arch=compute_72,code=compute_72"
];
gencodeString = "-gencode=arch=compute_62,code=sm_62 -gencode=arch=compute_72,code=sm_72 -gencode=arch=compute_72,code=compute_72";
isJetsonBuild = true;
};
actual = formatCapabilities {
cudaCapabilities = [
"6.2"
"7.2"
];
};
actualWrapped = (builtins.tryEval (builtins.deepSeq actual actual)).value;
in
asserts.assertMsg
# We can't do this test unless we're targeting aarch64
(hostPlatform.isAarch64 -> (expected == actualWrapped))
''
Jetson devices can only be built with other Jetson devices.
Both 6.2 and 7.2 are Jetson devices.
Expected: ${builtins.toJSON expected}
Actual: ${builtins.toJSON actualWrapped}
'';
{
# formatCapabilities :: { cudaCapabilities: List Capability, enableForwardCompat: Boolean } -> { ... }
inherit formatCapabilities;
# cudaArchNameToVersions :: String => String
inherit cudaArchNameToVersions;
# cudaComputeCapabilityToName :: String => String
inherit cudaComputeCapabilityToName;
# dropDot :: String -> String
inherit dropDot;
inherit
defaultCapabilities
supportedCapabilities
jetsonComputeCapabilities
jetsonTargets
getNixSystem
getRedistArch
;
}
// formatCapabilities {
cudaCapabilities = if cudaCapabilities == [] then defaultCapabilities else cudaCapabilities;
enableForwardCompat = cudaForwardCompat;
}