mirror of
https://github.com/NixOS/nixpkgs.git
synced 2024-11-02 15:41:48 +00:00
python3Packages.openai-triton: init at 2.0.0
This commit is contained in:
parent
0f76efb481
commit
378c0c6983
@ -24,6 +24,8 @@
|
||||
, targetDir ? "llvm"
|
||||
, targetProjects ? [ ]
|
||||
, targetRuntimes ? [ ]
|
||||
# "NATIVE" resolves into x86 or aarch64 depending on stdenv
|
||||
, llvmTargetsToBuild ? [ "NATIVE" ]
|
||||
, extraPatches ? [ ]
|
||||
, extraNativeBuildInputs ? [ ]
|
||||
, extraBuildInputs ? [ ]
|
||||
@ -46,6 +48,8 @@ let
|
||||
if stdenv.isx86_64 then "X86"
|
||||
else if stdenv.isAarch64 then "AArch64"
|
||||
else throw "Unsupported ROCm LLVM platform";
|
||||
inferNativeTarget = t: if t == "NATIVE" then llvmNativeTarget else t;
|
||||
llvmTargetsToBuild' = [ "AMDGPU" ] ++ builtins.map inferNativeTarget llvmTargetsToBuild;
|
||||
in stdenv.mkDerivation (finalAttrs: {
|
||||
pname = "rocm-llvm-${targetName}";
|
||||
version = "5.4.4";
|
||||
@ -98,7 +102,7 @@ in stdenv.mkDerivation (finalAttrs: {
|
||||
sourceRoot = "${finalAttrs.src.name}/${targetDir}";
|
||||
|
||||
cmakeFlags = [
|
||||
"-DLLVM_TARGETS_TO_BUILD=AMDGPU;${llvmNativeTarget}"
|
||||
"-DLLVM_TARGETS_TO_BUILD=${builtins.concatStringsSep ";" llvmTargetsToBuild'}"
|
||||
] ++ lib.optionals (finalAttrs.passthru.isLLVM && targetProjects != [ ]) [
|
||||
"-DLLVM_ENABLE_PROJECTS=${lib.concatStringsSep ";" targetProjects}"
|
||||
] ++ lib.optionals ((finalAttrs.passthru.isLLVM || targetDir == "runtimes") && targetRuntimes != [ ]) [
|
||||
|
246
pkgs/development/python-modules/openai-triton/default.nix
Normal file
246
pkgs/development/python-modules/openai-triton/default.nix
Normal file
@ -0,0 +1,246 @@
|
||||
{ lib
|
||||
, buildPythonPackage
|
||||
, python
|
||||
, fetchpatch
|
||||
, fetchFromGitHub
|
||||
, addOpenGLRunpath
|
||||
, cmake
|
||||
, cudaPackages
|
||||
, llvmPackages
|
||||
, pybind11
|
||||
, gtest
|
||||
, zlib
|
||||
, ncurses
|
||||
, libxml2
|
||||
, lit
|
||||
, filelock
|
||||
, torchWithRocm
|
||||
, pytest
|
||||
, pytestCheckHook
|
||||
, pythonRelaxDepsHook
|
||||
, pkgsTargetTarget
|
||||
}:
|
||||
|
||||
let
|
||||
pname = "triton";
|
||||
version = "2.0.0";
|
||||
|
||||
inherit (cudaPackages) cuda_cudart backendStdenv;
|
||||
ptxas = "${pkgsTargetTarget.cudaPackages.cuda_nvcc}/bin/ptxas";
|
||||
|
||||
llvm = (llvmPackages.llvm.override {
|
||||
llvmTargetsToBuild = [ "NATIVE" "NVPTX" ];
|
||||
# Upstream CI sets these too:
|
||||
# targetProjects = [ "mlir" ];
|
||||
extraCMakeFlags = [
|
||||
"-DLLVM_INSTALL_UTILS=ON"
|
||||
];
|
||||
});
|
||||
in
|
||||
buildPythonPackage {
|
||||
inherit pname version;
|
||||
|
||||
format = "setuptools";
|
||||
|
||||
src = fetchFromGitHub {
|
||||
owner = "openai";
|
||||
repo = pname;
|
||||
rev = "v${version}";
|
||||
hash = "sha256-9GZzugab+Pdt74Dj6zjlEzjj4BcJ69rzMJmqcVMxsKU=";
|
||||
};
|
||||
|
||||
patches = [
|
||||
# Prerequisite for llvm15 patch
|
||||
(fetchpatch {
|
||||
url = "https://github.com/openai/triton/commit/2aba985daaa70234823ea8f1161da938477d3e02.patch";
|
||||
hash = "sha256-LGv0+Ut2WYPC4Ksi4803Hwmhi3FyQOF9zElJc/JCobk=";
|
||||
})
|
||||
(fetchpatch {
|
||||
url = "https://github.com/openai/triton/commit/e3941f9d09cdd31529ba4a41018cfc0096aafea6.patch";
|
||||
hash = "sha256-A+Gor6qzFlGQhVVhiaaYOzqqx8yO2MdssnQS6TIfUWg=";
|
||||
})
|
||||
|
||||
# Source: https://github.com/openai/triton/commit/fc7a8e35819bda632bdcf1cf75fd9abe4d4e077a.patch
|
||||
# The original patch adds ptxas binary, so we include our own clean copy
|
||||
# Drop with the next update
|
||||
./llvm15.patch
|
||||
|
||||
# TODO: there have been commits upstream aimed at removing the "torch"
|
||||
# circular dependency, but the patches fail to apply on the release
|
||||
# revision. Keeping the link for future reference
|
||||
# Also cf. https://github.com/openai/triton/issues/1374
|
||||
|
||||
# (fetchpatch {
|
||||
# url = "https://github.com/openai/triton/commit/fc7c0b0e437a191e421faa61494b2ff4870850f1.patch";
|
||||
# hash = "sha256-f0shIqHJkVvuil2Yku7vuqWFn7VCRKFSFjYRlwx25ig=";
|
||||
# })
|
||||
];
|
||||
|
||||
postPatch = ''
|
||||
substituteInPlace python/setup.py \
|
||||
--replace \
|
||||
'= get_thirdparty_packages(triton_cache_path)' \
|
||||
'= os.environ["cmakeFlags"].split()'
|
||||
''
|
||||
# Wiring triton=2.0.0 with llcmPackages_rocm.llvm=5.4.3
|
||||
# Revisit when updating either triton or llvm
|
||||
+ ''
|
||||
substituteInPlace CMakeLists.txt \
|
||||
--replace "nvptx" "NVPTX" \
|
||||
--replace "LLVM 11" "LLVM"
|
||||
sed -i '/AddMLIR/a set(MLIR_TABLEGEN_EXE "${llvmPackages.mlir}/bin/mlir-tblgen")' CMakeLists.txt
|
||||
sed -i '/AddMLIR/a set(MLIR_INCLUDE_DIR ''${MLIR_INCLUDE_DIRS})' CMakeLists.txt
|
||||
find -iname '*.td' -exec \
|
||||
sed -i \
|
||||
-e '\|include "mlir/IR/OpBase.td"|a include "mlir/IR/AttrTypeBase.td"' \
|
||||
-e 's|include "mlir/Dialect/StandardOps/IR/Ops.td"|include "mlir/Dialect/Func/IR/FuncOps.td"|' \
|
||||
'{}' ';'
|
||||
substituteInPlace unittest/CMakeLists.txt --replace "include(GoogleTest)" "find_package(GTest REQUIRED)"
|
||||
sed -i 's/^include.*$//' unittest/CMakeLists.txt
|
||||
sed -i '/LINK_LIBS/i NVPTXInfo' lib/Target/PTX/CMakeLists.txt
|
||||
sed -i '/LINK_LIBS/i NVPTXCodeGen' lib/Target/PTX/CMakeLists.txt
|
||||
''
|
||||
# TritonMLIRIR already links MLIRIR. Not transitive?
|
||||
# + ''
|
||||
# echo "target_link_libraries(TritonPTX PUBLIC MLIRIR)" >> lib/Target/PTX/CMakeLists.txt
|
||||
# ''
|
||||
# Already defined in llvm, when built with -DLLVM_INSTALL_UTILS
|
||||
+ ''
|
||||
substituteInPlace bin/CMakeLists.txt \
|
||||
--replace "add_subdirectory(FileCheck)" ""
|
||||
|
||||
rm cmake/FindLLVM.cmake
|
||||
''
|
||||
+
|
||||
(
|
||||
let
|
||||
# Bash was getting weird without linting,
|
||||
# but basically upstream contains [cc, ..., "-lcuda", ...]
|
||||
# and we replace it with [..., "-lcuda", "-L/run/opengl-driver/lib", "-L$stubs", ...]
|
||||
old = [ "-lcuda" ];
|
||||
new = [ "-lcuda" "-L${addOpenGLRunpath.driverLink}" "-L${cuda_cudart}/lib/stubs/" ];
|
||||
|
||||
quote = x: ''"${x}"'';
|
||||
oldStr = lib.concatMapStringsSep ", " quote old;
|
||||
newStr = lib.concatMapStringsSep ", " quote new;
|
||||
in
|
||||
''
|
||||
substituteInPlace python/triton/compiler.py \
|
||||
--replace '${oldStr}' '${newStr}'
|
||||
''
|
||||
)
|
||||
# Triton seems to be looking up cuda.h
|
||||
+ ''
|
||||
sed -i 's|cu_include_dir = os.path.join.*$|cu_include_dir = "${cuda_cudart}/include"|' python/triton/compiler.py
|
||||
'';
|
||||
|
||||
nativeBuildInputs = [
|
||||
cmake
|
||||
pythonRelaxDepsHook
|
||||
|
||||
# Requires torch (circular dependency) and probably needs GPUs:
|
||||
# pytestCheckHook
|
||||
|
||||
# Note for future:
|
||||
# These *probably* should go in depsTargetTarget
|
||||
# ...but we cannot test cross right now anyway
|
||||
# because we only support cudaPackages on x86_64-linux atm
|
||||
lit
|
||||
llvm
|
||||
llvmPackages.mlir
|
||||
];
|
||||
|
||||
buildInputs = [
|
||||
gtest
|
||||
libxml2.dev
|
||||
ncurses
|
||||
pybind11
|
||||
zlib
|
||||
];
|
||||
|
||||
propagatedBuildInputs = [
|
||||
filelock
|
||||
];
|
||||
|
||||
# Avoid GLIBCXX mismatch with other cuda-enabled python packages
|
||||
preConfigure =
|
||||
''
|
||||
export CC="${backendStdenv.cc}/bin/cc";
|
||||
export CXX="${backendStdenv.cc}/bin/c++";
|
||||
''
|
||||
# Upstream's setup.py tries to write cache somewhere in ~/
|
||||
+ ''
|
||||
export HOME=$TMPDIR
|
||||
''
|
||||
# Upstream's github actions patch setup.cfg to write base-dir. May be redundant
|
||||
+ ''
|
||||
echo "" >> python/setup.cfg
|
||||
echo "[build_ext]" >> python/setup.cfg
|
||||
echo "base-dir=$PWD" >> python/setup.cfg
|
||||
''
|
||||
# The rest (including buildPhase) is relative to ./python/
|
||||
+ ''
|
||||
cd python/
|
||||
''
|
||||
# Work around download_and_copy_ptxas()
|
||||
+ ''
|
||||
dst_cuda="$PWD/triton/third_party/cuda/bin"
|
||||
mkdir -p "$dst_cuda"
|
||||
ln -s "${ptxas}" "$dst_cuda/"
|
||||
'';
|
||||
|
||||
# CMake is run by setup.py instead
|
||||
dontUseCmakeConfigure = true;
|
||||
cmakeFlags = [
|
||||
"-DMLIR_DIR=${llvmPackages.mlir}/lib/cmake/mlir"
|
||||
];
|
||||
|
||||
postFixup =
|
||||
let
|
||||
ptxasDestination = "$out/${python.sitePackages}/triton/third_party/cuda/bin/ptxas";
|
||||
in
|
||||
# Setuptools (?) strips runpath and +x flags. Let's just restore the symlink
|
||||
''
|
||||
rm -f ${ptxasDestination}
|
||||
ln -s ${ptxas} ${ptxasDestination}
|
||||
'';
|
||||
|
||||
checkInputs = [
|
||||
cmake # ctest
|
||||
];
|
||||
dontUseSetuptoolsCheck = true;
|
||||
preCheck =
|
||||
# build/temp* refers to build_ext.build_temp (looked up in the build logs)
|
||||
''
|
||||
(cd /build/source/python/build/temp* ; ctest)
|
||||
'' # For pytestCheckHook
|
||||
+ ''
|
||||
cd test/unit
|
||||
'';
|
||||
pythonImportsCheck = [
|
||||
# Circular dependency on torch
|
||||
# "triton"
|
||||
# "triton.language"
|
||||
];
|
||||
|
||||
# Ultimately, torch is our test suite:
|
||||
passthru.tests = {
|
||||
inherit torchWithRocm;
|
||||
};
|
||||
|
||||
pythonRemoveDeps = [
|
||||
# Circular dependency, cf. https://github.com/openai/triton/issues/1374
|
||||
"torch"
|
||||
|
||||
# CLI tools without dist-info
|
||||
"cmake"
|
||||
"lit"
|
||||
];
|
||||
meta = with lib; {
|
||||
description = "Development repository for the Triton language and compiler";
|
||||
homepage = "https://github.com/openai/triton/";
|
||||
platforms = lib.platforms.unix;
|
||||
license = licenses.mit;
|
||||
maintainers = with maintainers; [ SomeoneSerge ];
|
||||
};
|
||||
}
|
4617
pkgs/development/python-modules/openai-triton/llvm15.patch
Normal file
4617
pkgs/development/python-modules/openai-triton/llvm15.patch
Normal file
File diff suppressed because it is too large
Load Diff
@ -6,6 +6,7 @@
|
||||
|
||||
# Native build inputs
|
||||
cmake, util-linux, linkFarm, symlinkJoin, which, pybind11, removeReferencesTo,
|
||||
pythonRelaxDepsHook,
|
||||
|
||||
# Build inputs
|
||||
numactl,
|
||||
@ -13,9 +14,10 @@
|
||||
|
||||
# Propagated build inputs
|
||||
filelock,
|
||||
sympy,
|
||||
networkx,
|
||||
jinja2,
|
||||
networkx,
|
||||
openai-triton,
|
||||
sympy,
|
||||
numpy, pyyaml, cffi, click, typing-extensions,
|
||||
|
||||
# Unit tests
|
||||
@ -271,6 +273,7 @@ in buildPythonPackage rec {
|
||||
which
|
||||
ninja
|
||||
pybind11
|
||||
pythonRelaxDepsHook
|
||||
removeReferencesTo
|
||||
] ++ lib.optionals cudaSupport [ cudatoolkit_joined ]
|
||||
++ lib.optionals rocmSupport [ rocmtoolkit_joined ];
|
||||
@ -298,8 +301,17 @@ in buildPythonPackage rec {
|
||||
|
||||
# the following are required for tensorboard support
|
||||
pillow six future tensorboard protobuf
|
||||
] ++ lib.optionals MPISupport [ mpi ]
|
||||
++ lib.optionals rocmSupport [ rocmtoolkit_joined ];
|
||||
]
|
||||
++ lib.optionals MPISupport [ mpi ]
|
||||
++ lib.optionals rocmSupport [ rocmtoolkit_joined ]
|
||||
# rocm build requires openai-triton;
|
||||
# openai-triton currently requires cuda_nvcc,
|
||||
# so not including it in the cpu-only build;
|
||||
# torch.compile relies on openai-triton,
|
||||
# so we include it for the cuda build as well
|
||||
++ lib.optionals (rocmSupport || cudaSupport) [
|
||||
openai-triton
|
||||
];
|
||||
|
||||
# Tests take a long time and may be flaky, so just sanity-check imports
|
||||
doCheck = false;
|
||||
@ -327,6 +339,11 @@ in buildPythonPackage rec {
|
||||
"runHook postCheck"
|
||||
];
|
||||
|
||||
pythonRemoveDeps = [
|
||||
# In our dist-info the name is just "triton"
|
||||
"pytorch-triton-rocm"
|
||||
];
|
||||
|
||||
postInstall = ''
|
||||
find "$out/${python.sitePackages}/torch/include" "$out/${python.sitePackages}/torch/lib" -type f -exec remove-references-to -t ${stdenv.cc} '{}' +
|
||||
|
||||
|
@ -184,6 +184,8 @@ python.pkgs.buildPythonApplication rec {
|
||||
"tests/vocoder_tests/test_multiband_melgan_train.py"
|
||||
"tests/vocoder_tests/test_melgan_train.py"
|
||||
"tests/vocoder_tests/test_wavernn_train.py"
|
||||
# only a feed forward test, but still takes too long
|
||||
"tests/tts_tests/test_overflow.py"
|
||||
];
|
||||
|
||||
passthru = {
|
||||
|
@ -6801,6 +6801,8 @@ self: super: with self; {
|
||||
|
||||
open-meteo = callPackage ../development/python-modules/open-meteo { };
|
||||
|
||||
openai-triton = callPackage ../development/python-modules/openai-triton { llvmPackages = pkgs.llvmPackages_rocm; };
|
||||
|
||||
openai-whisper = callPackage ../development/python-modules/openai-whisper { };
|
||||
|
||||
openant = callPackage ../development/python-modules/openant { };
|
||||
|
Loading…
Reference in New Issue
Block a user