diff --git a/pkgs/development/python-modules/functorch/default.nix b/pkgs/development/python-modules/functorch/default.nix new file mode 100644 index 000000000000..05b96077edc9 --- /dev/null +++ b/pkgs/development/python-modules/functorch/default.nix @@ -0,0 +1,98 @@ +{ buildPythonPackage +, expecttest +, fetchFromGitHub +, lib +, ninja +, pytestCheckHook +, python +, pytorch +, which +}: + +buildPythonPackage rec { + pname = "functorch"; + version = "0.1.1"; + format = "setuptools"; + + src = fetchFromGitHub { + owner = "pytorch"; + repo = pname; + rev = "v${version}"; + hash = "sha256-FidM04Q3hkGEDr4dthJv0MWtGiRfnWxJoyzu7Wl3SD8="; + }; + + # Somewhat surprisingly pytorch is actually necessary for the build process. + # `setup.py` imports `torch.utils.cpp_extension`. + nativeBuildInputs = [ + ninja + pytorch + which + ]; + + preCheck = '' + rm -rf functorch/ + ''; + + checkInputs = [ + expecttest + pytestCheckHook + ]; + + # See https://github.com/pytorch/functorch/issues/835. + disabledTests = [ + # RuntimeError: ("('...', '') is in PyTorch's OpInfo db ", "but is not in functorch's OpInfo db. Please regenerate ", '... and add the new tests to ', 'denylists if necessary.') + "test_coverage_bernoulli_cpu_float32" + "test_coverage_column_stack_cpu_float32" + "test_coverage_diagflat_cpu_float32" + "test_coverage_flatten_cpu_float32" + "test_coverage_linalg_lu_factor_cpu_float32" + "test_coverage_linalg_lu_factor_ex_cpu_float32" + "test_coverage_multinomial_cpu_float32" + "test_coverage_nn_functional_dropout2d_cpu_float32" + "test_coverage_nn_functional_feature_alpha_dropout_with_train_cpu_float32" + "test_coverage_nn_functional_feature_alpha_dropout_without_train_cpu_float32" + "test_coverage_nn_functional_kl_div_cpu_float32" + "test_coverage_normal_cpu_float32" + "test_coverage_normal_number_mean_cpu_float32" + "test_coverage_pca_lowrank_cpu_float32" + "test_coverage_round_decimals_0_cpu_float32" + "test_coverage_round_decimals_3_cpu_float32" + "test_coverage_round_decimals_neg_3_cpu_float32" + "test_coverage_scatter_reduce_cpu_float32" + "test_coverage_svd_lowrank_cpu_float32" + + # > self.assertEqual(len(functorch_lagging_op_db), len(op_db)) + # E AssertionError: Scalars are not equal! + # E + # E Absolute difference: 19 + # E Relative difference: 0.03525046382189239 + "test_functorch_lagging_op_db_has_opinfos_cpu" + + # RuntimeError: PyTorch not compiled with LLVM support! + "test_bias_gelu" + "test_binary_ops" + "test_broadcast1" + "test_broadcast2" + "test_float_double" + "test_float_int" + "test_fx_trace" + "test_int_long" + "test_issue57611" + "test_slice1" + "test_slice2" + "test_transposed1" + "test_transposed2" + "test_unary_ops" + ]; + + pythonImportsCheck = [ "functorch" ]; + + meta = with lib; { + description = "JAX-like composable function transforms for PyTorch"; + homepage = "https://pytorch.org/functorch"; + license = licenses.bsd3; + maintainers = with maintainers; [ samuela ]; + # See https://github.com/NixOS/nixpkgs/pull/174248#issuecomment-1139895064. + platforms = platforms.x86_64; + }; +} diff --git a/pkgs/top-level/python-packages.nix b/pkgs/top-level/python-packages.nix index 4876e53e89d5..23b6c937572c 100644 --- a/pkgs/top-level/python-packages.nix +++ b/pkgs/top-level/python-packages.nix @@ -3260,6 +3260,8 @@ in { functools32 = callPackage ../development/python-modules/functools32 { }; + functorch = callPackage ../development/python-modules/functorch { }; + funcy = callPackage ../development/python-modules/funcy { }; furl = callPackage ../development/python-modules/furl { };