mirror of
https://github.com/rust-lang/rust.git
synced 2024-11-25 16:24:46 +00:00
Auto merge of #129458 - EnzymeAD:enzyme-frontend, r=jieyouxu
Autodiff Upstreaming - enzyme frontend This is an upstream PR for the `autodiff` rustc_builtin_macro that is part of the autodiff feature. For the full implementation, see: https://github.com/rust-lang/rust/pull/129175 **Content:** It contains a new `#[autodiff(<args>)]` rustc_builtin_macro, as well as a `#[rustc_autodiff]` builtin attribute. The autodiff macro is applied on function `f` and will expand to a second function `df` (name given by user). It will add a dummy body to `df` to make sure it type-checks. The body will later be replaced by enzyme on llvm-ir level, we therefore don't really care about the content. Most of the changes (700 from 1.2k) are in `compiler/rustc_builtin_macros/src/autodiff.rs`, which expand the macro. Nothing except expansion is implemented for now. I have a fallback implementation for relevant functions in case that rustc should be build without autodiff support. The default for now will be off, although we want to flip it later (once everything landed) to on for nightly. For the sake of CI, I have flipped the defaults, I'll revert this before merging. **Dummy function Body:** The first line is an `inline_asm` nop to make inlining less likely (I have additional checks to prevent this in the middle end of rustc. If `f` gets inlined too early, we can't pass it to enzyme and thus can't differentiate it. If `df` gets inlined too early, the call site will just compute this dummy code instead of the derivatives, a correctness issue. The following black_box lines make sure that none of the input arguments is getting optimized away before we replace the body. **Motivation:** The user facing autodiff macro can verify the user input. Then I write it as args to the rustc_attribute, so from here on I can know that these values should be sensible. A rustc_attribute also turned out to be quite nice to attach this information to the corresponding function and carry it till the backend. This is also just an experiment, I expect to adjust the user facing autodiff macro based on user feedback, to improve usability. As a simple example of what this will do, we can see this expansion: From: ``` #[autodiff(df, Reverse, Duplicated, Const, Active)] pub fn f1(x: &[f64], y: f64) -> f64 { unimplemented!() } ``` to ``` #[rustc_autodiff] #[inline(never)] pub fn f1(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, Duplicated, Const, Active,)] #[inline(never)] pub fn df(x: &[f64], dx: &mut [f64], y: f64, dret: f64) -> f64 { unsafe { asm!("NOP"); }; ::core::hint::black_box(f1(x, y)); ::core::hint::black_box((dx, dret)); ::core::hint::black_box(f1(x, y)) } ``` I will add a few more tests once I figured out why rustc rebuilds every time I touch a test. Tracking: - https://github.com/rust-lang/rust/issues/124509 try-job: dist-x86_64-msvc
This commit is contained in:
commit
785c83015c
283
compiler/rustc_ast/src/expand/autodiff_attrs.rs
Normal file
283
compiler/rustc_ast/src/expand/autodiff_attrs.rs
Normal file
@ -0,0 +1,283 @@
|
|||||||
|
//! This crate handles the user facing autodiff macro. For each `#[autodiff(...)]` attribute,
|
||||||
|
//! we create an [`AutoDiffItem`] which contains the source and target function names. The source
|
||||||
|
//! is the function to which the autodiff attribute is applied, and the target is the function
|
||||||
|
//! getting generated by us (with a name given by the user as the first autodiff arg).
|
||||||
|
|
||||||
|
use std::fmt::{self, Display, Formatter};
|
||||||
|
use std::str::FromStr;
|
||||||
|
|
||||||
|
use crate::expand::typetree::TypeTree;
|
||||||
|
use crate::expand::{Decodable, Encodable, HashStable_Generic};
|
||||||
|
use crate::ptr::P;
|
||||||
|
use crate::{Ty, TyKind};
|
||||||
|
|
||||||
|
/// Forward and Reverse Mode are well known names for automatic differentiation implementations.
|
||||||
|
/// Enzyme does support both, but with different semantics, see DiffActivity. The First variants
|
||||||
|
/// are a hack to support higher order derivatives. We need to compute first order derivatives
|
||||||
|
/// before we compute second order derivatives, otherwise we would differentiate our placeholder
|
||||||
|
/// functions. The proper solution is to recognize and resolve this DAG of autodiff invocations,
|
||||||
|
/// as it's already done in the C++ and Julia frontend of Enzyme.
|
||||||
|
///
|
||||||
|
/// (FIXME) remove *First variants.
|
||||||
|
/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and
|
||||||
|
/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online.
|
||||||
|
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
|
||||||
|
pub enum DiffMode {
|
||||||
|
/// No autodiff is applied (used during error handling).
|
||||||
|
Error,
|
||||||
|
/// The primal function which we will differentiate.
|
||||||
|
Source,
|
||||||
|
/// The target function, to be created using forward mode AD.
|
||||||
|
Forward,
|
||||||
|
/// The target function, to be created using reverse mode AD.
|
||||||
|
Reverse,
|
||||||
|
/// The target function, to be created using forward mode AD.
|
||||||
|
/// This target function will also be used as a source for higher order derivatives,
|
||||||
|
/// so compute it before all Forward/Reverse targets and optimize it through llvm.
|
||||||
|
ForwardFirst,
|
||||||
|
/// The target function, to be created using reverse mode AD.
|
||||||
|
/// This target function will also be used as a source for higher order derivatives,
|
||||||
|
/// so compute it before all Forward/Reverse targets and optimize it through llvm.
|
||||||
|
ReverseFirst,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
|
||||||
|
/// However, under forward mode we overwrite the previous shadow value, while for reverse mode
|
||||||
|
/// we add to the previous shadow value. To not surprise users, we picked different names.
|
||||||
|
/// Dual numbers is also a quite well known name for forward mode AD types.
|
||||||
|
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
|
||||||
|
pub enum DiffActivity {
|
||||||
|
/// Implicit or Explicit () return type, so a special case of Const.
|
||||||
|
None,
|
||||||
|
/// Don't compute derivatives with respect to this input/output.
|
||||||
|
Const,
|
||||||
|
/// Reverse Mode, Compute derivatives for this scalar input/output.
|
||||||
|
Active,
|
||||||
|
/// Reverse Mode, Compute derivatives for this scalar output, but don't compute
|
||||||
|
/// the original return value.
|
||||||
|
ActiveOnly,
|
||||||
|
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
|
||||||
|
/// with it.
|
||||||
|
Dual,
|
||||||
|
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
|
||||||
|
/// with it. Drop the code which updates the original input/output for maximum performance.
|
||||||
|
DualOnly,
|
||||||
|
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
|
||||||
|
Duplicated,
|
||||||
|
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
|
||||||
|
/// Drop the code which updates the original input for maximum performance.
|
||||||
|
DuplicatedOnly,
|
||||||
|
/// All Integers must be Const, but these are used to mark the integer which represents the
|
||||||
|
/// length of a slice/vec. This is used for safety checks on slices.
|
||||||
|
FakeActivitySize,
|
||||||
|
}
|
||||||
|
/// We generate one of these structs for each `#[autodiff(...)]` attribute.
|
||||||
|
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
|
||||||
|
pub struct AutoDiffItem {
|
||||||
|
/// The name of the function getting differentiated
|
||||||
|
pub source: String,
|
||||||
|
/// The name of the function being generated
|
||||||
|
pub target: String,
|
||||||
|
pub attrs: AutoDiffAttrs,
|
||||||
|
/// Describe the memory layout of input types
|
||||||
|
pub inputs: Vec<TypeTree>,
|
||||||
|
/// Describe the memory layout of the output type
|
||||||
|
pub output: TypeTree,
|
||||||
|
}
|
||||||
|
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
|
||||||
|
pub struct AutoDiffAttrs {
|
||||||
|
/// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
|
||||||
|
/// e.g. in the [JAX
|
||||||
|
/// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
|
||||||
|
pub mode: DiffMode,
|
||||||
|
pub ret_activity: DiffActivity,
|
||||||
|
pub input_activity: Vec<DiffActivity>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DiffMode {
|
||||||
|
pub fn is_rev(&self) -> bool {
|
||||||
|
matches!(self, DiffMode::Reverse | DiffMode::ReverseFirst)
|
||||||
|
}
|
||||||
|
pub fn is_fwd(&self) -> bool {
|
||||||
|
matches!(self, DiffMode::Forward | DiffMode::ForwardFirst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for DiffMode {
|
||||||
|
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||||
|
match self {
|
||||||
|
DiffMode::Error => write!(f, "Error"),
|
||||||
|
DiffMode::Source => write!(f, "Source"),
|
||||||
|
DiffMode::Forward => write!(f, "Forward"),
|
||||||
|
DiffMode::Reverse => write!(f, "Reverse"),
|
||||||
|
DiffMode::ForwardFirst => write!(f, "ForwardFirst"),
|
||||||
|
DiffMode::ReverseFirst => write!(f, "ReverseFirst"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Active(Only) is valid in reverse-mode AD for scalar float returns (f16/f32/...).
|
||||||
|
/// Dual(Only) is valid in forward-mode AD for scalar float returns (f16/f32/...).
|
||||||
|
/// Const is valid for all cases and means that we don't compute derivatives wrt. this output.
|
||||||
|
/// That usually means we have a &mut or *mut T output and compute derivatives wrt. that arg,
|
||||||
|
/// but this is too complex to verify here. Also it's just a logic error if users get this wrong.
|
||||||
|
pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
|
||||||
|
if activity == DiffActivity::None {
|
||||||
|
// Only valid if primal returns (), but we can't check that here.
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
match mode {
|
||||||
|
DiffMode::Error => false,
|
||||||
|
DiffMode::Source => false,
|
||||||
|
DiffMode::Forward | DiffMode::ForwardFirst => {
|
||||||
|
activity == DiffActivity::Dual
|
||||||
|
|| activity == DiffActivity::DualOnly
|
||||||
|
|| activity == DiffActivity::Const
|
||||||
|
}
|
||||||
|
DiffMode::Reverse | DiffMode::ReverseFirst => {
|
||||||
|
activity == DiffActivity::Const
|
||||||
|
|| activity == DiffActivity::Active
|
||||||
|
|| activity == DiffActivity::ActiveOnly
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// For indirections (ptr/ref) we can't use Active, since Active allocates a shadow value
|
||||||
|
/// for the given argument, but we generally can't know the size of such a type.
|
||||||
|
/// For scalar types (f16/f32/f64/f128) we can use Active and we can't use Duplicated,
|
||||||
|
/// since Duplicated expects a mutable ref/ptr and we would thus end up with a shadow value
|
||||||
|
/// who is an indirect type, which doesn't match the primal scalar type. We can't prevent
|
||||||
|
/// users here from marking scalars as Duplicated, due to type aliases.
|
||||||
|
pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
|
||||||
|
use DiffActivity::*;
|
||||||
|
// It's always allowed to mark something as Const, since we won't compute derivatives wrt. it.
|
||||||
|
if matches!(activity, Const) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if matches!(activity, Dual | DualOnly) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
// FIXME(ZuseZ4) We should make this more robust to also
|
||||||
|
// handle type aliases. Once that is done, we can be more restrictive here.
|
||||||
|
if matches!(activity, Active | ActiveOnly) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
matches!(ty.kind, TyKind::Ptr(_) | TyKind::Ref(..))
|
||||||
|
&& matches!(activity, Duplicated | DuplicatedOnly)
|
||||||
|
}
|
||||||
|
pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
|
||||||
|
use DiffActivity::*;
|
||||||
|
return match mode {
|
||||||
|
DiffMode::Error => false,
|
||||||
|
DiffMode::Source => false,
|
||||||
|
DiffMode::Forward | DiffMode::ForwardFirst => {
|
||||||
|
matches!(activity, Dual | DualOnly | Const)
|
||||||
|
}
|
||||||
|
DiffMode::Reverse | DiffMode::ReverseFirst => {
|
||||||
|
matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for DiffActivity {
|
||||||
|
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
DiffActivity::None => write!(f, "None"),
|
||||||
|
DiffActivity::Const => write!(f, "Const"),
|
||||||
|
DiffActivity::Active => write!(f, "Active"),
|
||||||
|
DiffActivity::ActiveOnly => write!(f, "ActiveOnly"),
|
||||||
|
DiffActivity::Dual => write!(f, "Dual"),
|
||||||
|
DiffActivity::DualOnly => write!(f, "DualOnly"),
|
||||||
|
DiffActivity::Duplicated => write!(f, "Duplicated"),
|
||||||
|
DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
|
||||||
|
DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromStr for DiffMode {
|
||||||
|
type Err = ();
|
||||||
|
|
||||||
|
fn from_str(s: &str) -> Result<DiffMode, ()> {
|
||||||
|
match s {
|
||||||
|
"Error" => Ok(DiffMode::Error),
|
||||||
|
"Source" => Ok(DiffMode::Source),
|
||||||
|
"Forward" => Ok(DiffMode::Forward),
|
||||||
|
"Reverse" => Ok(DiffMode::Reverse),
|
||||||
|
"ForwardFirst" => Ok(DiffMode::ForwardFirst),
|
||||||
|
"ReverseFirst" => Ok(DiffMode::ReverseFirst),
|
||||||
|
_ => Err(()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl FromStr for DiffActivity {
|
||||||
|
type Err = ();
|
||||||
|
|
||||||
|
fn from_str(s: &str) -> Result<DiffActivity, ()> {
|
||||||
|
match s {
|
||||||
|
"None" => Ok(DiffActivity::None),
|
||||||
|
"Active" => Ok(DiffActivity::Active),
|
||||||
|
"ActiveOnly" => Ok(DiffActivity::ActiveOnly),
|
||||||
|
"Const" => Ok(DiffActivity::Const),
|
||||||
|
"Dual" => Ok(DiffActivity::Dual),
|
||||||
|
"DualOnly" => Ok(DiffActivity::DualOnly),
|
||||||
|
"Duplicated" => Ok(DiffActivity::Duplicated),
|
||||||
|
"DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
|
||||||
|
_ => Err(()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AutoDiffAttrs {
|
||||||
|
pub fn has_ret_activity(&self) -> bool {
|
||||||
|
self.ret_activity != DiffActivity::None
|
||||||
|
}
|
||||||
|
pub fn has_active_only_ret(&self) -> bool {
|
||||||
|
self.ret_activity == DiffActivity::ActiveOnly
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn error() -> Self {
|
||||||
|
AutoDiffAttrs {
|
||||||
|
mode: DiffMode::Error,
|
||||||
|
ret_activity: DiffActivity::None,
|
||||||
|
input_activity: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub fn source() -> Self {
|
||||||
|
AutoDiffAttrs {
|
||||||
|
mode: DiffMode::Source,
|
||||||
|
ret_activity: DiffActivity::None,
|
||||||
|
input_activity: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_active(&self) -> bool {
|
||||||
|
self.mode != DiffMode::Error
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_source(&self) -> bool {
|
||||||
|
self.mode == DiffMode::Source
|
||||||
|
}
|
||||||
|
pub fn apply_autodiff(&self) -> bool {
|
||||||
|
!matches!(self.mode, DiffMode::Error | DiffMode::Source)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn into_item(
|
||||||
|
self,
|
||||||
|
source: String,
|
||||||
|
target: String,
|
||||||
|
inputs: Vec<TypeTree>,
|
||||||
|
output: TypeTree,
|
||||||
|
) -> AutoDiffItem {
|
||||||
|
AutoDiffItem { source, target, inputs, output, attrs: self }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for AutoDiffItem {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
|
||||||
|
write!(f, " with attributes: {:?}", self.attrs)?;
|
||||||
|
write!(f, " with inputs: {:?}", self.inputs)?;
|
||||||
|
write!(f, " with output: {:?}", self.output)
|
||||||
|
}
|
||||||
|
}
|
@ -7,6 +7,8 @@ use rustc_span::symbol::Ident;
|
|||||||
use crate::MetaItem;
|
use crate::MetaItem;
|
||||||
|
|
||||||
pub mod allocator;
|
pub mod allocator;
|
||||||
|
pub mod autodiff_attrs;
|
||||||
|
pub mod typetree;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Encodable, Decodable, HashStable_Generic)]
|
#[derive(Debug, Clone, Encodable, Decodable, HashStable_Generic)]
|
||||||
pub struct StrippedCfgItem<ModId = DefId> {
|
pub struct StrippedCfgItem<ModId = DefId> {
|
||||||
|
90
compiler/rustc_ast/src/expand/typetree.rs
Normal file
90
compiler/rustc_ast/src/expand/typetree.rs
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
//! This module contains the definition of the `TypeTree` and `Type` structs.
|
||||||
|
//! They are thin Rust wrappers around the TypeTrees used by Enzyme as the LLVM based autodiff
|
||||||
|
//! backend. The Enzyme TypeTrees currently have various limitations and should be rewritten, so the
|
||||||
|
//! Rust frontend obviously has the same limitations. The main motivation of TypeTrees is to
|
||||||
|
//! represent how a type looks like "in memory". Enzyme can deduce this based on usage patterns in
|
||||||
|
//! the user code, but this is extremely slow and not even always sufficient. As such we lower some
|
||||||
|
//! information from rustc to help Enzyme. For a full explanation of their design it is necessary to
|
||||||
|
//! analyze the implementation in Enzyme core itself. As a rough summary, `-1` in Enzyme speech means
|
||||||
|
//! everywhere. That is `{0:-1: Float}` means at index 0 you have a ptr, if you dereference it it
|
||||||
|
//! will be floats everywhere. Thus `* f32`. If you have `{-1:int}` it means int's everywhere,
|
||||||
|
//! e.g. [i32; N]. `{0:-1:-1 float}` then means one pointer at offset 0, if you dereference it there
|
||||||
|
//! will be only pointers, if you dereference these new pointers they will point to array of floats.
|
||||||
|
//! Generally, it allows byte-specific descriptions.
|
||||||
|
//! FIXME: This description might be partly inaccurate and should be extended, along with
|
||||||
|
//! adding documentation to the corresponding Enzyme core code.
|
||||||
|
//! FIXME: Rewrite the TypeTree logic in Enzyme core to reduce the need for the rustc frontend to
|
||||||
|
//! provide typetree information.
|
||||||
|
//! FIXME: We should also re-evaluate where we create TypeTrees from Rust types, since MIR
|
||||||
|
//! representations of some types might not be accurate. For example a vector of floats might be
|
||||||
|
//! represented as a vector of u8s in MIR in some cases.
|
||||||
|
|
||||||
|
use std::fmt;
|
||||||
|
|
||||||
|
use crate::expand::{Decodable, Encodable, HashStable_Generic};
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
|
||||||
|
pub enum Kind {
|
||||||
|
Anything,
|
||||||
|
Integer,
|
||||||
|
Pointer,
|
||||||
|
Half,
|
||||||
|
Float,
|
||||||
|
Double,
|
||||||
|
Unknown,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
|
||||||
|
pub struct TypeTree(pub Vec<Type>);
|
||||||
|
|
||||||
|
impl TypeTree {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self(Vec::new())
|
||||||
|
}
|
||||||
|
pub fn all_ints() -> Self {
|
||||||
|
Self(vec![Type { offset: -1, size: 1, kind: Kind::Integer, child: TypeTree::new() }])
|
||||||
|
}
|
||||||
|
pub fn int(size: usize) -> Self {
|
||||||
|
let mut ints = Vec::with_capacity(size);
|
||||||
|
for i in 0..size {
|
||||||
|
ints.push(Type {
|
||||||
|
offset: i as isize,
|
||||||
|
size: 1,
|
||||||
|
kind: Kind::Integer,
|
||||||
|
child: TypeTree::new(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Self(ints)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
|
||||||
|
pub struct FncTree {
|
||||||
|
pub args: Vec<TypeTree>,
|
||||||
|
pub ret: TypeTree,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
|
||||||
|
pub struct Type {
|
||||||
|
pub offset: isize,
|
||||||
|
pub size: usize,
|
||||||
|
pub kind: Kind,
|
||||||
|
pub child: TypeTree,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Type {
|
||||||
|
pub fn add_offset(self, add: isize) -> Self {
|
||||||
|
let offset = match self.offset {
|
||||||
|
-1 => add,
|
||||||
|
x => add + x,
|
||||||
|
};
|
||||||
|
|
||||||
|
Self { size: self.size, kind: self.kind, child: self.child, offset }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for Type {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
<Self as fmt::Debug>::fmt(self, f)
|
||||||
|
}
|
||||||
|
}
|
@ -3,6 +3,10 @@ name = "rustc_builtin_macros"
|
|||||||
version = "0.0.0"
|
version = "0.0.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
|
|
||||||
|
[lints.rust]
|
||||||
|
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(llvm_enzyme)'] }
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
doctest = false
|
doctest = false
|
||||||
|
|
||||||
|
@ -69,6 +69,15 @@ builtin_macros_assert_requires_boolean = macro requires a boolean expression as
|
|||||||
builtin_macros_assert_requires_expression = macro requires an expression as an argument
|
builtin_macros_assert_requires_expression = macro requires an expression as an argument
|
||||||
.suggestion = try removing semicolon
|
.suggestion = try removing semicolon
|
||||||
|
|
||||||
|
builtin_macros_autodiff = autodiff must be applied to function
|
||||||
|
builtin_macros_autodiff_missing_config = autodiff requires at least a name and mode
|
||||||
|
builtin_macros_autodiff_mode = unknown Mode: `{$mode}`. Use `Forward` or `Reverse`
|
||||||
|
builtin_macros_autodiff_mode_activity = {$act} can not be used in {$mode} Mode
|
||||||
|
builtin_macros_autodiff_not_build = this rustc version does not support autodiff
|
||||||
|
builtin_macros_autodiff_number_activities = expected {$expected} activities, but found {$found}
|
||||||
|
builtin_macros_autodiff_ty_activity = {$act} can not be used for this type
|
||||||
|
|
||||||
|
builtin_macros_autodiff_unknown_activity = did not recognize Activity: `{$act}`
|
||||||
builtin_macros_bad_derive_target = `derive` may only be applied to `struct`s, `enum`s and `union`s
|
builtin_macros_bad_derive_target = `derive` may only be applied to `struct`s, `enum`s and `union`s
|
||||||
.label = not applicable here
|
.label = not applicable here
|
||||||
.label2 = not a `struct`, `enum` or `union`
|
.label2 = not a `struct`, `enum` or `union`
|
||||||
|
820
compiler/rustc_builtin_macros/src/autodiff.rs
Normal file
820
compiler/rustc_builtin_macros/src/autodiff.rs
Normal file
@ -0,0 +1,820 @@
|
|||||||
|
//! This module contains the implementation of the `#[autodiff]` attribute.
|
||||||
|
//! Currently our linter isn't smart enough to see that each import is used in one of the two
|
||||||
|
//! configs (autodiff enabled or disabled), so we have to add cfg's to each import.
|
||||||
|
//! FIXME(ZuseZ4): Remove this once we have a smarter linter.
|
||||||
|
|
||||||
|
#[cfg(llvm_enzyme)]
|
||||||
|
mod llvm_enzyme {
|
||||||
|
use std::str::FromStr;
|
||||||
|
use std::string::String;
|
||||||
|
|
||||||
|
use rustc_ast::expand::autodiff_attrs::{
|
||||||
|
AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ty_for_activity,
|
||||||
|
};
|
||||||
|
use rustc_ast::ptr::P;
|
||||||
|
use rustc_ast::token::{Token, TokenKind};
|
||||||
|
use rustc_ast::tokenstream::*;
|
||||||
|
use rustc_ast::visit::AssocCtxt::*;
|
||||||
|
use rustc_ast::{
|
||||||
|
self as ast, AssocItemKind, BindingMode, FnRetTy, FnSig, Generics, ItemKind, MetaItemInner,
|
||||||
|
PatKind, TyKind,
|
||||||
|
};
|
||||||
|
use rustc_expand::base::{Annotatable, ExtCtxt};
|
||||||
|
use rustc_span::symbol::{Ident, kw, sym};
|
||||||
|
use rustc_span::{Span, Symbol};
|
||||||
|
use thin_vec::{ThinVec, thin_vec};
|
||||||
|
use tracing::{debug, trace};
|
||||||
|
|
||||||
|
use crate::errors;
|
||||||
|
|
||||||
|
// If we have a default `()` return type or explicitley `()` return type,
|
||||||
|
// then we often can skip doing some work.
|
||||||
|
fn has_ret(ty: &FnRetTy) -> bool {
|
||||||
|
match ty {
|
||||||
|
FnRetTy::Ty(ty) => !ty.kind.is_unit(),
|
||||||
|
FnRetTy::Default(_) => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn first_ident(x: &MetaItemInner) -> rustc_span::symbol::Ident {
|
||||||
|
let segments = &x.meta_item().unwrap().path.segments;
|
||||||
|
assert!(segments.len() == 1);
|
||||||
|
segments[0].ident
|
||||||
|
}
|
||||||
|
|
||||||
|
fn name(x: &MetaItemInner) -> String {
|
||||||
|
first_ident(x).name.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn from_ast(
|
||||||
|
ecx: &mut ExtCtxt<'_>,
|
||||||
|
meta_item: &ThinVec<MetaItemInner>,
|
||||||
|
has_ret: bool,
|
||||||
|
) -> AutoDiffAttrs {
|
||||||
|
let dcx = ecx.sess.dcx();
|
||||||
|
let mode = name(&meta_item[1]);
|
||||||
|
let Ok(mode) = DiffMode::from_str(&mode) else {
|
||||||
|
dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode });
|
||||||
|
return AutoDiffAttrs::error();
|
||||||
|
};
|
||||||
|
let mut activities: Vec<DiffActivity> = vec![];
|
||||||
|
let mut errors = false;
|
||||||
|
for x in &meta_item[2..] {
|
||||||
|
let activity_str = name(&x);
|
||||||
|
let res = DiffActivity::from_str(&activity_str);
|
||||||
|
match res {
|
||||||
|
Ok(x) => activities.push(x),
|
||||||
|
Err(_) => {
|
||||||
|
dcx.emit_err(errors::AutoDiffUnknownActivity {
|
||||||
|
span: x.span(),
|
||||||
|
act: activity_str,
|
||||||
|
});
|
||||||
|
errors = true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
if errors {
|
||||||
|
return AutoDiffAttrs::error();
|
||||||
|
}
|
||||||
|
|
||||||
|
// If a return type exist, we need to split the last activity,
|
||||||
|
// otherwise we return None as placeholder.
|
||||||
|
let (ret_activity, input_activity) = if has_ret {
|
||||||
|
let Some((last, rest)) = activities.split_last() else {
|
||||||
|
unreachable!(
|
||||||
|
"should not be reachable because we counted the number of activities previously"
|
||||||
|
);
|
||||||
|
};
|
||||||
|
(last, rest)
|
||||||
|
} else {
|
||||||
|
(&DiffActivity::None, activities.as_slice())
|
||||||
|
};
|
||||||
|
|
||||||
|
AutoDiffAttrs { mode, ret_activity: *ret_activity, input_activity: input_activity.to_vec() }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// We expand the autodiff macro to generate a new placeholder function which passes
|
||||||
|
/// type-checking and can be called by users. The function body of the placeholder function will
|
||||||
|
/// later be replaced on LLVM-IR level, so the design of the body is less important and for now
|
||||||
|
/// should just prevent early inlining and optimizations which alter the function signature.
|
||||||
|
/// The exact signature of the generated function depends on the configuration provided by the
|
||||||
|
/// user, but here is an example:
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// #[autodiff(cos_box, Reverse, Duplicated, Active)]
|
||||||
|
/// fn sin(x: &Box<f32>) -> f32 {
|
||||||
|
/// f32::sin(**x)
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
/// which becomes expanded to:
|
||||||
|
/// ```
|
||||||
|
/// #[rustc_autodiff]
|
||||||
|
/// #[inline(never)]
|
||||||
|
/// fn sin(x: &Box<f32>) -> f32 {
|
||||||
|
/// f32::sin(**x)
|
||||||
|
/// }
|
||||||
|
/// #[rustc_autodiff(Reverse, Duplicated, Active)]
|
||||||
|
/// #[inline(never)]
|
||||||
|
/// fn cos_box(x: &Box<f32>, dx: &mut Box<f32>, dret: f32) -> f32 {
|
||||||
|
/// unsafe {
|
||||||
|
/// asm!("NOP");
|
||||||
|
/// };
|
||||||
|
/// ::core::hint::black_box(sin(x));
|
||||||
|
/// ::core::hint::black_box((dx, dret));
|
||||||
|
/// ::core::hint::black_box(sin(x))
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
/// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
|
||||||
|
/// in CI.
|
||||||
|
pub(crate) fn expand(
|
||||||
|
ecx: &mut ExtCtxt<'_>,
|
||||||
|
expand_span: Span,
|
||||||
|
meta_item: &ast::MetaItem,
|
||||||
|
mut item: Annotatable,
|
||||||
|
) -> Vec<Annotatable> {
|
||||||
|
let dcx = ecx.sess.dcx();
|
||||||
|
// first get the annotable item:
|
||||||
|
let (sig, is_impl): (FnSig, bool) = match &item {
|
||||||
|
Annotatable::Item(ref iitem) => {
|
||||||
|
let sig = match &iitem.kind {
|
||||||
|
ItemKind::Fn(box ast::Fn { sig, .. }) => sig,
|
||||||
|
_ => {
|
||||||
|
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
|
||||||
|
return vec![item];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
(sig.clone(), false)
|
||||||
|
}
|
||||||
|
Annotatable::AssocItem(ref assoc_item, _) => {
|
||||||
|
let sig = match &assoc_item.kind {
|
||||||
|
ast::AssocItemKind::Fn(box ast::Fn { sig, .. }) => sig,
|
||||||
|
_ => {
|
||||||
|
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
|
||||||
|
return vec![item];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
(sig.clone(), true)
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
|
||||||
|
return vec![item];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let meta_item_vec: ThinVec<MetaItemInner> = match meta_item.kind {
|
||||||
|
ast::MetaItemKind::List(ref vec) => vec.clone(),
|
||||||
|
_ => {
|
||||||
|
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
|
||||||
|
return vec![item];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let has_ret = has_ret(&sig.decl.output);
|
||||||
|
let sig_span = ecx.with_call_site_ctxt(sig.span);
|
||||||
|
|
||||||
|
let (vis, primal) = match &item {
|
||||||
|
Annotatable::Item(ref iitem) => (iitem.vis.clone(), iitem.ident.clone()),
|
||||||
|
Annotatable::AssocItem(ref assoc_item, _) => {
|
||||||
|
(assoc_item.vis.clone(), assoc_item.ident.clone())
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
|
||||||
|
return vec![item];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// create TokenStream from vec elemtents:
|
||||||
|
// meta_item doesn't have a .tokens field
|
||||||
|
let comma: Token = Token::new(TokenKind::Comma, Span::default());
|
||||||
|
let mut ts: Vec<TokenTree> = vec![];
|
||||||
|
if meta_item_vec.len() < 2 {
|
||||||
|
// At the bare minimum, we need a fnc name and a mode, even for a dummy function with no
|
||||||
|
// input and output args.
|
||||||
|
dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
|
||||||
|
return vec![item];
|
||||||
|
} else {
|
||||||
|
for t in meta_item_vec.clone()[1..].iter() {
|
||||||
|
let val = first_ident(t);
|
||||||
|
let t = Token::from_ast_ident(val);
|
||||||
|
ts.push(TokenTree::Token(t, Spacing::Joint));
|
||||||
|
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !has_ret {
|
||||||
|
// We don't want users to provide a return activity if the function doesn't return anything.
|
||||||
|
// For simplicity, we just add a dummy token to the end of the list.
|
||||||
|
let t = Token::new(TokenKind::Ident(sym::None, false.into()), Span::default());
|
||||||
|
ts.push(TokenTree::Token(t, Spacing::Joint));
|
||||||
|
}
|
||||||
|
let ts: TokenStream = TokenStream::from_iter(ts);
|
||||||
|
|
||||||
|
let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret);
|
||||||
|
if !x.is_active() {
|
||||||
|
// We encountered an error, so we return the original item.
|
||||||
|
// This allows us to potentially parse other attributes.
|
||||||
|
return vec![item];
|
||||||
|
}
|
||||||
|
let span = ecx.with_def_site_ctxt(expand_span);
|
||||||
|
|
||||||
|
let n_active: u32 = x
|
||||||
|
.input_activity
|
||||||
|
.iter()
|
||||||
|
.filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly)
|
||||||
|
.count() as u32;
|
||||||
|
let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
|
||||||
|
let new_decl_span = d_sig.span;
|
||||||
|
let d_body = gen_enzyme_body(
|
||||||
|
ecx,
|
||||||
|
&x,
|
||||||
|
n_active,
|
||||||
|
&sig,
|
||||||
|
&d_sig,
|
||||||
|
primal,
|
||||||
|
&new_args,
|
||||||
|
span,
|
||||||
|
sig_span,
|
||||||
|
new_decl_span,
|
||||||
|
idents,
|
||||||
|
errored,
|
||||||
|
);
|
||||||
|
let d_ident = first_ident(&meta_item_vec[0]);
|
||||||
|
|
||||||
|
// The first element of it is the name of the function to be generated
|
||||||
|
let asdf = Box::new(ast::Fn {
|
||||||
|
defaultness: ast::Defaultness::Final,
|
||||||
|
sig: d_sig,
|
||||||
|
generics: Generics::default(),
|
||||||
|
body: Some(d_body),
|
||||||
|
});
|
||||||
|
let mut rustc_ad_attr =
|
||||||
|
P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff)));
|
||||||
|
|
||||||
|
let ts2: Vec<TokenTree> = vec![TokenTree::Token(
|
||||||
|
Token::new(TokenKind::Ident(sym::never, false.into()), span),
|
||||||
|
Spacing::Joint,
|
||||||
|
)];
|
||||||
|
let never_arg = ast::DelimArgs {
|
||||||
|
dspan: ast::tokenstream::DelimSpan::from_single(span),
|
||||||
|
delim: ast::token::Delimiter::Parenthesis,
|
||||||
|
tokens: ast::tokenstream::TokenStream::from_iter(ts2),
|
||||||
|
};
|
||||||
|
let inline_item = ast::AttrItem {
|
||||||
|
unsafety: ast::Safety::Default,
|
||||||
|
path: ast::Path::from_ident(Ident::with_dummy_span(sym::inline)),
|
||||||
|
args: ast::AttrArgs::Delimited(never_arg),
|
||||||
|
tokens: None,
|
||||||
|
};
|
||||||
|
let inline_never_attr = P(ast::NormalAttr { item: inline_item, tokens: None });
|
||||||
|
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
|
||||||
|
let attr: ast::Attribute = ast::Attribute {
|
||||||
|
kind: ast::AttrKind::Normal(rustc_ad_attr.clone()),
|
||||||
|
id: new_id,
|
||||||
|
style: ast::AttrStyle::Outer,
|
||||||
|
span,
|
||||||
|
};
|
||||||
|
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
|
||||||
|
let inline_never: ast::Attribute = ast::Attribute {
|
||||||
|
kind: ast::AttrKind::Normal(inline_never_attr),
|
||||||
|
id: new_id,
|
||||||
|
style: ast::AttrStyle::Outer,
|
||||||
|
span,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Don't add it multiple times:
|
||||||
|
let orig_annotatable: Annotatable = match item {
|
||||||
|
Annotatable::Item(ref mut iitem) => {
|
||||||
|
if !iitem.attrs.iter().any(|a| a.id == attr.id) {
|
||||||
|
iitem.attrs.push(attr.clone());
|
||||||
|
}
|
||||||
|
if !iitem.attrs.iter().any(|a| a.id == inline_never.id) {
|
||||||
|
iitem.attrs.push(inline_never.clone());
|
||||||
|
}
|
||||||
|
Annotatable::Item(iitem.clone())
|
||||||
|
}
|
||||||
|
Annotatable::AssocItem(ref mut assoc_item, i @ Impl) => {
|
||||||
|
if !assoc_item.attrs.iter().any(|a| a.id == attr.id) {
|
||||||
|
assoc_item.attrs.push(attr.clone());
|
||||||
|
}
|
||||||
|
if !assoc_item.attrs.iter().any(|a| a.id == inline_never.id) {
|
||||||
|
assoc_item.attrs.push(inline_never.clone());
|
||||||
|
}
|
||||||
|
Annotatable::AssocItem(assoc_item.clone(), i)
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
unreachable!("annotatable kind checked previously")
|
||||||
|
}
|
||||||
|
};
|
||||||
|
// Now update for d_fn
|
||||||
|
rustc_ad_attr.item.args = rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs {
|
||||||
|
dspan: DelimSpan::dummy(),
|
||||||
|
delim: rustc_ast::token::Delimiter::Parenthesis,
|
||||||
|
tokens: ts,
|
||||||
|
});
|
||||||
|
let d_attr: ast::Attribute = ast::Attribute {
|
||||||
|
kind: ast::AttrKind::Normal(rustc_ad_attr.clone()),
|
||||||
|
id: new_id,
|
||||||
|
style: ast::AttrStyle::Outer,
|
||||||
|
span,
|
||||||
|
};
|
||||||
|
|
||||||
|
let d_annotatable = if is_impl {
|
||||||
|
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
|
||||||
|
let d_fn = P(ast::AssocItem {
|
||||||
|
attrs: thin_vec![d_attr.clone(), inline_never],
|
||||||
|
id: ast::DUMMY_NODE_ID,
|
||||||
|
span,
|
||||||
|
vis,
|
||||||
|
ident: d_ident,
|
||||||
|
kind: assoc_item,
|
||||||
|
tokens: None,
|
||||||
|
});
|
||||||
|
Annotatable::AssocItem(d_fn, Impl)
|
||||||
|
} else {
|
||||||
|
let mut d_fn = ecx.item(
|
||||||
|
span,
|
||||||
|
d_ident,
|
||||||
|
thin_vec![d_attr.clone(), inline_never],
|
||||||
|
ItemKind::Fn(asdf),
|
||||||
|
);
|
||||||
|
d_fn.vis = vis;
|
||||||
|
Annotatable::Item(d_fn)
|
||||||
|
};
|
||||||
|
|
||||||
|
return vec![orig_annotatable, d_annotatable];
|
||||||
|
}
|
||||||
|
|
||||||
|
// shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be
|
||||||
|
// mutable references or ptrs, because Enzyme will write into them.
|
||||||
|
fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
|
||||||
|
let mut ty = ty.clone();
|
||||||
|
match ty.kind {
|
||||||
|
TyKind::Ptr(ref mut mut_ty) => {
|
||||||
|
mut_ty.mutbl = ast::Mutability::Mut;
|
||||||
|
}
|
||||||
|
TyKind::Ref(_, ref mut mut_ty) => {
|
||||||
|
mut_ty.mutbl = ast::Mutability::Mut;
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
panic!("unsupported type: {:?}", ty);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ty
|
||||||
|
}
|
||||||
|
|
||||||
|
/// We only want this function to type-check, since we will replace the body
|
||||||
|
/// later on llvm level. Using `loop {}` does not cover all return types anymore,
|
||||||
|
/// so instead we build something that should pass. We also add a inline_asm
|
||||||
|
/// line, as one more barrier for rustc to prevent inlining of this function.
|
||||||
|
/// FIXME(ZuseZ4): We still have cases of incorrect inlining across modules, see
|
||||||
|
/// <https://github.com/EnzymeAD/rust/issues/173>, so this isn't sufficient.
|
||||||
|
/// It also triggers an Enzyme crash if we due to a bug ever try to differentiate
|
||||||
|
/// this function (which should never happen, since it is only a placeholder).
|
||||||
|
/// Finally, we also add back_box usages of all input arguments, to prevent rustc
|
||||||
|
/// from optimizing any arguments away.
|
||||||
|
fn gen_enzyme_body(
|
||||||
|
ecx: &ExtCtxt<'_>,
|
||||||
|
x: &AutoDiffAttrs,
|
||||||
|
n_active: u32,
|
||||||
|
sig: &ast::FnSig,
|
||||||
|
d_sig: &ast::FnSig,
|
||||||
|
primal: Ident,
|
||||||
|
new_names: &[String],
|
||||||
|
span: Span,
|
||||||
|
sig_span: Span,
|
||||||
|
new_decl_span: Span,
|
||||||
|
idents: Vec<Ident>,
|
||||||
|
errored: bool,
|
||||||
|
) -> P<ast::Block> {
|
||||||
|
let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
|
||||||
|
let noop = ast::InlineAsm {
|
||||||
|
asm_macro: ast::AsmMacro::Asm,
|
||||||
|
template: vec![ast::InlineAsmTemplatePiece::String("NOP".into())],
|
||||||
|
template_strs: Box::new([]),
|
||||||
|
operands: vec![],
|
||||||
|
clobber_abis: vec![],
|
||||||
|
options: ast::InlineAsmOptions::PURE | ast::InlineAsmOptions::NOMEM,
|
||||||
|
line_spans: vec![],
|
||||||
|
};
|
||||||
|
let noop_expr = ecx.expr_asm(span, P(noop));
|
||||||
|
let unsf = ast::BlockCheckMode::Unsafe(ast::UnsafeSource::CompilerGenerated);
|
||||||
|
let unsf_block = ast::Block {
|
||||||
|
stmts: thin_vec![ecx.stmt_semi(noop_expr)],
|
||||||
|
id: ast::DUMMY_NODE_ID,
|
||||||
|
tokens: None,
|
||||||
|
rules: unsf,
|
||||||
|
span,
|
||||||
|
could_be_bare_literal: false,
|
||||||
|
};
|
||||||
|
let unsf_expr = ecx.expr_block(P(unsf_block));
|
||||||
|
let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
|
||||||
|
let primal_call = gen_primal_call(ecx, span, primal, idents);
|
||||||
|
let black_box_primal_call =
|
||||||
|
ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![
|
||||||
|
primal_call.clone()
|
||||||
|
]);
|
||||||
|
let tup_args = new_names
|
||||||
|
.iter()
|
||||||
|
.map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg))))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let black_box_remaining_args =
|
||||||
|
ecx.expr_call(sig_span, blackbox_call_expr.clone(), thin_vec![
|
||||||
|
ecx.expr_tuple(sig_span, tup_args)
|
||||||
|
]);
|
||||||
|
|
||||||
|
let mut body = ecx.block(span, ThinVec::new());
|
||||||
|
body.stmts.push(ecx.stmt_semi(unsf_expr));
|
||||||
|
|
||||||
|
// This uses primal args which won't be available if we errored before
|
||||||
|
if !errored {
|
||||||
|
body.stmts.push(ecx.stmt_semi(black_box_primal_call.clone()));
|
||||||
|
}
|
||||||
|
body.stmts.push(ecx.stmt_semi(black_box_remaining_args));
|
||||||
|
|
||||||
|
if !has_ret(&d_sig.decl.output) {
|
||||||
|
// there is no return type that we have to match, () works fine.
|
||||||
|
return body;
|
||||||
|
}
|
||||||
|
|
||||||
|
// having an active-only return means we'll drop the original return type.
|
||||||
|
// So that can be treated identical to not having one in the first place.
|
||||||
|
let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret();
|
||||||
|
|
||||||
|
if primal_ret && n_active == 0 && x.mode.is_rev() {
|
||||||
|
// We only have the primal ret.
|
||||||
|
body.stmts.push(ecx.stmt_expr(black_box_primal_call.clone()));
|
||||||
|
return body;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !primal_ret && n_active == 1 {
|
||||||
|
// Again no tuple return, so return default float val.
|
||||||
|
let ty = match d_sig.decl.output {
|
||||||
|
FnRetTy::Ty(ref ty) => ty.clone(),
|
||||||
|
FnRetTy::Default(span) => {
|
||||||
|
panic!("Did not expect Default ret ty: {:?}", span);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let arg = ty.kind.is_simple_path().unwrap();
|
||||||
|
let sl: Vec<Symbol> = vec![arg, kw::Default];
|
||||||
|
let tmp = ecx.def_site_path(&sl);
|
||||||
|
let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
|
||||||
|
let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
|
||||||
|
body.stmts.push(ecx.stmt_expr(default_call_expr));
|
||||||
|
return body;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut exprs = ThinVec::<P<ast::Expr>>::new();
|
||||||
|
if primal_ret {
|
||||||
|
// We have both primal ret and active floats.
|
||||||
|
// primal ret is first, by construction.
|
||||||
|
exprs.push(primal_call.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now construct default placeholder for each active float.
|
||||||
|
// Is there something nicer than f32::default() and f64::default()?
|
||||||
|
let d_ret_ty = match d_sig.decl.output {
|
||||||
|
FnRetTy::Ty(ref ty) => ty.clone(),
|
||||||
|
FnRetTy::Default(span) => {
|
||||||
|
panic!("Did not expect Default ret ty: {:?}", span);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let mut d_ret_ty = match d_ret_ty.kind.clone() {
|
||||||
|
TyKind::Tup(ref tys) => tys.clone(),
|
||||||
|
TyKind::Path(_, rustc_ast::Path { segments, .. }) => {
|
||||||
|
if let [segment] = &segments[..]
|
||||||
|
&& segment.args.is_none()
|
||||||
|
{
|
||||||
|
let id = vec![segments[0].ident];
|
||||||
|
let kind = TyKind::Path(None, ecx.path(span, id));
|
||||||
|
let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None });
|
||||||
|
thin_vec![ty]
|
||||||
|
} else {
|
||||||
|
panic!("Expected tuple or simple path return type");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
// We messed up construction of d_sig
|
||||||
|
panic!("Did not expect non-tuple ret ty: {:?}", d_ret_ty);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if x.mode.is_fwd() && x.ret_activity == DiffActivity::Dual {
|
||||||
|
assert!(d_ret_ty.len() == 2);
|
||||||
|
// both should be identical, by construction
|
||||||
|
let arg = d_ret_ty[0].kind.is_simple_path().unwrap();
|
||||||
|
let arg2 = d_ret_ty[1].kind.is_simple_path().unwrap();
|
||||||
|
assert!(arg == arg2);
|
||||||
|
let sl: Vec<Symbol> = vec![arg, kw::Default];
|
||||||
|
let tmp = ecx.def_site_path(&sl);
|
||||||
|
let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
|
||||||
|
let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
|
||||||
|
exprs.push(default_call_expr);
|
||||||
|
} else if x.mode.is_rev() {
|
||||||
|
if primal_ret {
|
||||||
|
// We have extra handling above for the primal ret
|
||||||
|
d_ret_ty = d_ret_ty[1..].to_vec().into();
|
||||||
|
}
|
||||||
|
|
||||||
|
for arg in d_ret_ty.iter() {
|
||||||
|
let arg = arg.kind.is_simple_path().unwrap();
|
||||||
|
let sl: Vec<Symbol> = vec![arg, kw::Default];
|
||||||
|
let tmp = ecx.def_site_path(&sl);
|
||||||
|
let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
|
||||||
|
let default_call_expr =
|
||||||
|
ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
|
||||||
|
exprs.push(default_call_expr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let ret: P<ast::Expr>;
|
||||||
|
match &exprs[..] {
|
||||||
|
[] => {
|
||||||
|
assert!(!has_ret(&d_sig.decl.output));
|
||||||
|
// We don't have to match the return type.
|
||||||
|
return body;
|
||||||
|
}
|
||||||
|
[arg] => {
|
||||||
|
ret = ecx
|
||||||
|
.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![arg.clone()]);
|
||||||
|
}
|
||||||
|
args => {
|
||||||
|
let ret_tuple: P<ast::Expr> = ecx.expr_tuple(span, args.into());
|
||||||
|
ret =
|
||||||
|
ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![ret_tuple]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert!(has_ret(&d_sig.decl.output));
|
||||||
|
body.stmts.push(ecx.stmt_expr(ret));
|
||||||
|
|
||||||
|
body
|
||||||
|
}
|
||||||
|
|
||||||
|
fn gen_primal_call(
|
||||||
|
ecx: &ExtCtxt<'_>,
|
||||||
|
span: Span,
|
||||||
|
primal: Ident,
|
||||||
|
idents: Vec<Ident>,
|
||||||
|
) -> P<ast::Expr> {
|
||||||
|
let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
|
||||||
|
if has_self {
|
||||||
|
let args: ThinVec<_> =
|
||||||
|
idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
|
||||||
|
let self_expr = ecx.expr_self(span);
|
||||||
|
ecx.expr_method_call(span, self_expr, primal, args.clone())
|
||||||
|
} else {
|
||||||
|
let args: ThinVec<_> =
|
||||||
|
idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
|
||||||
|
let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal));
|
||||||
|
ecx.expr_call(span, primal_call_expr, args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must
|
||||||
|
// be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer.
|
||||||
|
// Active arguments must be scalars. Their shadow argument is added to the return type (and will be
|
||||||
|
// zero-initialized by Enzyme).
|
||||||
|
// Each argument of the primal function (and the return type if existing) must be annotated with an
|
||||||
|
// activity.
|
||||||
|
//
|
||||||
|
// Error handling: If the user provides an invalid configuration (incorrect numbers, types, or
|
||||||
|
// both), we emit an error and return the original signature. This allows us to continue parsing.
|
||||||
|
fn gen_enzyme_decl(
|
||||||
|
ecx: &ExtCtxt<'_>,
|
||||||
|
sig: &ast::FnSig,
|
||||||
|
x: &AutoDiffAttrs,
|
||||||
|
span: Span,
|
||||||
|
) -> (ast::FnSig, Vec<String>, Vec<Ident>, bool) {
|
||||||
|
let dcx = ecx.sess.dcx();
|
||||||
|
let has_ret = has_ret(&sig.decl.output);
|
||||||
|
let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 };
|
||||||
|
let num_activities = x.input_activity.len() + if x.has_ret_activity() { 1 } else { 0 };
|
||||||
|
if sig_args != num_activities {
|
||||||
|
dcx.emit_err(errors::AutoDiffInvalidNumberActivities {
|
||||||
|
span,
|
||||||
|
expected: sig_args,
|
||||||
|
found: num_activities,
|
||||||
|
});
|
||||||
|
// This is not the right signature, but we can continue parsing.
|
||||||
|
return (sig.clone(), vec![], vec![], true);
|
||||||
|
}
|
||||||
|
assert!(sig.decl.inputs.len() == x.input_activity.len());
|
||||||
|
assert!(has_ret == x.has_ret_activity());
|
||||||
|
let mut d_decl = sig.decl.clone();
|
||||||
|
let mut d_inputs = Vec::new();
|
||||||
|
let mut new_inputs = Vec::new();
|
||||||
|
let mut idents = Vec::new();
|
||||||
|
let mut act_ret = ThinVec::new();
|
||||||
|
|
||||||
|
// We have two loops, a first one just to check the activities and types and possibly report
|
||||||
|
// multiple errors in one compilation session.
|
||||||
|
let mut errors = false;
|
||||||
|
for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
|
||||||
|
if !valid_input_activity(x.mode, *activity) {
|
||||||
|
dcx.emit_err(errors::AutoDiffInvalidApplicationModeAct {
|
||||||
|
span,
|
||||||
|
mode: x.mode.to_string(),
|
||||||
|
act: activity.to_string(),
|
||||||
|
});
|
||||||
|
errors = true;
|
||||||
|
}
|
||||||
|
if !valid_ty_for_activity(&arg.ty, *activity) {
|
||||||
|
dcx.emit_err(errors::AutoDiffInvalidTypeForActivity {
|
||||||
|
span: arg.ty.span,
|
||||||
|
act: activity.to_string(),
|
||||||
|
});
|
||||||
|
errors = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if errors {
|
||||||
|
// This is not the right signature, but we can continue parsing.
|
||||||
|
return (sig.clone(), new_inputs, idents, true);
|
||||||
|
}
|
||||||
|
let unsafe_activities = x
|
||||||
|
.input_activity
|
||||||
|
.iter()
|
||||||
|
.any(|&act| matches!(act, DiffActivity::DuplicatedOnly | DiffActivity::DualOnly));
|
||||||
|
for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
|
||||||
|
d_inputs.push(arg.clone());
|
||||||
|
match activity {
|
||||||
|
DiffActivity::Active => {
|
||||||
|
act_ret.push(arg.ty.clone());
|
||||||
|
}
|
||||||
|
DiffActivity::ActiveOnly => {
|
||||||
|
// We will add the active scalar to the return type.
|
||||||
|
// This is handled later.
|
||||||
|
}
|
||||||
|
DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => {
|
||||||
|
let mut shadow_arg = arg.clone();
|
||||||
|
// We += into the shadow in reverse mode.
|
||||||
|
shadow_arg.ty = P(assure_mut_ref(&arg.ty));
|
||||||
|
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
|
||||||
|
ident.name
|
||||||
|
} else {
|
||||||
|
debug!("{:#?}", &shadow_arg.pat);
|
||||||
|
panic!("not an ident?");
|
||||||
|
};
|
||||||
|
let name: String = format!("d{}", old_name);
|
||||||
|
new_inputs.push(name.clone());
|
||||||
|
let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
|
||||||
|
shadow_arg.pat = P(ast::Pat {
|
||||||
|
id: ast::DUMMY_NODE_ID,
|
||||||
|
kind: PatKind::Ident(BindingMode::NONE, ident, None),
|
||||||
|
span: shadow_arg.pat.span,
|
||||||
|
tokens: shadow_arg.pat.tokens.clone(),
|
||||||
|
});
|
||||||
|
d_inputs.push(shadow_arg);
|
||||||
|
}
|
||||||
|
DiffActivity::Dual | DiffActivity::DualOnly => {
|
||||||
|
let mut shadow_arg = arg.clone();
|
||||||
|
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
|
||||||
|
ident.name
|
||||||
|
} else {
|
||||||
|
debug!("{:#?}", &shadow_arg.pat);
|
||||||
|
panic!("not an ident?");
|
||||||
|
};
|
||||||
|
let name: String = format!("b{}", old_name);
|
||||||
|
new_inputs.push(name.clone());
|
||||||
|
let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
|
||||||
|
shadow_arg.pat = P(ast::Pat {
|
||||||
|
id: ast::DUMMY_NODE_ID,
|
||||||
|
kind: PatKind::Ident(BindingMode::NONE, ident, None),
|
||||||
|
span: shadow_arg.pat.span,
|
||||||
|
tokens: shadow_arg.pat.tokens.clone(),
|
||||||
|
});
|
||||||
|
d_inputs.push(shadow_arg);
|
||||||
|
}
|
||||||
|
DiffActivity::Const => {
|
||||||
|
// Nothing to do here.
|
||||||
|
}
|
||||||
|
DiffActivity::None | DiffActivity::FakeActivitySize => {
|
||||||
|
panic!("Should not happen");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let PatKind::Ident(_, ident, _) = arg.pat.kind {
|
||||||
|
idents.push(ident.clone());
|
||||||
|
} else {
|
||||||
|
panic!("not an ident?");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let active_only_ret = x.ret_activity == DiffActivity::ActiveOnly;
|
||||||
|
if active_only_ret {
|
||||||
|
assert!(x.mode.is_rev());
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we return a scalar in the primal and the scalar is active,
|
||||||
|
// then add it as last arg to the inputs.
|
||||||
|
if x.mode.is_rev() {
|
||||||
|
match x.ret_activity {
|
||||||
|
DiffActivity::Active | DiffActivity::ActiveOnly => {
|
||||||
|
let ty = match d_decl.output {
|
||||||
|
FnRetTy::Ty(ref ty) => ty.clone(),
|
||||||
|
FnRetTy::Default(span) => {
|
||||||
|
panic!("Did not expect Default ret ty: {:?}", span);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let name = "dret".to_string();
|
||||||
|
let ident = Ident::from_str_and_span(&name, ty.span);
|
||||||
|
let shadow_arg = ast::Param {
|
||||||
|
attrs: ThinVec::new(),
|
||||||
|
ty: ty.clone(),
|
||||||
|
pat: P(ast::Pat {
|
||||||
|
id: ast::DUMMY_NODE_ID,
|
||||||
|
kind: PatKind::Ident(BindingMode::NONE, ident, None),
|
||||||
|
span: ty.span,
|
||||||
|
tokens: None,
|
||||||
|
}),
|
||||||
|
id: ast::DUMMY_NODE_ID,
|
||||||
|
span: ty.span,
|
||||||
|
is_placeholder: false,
|
||||||
|
};
|
||||||
|
d_inputs.push(shadow_arg);
|
||||||
|
new_inputs.push(name);
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
d_decl.inputs = d_inputs.into();
|
||||||
|
|
||||||
|
if x.mode.is_fwd() {
|
||||||
|
if let DiffActivity::Dual = x.ret_activity {
|
||||||
|
let ty = match d_decl.output {
|
||||||
|
FnRetTy::Ty(ref ty) => ty.clone(),
|
||||||
|
FnRetTy::Default(span) => {
|
||||||
|
panic!("Did not expect Default ret ty: {:?}", span);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
// Dual can only be used for f32/f64 ret.
|
||||||
|
// In that case we return now a tuple with two floats.
|
||||||
|
let kind = TyKind::Tup(thin_vec![ty.clone(), ty.clone()]);
|
||||||
|
let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
|
||||||
|
d_decl.output = FnRetTy::Ty(ty);
|
||||||
|
}
|
||||||
|
if let DiffActivity::DualOnly = x.ret_activity {
|
||||||
|
// No need to change the return type,
|
||||||
|
// we will just return the shadow in place
|
||||||
|
// of the primal return.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we use ActiveOnly, drop the original return value.
|
||||||
|
d_decl.output =
|
||||||
|
if active_only_ret { FnRetTy::Default(span) } else { d_decl.output.clone() };
|
||||||
|
|
||||||
|
trace!("act_ret: {:?}", act_ret);
|
||||||
|
|
||||||
|
// If we have an active input scalar, add it's gradient to the
|
||||||
|
// return type. This might require changing the return type to a
|
||||||
|
// tuple.
|
||||||
|
if act_ret.len() > 0 {
|
||||||
|
let ret_ty = match d_decl.output {
|
||||||
|
FnRetTy::Ty(ref ty) => {
|
||||||
|
if !active_only_ret {
|
||||||
|
act_ret.insert(0, ty.clone());
|
||||||
|
}
|
||||||
|
let kind = TyKind::Tup(act_ret);
|
||||||
|
P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None })
|
||||||
|
}
|
||||||
|
FnRetTy::Default(span) => {
|
||||||
|
if act_ret.len() == 1 {
|
||||||
|
act_ret[0].clone()
|
||||||
|
} else {
|
||||||
|
let kind = TyKind::Tup(act_ret.iter().map(|arg| arg.clone()).collect());
|
||||||
|
P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
d_decl.output = FnRetTy::Ty(ret_ty);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut d_header = sig.header.clone();
|
||||||
|
if unsafe_activities {
|
||||||
|
d_header.safety = rustc_ast::Safety::Unsafe(span);
|
||||||
|
}
|
||||||
|
let d_sig = FnSig { header: d_header, decl: d_decl, span };
|
||||||
|
trace!("Generated signature: {:?}", d_sig);
|
||||||
|
(d_sig, new_inputs, idents, false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(llvm_enzyme))]
|
||||||
|
mod ad_fallback {
|
||||||
|
use rustc_ast::ast;
|
||||||
|
use rustc_expand::base::{Annotatable, ExtCtxt};
|
||||||
|
use rustc_span::Span;
|
||||||
|
|
||||||
|
use crate::errors;
|
||||||
|
pub(crate) fn expand(
|
||||||
|
ecx: &mut ExtCtxt<'_>,
|
||||||
|
_expand_span: Span,
|
||||||
|
meta_item: &ast::MetaItem,
|
||||||
|
item: Annotatable,
|
||||||
|
) -> Vec<Annotatable> {
|
||||||
|
ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span });
|
||||||
|
return vec![item];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(llvm_enzyme))]
|
||||||
|
pub(crate) use ad_fallback::expand;
|
||||||
|
#[cfg(llvm_enzyme)]
|
||||||
|
pub(crate) use llvm_enzyme::expand;
|
@ -145,6 +145,78 @@ pub(crate) struct AllocMustStatics {
|
|||||||
pub(crate) span: Span,
|
pub(crate) span: Span,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(llvm_enzyme)]
|
||||||
|
pub(crate) use autodiff::*;
|
||||||
|
|
||||||
|
#[cfg(llvm_enzyme)]
|
||||||
|
mod autodiff {
|
||||||
|
use super::*;
|
||||||
|
#[derive(Diagnostic)]
|
||||||
|
#[diag(builtin_macros_autodiff_missing_config)]
|
||||||
|
pub(crate) struct AutoDiffMissingConfig {
|
||||||
|
#[primary_span]
|
||||||
|
pub(crate) span: Span,
|
||||||
|
}
|
||||||
|
#[derive(Diagnostic)]
|
||||||
|
#[diag(builtin_macros_autodiff_unknown_activity)]
|
||||||
|
pub(crate) struct AutoDiffUnknownActivity {
|
||||||
|
#[primary_span]
|
||||||
|
pub(crate) span: Span,
|
||||||
|
pub(crate) act: String,
|
||||||
|
}
|
||||||
|
#[derive(Diagnostic)]
|
||||||
|
#[diag(builtin_macros_autodiff_ty_activity)]
|
||||||
|
pub(crate) struct AutoDiffInvalidTypeForActivity {
|
||||||
|
#[primary_span]
|
||||||
|
pub(crate) span: Span,
|
||||||
|
pub(crate) act: String,
|
||||||
|
}
|
||||||
|
#[derive(Diagnostic)]
|
||||||
|
#[diag(builtin_macros_autodiff_number_activities)]
|
||||||
|
pub(crate) struct AutoDiffInvalidNumberActivities {
|
||||||
|
#[primary_span]
|
||||||
|
pub(crate) span: Span,
|
||||||
|
pub(crate) expected: usize,
|
||||||
|
pub(crate) found: usize,
|
||||||
|
}
|
||||||
|
#[derive(Diagnostic)]
|
||||||
|
#[diag(builtin_macros_autodiff_mode_activity)]
|
||||||
|
pub(crate) struct AutoDiffInvalidApplicationModeAct {
|
||||||
|
#[primary_span]
|
||||||
|
pub(crate) span: Span,
|
||||||
|
pub(crate) mode: String,
|
||||||
|
pub(crate) act: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Diagnostic)]
|
||||||
|
#[diag(builtin_macros_autodiff_mode)]
|
||||||
|
pub(crate) struct AutoDiffInvalidMode {
|
||||||
|
#[primary_span]
|
||||||
|
pub(crate) span: Span,
|
||||||
|
pub(crate) mode: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Diagnostic)]
|
||||||
|
#[diag(builtin_macros_autodiff)]
|
||||||
|
pub(crate) struct AutoDiffInvalidApplication {
|
||||||
|
#[primary_span]
|
||||||
|
pub(crate) span: Span,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(llvm_enzyme))]
|
||||||
|
pub(crate) use ad_fallback::*;
|
||||||
|
#[cfg(not(llvm_enzyme))]
|
||||||
|
mod ad_fallback {
|
||||||
|
use super::*;
|
||||||
|
#[derive(Diagnostic)]
|
||||||
|
#[diag(builtin_macros_autodiff_not_build)]
|
||||||
|
pub(crate) struct AutoDiffSupportNotBuild {
|
||||||
|
#[primary_span]
|
||||||
|
pub(crate) span: Span,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Diagnostic)]
|
#[derive(Diagnostic)]
|
||||||
#[diag(builtin_macros_concat_bytes_invalid)]
|
#[diag(builtin_macros_concat_bytes_invalid)]
|
||||||
pub(crate) struct ConcatBytesInvalid {
|
pub(crate) struct ConcatBytesInvalid {
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
#![allow(internal_features)]
|
#![allow(internal_features)]
|
||||||
#![allow(rustc::diagnostic_outside_of_impl)]
|
#![allow(rustc::diagnostic_outside_of_impl)]
|
||||||
#![allow(rustc::untranslatable_diagnostic)]
|
#![allow(rustc::untranslatable_diagnostic)]
|
||||||
|
#![cfg_attr(not(bootstrap), feature(autodiff))]
|
||||||
#![doc(html_root_url = "https://doc.rust-lang.org/nightly/nightly-rustc/")]
|
#![doc(html_root_url = "https://doc.rust-lang.org/nightly/nightly-rustc/")]
|
||||||
#![doc(rust_logo)]
|
#![doc(rust_logo)]
|
||||||
#![feature(assert_matches)]
|
#![feature(assert_matches)]
|
||||||
@ -29,6 +30,7 @@ use crate::deriving::*;
|
|||||||
|
|
||||||
mod alloc_error_handler;
|
mod alloc_error_handler;
|
||||||
mod assert;
|
mod assert;
|
||||||
|
mod autodiff;
|
||||||
mod cfg;
|
mod cfg;
|
||||||
mod cfg_accessible;
|
mod cfg_accessible;
|
||||||
mod cfg_eval;
|
mod cfg_eval;
|
||||||
@ -106,6 +108,7 @@ pub fn register_builtin_macros(resolver: &mut dyn ResolverExpand) {
|
|||||||
|
|
||||||
register_attr! {
|
register_attr! {
|
||||||
alloc_error_handler: alloc_error_handler::expand,
|
alloc_error_handler: alloc_error_handler::expand,
|
||||||
|
autodiff: autodiff::expand,
|
||||||
bench: test::expand_bench,
|
bench: test::expand_bench,
|
||||||
cfg_accessible: cfg_accessible::Expander,
|
cfg_accessible: cfg_accessible::Expander,
|
||||||
cfg_eval: cfg_eval::expand,
|
cfg_eval: cfg_eval::expand,
|
||||||
|
@ -220,6 +220,10 @@ impl<'a> ExtCtxt<'a> {
|
|||||||
self.stmt_local(local, span)
|
self.stmt_local(local, span)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn stmt_semi(&self, expr: P<ast::Expr>) -> ast::Stmt {
|
||||||
|
ast::Stmt { id: ast::DUMMY_NODE_ID, span: expr.span, kind: ast::StmtKind::Semi(expr) }
|
||||||
|
}
|
||||||
|
|
||||||
pub fn stmt_local(&self, local: P<ast::Local>, span: Span) -> ast::Stmt {
|
pub fn stmt_local(&self, local: P<ast::Local>, span: Span) -> ast::Stmt {
|
||||||
ast::Stmt { id: ast::DUMMY_NODE_ID, kind: ast::StmtKind::Let(local), span }
|
ast::Stmt { id: ast::DUMMY_NODE_ID, kind: ast::StmtKind::Let(local), span }
|
||||||
}
|
}
|
||||||
@ -287,6 +291,25 @@ impl<'a> ExtCtxt<'a> {
|
|||||||
self.expr(sp, ast::ExprKind::Paren(e))
|
self.expr(sp, ast::ExprKind::Paren(e))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn expr_method_call(
|
||||||
|
&self,
|
||||||
|
span: Span,
|
||||||
|
expr: P<ast::Expr>,
|
||||||
|
ident: Ident,
|
||||||
|
args: ThinVec<P<ast::Expr>>,
|
||||||
|
) -> P<ast::Expr> {
|
||||||
|
let seg = ast::PathSegment::from_ident(ident);
|
||||||
|
self.expr(
|
||||||
|
span,
|
||||||
|
ast::ExprKind::MethodCall(Box::new(ast::MethodCall {
|
||||||
|
seg,
|
||||||
|
receiver: expr,
|
||||||
|
args,
|
||||||
|
span,
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn expr_call(
|
pub fn expr_call(
|
||||||
&self,
|
&self,
|
||||||
span: Span,
|
span: Span,
|
||||||
@ -295,6 +318,12 @@ impl<'a> ExtCtxt<'a> {
|
|||||||
) -> P<ast::Expr> {
|
) -> P<ast::Expr> {
|
||||||
self.expr(span, ast::ExprKind::Call(expr, args))
|
self.expr(span, ast::ExprKind::Call(expr, args))
|
||||||
}
|
}
|
||||||
|
pub fn expr_loop(&self, sp: Span, block: P<ast::Block>) -> P<ast::Expr> {
|
||||||
|
self.expr(sp, ast::ExprKind::Loop(block, None, sp))
|
||||||
|
}
|
||||||
|
pub fn expr_asm(&self, sp: Span, expr: P<ast::InlineAsm>) -> P<ast::Expr> {
|
||||||
|
self.expr(sp, ast::ExprKind::InlineAsm(expr))
|
||||||
|
}
|
||||||
pub fn expr_call_ident(
|
pub fn expr_call_ident(
|
||||||
&self,
|
&self,
|
||||||
span: Span,
|
span: Span,
|
||||||
|
@ -752,6 +752,11 @@ pub const BUILTIN_ATTRIBUTES: &[BuiltinAttribute] = &[
|
|||||||
template!(NameValueStr: "transparent|semitransparent|opaque"), ErrorFollowing,
|
template!(NameValueStr: "transparent|semitransparent|opaque"), ErrorFollowing,
|
||||||
EncodeCrossCrate::Yes, "used internally for testing macro hygiene",
|
EncodeCrossCrate::Yes, "used internally for testing macro hygiene",
|
||||||
),
|
),
|
||||||
|
rustc_attr!(
|
||||||
|
rustc_autodiff, Normal,
|
||||||
|
template!(Word, List: r#""...""#), DuplicatesOk,
|
||||||
|
EncodeCrossCrate::No, INTERNAL_UNSTABLE
|
||||||
|
),
|
||||||
|
|
||||||
// ==========================================================================
|
// ==========================================================================
|
||||||
// Internal attributes, Diagnostics related:
|
// Internal attributes, Diagnostics related:
|
||||||
|
@ -49,6 +49,10 @@ passes_attr_crate_level =
|
|||||||
passes_attr_only_in_functions =
|
passes_attr_only_in_functions =
|
||||||
`{$attr}` attribute can only be used on functions
|
`{$attr}` attribute can only be used on functions
|
||||||
|
|
||||||
|
passes_autodiff_attr =
|
||||||
|
`#[autodiff]` should be applied to a function
|
||||||
|
.label = not a function
|
||||||
|
|
||||||
passes_both_ffi_const_and_pure =
|
passes_both_ffi_const_and_pure =
|
||||||
`#[ffi_const]` function cannot be `#[ffi_pure]`
|
`#[ffi_const]` function cannot be `#[ffi_pure]`
|
||||||
|
|
||||||
|
@ -243,6 +243,9 @@ impl<'tcx> CheckAttrVisitor<'tcx> {
|
|||||||
self.check_generic_attr(hir_id, attr, target, Target::Fn);
|
self.check_generic_attr(hir_id, attr, target, Target::Fn);
|
||||||
self.check_proc_macro(hir_id, target, ProcMacroKind::Derive)
|
self.check_proc_macro(hir_id, target, ProcMacroKind::Derive)
|
||||||
}
|
}
|
||||||
|
[sym::autodiff, ..] => {
|
||||||
|
self.check_autodiff(hir_id, attr, span, target)
|
||||||
|
}
|
||||||
[sym::coroutine, ..] => {
|
[sym::coroutine, ..] => {
|
||||||
self.check_coroutine(attr, target);
|
self.check_coroutine(attr, target);
|
||||||
}
|
}
|
||||||
@ -2345,6 +2348,18 @@ impl<'tcx> CheckAttrVisitor<'tcx> {
|
|||||||
self.dcx().emit_err(errors::RustcPubTransparent { span, attr_span });
|
self.dcx().emit_err(errors::RustcPubTransparent { span, attr_span });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Checks if `#[autodiff]` is applied to an item other than a function item.
|
||||||
|
fn check_autodiff(&self, _hir_id: HirId, _attr: &Attribute, span: Span, target: Target) {
|
||||||
|
debug!("check_autodiff");
|
||||||
|
match target {
|
||||||
|
Target::Fn => {}
|
||||||
|
_ => {
|
||||||
|
self.dcx().emit_err(errors::AutoDiffAttr { attr_span: span });
|
||||||
|
self.abort.set(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'tcx> Visitor<'tcx> for CheckAttrVisitor<'tcx> {
|
impl<'tcx> Visitor<'tcx> for CheckAttrVisitor<'tcx> {
|
||||||
|
@ -20,6 +20,14 @@ use crate::lang_items::Duplicate;
|
|||||||
#[diag(passes_incorrect_do_not_recommend_location)]
|
#[diag(passes_incorrect_do_not_recommend_location)]
|
||||||
pub(crate) struct IncorrectDoNotRecommendLocation;
|
pub(crate) struct IncorrectDoNotRecommendLocation;
|
||||||
|
|
||||||
|
#[derive(Diagnostic)]
|
||||||
|
#[diag(passes_autodiff_attr)]
|
||||||
|
pub(crate) struct AutoDiffAttr {
|
||||||
|
#[primary_span]
|
||||||
|
#[label]
|
||||||
|
pub attr_span: Span,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(LintDiagnostic)]
|
#[derive(LintDiagnostic)]
|
||||||
#[diag(passes_outer_crate_level_attr)]
|
#[diag(passes_outer_crate_level_attr)]
|
||||||
pub(crate) struct OuterCrateLevelAttr;
|
pub(crate) struct OuterCrateLevelAttr;
|
||||||
|
@ -481,6 +481,8 @@ symbols! {
|
|||||||
audit_that,
|
audit_that,
|
||||||
augmented_assignments,
|
augmented_assignments,
|
||||||
auto_traits,
|
auto_traits,
|
||||||
|
autodiff,
|
||||||
|
autodiff_fallback,
|
||||||
automatically_derived,
|
automatically_derived,
|
||||||
avx,
|
avx,
|
||||||
avx512_target_feature,
|
avx512_target_feature,
|
||||||
@ -544,6 +546,7 @@ symbols! {
|
|||||||
cfg_accessible,
|
cfg_accessible,
|
||||||
cfg_attr,
|
cfg_attr,
|
||||||
cfg_attr_multi,
|
cfg_attr_multi,
|
||||||
|
cfg_autodiff_fallback,
|
||||||
cfg_boolean_literals,
|
cfg_boolean_literals,
|
||||||
cfg_doctest,
|
cfg_doctest,
|
||||||
cfg_eval,
|
cfg_eval,
|
||||||
@ -1002,6 +1005,7 @@ symbols! {
|
|||||||
hashset_iter_ty,
|
hashset_iter_ty,
|
||||||
hexagon_target_feature,
|
hexagon_target_feature,
|
||||||
hidden,
|
hidden,
|
||||||
|
hint,
|
||||||
homogeneous_aggregate,
|
homogeneous_aggregate,
|
||||||
host,
|
host,
|
||||||
html_favicon_url,
|
html_favicon_url,
|
||||||
@ -1654,6 +1658,7 @@ symbols! {
|
|||||||
rustc_allow_incoherent_impl,
|
rustc_allow_incoherent_impl,
|
||||||
rustc_allowed_through_unstable_modules,
|
rustc_allowed_through_unstable_modules,
|
||||||
rustc_attrs,
|
rustc_attrs,
|
||||||
|
rustc_autodiff,
|
||||||
rustc_box,
|
rustc_box,
|
||||||
rustc_builtin_macro,
|
rustc_builtin_macro,
|
||||||
rustc_capture_analysis,
|
rustc_capture_analysis,
|
||||||
|
@ -270,6 +270,15 @@ pub mod assert_matches {
|
|||||||
pub use crate::macros::{assert_matches, debug_assert_matches};
|
pub use crate::macros::{assert_matches, debug_assert_matches};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We don't export this through #[macro_export] for now, to avoid breakage.
|
||||||
|
#[cfg(not(bootstrap))]
|
||||||
|
#[unstable(feature = "autodiff", issue = "124509")]
|
||||||
|
/// Unstable module containing the unstable `autodiff` macro.
|
||||||
|
pub mod autodiff {
|
||||||
|
#[unstable(feature = "autodiff", issue = "124509")]
|
||||||
|
pub use crate::macros::builtin::autodiff;
|
||||||
|
}
|
||||||
|
|
||||||
#[unstable(feature = "cfg_match", issue = "115585")]
|
#[unstable(feature = "cfg_match", issue = "115585")]
|
||||||
pub use crate::macros::cfg_match;
|
pub use crate::macros::cfg_match;
|
||||||
|
|
||||||
|
@ -1539,6 +1539,24 @@ pub(crate) mod builtin {
|
|||||||
($file:expr $(,)?) => {{ /* compiler built-in */ }};
|
($file:expr $(,)?) => {{ /* compiler built-in */ }};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Automatic Differentiation macro which allows generating a new function to compute
|
||||||
|
/// the derivative of a given function. It may only be applied to a function.
|
||||||
|
/// The expected usage syntax is
|
||||||
|
/// `#[autodiff(NAME, MODE, INPUT_ACTIVITIES, OUTPUT_ACTIVITY)]`
|
||||||
|
/// where:
|
||||||
|
/// NAME is a string that represents a valid function name.
|
||||||
|
/// MODE is any of Forward, Reverse, ForwardFirst, ReverseFirst.
|
||||||
|
/// INPUT_ACTIVITIES consists of one valid activity for each input parameter.
|
||||||
|
/// OUTPUT_ACTIVITY must not be set if we implicitely return nothing (or explicitely return
|
||||||
|
/// `-> ()`. Otherwise it must be set to one of the allowed activities.
|
||||||
|
#[unstable(feature = "autodiff", issue = "124509")]
|
||||||
|
#[allow_internal_unstable(rustc_attrs)]
|
||||||
|
#[rustc_builtin_macro]
|
||||||
|
#[cfg(not(bootstrap))]
|
||||||
|
pub macro autodiff($item:item) {
|
||||||
|
/* compiler built-in */
|
||||||
|
}
|
||||||
|
|
||||||
/// Asserts that a boolean expression is `true` at runtime.
|
/// Asserts that a boolean expression is `true` at runtime.
|
||||||
///
|
///
|
||||||
/// This will invoke the [`panic!`] macro if the provided expression cannot be
|
/// This will invoke the [`panic!`] macro if the provided expression cannot be
|
||||||
|
@ -267,6 +267,7 @@
|
|||||||
#![allow(unused_features)]
|
#![allow(unused_features)]
|
||||||
//
|
//
|
||||||
// Features:
|
// Features:
|
||||||
|
#![cfg_attr(not(bootstrap), feature(autodiff))]
|
||||||
#![cfg_attr(test, feature(internal_output_capture, print_internals, update_panic_count, rt))]
|
#![cfg_attr(test, feature(internal_output_capture, print_internals, update_panic_count, rt))]
|
||||||
#![cfg_attr(
|
#![cfg_attr(
|
||||||
all(target_vendor = "fortanix", target_env = "sgx"),
|
all(target_vendor = "fortanix", target_env = "sgx"),
|
||||||
@ -624,7 +625,13 @@ pub mod simd {
|
|||||||
#[doc(inline)]
|
#[doc(inline)]
|
||||||
pub use crate::std_float::StdFloat;
|
pub use crate::std_float::StdFloat;
|
||||||
}
|
}
|
||||||
|
#[cfg(not(bootstrap))]
|
||||||
|
#[unstable(feature = "autodiff", issue = "124509")]
|
||||||
|
/// This module provides support for automatic differentiation.
|
||||||
|
pub mod autodiff {
|
||||||
|
/// This macro handles automatic differentiation.
|
||||||
|
pub use core::autodiff::autodiff;
|
||||||
|
}
|
||||||
#[stable(feature = "futures_api", since = "1.36.0")]
|
#[stable(feature = "futures_api", since = "1.36.0")]
|
||||||
pub mod task {
|
pub mod task {
|
||||||
//! Types and Traits for working with asynchronous tasks.
|
//! Types and Traits for working with asynchronous tasks.
|
||||||
|
107
tests/pretty/autodiff_forward.pp
Normal file
107
tests/pretty/autodiff_forward.pp
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
#![feature(prelude_import)]
|
||||||
|
#![no_std]
|
||||||
|
//@ needs-enzyme
|
||||||
|
|
||||||
|
#![feature(autodiff)]
|
||||||
|
#[prelude_import]
|
||||||
|
use ::std::prelude::rust_2015::*;
|
||||||
|
#[macro_use]
|
||||||
|
extern crate std;
|
||||||
|
//@ pretty-mode:expanded
|
||||||
|
//@ pretty-compare-only
|
||||||
|
//@ pp-exact:autodiff_forward.pp
|
||||||
|
|
||||||
|
// Test that forward mode ad macros are expanded correctly.
|
||||||
|
|
||||||
|
use std::autodiff::autodiff;
|
||||||
|
|
||||||
|
#[rustc_autodiff]
|
||||||
|
#[inline(never)]
|
||||||
|
pub fn f1(x: &[f64], y: f64) -> f64 {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// Not the most interesting derivative, but who are we to judge
|
||||||
|
|
||||||
|
// We want to be sure that the same function can be differentiated in different ways
|
||||||
|
|
||||||
|
::core::panicking::panic("not implemented")
|
||||||
|
}
|
||||||
|
#[rustc_autodiff(Forward, Dual, Const, Dual,)]
|
||||||
|
#[inline(never)]
|
||||||
|
pub fn df1(x: &[f64], bx: &[f64], y: f64) -> (f64, f64) {
|
||||||
|
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||||
|
::core::hint::black_box(f1(x, y));
|
||||||
|
::core::hint::black_box((bx,));
|
||||||
|
::core::hint::black_box((f1(x, y), f64::default()))
|
||||||
|
}
|
||||||
|
#[rustc_autodiff]
|
||||||
|
#[inline(never)]
|
||||||
|
pub fn f2(x: &[f64], y: f64) -> f64 {
|
||||||
|
::core::panicking::panic("not implemented")
|
||||||
|
}
|
||||||
|
#[rustc_autodiff(Forward, Dual, Const, Const,)]
|
||||||
|
#[inline(never)]
|
||||||
|
pub fn df2(x: &[f64], bx: &[f64], y: f64) -> f64 {
|
||||||
|
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||||
|
::core::hint::black_box(f2(x, y));
|
||||||
|
::core::hint::black_box((bx,));
|
||||||
|
::core::hint::black_box(f2(x, y))
|
||||||
|
}
|
||||||
|
#[rustc_autodiff]
|
||||||
|
#[inline(never)]
|
||||||
|
pub fn f3(x: &[f64], y: f64) -> f64 {
|
||||||
|
::core::panicking::panic("not implemented")
|
||||||
|
}
|
||||||
|
#[rustc_autodiff(ForwardFirst, Dual, Const, Const,)]
|
||||||
|
#[inline(never)]
|
||||||
|
pub fn df3(x: &[f64], bx: &[f64], y: f64) -> f64 {
|
||||||
|
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||||
|
::core::hint::black_box(f3(x, y));
|
||||||
|
::core::hint::black_box((bx,));
|
||||||
|
::core::hint::black_box(f3(x, y))
|
||||||
|
}
|
||||||
|
#[rustc_autodiff]
|
||||||
|
#[inline(never)]
|
||||||
|
pub fn f4() {}
|
||||||
|
#[rustc_autodiff(Forward, None)]
|
||||||
|
#[inline(never)]
|
||||||
|
pub fn df4() {
|
||||||
|
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||||
|
::core::hint::black_box(f4());
|
||||||
|
::core::hint::black_box(());
|
||||||
|
}
|
||||||
|
#[rustc_autodiff]
|
||||||
|
#[inline(never)]
|
||||||
|
#[rustc_autodiff]
|
||||||
|
#[inline(never)]
|
||||||
|
#[rustc_autodiff]
|
||||||
|
#[inline(never)]
|
||||||
|
pub fn f5(x: &[f64], y: f64) -> f64 {
|
||||||
|
::core::panicking::panic("not implemented")
|
||||||
|
}
|
||||||
|
#[rustc_autodiff(Forward, Const, Dual, Const,)]
|
||||||
|
#[inline(never)]
|
||||||
|
pub fn df5_y(x: &[f64], y: f64, by: f64) -> f64 {
|
||||||
|
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||||
|
::core::hint::black_box(f5(x, y));
|
||||||
|
::core::hint::black_box((by,));
|
||||||
|
::core::hint::black_box(f5(x, y))
|
||||||
|
}
|
||||||
|
#[rustc_autodiff(Forward, Dual, Const, Const,)]
|
||||||
|
#[inline(never)]
|
||||||
|
pub fn df5_x(x: &[f64], bx: &[f64], y: f64) -> f64 {
|
||||||
|
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||||
|
::core::hint::black_box(f5(x, y));
|
||||||
|
::core::hint::black_box((bx,));
|
||||||
|
::core::hint::black_box(f5(x, y))
|
||||||
|
}
|
||||||
|
#[rustc_autodiff(Reverse, Duplicated, Const, Active,)]
|
||||||
|
#[inline(never)]
|
||||||
|
pub fn df5_rev(x: &[f64], dx: &mut [f64], y: f64, dret: f64) -> f64 {
|
||||||
|
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||||
|
::core::hint::black_box(f5(x, y));
|
||||||
|
::core::hint::black_box((dx, dret));
|
||||||
|
::core::hint::black_box(f5(x, y))
|
||||||
|
}
|
||||||
|
fn main() {}
|
39
tests/pretty/autodiff_forward.rs
Normal file
39
tests/pretty/autodiff_forward.rs
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
//@ needs-enzyme
|
||||||
|
|
||||||
|
#![feature(autodiff)]
|
||||||
|
//@ pretty-mode:expanded
|
||||||
|
//@ pretty-compare-only
|
||||||
|
//@ pp-exact:autodiff_forward.pp
|
||||||
|
|
||||||
|
// Test that forward mode ad macros are expanded correctly.
|
||||||
|
|
||||||
|
use std::autodiff::autodiff;
|
||||||
|
|
||||||
|
#[autodiff(df1, Forward, Dual, Const, Dual)]
|
||||||
|
pub fn f1(x: &[f64], y: f64) -> f64 {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[autodiff(df2, Forward, Dual, Const, Const)]
|
||||||
|
pub fn f2(x: &[f64], y: f64) -> f64 {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[autodiff(df3, ForwardFirst, Dual, Const, Const)]
|
||||||
|
pub fn f3(x: &[f64], y: f64) -> f64 {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not the most interesting derivative, but who are we to judge
|
||||||
|
#[autodiff(df4, Forward)]
|
||||||
|
pub fn f4() {}
|
||||||
|
|
||||||
|
// We want to be sure that the same function can be differentiated in different ways
|
||||||
|
#[autodiff(df5_rev, Reverse, Duplicated, Const, Active)]
|
||||||
|
#[autodiff(df5_x, Forward, Dual, Const, Const)]
|
||||||
|
#[autodiff(df5_y, Forward, Const, Dual, Const)]
|
||||||
|
pub fn f5(x: &[f64], y: f64) -> f64 {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {}
|
86
tests/pretty/autodiff_reverse.pp
Normal file
86
tests/pretty/autodiff_reverse.pp
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
#![feature(prelude_import)]
|
||||||
|
#![no_std]
|
||||||
|
//@ needs-enzyme
|
||||||
|
|
||||||
|
#![feature(autodiff)]
|
||||||
|
#[prelude_import]
|
||||||
|
use ::std::prelude::rust_2015::*;
|
||||||
|
#[macro_use]
|
||||||
|
extern crate std;
|
||||||
|
//@ pretty-mode:expanded
|
||||||
|
//@ pretty-compare-only
|
||||||
|
//@ pp-exact:autodiff_reverse.pp
|
||||||
|
|
||||||
|
// Test that reverse mode ad macros are expanded correctly.
|
||||||
|
|
||||||
|
use std::autodiff::autodiff;
|
||||||
|
|
||||||
|
#[rustc_autodiff]
|
||||||
|
#[inline(never)]
|
||||||
|
pub fn f1(x: &[f64], y: f64) -> f64 {
|
||||||
|
|
||||||
|
// Not the most interesting derivative, but who are we to judge
|
||||||
|
|
||||||
|
|
||||||
|
// What happens if we already have Reverse in type (enum variant decl) and value (enum variant
|
||||||
|
// constructor) namespace? > It's expected to work normally.
|
||||||
|
|
||||||
|
|
||||||
|
::core::panicking::panic("not implemented")
|
||||||
|
}
|
||||||
|
#[rustc_autodiff(Reverse, Duplicated, Const, Active,)]
|
||||||
|
#[inline(never)]
|
||||||
|
pub fn df1(x: &[f64], dx: &mut [f64], y: f64, dret: f64) -> f64 {
|
||||||
|
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||||
|
::core::hint::black_box(f1(x, y));
|
||||||
|
::core::hint::black_box((dx, dret));
|
||||||
|
::core::hint::black_box(f1(x, y))
|
||||||
|
}
|
||||||
|
#[rustc_autodiff]
|
||||||
|
#[inline(never)]
|
||||||
|
pub fn f2() {}
|
||||||
|
#[rustc_autodiff(Reverse, None)]
|
||||||
|
#[inline(never)]
|
||||||
|
pub fn df2() {
|
||||||
|
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||||
|
::core::hint::black_box(f2());
|
||||||
|
::core::hint::black_box(());
|
||||||
|
}
|
||||||
|
#[rustc_autodiff]
|
||||||
|
#[inline(never)]
|
||||||
|
pub fn f3(x: &[f64], y: f64) -> f64 {
|
||||||
|
::core::panicking::panic("not implemented")
|
||||||
|
}
|
||||||
|
#[rustc_autodiff(ReverseFirst, Duplicated, Const, Active,)]
|
||||||
|
#[inline(never)]
|
||||||
|
pub fn df3(x: &[f64], dx: &mut [f64], y: f64, dret: f64) -> f64 {
|
||||||
|
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||||
|
::core::hint::black_box(f3(x, y));
|
||||||
|
::core::hint::black_box((dx, dret));
|
||||||
|
::core::hint::black_box(f3(x, y))
|
||||||
|
}
|
||||||
|
enum Foo { Reverse, }
|
||||||
|
use Foo::Reverse;
|
||||||
|
#[rustc_autodiff]
|
||||||
|
#[inline(never)]
|
||||||
|
pub fn f4(x: f32) { ::core::panicking::panic("not implemented") }
|
||||||
|
#[rustc_autodiff(Reverse, Const, None)]
|
||||||
|
#[inline(never)]
|
||||||
|
pub fn df4(x: f32) {
|
||||||
|
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||||
|
::core::hint::black_box(f4(x));
|
||||||
|
::core::hint::black_box(());
|
||||||
|
}
|
||||||
|
#[rustc_autodiff]
|
||||||
|
#[inline(never)]
|
||||||
|
pub fn f5(x: *const f32, y: &f32) {
|
||||||
|
::core::panicking::panic("not implemented")
|
||||||
|
}
|
||||||
|
#[rustc_autodiff(Reverse, DuplicatedOnly, Duplicated, None)]
|
||||||
|
#[inline(never)]
|
||||||
|
pub unsafe fn df5(x: *const f32, dx: *mut f32, y: &f32, dy: &mut f32) {
|
||||||
|
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||||
|
::core::hint::black_box(f5(x, y));
|
||||||
|
::core::hint::black_box((dx, dy));
|
||||||
|
}
|
||||||
|
fn main() {}
|
40
tests/pretty/autodiff_reverse.rs
Normal file
40
tests/pretty/autodiff_reverse.rs
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
//@ needs-enzyme
|
||||||
|
|
||||||
|
#![feature(autodiff)]
|
||||||
|
//@ pretty-mode:expanded
|
||||||
|
//@ pretty-compare-only
|
||||||
|
//@ pp-exact:autodiff_reverse.pp
|
||||||
|
|
||||||
|
// Test that reverse mode ad macros are expanded correctly.
|
||||||
|
|
||||||
|
use std::autodiff::autodiff;
|
||||||
|
|
||||||
|
#[autodiff(df1, Reverse, Duplicated, Const, Active)]
|
||||||
|
pub fn f1(x: &[f64], y: f64) -> f64 {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not the most interesting derivative, but who are we to judge
|
||||||
|
#[autodiff(df2, Reverse)]
|
||||||
|
pub fn f2() {}
|
||||||
|
|
||||||
|
#[autodiff(df3, ReverseFirst, Duplicated, Const, Active)]
|
||||||
|
pub fn f3(x: &[f64], y: f64) -> f64 {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
enum Foo { Reverse }
|
||||||
|
use Foo::Reverse;
|
||||||
|
// What happens if we already have Reverse in type (enum variant decl) and value (enum variant
|
||||||
|
// constructor) namespace? > It's expected to work normally.
|
||||||
|
#[autodiff(df4, Reverse, Const)]
|
||||||
|
pub fn f4(x: f32) {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[autodiff(df5, Reverse, DuplicatedOnly, Duplicated)]
|
||||||
|
pub fn f5(x: *const f32, y: &f32) {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {}
|
160
tests/ui/autodiff/autodiff_illegal.rs
Normal file
160
tests/ui/autodiff/autodiff_illegal.rs
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
//@ needs-enzyme
|
||||||
|
|
||||||
|
#![feature(autodiff)]
|
||||||
|
//@ pretty-mode:expanded
|
||||||
|
//@ pretty-compare-only
|
||||||
|
//@ pp-exact:autodiff_illegal.pp
|
||||||
|
|
||||||
|
// Test that invalid ad macros give nice errors and don't ICE.
|
||||||
|
|
||||||
|
use std::autodiff::autodiff;
|
||||||
|
|
||||||
|
// We can't use Duplicated on scalars
|
||||||
|
#[autodiff(df1, Reverse, Duplicated)]
|
||||||
|
pub fn f1(x: f64) {
|
||||||
|
//~^ ERROR Duplicated can not be used for this type
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Too many activities
|
||||||
|
#[autodiff(df3, Reverse, Duplicated, Const)]
|
||||||
|
pub fn f3(x: f64) {
|
||||||
|
//~^^ ERROR expected 1 activities, but found 2
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
// To few activities
|
||||||
|
#[autodiff(df4, Reverse)]
|
||||||
|
pub fn f4(x: f64) {
|
||||||
|
//~^^ ERROR expected 1 activities, but found 0
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
// We can't use Dual in Reverse mode
|
||||||
|
#[autodiff(df5, Reverse, Dual)]
|
||||||
|
pub fn f5(x: f64) {
|
||||||
|
//~^^ ERROR Dual can not be used in Reverse Mode
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
// We can't use Duplicated in Forward mode
|
||||||
|
#[autodiff(df6, Forward, Duplicated)]
|
||||||
|
pub fn f6(x: f64) {
|
||||||
|
//~^^ ERROR Duplicated can not be used in Forward Mode
|
||||||
|
//~^^ ERROR Duplicated can not be used for this type
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn dummy() {
|
||||||
|
|
||||||
|
#[autodiff(df7, Forward, Dual)]
|
||||||
|
let mut x = 5;
|
||||||
|
//~^ ERROR autodiff must be applied to function
|
||||||
|
|
||||||
|
#[autodiff(df7, Forward, Dual)]
|
||||||
|
x = x + 3;
|
||||||
|
//~^^ ERROR attributes on expressions are experimental [E0658]
|
||||||
|
//~^^ ERROR autodiff must be applied to function
|
||||||
|
|
||||||
|
#[autodiff(df7, Forward, Dual)]
|
||||||
|
let add_one_v2 = |x: u32| -> u32 { x + 1 };
|
||||||
|
//~^ ERROR autodiff must be applied to function
|
||||||
|
}
|
||||||
|
|
||||||
|
// Malformed, where args?
|
||||||
|
#[autodiff]
|
||||||
|
pub fn f7(x: f64) {
|
||||||
|
//~^ ERROR autodiff must be applied to function
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Malformed, where args?
|
||||||
|
#[autodiff()]
|
||||||
|
pub fn f8(x: f64) {
|
||||||
|
//~^ ERROR autodiff requires at least a name and mode
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalid attribute syntax
|
||||||
|
#[autodiff = ""]
|
||||||
|
pub fn f9(x: f64) {
|
||||||
|
//~^ ERROR autodiff must be applied to function
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fn_exists() {}
|
||||||
|
|
||||||
|
// We colide with an already existing function
|
||||||
|
#[autodiff(fn_exists, Reverse, Active)]
|
||||||
|
pub fn f10(x: f64) {
|
||||||
|
//~^^ ERROR the name `fn_exists` is defined multiple times [E0428]
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Malformed, missing a mode
|
||||||
|
#[autodiff(df11)]
|
||||||
|
pub fn f11() {
|
||||||
|
//~^ ERROR autodiff requires at least a name and mode
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalid Mode
|
||||||
|
#[autodiff(df12, Debug)]
|
||||||
|
pub fn f12() {
|
||||||
|
//~^^ ERROR unknown Mode: `Debug`. Use `Forward` or `Reverse`
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalid, please pick one Mode
|
||||||
|
// or use two autodiff macros.
|
||||||
|
#[autodiff(df13, Forward, Reverse)]
|
||||||
|
pub fn f13() {
|
||||||
|
//~^^ ERROR did not recognize Activity: `Reverse`
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Foo {}
|
||||||
|
|
||||||
|
// We can't handle Active structs, because that would mean (in the general case), that we would
|
||||||
|
// need to allocate and initialize arbitrary user types. We have Duplicated/Dual input args for
|
||||||
|
// that. FIXME: Give a nicer error and suggest to the user to have a `&mut Foo` input instead.
|
||||||
|
#[autodiff(df14, Reverse, Active, Active)]
|
||||||
|
fn f14(x: f32) -> Foo {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
type MyFloat = f32;
|
||||||
|
|
||||||
|
// We would like to support type alias to f32/f64 in argument type in the future,
|
||||||
|
// but that requires us to implement our checks at a later stage
|
||||||
|
// like THIR which has type information available.
|
||||||
|
#[autodiff(df15, Reverse, Active, Active)]
|
||||||
|
fn f15(x: MyFloat) -> f32 {
|
||||||
|
//~^^ ERROR failed to resolve: use of undeclared type `MyFloat` [E0433]
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
// We would like to support type alias to f32/f64 in return type in the future
|
||||||
|
#[autodiff(df16, Reverse, Active, Active)]
|
||||||
|
fn f16(x: f32) -> MyFloat {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[repr(transparent)]
|
||||||
|
struct F64Trans { inner: f64 }
|
||||||
|
|
||||||
|
// We would like to support `#[repr(transparent)]` f32/f64 wrapper in return type in the future
|
||||||
|
#[autodiff(df17, Reverse, Active, Active)]
|
||||||
|
fn f17(x: f64) -> F64Trans {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
// We would like to support `#[repr(transparent)]` f32/f64 wrapper in argument type in the future
|
||||||
|
#[autodiff(df18, Reverse, Active, Active)]
|
||||||
|
fn f18(x: F64Trans) -> f64 {
|
||||||
|
//~^^ ERROR failed to resolve: use of undeclared type `F64Trans` [E0433]
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
fn main() {}
|
152
tests/ui/autodiff/autodiff_illegal.stderr
Normal file
152
tests/ui/autodiff/autodiff_illegal.stderr
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
error[E0658]: attributes on expressions are experimental
|
||||||
|
--> $DIR/autodiff_illegal.rs:54:5
|
||||||
|
|
|
||||||
|
LL | #[autodiff(df7, Forward, Dual)]
|
||||||
|
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
|
||||||
|
= note: see issue #15701 <https://github.com/rust-lang/rust/issues/15701> for more information
|
||||||
|
= help: add `#![feature(stmt_expr_attributes)]` to the crate attributes to enable
|
||||||
|
= note: this compiler was built on YYYY-MM-DD; consider upgrading it if it is out of date
|
||||||
|
|
||||||
|
error: Duplicated can not be used for this type
|
||||||
|
--> $DIR/autodiff_illegal.rs:14:14
|
||||||
|
|
|
||||||
|
LL | pub fn f1(x: f64) {
|
||||||
|
| ^^^
|
||||||
|
|
||||||
|
error: expected 1 activities, but found 2
|
||||||
|
--> $DIR/autodiff_illegal.rs:20:1
|
||||||
|
|
|
||||||
|
LL | #[autodiff(df3, Reverse, Duplicated, Const)]
|
||||||
|
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
|
||||||
|
= note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info)
|
||||||
|
|
||||||
|
error: expected 1 activities, but found 0
|
||||||
|
--> $DIR/autodiff_illegal.rs:27:1
|
||||||
|
|
|
||||||
|
LL | #[autodiff(df4, Reverse)]
|
||||||
|
| ^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
|
||||||
|
= note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info)
|
||||||
|
|
||||||
|
error: Dual can not be used in Reverse Mode
|
||||||
|
--> $DIR/autodiff_illegal.rs:34:1
|
||||||
|
|
|
||||||
|
LL | #[autodiff(df5, Reverse, Dual)]
|
||||||
|
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
|
||||||
|
= note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info)
|
||||||
|
|
||||||
|
error: Duplicated can not be used in Forward Mode
|
||||||
|
--> $DIR/autodiff_illegal.rs:41:1
|
||||||
|
|
|
||||||
|
LL | #[autodiff(df6, Forward, Duplicated)]
|
||||||
|
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
|
||||||
|
= note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info)
|
||||||
|
|
||||||
|
error: Duplicated can not be used for this type
|
||||||
|
--> $DIR/autodiff_illegal.rs:42:14
|
||||||
|
|
|
||||||
|
LL | pub fn f6(x: f64) {
|
||||||
|
| ^^^
|
||||||
|
|
||||||
|
error: autodiff must be applied to function
|
||||||
|
--> $DIR/autodiff_illegal.rs:51:5
|
||||||
|
|
|
||||||
|
LL | let mut x = 5;
|
||||||
|
| ^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
error: autodiff must be applied to function
|
||||||
|
--> $DIR/autodiff_illegal.rs:55:5
|
||||||
|
|
|
||||||
|
LL | x = x + 3;
|
||||||
|
| ^
|
||||||
|
|
||||||
|
error: autodiff must be applied to function
|
||||||
|
--> $DIR/autodiff_illegal.rs:60:5
|
||||||
|
|
|
||||||
|
LL | let add_one_v2 = |x: u32| -> u32 { x + 1 };
|
||||||
|
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
error: autodiff must be applied to function
|
||||||
|
--> $DIR/autodiff_illegal.rs:66:1
|
||||||
|
|
|
||||||
|
LL | / pub fn f7(x: f64) {
|
||||||
|
LL | |
|
||||||
|
LL | | unimplemented!()
|
||||||
|
LL | | }
|
||||||
|
| |_^
|
||||||
|
|
||||||
|
error: autodiff requires at least a name and mode
|
||||||
|
--> $DIR/autodiff_illegal.rs:73:1
|
||||||
|
|
|
||||||
|
LL | / pub fn f8(x: f64) {
|
||||||
|
LL | |
|
||||||
|
LL | | unimplemented!()
|
||||||
|
LL | | }
|
||||||
|
| |_^
|
||||||
|
|
||||||
|
error: autodiff must be applied to function
|
||||||
|
--> $DIR/autodiff_illegal.rs:80:1
|
||||||
|
|
|
||||||
|
LL | / pub fn f9(x: f64) {
|
||||||
|
LL | |
|
||||||
|
LL | | unimplemented!()
|
||||||
|
LL | | }
|
||||||
|
| |_^
|
||||||
|
|
||||||
|
error[E0428]: the name `fn_exists` is defined multiple times
|
||||||
|
--> $DIR/autodiff_illegal.rs:88:1
|
||||||
|
|
|
||||||
|
LL | fn fn_exists() {}
|
||||||
|
| -------------- previous definition of the value `fn_exists` here
|
||||||
|
...
|
||||||
|
LL | #[autodiff(fn_exists, Reverse, Active)]
|
||||||
|
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ `fn_exists` redefined here
|
||||||
|
|
|
||||||
|
= note: `fn_exists` must be defined only once in the value namespace of this module
|
||||||
|
= note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info)
|
||||||
|
|
||||||
|
error: autodiff requires at least a name and mode
|
||||||
|
--> $DIR/autodiff_illegal.rs:96:1
|
||||||
|
|
|
||||||
|
LL | / pub fn f11() {
|
||||||
|
LL | |
|
||||||
|
LL | | unimplemented!()
|
||||||
|
LL | | }
|
||||||
|
| |_^
|
||||||
|
|
||||||
|
error: unknown Mode: `Debug`. Use `Forward` or `Reverse`
|
||||||
|
--> $DIR/autodiff_illegal.rs:102:18
|
||||||
|
|
|
||||||
|
LL | #[autodiff(df12, Debug)]
|
||||||
|
| ^^^^^
|
||||||
|
|
||||||
|
error: did not recognize Activity: `Reverse`
|
||||||
|
--> $DIR/autodiff_illegal.rs:110:27
|
||||||
|
|
|
||||||
|
LL | #[autodiff(df13, Forward, Reverse)]
|
||||||
|
| ^^^^^^^
|
||||||
|
|
||||||
|
error[E0433]: failed to resolve: use of undeclared type `MyFloat`
|
||||||
|
--> $DIR/autodiff_illegal.rs:131:1
|
||||||
|
|
|
||||||
|
LL | #[autodiff(df15, Reverse, Active, Active)]
|
||||||
|
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ use of undeclared type `MyFloat`
|
||||||
|
|
|
||||||
|
= note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info)
|
||||||
|
|
||||||
|
error[E0433]: failed to resolve: use of undeclared type `F64Trans`
|
||||||
|
--> $DIR/autodiff_illegal.rs:153:1
|
||||||
|
|
|
||||||
|
LL | #[autodiff(df18, Reverse, Active, Active)]
|
||||||
|
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ use of undeclared type `F64Trans`
|
||||||
|
|
|
||||||
|
= note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info)
|
||||||
|
|
||||||
|
error: aborting due to 19 previous errors
|
||||||
|
|
||||||
|
Some errors have detailed explanations: E0428, E0433, E0658.
|
||||||
|
For more information about an error, try `rustc --explain E0428`.
|
12
tests/ui/autodiff/auxiliary/my_macro.rs
Normal file
12
tests/ui/autodiff/auxiliary/my_macro.rs
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
//@ force-host
|
||||||
|
//@ no-prefer-dynamic
|
||||||
|
#![crate_type = "proc-macro"]
|
||||||
|
|
||||||
|
extern crate proc_macro;
|
||||||
|
use proc_macro::TokenStream;
|
||||||
|
|
||||||
|
#[proc_macro_attribute]
|
||||||
|
#[macro_use]
|
||||||
|
pub fn autodiff(_attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||||
|
item // identity proc-macro
|
||||||
|
}
|
17
tests/ui/autodiff/visibility.rs
Normal file
17
tests/ui/autodiff/visibility.rs
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
//@ ignore-enzyme
|
||||||
|
//@ revisions: std_autodiff no_std_autodiff
|
||||||
|
//@[no_std_autodiff] check-pass
|
||||||
|
//@ aux-build: my_macro.rs
|
||||||
|
#![crate_type = "lib"]
|
||||||
|
#![feature(autodiff)]
|
||||||
|
|
||||||
|
#[cfg(std_autodiff)]
|
||||||
|
use std::autodiff::autodiff;
|
||||||
|
|
||||||
|
extern crate my_macro;
|
||||||
|
use my_macro::autodiff; // bring `autodiff` in scope
|
||||||
|
|
||||||
|
#[autodiff]
|
||||||
|
//[std_autodiff]~^^^ ERROR the name `autodiff` is defined multiple times
|
||||||
|
//[std_autodiff]~^^ ERROR this rustc version does not support autodiff
|
||||||
|
fn foo() {}
|
24
tests/ui/autodiff/visibility.std_autodiff.stderr
Normal file
24
tests/ui/autodiff/visibility.std_autodiff.stderr
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
error[E0252]: the name `autodiff` is defined multiple times
|
||||||
|
--> $DIR/visibility.rs:12:5
|
||||||
|
|
|
||||||
|
LL | use std::autodiff::autodiff;
|
||||||
|
| ----------------------- previous import of the macro `autodiff` here
|
||||||
|
...
|
||||||
|
LL | use my_macro::autodiff; // bring `autodiff` in scope
|
||||||
|
| ^^^^^^^^^^^^^^^^^^ `autodiff` reimported here
|
||||||
|
|
|
||||||
|
= note: `autodiff` must be defined only once in the macro namespace of this module
|
||||||
|
help: you can use `as` to change the binding name of the import
|
||||||
|
|
|
||||||
|
LL | use my_macro::autodiff as other_autodiff; // bring `autodiff` in scope
|
||||||
|
| +++++++++++++++++
|
||||||
|
|
||||||
|
error: this rustc version does not support autodiff
|
||||||
|
--> $DIR/visibility.rs:14:1
|
||||||
|
|
|
||||||
|
LL | #[autodiff]
|
||||||
|
| ^^^^^^^^^^^
|
||||||
|
|
||||||
|
error: aborting due to 2 previous errors
|
||||||
|
|
||||||
|
For more information about this error, try `rustc --explain E0252`.
|
@ -0,0 +1,23 @@
|
|||||||
|
error[E0658]: use of unstable library feature 'autodiff'
|
||||||
|
--> $DIR/feature-gate-autodiff-use.rs:13:3
|
||||||
|
|
|
||||||
|
LL | #[autodiff(dfoo, Reverse)]
|
||||||
|
| ^^^^^^^^
|
||||||
|
|
|
||||||
|
= note: see issue #124509 <https://github.com/rust-lang/rust/issues/124509> for more information
|
||||||
|
= help: add `#![feature(autodiff)]` to the crate attributes to enable
|
||||||
|
= note: this compiler was built on YYYY-MM-DD; consider upgrading it if it is out of date
|
||||||
|
|
||||||
|
error[E0658]: use of unstable library feature 'autodiff'
|
||||||
|
--> $DIR/feature-gate-autodiff-use.rs:9:5
|
||||||
|
|
|
||||||
|
LL | use std::autodiff::autodiff;
|
||||||
|
| ^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
|
||||||
|
= note: see issue #124509 <https://github.com/rust-lang/rust/issues/124509> for more information
|
||||||
|
= help: add `#![feature(autodiff)]` to the crate attributes to enable
|
||||||
|
= note: this compiler was built on YYYY-MM-DD; consider upgrading it if it is out of date
|
||||||
|
|
||||||
|
error: aborting due to 2 previous errors
|
||||||
|
|
||||||
|
For more information about this error, try `rustc --explain E0658`.
|
@ -0,0 +1,29 @@
|
|||||||
|
error[E0658]: use of unstable library feature 'autodiff'
|
||||||
|
--> $DIR/feature-gate-autodiff-use.rs:13:3
|
||||||
|
|
|
||||||
|
LL | #[autodiff(dfoo, Reverse)]
|
||||||
|
| ^^^^^^^^
|
||||||
|
|
|
||||||
|
= note: see issue #124509 <https://github.com/rust-lang/rust/issues/124509> for more information
|
||||||
|
= help: add `#![feature(autodiff)]` to the crate attributes to enable
|
||||||
|
= note: this compiler was built on YYYY-MM-DD; consider upgrading it if it is out of date
|
||||||
|
|
||||||
|
error: this rustc version does not support autodiff
|
||||||
|
--> $DIR/feature-gate-autodiff-use.rs:13:1
|
||||||
|
|
|
||||||
|
LL | #[autodiff(dfoo, Reverse)]
|
||||||
|
| ^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
error[E0658]: use of unstable library feature 'autodiff'
|
||||||
|
--> $DIR/feature-gate-autodiff-use.rs:9:5
|
||||||
|
|
|
||||||
|
LL | use std::autodiff::autodiff;
|
||||||
|
| ^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
|
||||||
|
= note: see issue #124509 <https://github.com/rust-lang/rust/issues/124509> for more information
|
||||||
|
= help: add `#![feature(autodiff)]` to the crate attributes to enable
|
||||||
|
= note: this compiler was built on YYYY-MM-DD; consider upgrading it if it is out of date
|
||||||
|
|
||||||
|
error: aborting due to 3 previous errors
|
||||||
|
|
||||||
|
For more information about this error, try `rustc --explain E0658`.
|
17
tests/ui/feature-gates/feature-gate-autodiff-use.rs
Normal file
17
tests/ui/feature-gates/feature-gate-autodiff-use.rs
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
//@ revisions: has_support no_support
|
||||||
|
//@[no_support] ignore-enzyme
|
||||||
|
//@[has_support] needs-enzyme
|
||||||
|
|
||||||
|
// This checks that without enabling the autodiff feature, we can't import std::autodiff::autodiff;
|
||||||
|
|
||||||
|
#![crate_type = "lib"]
|
||||||
|
|
||||||
|
use std::autodiff::autodiff;
|
||||||
|
//[has_support]~^ ERROR use of unstable library feature 'autodiff'
|
||||||
|
//[no_support]~^^ ERROR use of unstable library feature 'autodiff'
|
||||||
|
|
||||||
|
#[autodiff(dfoo, Reverse)]
|
||||||
|
//[has_support]~^ ERROR use of unstable library feature 'autodiff' [E0658]
|
||||||
|
//[no_support]~^^ ERROR use of unstable library feature 'autodiff' [E0658]
|
||||||
|
//[no_support]~| ERROR this rustc version does not support autodiff
|
||||||
|
fn foo() {}
|
@ -0,0 +1,13 @@
|
|||||||
|
error: cannot find attribute `autodiff` in this scope
|
||||||
|
--> $DIR/feature-gate-autodiff.rs:9:3
|
||||||
|
|
|
||||||
|
LL | #[autodiff(dfoo, Reverse)]
|
||||||
|
| ^^^^^^^^
|
||||||
|
|
|
||||||
|
help: consider importing this attribute macro
|
||||||
|
|
|
||||||
|
LL + use std::autodiff::autodiff;
|
||||||
|
|
|
||||||
|
|
||||||
|
error: aborting due to 1 previous error
|
||||||
|
|
@ -0,0 +1,13 @@
|
|||||||
|
error: cannot find attribute `autodiff` in this scope
|
||||||
|
--> $DIR/feature-gate-autodiff.rs:9:3
|
||||||
|
|
|
||||||
|
LL | #[autodiff(dfoo, Reverse)]
|
||||||
|
| ^^^^^^^^
|
||||||
|
|
|
||||||
|
help: consider importing this attribute macro
|
||||||
|
|
|
||||||
|
LL + use std::autodiff::autodiff;
|
||||||
|
|
|
||||||
|
|
||||||
|
error: aborting due to 1 previous error
|
||||||
|
|
12
tests/ui/feature-gates/feature-gate-autodiff.rs
Normal file
12
tests/ui/feature-gates/feature-gate-autodiff.rs
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
//@ revisions: has_support no_support
|
||||||
|
//@[no_support] ignore-enzyme
|
||||||
|
//@[has_support] needs-enzyme
|
||||||
|
|
||||||
|
#![crate_type = "lib"]
|
||||||
|
|
||||||
|
// This checks that without the autodiff feature enabled, we can't use it.
|
||||||
|
|
||||||
|
#[autodiff(dfoo, Reverse)]
|
||||||
|
//[has_support]~^ ERROR cannot find attribute `autodiff` in this scope
|
||||||
|
//[no_support]~^^ ERROR cannot find attribute `autodiff` in this scope
|
||||||
|
fn foo() {}
|
Loading…
Reference in New Issue
Block a user