addressing feedback

This commit is contained in:
Manuel Drehwald 2025-04-10 05:22:17 -04:00
parent 64718ab9ad
commit ae6247c9fe

View File

@ -69,6 +69,12 @@ pub enum DiffActivity {
/// length of a slice/vec. This is used for safety checks on slices.
FakeActivitySize,
}
impl DiffActivity {
pub fn is_dual_or_const(&self) -> bool {
use DiffActivity::*;
matches!(self, |Dual | DualOnly | Dualv | DualvOnly | Const)
}
/// We generate one of these structs for each `#[autodiff(...)]` attribute.
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct AutoDiffItem {
@ -140,11 +146,7 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
DiffMode::Error => false,
DiffMode::Source => false,
DiffMode::Forward => {
activity == DiffActivity::Dual
|| activity == DiffActivity::Dualv
|| activity == DiffActivity::DualOnly
|| activity == DiffActivity::DualvOnly
|| activity == DiffActivity::Const
activity.is_dual_or_const()
}
DiffMode::Reverse => {
activity == DiffActivity::Const
@ -163,10 +165,8 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
use DiffActivity::*;
// It's always allowed to mark something as Const, since we won't compute derivatives wrt. it.
if matches!(activity, Const) {
return true;
}
if matches!(activity, Dual | DualOnly | Dualv | DualvOnly) {
// Dual variants also support all types.
if activity.is_dual_or_const() {
return true;
}
// FIXME(ZuseZ4) We should make this more robust to also
@ -183,7 +183,7 @@ pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
DiffMode::Error => false,
DiffMode::Source => false,
DiffMode::Forward => {
matches!(activity, Dual | DualOnly | Dualv | DualvOnly | Const)
activity.is_dual_or_const()
}
DiffMode::Reverse => {
matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)