From 64718ab9adfd5564ec2b18b954eebf736133b808 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Tue, 8 Apr 2025 20:59:26 -0400 Subject: [PATCH] working dupvonly for fwd mode --- .../rustc_ast/src/expand/autodiff_attrs.rs | 11 ++++++++-- compiler/rustc_builtin_macros/src/autodiff.rs | 20 ++++++++++++------- .../src/builder/autodiff.rs | 18 ++++++++--------- .../src/partitioning/autodiff.rs | 3 +++ 4 files changed, 33 insertions(+), 19 deletions(-) diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index 0c8ba1b0154..63f4c8c7e62 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -56,6 +56,10 @@ pub enum DiffActivity { /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument /// with it. Drop the code which updates the original input/output for maximum performance. DualOnly, + /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument + /// with it. Drop the code which updates the original input/output for maximum performance. + /// It expects the shadow argument to be `width` times larger than the original input/output. + DualvOnly, /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument. Duplicated, /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument. @@ -139,6 +143,7 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool { activity == DiffActivity::Dual || activity == DiffActivity::Dualv || activity == DiffActivity::DualOnly + || activity == DiffActivity::DualvOnly || activity == DiffActivity::Const } DiffMode::Reverse => { @@ -161,7 +166,7 @@ pub fn valid_ty_for_activity(ty: &P, activity: DiffActivity) -> bool { if matches!(activity, Const) { return true; } - if matches!(activity, Dual | DualOnly | Dualv) { + if matches!(activity, Dual | DualOnly | Dualv | DualvOnly) { return true; } // FIXME(ZuseZ4) We should make this more robust to also @@ -178,7 +183,7 @@ pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool { DiffMode::Error => false, DiffMode::Source => false, DiffMode::Forward => { - matches!(activity, Dual | DualOnly | Dualv | Const) + matches!(activity, Dual | DualOnly | Dualv | DualvOnly | Const) } DiffMode::Reverse => { matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const) @@ -196,6 +201,7 @@ impl Display for DiffActivity { DiffActivity::Dual => write!(f, "Dual"), DiffActivity::Dualv => write!(f, "Dualv"), DiffActivity::DualOnly => write!(f, "DualOnly"), + DiffActivity::DualvOnly => write!(f, "DualvOnly"), DiffActivity::Duplicated => write!(f, "Duplicated"), DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"), DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"), @@ -228,6 +234,7 @@ impl FromStr for DiffActivity { "Dual" => Ok(DiffActivity::Dual), "Dualv" => Ok(DiffActivity::Dualv), "DualOnly" => Ok(DiffActivity::DualOnly), + "DualvOnly" => Ok(DiffActivity::DualvOnly), "Duplicated" => Ok(DiffActivity::Duplicated), "DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly), _ => Err(()), diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 183d55293e6..763378c99bf 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -799,12 +799,18 @@ mod llvm_enzyme { d_inputs.push(shadow_arg.clone()); } } - DiffActivity::Dual | DiffActivity::DualOnly | DiffActivity::Dualv => { - let iterations = if matches!(activity, DiffActivity::Dualv) { - 1 - } else { - x.width - }; + DiffActivity::Dual + | DiffActivity::DualOnly + | DiffActivity::Dualv + | DiffActivity::DualvOnly => { + // the *v variants get lowered to enzyme_dupv and enzyme_dupnoneedv, which cause + // Enzyme to not expect N arguments, but one argument (which is instead larger). + let iterations = + if matches!(activity, DiffActivity::Dualv | DiffActivity::DualvOnly) { + 1 + } else { + x.width + }; for i in 0..iterations { let mut shadow_arg = arg.clone(); let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind { @@ -908,7 +914,7 @@ mod llvm_enzyme { let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None }); d_decl.output = FnRetTy::Ty(ty); } - if let DiffActivity::DualOnly = x.ret_activity { + if matches!(x.ret_activity, DiffActivity::DualOnly | DiffActivity::DualvOnly) { // No need to change the return type, // we will just return the shadow in place of the primal return. // However, if we have a width > 1, then we don't return -> T, but -> [T; width] diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index ea962e164fb..db1b311a9fe 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -51,7 +51,7 @@ fn has_sret(fnc: &Value) -> bool { // using iterators and peek()? fn match_args_from_caller_to_enzyme<'ll>( cx: &SimpleCx<'ll>, - builder: &SBuilder<'ll,'ll>, + builder: &SBuilder<'ll, 'll>, width: u32, args: &mut Vec<&'ll llvm::Value>, inputs: &[DiffActivity], @@ -81,6 +81,7 @@ fn match_args_from_caller_to_enzyme<'ll>( let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap(); let enzyme_dupv = cx.create_metadata("enzyme_dupv".to_string()).unwrap(); let enzyme_dupnoneed = cx.create_metadata("enzyme_dupnoneed".to_string()).unwrap(); + let enzyme_dupnoneedv = cx.create_metadata("enzyme_dupnoneedv".to_string()).unwrap(); while activity_pos < inputs.len() { let diff_activity = inputs[activity_pos as usize]; @@ -94,6 +95,7 @@ fn match_args_from_caller_to_enzyme<'ll>( DiffActivity::Dual => (enzyme_dup, true), DiffActivity::Dualv => (enzyme_dupv, true), DiffActivity::DualOnly => (enzyme_dupnoneed, true), + DiffActivity::DualvOnly => (enzyme_dupnoneedv, true), DiffActivity::Duplicated => (enzyme_dup, true), DiffActivity::DuplicatedOnly => (enzyme_dupnoneed, true), DiffActivity::FakeActivitySize => (enzyme_const, false), @@ -106,10 +108,9 @@ fn match_args_from_caller_to_enzyme<'ll>( // T=f32 => 4 bytes // n_elems is the next integer. // Now we multiply `4 * next_outer_arg` to get the stride. - //let mul = builder - // .build_mul(cx.get_const_i64(4), next_outer_arg) - // .unwrap(); - let mul = unsafe {llvm::LLVMBuildMul(builder.llbuilder, cx.get_const_i64(4), next_outer_arg, UNNAMED)}; + let mul = unsafe { + llvm::LLVMBuildMul(builder.llbuilder, cx.get_const_i64(4), next_outer_arg, UNNAMED) + }; args.push(mul); } args.push(outer_arg); @@ -140,11 +141,8 @@ fn match_args_from_caller_to_enzyme<'ll>( // int2 >= int1, which means the shadow vector is large enough to store the gradient. assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Integer); - let iterations = if matches!(diff_activity, DiffActivity::Dualv) { - 1 - } else { - width as usize - }; + let iterations = + if matches!(diff_activity, DiffActivity::Dualv) { 1 } else { width as usize }; for i in 0..iterations { let next_outer_arg2 = outer_args[outer_pos + 2 * (i + 1)]; diff --git a/compiler/rustc_monomorphize/src/partitioning/autodiff.rs b/compiler/rustc_monomorphize/src/partitioning/autodiff.rs index 27189ea6713..6c83fe758a5 100644 --- a/compiler/rustc_monomorphize/src/partitioning/autodiff.rs +++ b/compiler/rustc_monomorphize/src/partitioning/autodiff.rs @@ -40,6 +40,9 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec new_activities.push(activity); new_positions.push(i + 1); } + // Now we need to figure out the size of each slice element in memory. + // Can we actually do that here? + continue; } }