python3Packages.torch{,-bin}: test torch.cuda.is_available()

Someone Serge 2023-09-20 07:14:07 +03:00
parent 7ed2cba5e8
commit f22b9da6b8
3 changed files with 35 additions and 1 deletion
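The change adds a small GPU smoke test to both the source-built torch and the torch-bin wheel, exposed through passthru. As a rough sketch of how the resulting attributes would be evaluated (attribute paths inferred from the hunks below; allowUnfree, cudaSupport, and a builder that advertises the "cuda" system feature are assumptions, not part of this commit):

# Sketch only: evaluate the smoke tests added by this commit.
# config.allowUnfree and config.cudaSupport are assumed prerequisites for a CUDA-enabled torch.
let
  pkgs = import <nixpkgs> {
    config = {
      allowUnfree = true;
      cudaSupport = true;
    };
  };
in
{
  # passthru.tests is merged onto the derivation, so both tests are reachable as attributes:
  torchSource = pkgs.python3Packages.torch.tests.cudaAvailable;
  torchBin = pkgs.python3Packages.torch-bin.tests.cudaAvailable;
}

Either attribute can also be built directly, e.g. nix-build -A python3Packages.torch-bin.tests.cudaAvailable, on a machine whose Nix configuration lists cuda among its system features; requiredSystemFeatures in the new test file below keeps the derivation off builders that do not.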

View File

@@ -8,6 +8,7 @@
pythonAtLeast,
pythonOlder,
addOpenGLRunpath,
callPackage,
cudaPackages,
future,
numpy,
@@ -15,6 +16,7 @@
pyyaml,
requests,
setuptools,
torch-bin,
typing-extensions,
sympy,
jinja2,
@@ -119,6 +121,8 @@ buildPythonPackage {
pythonImportsCheck = [ "torch" ];
passthru.tests.cudaAvailable = callPackage ./test-cuda.nix { torch = torch-bin; };
meta = {
description = "PyTorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration";
homepage = "https://pytorch.org/";

View File

@@ -24,6 +24,10 @@
mpi,
buildDocs ? false,
# tests.cudaAvailable:
callPackage,
torch,
# Native build inputs
cmake,
symlinkJoin,
@@ -639,11 +643,16 @@ buildPythonPackage rec {
rocmSupport
rocmPackages
;
cudaCapabilities = if cudaSupport then supportedCudaCapabilities else [ ];
# At least for 1.10.2 `torch.fft` is unavailable unless BLAS provider is MKL. This attribute allows for easy detection of its availability.
blasProvider = blas.provider;
# To help debug when a package is broken due to CUDA support
inherit brokenConditions;
cudaCapabilities = if cudaSupport then supportedCudaCapabilities else [ ];
} // lib.optionalAttrs cudaSupport {
tests = lib.optionalAttrs cudaSupport {
cudaAvailable = callPackage ./test-cuda.nix { inherit torch; };
};
};
meta = {

View File

@@ -0,0 +1,21 @@
{ runCommandNoCC
, python
, torch
}:
runCommandNoCC "${torch.name}-gpu-test"
  {
    nativeBuildInputs = [
      (python.withPackages (_: [ torch ]))
    ];
    requiredSystemFeatures = [
      "cuda"
    ];
  } ''
    python3 << EOF
    import torch
    assert torch.cuda.is_available(), f"{torch.cuda.is_available()=}"
    EOF
    touch $out
  ''
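Because the helper only takes runCommandNoCC, python, and the torch to probe, it can also be instantiated directly against any CUDA-capable torch variant rather than through passthru. A minimal sketch (torchWithCuda is a pre-existing convenience attribute, not part of this commit, and the relative path assumes the file sits next to the callers above):

# Sketch only: call the new helper by hand with an assumed CUDA-enabled torch.
{ pkgs ? import <nixpkgs> { config.allowUnfree = true; } }:
pkgs.python3Packages.callPackage ./test-cuda.nix {
  torch = pkgs.python3Packages.torchWithCuda;
}

The requiredSystemFeatures = [ "cuda" ] attribute means Nix only dispatches the check to builders that advertise the cuda system feature and refuses to run it elsewhere, so ordinary GPU-less build machines never execute the torch.cuda.is_available() probe.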