diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index 63f4c8c7e62..13b15d020e0 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -69,6 +69,12 @@ pub enum DiffActivity { /// length of a slice/vec. This is used for safety checks on slices. FakeActivitySize, } + +impl DiffActivity { + pub fn is_dual_or_const(&self) -> bool { + use DiffActivity::*; + matches!(self, |Dual | DualOnly | Dualv | DualvOnly | Const) + } /// We generate one of these structs for each `#[autodiff(...)]` attribute. #[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] pub struct AutoDiffItem { @@ -140,11 +146,7 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool { DiffMode::Error => false, DiffMode::Source => false, DiffMode::Forward => { - activity == DiffActivity::Dual - || activity == DiffActivity::Dualv - || activity == DiffActivity::DualOnly - || activity == DiffActivity::DualvOnly - || activity == DiffActivity::Const + activity.is_dual_or_const() } DiffMode::Reverse => { activity == DiffActivity::Const @@ -163,10 +165,8 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool { pub fn valid_ty_for_activity(ty: &P, activity: DiffActivity) -> bool { use DiffActivity::*; // It's always allowed to mark something as Const, since we won't compute derivatives wrt. it. - if matches!(activity, Const) { - return true; - } - if matches!(activity, Dual | DualOnly | Dualv | DualvOnly) { + // Dual variants also support all types. + if activity.is_dual_or_const() { return true; } // FIXME(ZuseZ4) We should make this more robust to also @@ -183,7 +183,7 @@ pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool { DiffMode::Error => false, DiffMode::Source => false, DiffMode::Forward => { - matches!(activity, Dual | DualOnly | Dualv | DualvOnly | Const) + activity.is_dual_or_const() } DiffMode::Reverse => { matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)