mirror of
https://github.com/rust-lang/rust.git
synced 2025-04-28 02:57:37 +00:00
Rollup merge of #138314 - haenoe:autodiff-inner-function, r=ZuseZ4
fix usage of `autodiff` macro with inner functions This PR adds additional handling into the expansion step of the `std::autodiff` macro (in `compiler/rustc_builtin_macros/src/autodiff.rs`), which allows the macro to be applied to inner functions. ```rust #![feature(autodiff)] use std::autodiff::autodiff; fn main() { #[autodiff(d_inner, Forward, Dual, DualOnly)] fn inner(x: f32) -> f32 { x * x } } ``` Previously, the compiler didn't allow this due to only handling `Annotatable::Item` and `Annotatable::AssocItem` and missing the handling of `Annotatable::Stmt`. This resulted in the rather generic error ``` error: autodiff must be applied to function --> src/main.rs:6:5 | 6 | / fn inner(x: f32) -> f32 { 7 | | x * x 8 | | } | |_____^ error: could not compile `enzyme-test` (bin "enzyme-test") due to 1 previous error ``` This issue was originally reported [here](https://github.com/EnzymeAD/rust/issues/184). Quick question: would it make sense to add a ui test to ensure there is no regression on this? This is my first contribution, so I'm extra grateful for any piece of feedback!! :D r? `@oli-obk` Tracking issue for autodiff: #124509
This commit is contained in:
commit
5a43d92382
@ -17,7 +17,7 @@ mod llvm_enzyme {
|
||||
use rustc_ast::visit::AssocCtxt::*;
|
||||
use rustc_ast::{
|
||||
self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind,
|
||||
MetaItemInner, PatKind, QSelf, TyKind,
|
||||
MetaItemInner, PatKind, QSelf, TyKind, Visibility,
|
||||
};
|
||||
use rustc_expand::base::{Annotatable, ExtCtxt};
|
||||
use rustc_span::{Ident, Span, Symbol, kw, sym};
|
||||
@ -72,6 +72,16 @@ mod llvm_enzyme {
|
||||
}
|
||||
}
|
||||
|
||||
// Get information about the function the macro is applied to
|
||||
fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident)> {
|
||||
match &iitem.kind {
|
||||
ItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
|
||||
Some((iitem.vis.clone(), sig.clone(), ident.clone()))
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn from_ast(
|
||||
ecx: &mut ExtCtxt<'_>,
|
||||
meta_item: &ThinVec<MetaItemInner>,
|
||||
@ -199,32 +209,26 @@ mod llvm_enzyme {
|
||||
return vec![item];
|
||||
}
|
||||
let dcx = ecx.sess.dcx();
|
||||
// first get the annotable item:
|
||||
let (primal, sig, is_impl): (Ident, FnSig, bool) = match &item {
|
||||
Annotatable::Item(iitem) => {
|
||||
let (ident, sig) = match &iitem.kind {
|
||||
ItemKind::Fn(box ast::Fn { ident, sig, .. }) => (ident, sig),
|
||||
_ => {
|
||||
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
|
||||
return vec![item];
|
||||
}
|
||||
};
|
||||
(*ident, sig.clone(), false)
|
||||
}
|
||||
|
||||
// first get information about the annotable item:
|
||||
let Some((vis, sig, primal)) = (match &item {
|
||||
Annotatable::Item(iitem) => extract_item_info(iitem),
|
||||
Annotatable::Stmt(stmt) => match &stmt.kind {
|
||||
ast::StmtKind::Item(iitem) => extract_item_info(iitem),
|
||||
_ => None,
|
||||
},
|
||||
Annotatable::AssocItem(assoc_item, Impl { of_trait: false }) => {
|
||||
let (ident, sig) = match &assoc_item.kind {
|
||||
ast::AssocItemKind::Fn(box ast::Fn { ident, sig, .. }) => (ident, sig),
|
||||
_ => {
|
||||
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
|
||||
return vec![item];
|
||||
match &assoc_item.kind {
|
||||
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
|
||||
Some((assoc_item.vis.clone(), sig.clone(), ident.clone()))
|
||||
}
|
||||
};
|
||||
(*ident, sig.clone(), true)
|
||||
}
|
||||
_ => {
|
||||
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
|
||||
return vec![item];
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
}) else {
|
||||
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
|
||||
return vec![item];
|
||||
};
|
||||
|
||||
let meta_item_vec: ThinVec<MetaItemInner> = match meta_item.kind {
|
||||
@ -238,15 +242,6 @@ mod llvm_enzyme {
|
||||
let has_ret = has_ret(&sig.decl.output);
|
||||
let sig_span = ecx.with_call_site_ctxt(sig.span);
|
||||
|
||||
let vis = match &item {
|
||||
Annotatable::Item(iitem) => iitem.vis.clone(),
|
||||
Annotatable::AssocItem(assoc_item, _) => assoc_item.vis.clone(),
|
||||
_ => {
|
||||
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
|
||||
return vec![item];
|
||||
}
|
||||
};
|
||||
|
||||
// create TokenStream from vec elemtents:
|
||||
// meta_item doesn't have a .tokens field
|
||||
let mut ts: Vec<TokenTree> = vec![];
|
||||
@ -379,6 +374,22 @@ mod llvm_enzyme {
|
||||
}
|
||||
Annotatable::AssocItem(assoc_item.clone(), i)
|
||||
}
|
||||
Annotatable::Stmt(ref mut stmt) => {
|
||||
match stmt.kind {
|
||||
ast::StmtKind::Item(ref mut iitem) => {
|
||||
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
|
||||
iitem.attrs.push(attr);
|
||||
}
|
||||
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind))
|
||||
{
|
||||
iitem.attrs.push(inline_never.clone());
|
||||
}
|
||||
}
|
||||
_ => unreachable!("stmt kind checked previously"),
|
||||
};
|
||||
|
||||
Annotatable::Stmt(stmt.clone())
|
||||
}
|
||||
_ => {
|
||||
unreachable!("annotatable kind checked previously")
|
||||
}
|
||||
@ -389,22 +400,40 @@ mod llvm_enzyme {
|
||||
delim: rustc_ast::token::Delimiter::Parenthesis,
|
||||
tokens: ts,
|
||||
});
|
||||
|
||||
let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
|
||||
let d_annotatable = if is_impl {
|
||||
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
|
||||
let d_fn = P(ast::AssocItem {
|
||||
attrs: thin_vec![d_attr, inline_never],
|
||||
id: ast::DUMMY_NODE_ID,
|
||||
span,
|
||||
vis,
|
||||
kind: assoc_item,
|
||||
tokens: None,
|
||||
});
|
||||
Annotatable::AssocItem(d_fn, Impl { of_trait: false })
|
||||
} else {
|
||||
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
|
||||
d_fn.vis = vis;
|
||||
Annotatable::Item(d_fn)
|
||||
let d_annotatable = match &item {
|
||||
Annotatable::AssocItem(_, _) => {
|
||||
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
|
||||
let d_fn = P(ast::AssocItem {
|
||||
attrs: thin_vec![d_attr, inline_never],
|
||||
id: ast::DUMMY_NODE_ID,
|
||||
span,
|
||||
vis,
|
||||
kind: assoc_item,
|
||||
tokens: None,
|
||||
});
|
||||
Annotatable::AssocItem(d_fn, Impl { of_trait: false })
|
||||
}
|
||||
Annotatable::Item(_) => {
|
||||
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
|
||||
d_fn.vis = vis;
|
||||
|
||||
Annotatable::Item(d_fn)
|
||||
}
|
||||
Annotatable::Stmt(_) => {
|
||||
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
|
||||
d_fn.vis = vis;
|
||||
|
||||
Annotatable::Stmt(P(ast::Stmt {
|
||||
id: ast::DUMMY_NODE_ID,
|
||||
kind: ast::StmtKind::Item(d_fn),
|
||||
span,
|
||||
}))
|
||||
}
|
||||
_ => {
|
||||
unreachable!("item kind checked previously")
|
||||
}
|
||||
};
|
||||
|
||||
return vec![orig_annotatable, d_annotatable];
|
||||
|
@ -29,6 +29,8 @@ pub fn f1(x: &[f64], y: f64) -> f64 {
|
||||
// Make sure, that we add the None for the default return.
|
||||
|
||||
|
||||
// We want to make sure that we can use the macro for functions defined inside of functions
|
||||
|
||||
::core::panicking::panic("not implemented")
|
||||
}
|
||||
#[rustc_autodiff(Forward, 1, Dual, Const, Dual)]
|
||||
@ -158,4 +160,25 @@ fn f8_1(x: &f32, bx_0: &f32) -> f32 {
|
||||
::core::hint::black_box((bx_0,));
|
||||
::core::hint::black_box(<f32>::default())
|
||||
}
|
||||
pub fn f9() {
|
||||
#[rustc_autodiff]
|
||||
#[inline(never)]
|
||||
fn inner(x: f32) -> f32 { x * x }
|
||||
#[rustc_autodiff(Forward, 1, Dual, Dual)]
|
||||
#[inline(never)]
|
||||
fn d_inner_2(x: f32, bx_0: f32) -> (f32, f32) {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(inner(x));
|
||||
::core::hint::black_box((bx_0,));
|
||||
::core::hint::black_box(<(f32, f32)>::default())
|
||||
}
|
||||
#[rustc_autodiff(Forward, 1, Dual, DualOnly)]
|
||||
#[inline(never)]
|
||||
fn d_inner_1(x: f32, bx_0: f32) -> f32 {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(inner(x));
|
||||
::core::hint::black_box((bx_0,));
|
||||
::core::hint::black_box(<f32>::default())
|
||||
}
|
||||
}
|
||||
fn main() {}
|
||||
|
@ -54,4 +54,13 @@ fn f8(x: &f32) -> f32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
// We want to make sure that we can use the macro for functions defined inside of functions
|
||||
pub fn f9() {
|
||||
#[autodiff(d_inner_1, Forward, Dual, DualOnly)]
|
||||
#[autodiff(d_inner_2, Forward, Dual, Dual)]
|
||||
fn inner(x: f32) -> f32 {
|
||||
x * x
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {}
|
||||
|
Loading…
Reference in New Issue
Block a user