Make DualvOnly (enzyme_dupnoneedv) work for forward mode

This commit is contained in:
Manuel Drehwald 2025-04-08 20:59:26 -04:00
parent 42cc67243d
commit 64718ab9ad
4 changed files with 33 additions and 19 deletions

View File

@ -56,6 +56,10 @@ pub enum DiffActivity {
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
/// with it. Drop the code which updates the original input/output for maximum performance.
DualOnly,
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
/// with it. Drop the code which updates the original input/output for maximum performance.
/// It expects the shadow argument to be `width` times larger than the original input/output.
DualvOnly,
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
Duplicated,
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
@ -139,6 +143,7 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
activity == DiffActivity::Dual
|| activity == DiffActivity::Dualv
|| activity == DiffActivity::DualOnly
|| activity == DiffActivity::DualvOnly
|| activity == DiffActivity::Const
}
DiffMode::Reverse => {
@ -161,7 +166,7 @@ pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
if matches!(activity, Const) {
return true;
}
if matches!(activity, Dual | DualOnly | Dualv) {
if matches!(activity, Dual | DualOnly | Dualv | DualvOnly) {
return true;
}
// FIXME(ZuseZ4) We should make this more robust to also
@ -178,7 +183,7 @@ pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
DiffMode::Error => false,
DiffMode::Source => false,
DiffMode::Forward => {
matches!(activity, Dual | DualOnly | Dualv | Const)
matches!(activity, Dual | DualOnly | Dualv | DualvOnly | Const)
}
DiffMode::Reverse => {
matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
@ -196,6 +201,7 @@ impl Display for DiffActivity {
DiffActivity::Dual => write!(f, "Dual"),
DiffActivity::Dualv => write!(f, "Dualv"),
DiffActivity::DualOnly => write!(f, "DualOnly"),
DiffActivity::DualvOnly => write!(f, "DualvOnly"),
DiffActivity::Duplicated => write!(f, "Duplicated"),
DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"),
@ -228,6 +234,7 @@ impl FromStr for DiffActivity {
"Dual" => Ok(DiffActivity::Dual),
"Dualv" => Ok(DiffActivity::Dualv),
"DualOnly" => Ok(DiffActivity::DualOnly),
"DualvOnly" => Ok(DiffActivity::DualvOnly),
"Duplicated" => Ok(DiffActivity::Duplicated),
"DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
_ => Err(()),

View File

@ -799,12 +799,18 @@ mod llvm_enzyme {
d_inputs.push(shadow_arg.clone());
}
}
DiffActivity::Dual | DiffActivity::DualOnly | DiffActivity::Dualv => {
let iterations = if matches!(activity, DiffActivity::Dualv) {
1
} else {
x.width
};
DiffActivity::Dual
| DiffActivity::DualOnly
| DiffActivity::Dualv
| DiffActivity::DualvOnly => {
// The *v variants get lowered to enzyme_dupv and enzyme_dupnoneedv, which make Enzyme
// expect a single, `width` times larger shadow argument instead of `width` separate ones.
let iterations =
if matches!(activity, DiffActivity::Dualv | DiffActivity::DualvOnly) {
1
} else {
x.width
};
for i in 0..iterations {
let mut shadow_arg = arg.clone();
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
@ -908,7 +914,7 @@ mod llvm_enzyme {
let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
d_decl.output = FnRetTy::Ty(ty);
}
if let DiffActivity::DualOnly = x.ret_activity {
if matches!(x.ret_activity, DiffActivity::DualOnly | DiffActivity::DualvOnly) {
// No need to change the return type,
// we will just return the shadow in place of the primal return.
// However, if we have a width > 1, then we don't return -> T, but -> [T; width]

View File

@ -51,7 +51,7 @@ fn has_sret(fnc: &Value) -> bool {
// using iterators and peek()?
fn match_args_from_caller_to_enzyme<'ll>(
cx: &SimpleCx<'ll>,
builder: &SBuilder<'ll,'ll>,
builder: &SBuilder<'ll, 'll>,
width: u32,
args: &mut Vec<&'ll llvm::Value>,
inputs: &[DiffActivity],
@ -81,6 +81,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap();
let enzyme_dupv = cx.create_metadata("enzyme_dupv".to_string()).unwrap();
let enzyme_dupnoneed = cx.create_metadata("enzyme_dupnoneed".to_string()).unwrap();
let enzyme_dupnoneedv = cx.create_metadata("enzyme_dupnoneedv".to_string()).unwrap();
while activity_pos < inputs.len() {
let diff_activity = inputs[activity_pos as usize];
@ -94,6 +95,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
DiffActivity::Dual => (enzyme_dup, true),
DiffActivity::Dualv => (enzyme_dupv, true),
DiffActivity::DualOnly => (enzyme_dupnoneed, true),
DiffActivity::DualvOnly => (enzyme_dupnoneedv, true),
DiffActivity::Duplicated => (enzyme_dup, true),
DiffActivity::DuplicatedOnly => (enzyme_dupnoneed, true),
DiffActivity::FakeActivitySize => (enzyme_const, false),
@ -106,10 +108,9 @@ fn match_args_from_caller_to_enzyme<'ll>(
// T=f32 => 4 bytes
// n_elems is the next integer.
// Now we multiply `4 * next_outer_arg` to get the stride.
//let mul = builder
// .build_mul(cx.get_const_i64(4), next_outer_arg)
// .unwrap();
let mul = unsafe {llvm::LLVMBuildMul(builder.llbuilder, cx.get_const_i64(4), next_outer_arg, UNNAMED)};
let mul = unsafe {
llvm::LLVMBuildMul(builder.llbuilder, cx.get_const_i64(4), next_outer_arg, UNNAMED)
};
args.push(mul);
}
args.push(outer_arg);
@ -140,11 +141,8 @@ fn match_args_from_caller_to_enzyme<'ll>(
// int2 >= int1, which means the shadow vector is large enough to store the gradient.
assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Integer);
let iterations = if matches!(diff_activity, DiffActivity::Dualv) {
1
} else {
width as usize
};
let iterations =
if matches!(diff_activity, DiffActivity::Dualv) { 1 } else { width as usize };
for i in 0..iterations {
let next_outer_arg2 = outer_args[outer_pos + 2 * (i + 1)];

View File

@ -40,6 +40,9 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec
new_activities.push(activity);
new_positions.push(i + 1);
}
// Now we need to figure out the size of each slice element in memory.
// Can we actually do that here?
continue;
}
}