nixpkgs/pkgs/development/python-modules/triton/default.nix
2024-10-17 01:00:57 +02:00

280 lines
8.1 KiB
Nix

{
  # nixpkgs infrastructure and configuration
  lib,
  config,
  stdenv,
  substituteAll,
  addDriverRunpath,
  fetchFromGitHub,

  # Python tooling and Python-level dependencies
  python,
  buildPythonPackage,
  pytestCheckHook,
  setuptools,
  filelock,
  pybind11,
  lit,

  # Native toolchain and libraries for the CMake build
  cmake,
  ninja,
  llvm,
  gtest,
  libxml2,
  ncurses,
  zlib,

  # GPU backends and their feature switches
  cudaPackages,
  rocmPackages,
  cudaSupport ? config.cudaSupport,
  rocmSupport ? config.rocmSupport,

  # Referenced only from passthru (gpuCheck / tests)
  torchWithRocm,
  triton,
}:
buildPythonPackage {
pname = "triton";
version = "3.1.0";

# Pinned to a commit on the release/3.1.x branch (the branch tip at the time
# of the last bump): https://github.com/triton-lang/triton/commits/release/3.1.x/
src = fetchFromGitHub {
  owner = "triton-lang";
  repo = "triton";
  rev = "cf34004b8a67d290a962da166f5aa2fc66751326";
  hash = "sha256-233fpuR7XXOaSKN+slhJbE/CMFzAqCRCE4V4rIoJZrk=";
};

# Build through the pyproject (PEP 517) path; the backend is set in build-system.
pyproject = true;
patches =
  let
    # Linker flag adding `addDriverRunpath.driverLink`/lib to the RPATH.
    # It is substituted as @ccCmdExtraFlags@ into two different patches, so it
    # is defined once here to keep both substitutions identical.
    # NOTE(review): presumably consumed by the C compiler command in triton's
    # runtime `_build` helper (cf. the patch names) — confirm on version bumps.
    driverRpathFlags = "-Wl,-rpath,${addDriverRunpath.driverLink}/lib";
  in
  [
    ./0001-setup.py-introduce-TRITON_OFFLINE_BUILD.patch
    (substituteAll {
      src = ./0001-_build-allow-extra-cc-flags.patch;
      ccCmdExtraFlags = driverRpathFlags;
    })
    # Driver-library locations: only the attributes for the enabled backend(s)
    # are passed to substituteAll.
    (substituteAll (
      { src = ./0002-nvidia-amd-driver-short-circuit-before-ldconfig.patch; }
      // lib.optionalAttrs rocmSupport { libhipDir = "${lib.getLib rocmPackages.clr}/lib"; }
      // lib.optionalAttrs cudaSupport {
        libcudaStubsDir = "${lib.getLib cudaPackages.cuda_cudart}/lib/stubs";
        ccCmdExtraFlags = driverRpathFlags;
      }
    ))
  ]
  # CUDA only: point triton at the cudart headers and the ptxas binary from cudaPackages.
  ++ lib.optionals cudaSupport [
    (substituteAll {
      src = ./0003-nvidia-cudart-a-systempath.patch;
      cudaToolkitIncludeDirs = "${lib.getInclude cudaPackages.cuda_cudart}/include";
    })
    (substituteAll {
      src = ./0004-nvidia-allow-static-ptxas-path.patch;
      nixpkgsExtraBinaryPaths = lib.escapeShellArgs [ (lib.getExe' cudaPackages.cuda_nvcc "ptxas") ];
    })
  ];
postPatch = ''
  # Stop setup.py from downloading dependencies (JSON, pybind11, LLVM) so the
  # derivation's own inputs are used; also drop the triton/profiler package and
  # force the `curr_version != version` check to False.
  # Every continuation backslash is preceded by a space: `""\` would glue the
  # empty replacement string onto the next `--replace-fail` and lose it.
  substituteInPlace python/setup.py \
    --replace-fail "get_json_package_info(), get_pybind11_package_info()" "" \
    --replace-fail "get_pybind11_package_info(), get_llvm_package_info()" "" \
    --replace-fail 'packages += ["triton/profiler"]' "" \
    --replace-fail "curr_version != version" "False"

  # Don't fetch googletest; locate the packaged GTest via find_package instead
  substituteInPlace unittest/CMakeLists.txt \
    --replace-fail "include (\''${CMAKE_CURRENT_SOURCE_DIR}/googletest.cmake)" "" \
    --replace-fail "include(GoogleTest)" "find_package(GTest REQUIRED)"
'';
# PEP 517 build backend for python/setup.py
build-system = [ setuptools ];

# Runtime Python dependencies
dependencies = [
  filelock
  # setuptools is needed at runtime too, not just at build time:
  # https://github.com/NixOS/nixpkgs/pull/286763/#discussion_r1480392652
  setuptools
];

# Tools for the CMake build that setup.py drives.
# Note for future: lit and llvm *probably* belong in depsTargetTarget, but
# cross can't be tested right now anyway, because cudaPackages is only
# supported on x86_64-linux at the moment.
nativeBuildInputs = [
  cmake
  ninja
  lit
  llvm
];

# Libraries made available to the CMake build
buildInputs = [
  gtest
  libxml2.dev
  ncurses
  pybind11
  zlib
];
# Since Python 3.12, pybind11 triggers strange errors that are silenced here
# with -Wno-stringop-overread. They were only observed in the CUDA branch, so
# the flag is only added when cudaSupport is enabled.
# https://gist.github.com/SomeoneSerge/7d390b2b1313957c378e99ed57168219#file-gistfile0-txt-L1042
NIX_CFLAGS_COMPILE = lib.optional cudaSupport "-Wno-stringop-overread";
# Prepare the source tree for upstream's setup.py and enter ./python.
# (The CC/CXX overrides that avoid a GLIBCXX mismatch with other
# CUDA-enabled Python packages are set in `env`, not here.)
preConfigure = ''
  # Ensure that the build process uses the requested number of cores
  export MAX_JOBS="$NIX_BUILD_CORES"

  # Upstream's setup.py tries to write cache somewhere in ~/
  # (assigned before export so a failing mktemp is not masked)
  HOME=$(mktemp -d)
  export HOME

  # Upstream's GitHub Actions patch setup.cfg to write base-dir. May be redundant.
  # printf emits the same bytes as the former multi-line echo, without
  # depending on this file's indentation.
  printf '\n[build_ext]\nbase-dir=%s\n' "$PWD" >> python/setup.cfg

  # The rest (including buildPhase) is relative to ./python/
  cd python
'';
env =
  let
    # Extra variables, only merged in when cudaSupport is enabled.
    cudaEnv = {
      # Compile with the CUDA backend stdenv's compiler to avoid a GLIBCXX
      # mismatch with other CUDA-enabled Python packages.
      CC = lib.getExe' cudaPackages.backendStdenv.cc "cc";
      CXX = lib.getExe' cudaPackages.backendStdenv.cc "c++";

      # TODO: Unused because of how TRITON_OFFLINE_BUILD currently works (subject to change)
      # Make sure cudaPackages is the right version each update (See python/setup.py)
      TRITON_PTXAS_PATH = lib.getExe' cudaPackages.cuda_nvcc "ptxas";
      TRITON_CUOBJDUMP_PATH = lib.getExe' cudaPackages.cuda_cuobjdump "cuobjdump";
      TRITON_NVDISASM_PATH = lib.getExe' cudaPackages.cuda_nvdisasm "nvdisasm";
      TRITON_CUDACRT_PATH = lib.getInclude cudaPackages.cuda_nvcc;
      TRITON_CUDART_PATH = lib.getInclude cudaPackages.cuda_cudart;
      TRITON_CUPTI_PATH = cudaPackages.cuda_cupti;
    };
  in
  {
    # Don't build the proton component
    TRITON_BUILD_PROTON = "OFF";
    # Switch introduced by ./0001-setup.py-introduce-TRITON_OFFLINE_BUILD.patch
    TRITON_OFFLINE_BUILD = true;
  }
  // (if cudaSupport then cudaEnv else { });
# setup.py runs CMake itself, so the cmake setup hook must not configure.
dontUseCmakeConfigure = true;

# Upstream requirements that are intentionally not provided as Python deps
pythonRemoveDeps = [
  # Circular dependency, cf. https://github.com/triton-lang/triton/issues/1374
  "torch"
  # CLI tools without dist-info
  "cmake"
  "lit"
];
# Provides ctest
nativeCheckInputs = [ cmake ];

# Run ctest inside setup.py's CMake build tree.
# build/temp* refers to build_ext.build_temp (looked up in the build logs).
# `&&` ensures ctest never runs from the wrong directory if `cd` fails,
# independent of whether errexit is in effect.
preCheck = ''
  (cd ./build/temp* && ctest)
'';

pythonImportsCheck = [
  "triton"
  "triton.language"
];
# Upstream's python/test/unit suite, run with pytest. It needs a builder that
# advertises the "cuda" system feature.
passthru.gpuCheck =
  let
    testPython = python.withPackages (ps: [
      ps.scipy
      ps.torchWithCuda
      ps.triton-cuda
    ]);
  in
  stdenv.mkDerivation {
    pname = "triton-pytest";
    inherit (triton) version src;

    requiredSystemFeatures = [ "cuda" ];

    nativeBuildInputs = [ testPython ];

    # Nothing to compile; this derivation only runs the tests.
    dontBuild = true;

    doCheck = true;
    nativeCheckInputs = [ pytestCheckHook ];
    preCheck = ''
      cd python/test/unit
      export HOME=$TMPDIR
    '';
    checkPhase = "pytestCheckPhase";

    # The output is an empty success marker.
    installPhase = "touch $out";
  };
passthru.tests = {
  # Ultimately, torch is our test suite:
  inherit torchWithRocm;

  # Test as `nix run -f "<nixpkgs>" python3Packages.triton.tests.axpy-cuda`
  # or, using `programs.nix-required-mounts`, as `nix build -f "<nixpkgs>" python3Packages.triton.tests.axpy-cuda.gpuCheck`
  # The embedded script is Python, so its indentation is significant.
  axpy-cuda =
    cudaPackages.writeGpuTestPython
      {
        libraries = ps: [
          ps.triton
          ps.torch-no-triton
        ];
      }
      ''
        # Adapted from Philippe Tillet https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html

        import triton
        import triton.language as tl
        import torch
        import os

        @triton.jit
        def axpy_kernel(n, a: tl.constexpr, x_ptr, y_ptr, out, BLOCK_SIZE: tl.constexpr):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n
            x = tl.load(x_ptr + offsets, mask=mask)
            y = tl.load(y_ptr + offsets, mask=mask)
            output = a * x + y
            tl.store(out + offsets, output, mask=mask)

        def axpy(a, x, y):
            output = torch.empty_like(x)
            assert x.is_cuda and y.is_cuda and output.is_cuda
            n_elements = output.numel()

            def grid(meta):
                return (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )

            axpy_kernel[grid](n_elements, a, x, y, output, BLOCK_SIZE=1024)
            return output

        if __name__ == "__main__":
            # The Nix sandbox's HOME is /homeless-shelter; redirect it to a
            # writable location (presumably for triton's on-disk caches).
            if os.environ.get("HOME", None) == "/homeless-shelter":
                os.environ["HOME"] = os.environ.get("TMPDIR", "/tmp")
            if "CC" not in os.environ:
                os.environ["CC"] = "${lib.getExe' cudaPackages.backendStdenv.cc "cc"}"
            torch.manual_seed(0)
            size = 12345
            x = torch.rand(size, device='cuda')
            y = torch.rand(size, device='cuda')
            output_torch = 3.14 * x + y
            output_triton = axpy(3.14, x, y)
            assert output_torch.sub(output_triton).abs().max().item() < 1e-6
            print("Triton axpy: OK")
      '';
};
meta = {
  homepage = "https://github.com/triton-lang/triton";
  description = "Language and compiler for writing highly efficient custom Deep-Learning primitives";
  license = lib.licenses.mit;
  platforms = lib.platforms.linux;
  # Explicit attribute paths instead of `with lib.maintainers;`
  maintainers = [
    lib.maintainers.SomeoneSerge
    lib.maintainers.Madouura
    lib.maintainers.derdennisop
  ];
};
}