nixpkgs/pkgs/development/python-modules/pytorch/default.nix
Austin Seipp c6ecb3887f pythonPackages.pytorch: fix weirdly broken wheel version number
Otherwise, the wheel gets built with invalid metadata -- causing
'torch >= 1.0.0' to be unsatisfiable in other python packages, for
instance.

Signed-off-by: Austin Seipp <aseipp@pobox.com>
2018-12-25 10:49:39 -06:00

93 lines
2.7 KiB
Nix
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Package function arguments. The three CUDA-related arguments are optional
# and default to a CPU-only build.
{ lib
, buildPythonPackage
, pythonOlder
, fetchFromGitHub
, linkFarm
, symlinkJoin
, cmake
, utillinux
, which
, numpy
, pyyaml
, cffi
, typing
, hypothesis
, cudaSupport ? false
, cudatoolkit ? null
, cudnn ? null
}:

# Passing cudnn implies that a CUDA toolkit was passed as well.
assert cudnn != null -> cudatoolkit != null;
# Enabling CUDA support implies that a CUDA toolkit was passed.
assert cudaSupport -> cudatoolkit != null;
let
# Merge the `out` and `lib` outputs of cudatoolkit into a single tree so
# that both are reachable under one prefix.
cudatoolkit_joined = symlinkJoin {
  name = cudatoolkit.name + "-unsplit";
  paths = map (output: cudatoolkit.${output}) [ "out" "lib" ];
};
# At runtime libcuda.so.1 normally comes from nvidia-x11 through
# LD_LIBRARY_PATH=/run/opengl-driver/lib. For running the tests we only
# use the stub libcuda.so shipped with cudatoolkit, so that we don't have
# to recompile pytorch on every update to nvidia-x11 or the kernel.
cudaStub =
  let
    # Expose the toolkit's stub library under the soname libcuda.so.1.
    stubEntry = {
      path = "${cudatoolkit}/lib/stubs/libcuda.so";
      name = "libcuda.so.1";
    };
  in linkFarm "cuda-stub" [ stubEntry ];
# Shell prefix that prepends the stub directory to LD_LIBRARY_PATH
# (keeping any existing value); empty for non-CUDA builds.
cudaStubEnv =
  if cudaSupport
  then "LD_LIBRARY_PATH=${cudaStub}\${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH} "
  else "";
in buildPythonPackage rec {
pname = "pytorch";
version = "1.0.0";

# Sources come from the upstream GitHub tag "v<version>", including git
# submodules.
src = fetchFromGitHub {
  repo = "pytorch";
  owner = "pytorch";
  rev = "v" + version;
  sha256 = "076cpbig4sywn9vv674c0xdg832sdrd5pk1d0725pjkm436kpvlm";
  fetchSubmodules = true;
};
# For CUDA builds, use the compiler referenced by cudatoolkit.cc, and point
# the build at the cudnn headers when cudnn was supplied.
preConfigure =
  let
    withCudnn = cudaSupport && cudnn != null;
  in lib.concatStrings [
    (lib.optionalString cudaSupport ''
      export CC=${cudatoolkit.cc}/bin/gcc CXX=${cudatoolkit.cc}/bin/g++
    '')
    (lib.optionalString withCudnn ''
      export CUDNN_INCLUDE_DIR=${cudnn}/include
    '')
  ];
# Rewrite the RPATH of every libcaffe2*.so found under $out: the first two
# RPATH entries are dropped and $ORIGIN is put in front of the remainder.
#
# Compared to a naive implementation this takes care to
#  * scope IFS to the single `read` call, so the modified field separator
#    does not leak into the rest of the build shell,
#  * quote every expansion, so paths are never re-split or globbed,
#  * avoid a trailing ":" (an empty RPATH component) when fewer than three
#    entries were present.
preFixup = ''
  function join_by { local IFS="$1"; shift; echo "$*"; }
  function strip2 {
    local -a RP
    local RP_NEW
    IFS=':' read -ra RP <<< "$(patchelf --print-rpath "$1")"
    RP_NEW=$(join_by : "''${RP[@]:2}")
    patchelf --set-rpath "\$ORIGIN''${RP_NEW:+:$RP_NEW}" "$1"
  }
  while IFS= read -r f; do
    strip2 "$f"
  done < <(find "$out" -name 'libcaffe2*.so')
'';
# The version chosen by default is (weirdly) wrong, which gives the wheel
# invalid metadata: a requirement such as 'torch >= 1.0.0' in another
# python package then cannot be satisfied. Pin the build version and
# number explicitly instead. See
# https://github.com/NixOS/nixpkgs/pull/52437#issuecomment-449718038
# https://github.com/pytorch/pytorch/blob/v1.0.0/setup.py#L267
PYTORCH_BUILD_NUMBER = 0;
PYTORCH_BUILD_VERSION = version;
# Build-time dependencies. For CUDA builds the joined cudatoolkit tree is
# added, plus cudnn when it was supplied.
buildInputs = [
  cmake
  numpy.blas
  utillinux
  which
] ++ lib.optionals cudaSupport (
  [ cudatoolkit_joined ]
  # cudnn may legitimately be null even when cudaSupport is enabled (the
  # asserts permit it and preConfigure guards on it), so never put a
  # literal null into the input list.
  ++ lib.optional (cudnn != null) cudnn
);
# Python dependencies propagated to dependents; `typing` is only added for
# interpreters older than 3.5.
propagatedBuildInputs =
  [ cffi numpy pyyaml ]
  ++ lib.optionals (pythonOlder "3.5") [ typing ];
checkInputs = [ hypothesis ];

# Run the upstream test driver, excluding the suites listed below. For CUDA
# builds the command is prefixed with the stub-libcuda LD_LIBRARY_PATH
# assignment from cudaStubEnv.
checkPhase =
  let
    excludedSuites = lib.concatStringsSep " " [
      "dataloader"
      "sparse"
      "torch"
      "utils"
      "thd_distributed"
      "distributed"
      "cpp_extensions"
    ];
  in ''
    ${cudaStubEnv}python test/run_test.py --exclude ${excludedSuites}
  '';
# Package metadata. The homepage is written as a quoted string rather than
# a bare URL literal (that syntax is deprecated), and the description
# carries no trailing period, following the nixpkgs meta conventions.
meta = {
  description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration";
  homepage = "https://pytorch.org/";
  license = lib.licenses.bsd3;
  platforms = lib.platforms.linux;
  maintainers = with lib.maintainers; [ teh ];
};
}