mirror of
https://github.com/rust-lang/rust.git
synced 2025-04-16 22:16:53 +00:00
Make DualvOnly work for forward mode
This commit is contained in:
parent
42cc67243d
commit
64718ab9ad
@ -56,6 +56,10 @@ pub enum DiffActivity {
|
||||
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
|
||||
/// with it. Drop the code which updates the original input/output for maximum performance.
|
||||
DualOnly,
|
||||
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
|
||||
/// with it. Drop the code which updates the original input/output for maximum performance.
|
||||
/// It expects the shadow argument to be `width` times larger than the original input/output.
|
||||
DualvOnly,
|
||||
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
|
||||
Duplicated,
|
||||
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
|
||||
@ -139,6 +143,7 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
|
||||
activity == DiffActivity::Dual
|
||||
|| activity == DiffActivity::Dualv
|
||||
|| activity == DiffActivity::DualOnly
|
||||
|| activity == DiffActivity::DualvOnly
|
||||
|| activity == DiffActivity::Const
|
||||
}
|
||||
DiffMode::Reverse => {
|
||||
@ -161,7 +166,7 @@ pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
|
||||
if matches!(activity, Const) {
|
||||
return true;
|
||||
}
|
||||
if matches!(activity, Dual | DualOnly | Dualv) {
|
||||
if matches!(activity, Dual | DualOnly | Dualv | DualvOnly) {
|
||||
return true;
|
||||
}
|
||||
// FIXME(ZuseZ4) We should make this more robust to also
|
||||
@ -178,7 +183,7 @@ pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
|
||||
DiffMode::Error => false,
|
||||
DiffMode::Source => false,
|
||||
DiffMode::Forward => {
|
||||
matches!(activity, Dual | DualOnly | Dualv | Const)
|
||||
matches!(activity, Dual | DualOnly | Dualv | DualvOnly | Const)
|
||||
}
|
||||
DiffMode::Reverse => {
|
||||
matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
|
||||
@ -196,6 +201,7 @@ impl Display for DiffActivity {
|
||||
DiffActivity::Dual => write!(f, "Dual"),
|
||||
DiffActivity::Dualv => write!(f, "Dualv"),
|
||||
DiffActivity::DualOnly => write!(f, "DualOnly"),
|
||||
DiffActivity::DualvOnly => write!(f, "DualvOnly"),
|
||||
DiffActivity::Duplicated => write!(f, "Duplicated"),
|
||||
DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
|
||||
DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"),
|
||||
@ -228,6 +234,7 @@ impl FromStr for DiffActivity {
|
||||
"Dual" => Ok(DiffActivity::Dual),
|
||||
"Dualv" => Ok(DiffActivity::Dualv),
|
||||
"DualOnly" => Ok(DiffActivity::DualOnly),
|
||||
"DualvOnly" => Ok(DiffActivity::DualvOnly),
|
||||
"Duplicated" => Ok(DiffActivity::Duplicated),
|
||||
"DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
|
||||
_ => Err(()),
|
||||
|
@ -799,12 +799,18 @@ mod llvm_enzyme {
|
||||
d_inputs.push(shadow_arg.clone());
|
||||
}
|
||||
}
|
||||
DiffActivity::Dual | DiffActivity::DualOnly | DiffActivity::Dualv => {
|
||||
let iterations = if matches!(activity, DiffActivity::Dualv) {
|
||||
1
|
||||
} else {
|
||||
x.width
|
||||
};
|
||||
DiffActivity::Dual
|
||||
| DiffActivity::DualOnly
|
||||
| DiffActivity::Dualv
|
||||
| DiffActivity::DualvOnly => {
|
||||
// the *v variants get lowered to enzyme_dupv and enzyme_dupnoneedv, which cause
|
||||
// Enzyme to not expect N arguments, but one argument (which is instead larger).
|
||||
let iterations =
|
||||
if matches!(activity, DiffActivity::Dualv | DiffActivity::DualvOnly) {
|
||||
1
|
||||
} else {
|
||||
x.width
|
||||
};
|
||||
for i in 0..iterations {
|
||||
let mut shadow_arg = arg.clone();
|
||||
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
|
||||
@ -908,7 +914,7 @@ mod llvm_enzyme {
|
||||
let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
|
||||
d_decl.output = FnRetTy::Ty(ty);
|
||||
}
|
||||
if let DiffActivity::DualOnly = x.ret_activity {
|
||||
if matches!(x.ret_activity, DiffActivity::DualOnly | DiffActivity::DualvOnly) {
|
||||
// No need to change the return type,
|
||||
// we will just return the shadow in place of the primal return.
|
||||
// However, if we have a width > 1, then we don't return -> T, but -> [T; width]
|
||||
|
@ -51,7 +51,7 @@ fn has_sret(fnc: &Value) -> bool {
|
||||
// using iterators and peek()?
|
||||
fn match_args_from_caller_to_enzyme<'ll>(
|
||||
cx: &SimpleCx<'ll>,
|
||||
builder: &SBuilder<'ll,'ll>,
|
||||
builder: &SBuilder<'ll, 'll>,
|
||||
width: u32,
|
||||
args: &mut Vec<&'ll llvm::Value>,
|
||||
inputs: &[DiffActivity],
|
||||
@ -81,6 +81,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
|
||||
let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap();
|
||||
let enzyme_dupv = cx.create_metadata("enzyme_dupv".to_string()).unwrap();
|
||||
let enzyme_dupnoneed = cx.create_metadata("enzyme_dupnoneed".to_string()).unwrap();
|
||||
let enzyme_dupnoneedv = cx.create_metadata("enzyme_dupnoneedv".to_string()).unwrap();
|
||||
|
||||
while activity_pos < inputs.len() {
|
||||
let diff_activity = inputs[activity_pos as usize];
|
||||
@ -94,6 +95,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
|
||||
DiffActivity::Dual => (enzyme_dup, true),
|
||||
DiffActivity::Dualv => (enzyme_dupv, true),
|
||||
DiffActivity::DualOnly => (enzyme_dupnoneed, true),
|
||||
DiffActivity::DualvOnly => (enzyme_dupnoneedv, true),
|
||||
DiffActivity::Duplicated => (enzyme_dup, true),
|
||||
DiffActivity::DuplicatedOnly => (enzyme_dupnoneed, true),
|
||||
DiffActivity::FakeActivitySize => (enzyme_const, false),
|
||||
@ -106,10 +108,9 @@ fn match_args_from_caller_to_enzyme<'ll>(
|
||||
// T=f32 => 4 bytes
|
||||
// n_elems is the next integer.
|
||||
// Now we multiply `4 * next_outer_arg` to get the stride.
|
||||
//let mul = builder
|
||||
// .build_mul(cx.get_const_i64(4), next_outer_arg)
|
||||
// .unwrap();
|
||||
let mul = unsafe {llvm::LLVMBuildMul(builder.llbuilder, cx.get_const_i64(4), next_outer_arg, UNNAMED)};
|
||||
let mul = unsafe {
|
||||
llvm::LLVMBuildMul(builder.llbuilder, cx.get_const_i64(4), next_outer_arg, UNNAMED)
|
||||
};
|
||||
args.push(mul);
|
||||
}
|
||||
args.push(outer_arg);
|
||||
@ -140,11 +141,8 @@ fn match_args_from_caller_to_enzyme<'ll>(
|
||||
// int2 >= int1, which means the shadow vector is large enough to store the gradient.
|
||||
assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Integer);
|
||||
|
||||
let iterations = if matches!(diff_activity, DiffActivity::Dualv) {
|
||||
1
|
||||
} else {
|
||||
width as usize
|
||||
};
|
||||
let iterations =
|
||||
if matches!(diff_activity, DiffActivity::Dualv) { 1 } else { width as usize };
|
||||
|
||||
for i in 0..iterations {
|
||||
let next_outer_arg2 = outer_args[outer_pos + 2 * (i + 1)];
|
||||
|
@ -40,6 +40,9 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec
|
||||
new_activities.push(activity);
|
||||
new_positions.push(i + 1);
|
||||
}
|
||||
// Now we need to figure out the size of each slice element in memory.
|
||||
// Can we actually do that here?
|
||||
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user