attr: add a pre-codegen #[spirv(...)] attribute checking pass.

This commit is contained in:
Eduard-Mihai Burtescu 2021-03-02 05:23:49 +02:00 committed by Eduard-Mihai Burtescu
parent 49509e3ccb
commit 71254b48fa
3 changed files with 340 additions and 17 deletions

View File

@ -0,0 +1,312 @@
//! `#[spirv(...)]` attribute support.
//!
//! The attribute-checking parts of this try to follow `rustc_passes::check_attr`.
use crate::symbols::{SpirvAttribute, Symbols};
use rustc_ast::Attribute;
use rustc_hir as hir;
use rustc_hir::def_id::LocalDefId;
use rustc_hir::intravisit::{self, NestedVisitorMap, Visitor};
use rustc_hir::{HirId, MethodKind, Target, CRATE_HIR_ID};
use rustc_middle::hir::map::Map;
use rustc_middle::ty::query::Providers;
use rustc_middle::ty::TyCtxt;
use std::fmt;
use std::rc::Rc;
// FIXME(eddyb) make this reusable from somewhere in `rustc`.
pub(crate) fn target_from_impl_item<'tcx>(
tcx: TyCtxt<'tcx>,
impl_item: &hir::ImplItem<'_>,
) -> Target {
match impl_item.kind {
hir::ImplItemKind::Const(..) => Target::AssocConst,
hir::ImplItemKind::Fn(..) => {
let parent_hir_id = tcx.hir().get_parent_item(impl_item.hir_id);
let containing_item = tcx.hir().expect_item(parent_hir_id);
let containing_impl_is_for_trait = match &containing_item.kind {
hir::ItemKind::Impl { of_trait, .. } => of_trait.is_some(),
_ => unreachable!("parent of an ImplItem must be an Impl"),
};
if containing_impl_is_for_trait {
Target::Method(MethodKind::Trait { body: true })
} else {
Target::Method(MethodKind::Inherent)
}
}
hir::ImplItemKind::TyAlias(..) => Target::AssocTy,
}
}
// HACK(eddyb) current `Target` (after rust-lang/rust#80641 + rust-lang/rust#80920),
// emulated before we can rustup to that point and use the new variants directly.
enum TargetNew {
Old(Target),
// Added by rust-lang/rust#80641.
Field,
Arm,
MacroDef,
// Added by rust-lang/rust#80920.
Param,
}
impl From<Target> for TargetNew {
fn from(target: Target) -> Self {
Self::Old(target)
}
}
impl fmt::Display for TargetNew {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let description = match self {
Self::Old(target) => return write!(f, "{}", target),
Self::Field => "struct field",
Self::Arm => "match arm",
Self::MacroDef => "macro def",
Self::Param => "function param",
};
f.write_str(description)
}
}
struct CheckSpirvAttrVisitor<'tcx> {
tcx: TyCtxt<'tcx>,
sym: Rc<Symbols>,
}
impl CheckSpirvAttrVisitor<'_> {
fn check_spirv_attributes(
&self,
hir_id: HirId,
attrs: &[Attribute],
target: impl Into<TargetNew>,
) {
let parse_attrs = |attrs| crate::symbols::parse_attrs_for_checking(&self.sym, attrs);
let target = target.into();
for parse_attr_result in parse_attrs(attrs) {
let (span, attr) = match parse_attr_result {
Ok(span_and_attr) => span_and_attr,
Err((span, msg)) => {
self.tcx.sess.span_err(span, &msg);
continue;
}
};
/// Error newtype marker used below for readability.
struct Expected<T>(T);
let valid_target = match attr {
SpirvAttribute::Builtin(_)
| SpirvAttribute::DescriptorSet(_)
| SpirvAttribute::Binding(_)
| SpirvAttribute::Flat => match target {
TargetNew::Param => {
let parent_hir_id = self.tcx.hir().get_parent_node(hir_id);
let parent_is_entry_point =
parse_attrs(self.tcx.hir().attrs(parent_hir_id))
.filter_map(|r| r.ok())
.any(|(_, attr)| matches!(attr, SpirvAttribute::Entry(_)));
if !parent_is_entry_point {
self.tcx.sess.span_err(
span,
"attribute is only valid on a parameter of an entry-point function",
);
}
Ok(())
}
_ => Err(Expected("function parameter")),
},
SpirvAttribute::Entry(_) => match target {
TargetNew::Old(Target::Fn)
| TargetNew::Old(Target::Method(MethodKind::Trait { body: true }))
| TargetNew::Old(Target::Method(MethodKind::Inherent)) => {
// FIXME(eddyb) further check entry-point attribute validity,
// e.g. signature, shouldn't have `#[inline]` or generics, etc.
Ok(())
}
_ => Err(Expected("function")),
},
SpirvAttribute::UnrollLoops => match target {
TargetNew::Old(Target::Fn)
| TargetNew::Old(Target::Closure)
| TargetNew::Old(Target::Method(MethodKind::Trait { body: true }))
| TargetNew::Old(Target::Method(MethodKind::Inherent)) => Ok(()),
_ => Err(Expected("function or closure")),
},
SpirvAttribute::StorageClass(_)
| SpirvAttribute::ImageType { .. }
| SpirvAttribute::Sampler
| SpirvAttribute::SampledImage
| SpirvAttribute::Block => match target {
TargetNew::Old(Target::Struct) => {
// FIXME(eddyb) further check type attribute validity,
// e.g. layout, generics, other attributes, etc.
Ok(())
}
_ => Err(Expected("struct")),
},
};
match valid_target {
Ok(()) => {}
Err(Expected(expected_target)) => self.tcx.sess.span_err(
span,
&format!(
"attribute is only valid on a {}, not on a {}",
expected_target, target
),
),
}
}
}
}
// FIXME(eddyb) DRY this somehow and make it reusable from somewhere in `rustc`.
impl<'tcx> Visitor<'tcx> for CheckSpirvAttrVisitor<'tcx> {
type Map = Map<'tcx>;
fn nested_visit_map(&mut self) -> NestedVisitorMap<Self::Map> {
NestedVisitorMap::OnlyBodies(self.tcx.hir())
}
fn visit_item(&mut self, item: &'tcx hir::Item<'tcx>) {
let target = Target::from_item(item);
self.check_spirv_attributes(item.hir_id, item.attrs, target);
intravisit::walk_item(self, item)
}
fn visit_generic_param(&mut self, generic_param: &'tcx hir::GenericParam<'tcx>) {
let target = Target::from_generic_param(generic_param);
self.check_spirv_attributes(generic_param.hir_id, generic_param.attrs, target);
intravisit::walk_generic_param(self, generic_param)
}
fn visit_trait_item(&mut self, trait_item: &'tcx hir::TraitItem<'tcx>) {
let target = Target::from_trait_item(trait_item);
self.check_spirv_attributes(trait_item.hir_id, trait_item.attrs, target);
intravisit::walk_trait_item(self, trait_item)
}
fn visit_struct_field(&mut self, struct_field: &'tcx hir::StructField<'tcx>) {
self.check_spirv_attributes(struct_field.hir_id, struct_field.attrs, TargetNew::Field);
intravisit::walk_struct_field(self, struct_field);
}
fn visit_arm(&mut self, arm: &'tcx hir::Arm<'tcx>) {
self.check_spirv_attributes(arm.hir_id, arm.attrs, TargetNew::Arm);
intravisit::walk_arm(self, arm);
}
fn visit_foreign_item(&mut self, f_item: &'tcx hir::ForeignItem<'tcx>) {
let target = Target::from_foreign_item(f_item);
self.check_spirv_attributes(f_item.hir_id, f_item.attrs, target);
intravisit::walk_foreign_item(self, f_item)
}
fn visit_impl_item(&mut self, impl_item: &'tcx hir::ImplItem<'tcx>) {
let target = target_from_impl_item(self.tcx, impl_item);
self.check_spirv_attributes(impl_item.hir_id, impl_item.attrs, target);
intravisit::walk_impl_item(self, impl_item)
}
fn visit_stmt(&mut self, stmt: &'tcx hir::Stmt<'tcx>) {
// When checking statements ignore expressions, they will be checked later.
if let hir::StmtKind::Local(l) = stmt.kind {
self.check_spirv_attributes(l.hir_id, &l.attrs, Target::Statement);
}
intravisit::walk_stmt(self, stmt)
}
fn visit_expr(&mut self, expr: &'tcx hir::Expr<'tcx>) {
let target = match expr.kind {
hir::ExprKind::Closure(..) => Target::Closure,
_ => Target::Expression,
};
self.check_spirv_attributes(expr.hir_id, &expr.attrs, target);
intravisit::walk_expr(self, expr)
}
fn visit_variant(
&mut self,
variant: &'tcx hir::Variant<'tcx>,
generics: &'tcx hir::Generics<'tcx>,
item_id: HirId,
) {
self.check_spirv_attributes(variant.id, variant.attrs, Target::Variant);
intravisit::walk_variant(self, variant, generics, item_id)
}
fn visit_macro_def(&mut self, macro_def: &'tcx hir::MacroDef<'tcx>) {
self.check_spirv_attributes(macro_def.hir_id, macro_def.attrs, TargetNew::MacroDef);
intravisit::walk_macro_def(self, macro_def);
}
fn visit_param(&mut self, param: &'tcx hir::Param<'tcx>) {
self.check_spirv_attributes(param.hir_id, param.attrs, TargetNew::Param);
intravisit::walk_param(self, param);
}
}
fn check_invalid_macro_level_spirv_attr(tcx: TyCtxt<'_>, sym: &Symbols, attrs: &[Attribute]) {
for attr in attrs {
if tcx.sess.check_name(attr, sym.spirv) {
tcx.sess
.span_err(attr.span, "#[spirv(..)] cannot be applied to a macro");
}
}
}
// FIXME(eddyb) DRY this somehow and make it reusable from somewhere in `rustc`.
fn check_mod_attrs(tcx: TyCtxt<'_>, module_def_id: LocalDefId) {
let check_spirv_attr_visitor = &mut CheckSpirvAttrVisitor {
tcx,
sym: Symbols::get(),
};
tcx.hir().visit_item_likes_in_module(
module_def_id,
&mut check_spirv_attr_visitor.as_deep_visitor(),
);
// FIXME(eddyb) use `tcx.hir().visit_exported_macros_in_krate(...)` after rustup.
for id in tcx.hir().krate().exported_macros {
check_spirv_attr_visitor.visit_macro_def(match tcx.hir().find(id.hir_id) {
Some(hir::Node::MacroDef(macro_def)) => macro_def,
_ => unreachable!(),
});
}
check_invalid_macro_level_spirv_attr(
tcx,
&check_spirv_attr_visitor.sym,
tcx.hir().krate().non_exported_macro_attrs,
);
if module_def_id.is_top_level_module() {
check_spirv_attr_visitor.check_spirv_attributes(
CRATE_HIR_ID,
tcx.hir().krate_attrs(),
Target::Mod,
);
}
}
pub(crate) fn provide(providers: &mut Providers) {
*providers = Providers {
check_mod_attrs: |tcx, def_id| {
// Run both the default checks, and our `#[spirv(...)]` ones.
(rustc_interface::DEFAULT_QUERY_PROVIDERS.check_mod_attrs)(tcx, def_id);
check_mod_attrs(tcx, def_id)
},
..*providers
};
}

View File

@ -81,6 +81,7 @@ macro_rules! assert_ty_eq {
}
mod abi;
mod attr;
mod builder;
mod builder_spirv;
mod codegen_cx;
@ -299,6 +300,9 @@ impl CodegenBackend for SpirvCodegenBackend {
inner
})
};
// Extra hooks provided by other parts of `rustc_codegen_spirv`.
crate::attr::provide(providers);
}
fn provide_extern(&self, providers: &mut query::Providers) {

View File

@ -448,6 +448,7 @@ impl From<ExecutionModel> for Entry {
}
}
// FIXME(eddyb) maybe move this to `attr`?
#[derive(Debug, Clone)]
pub enum SpirvAttribute {
Builtin(BuiltIn),
@ -471,28 +472,32 @@ pub enum SpirvAttribute {
UnrollLoops,
}
// Note that we could mark the attr as used via cx.tcx.sess.mark_attr_used(attr), but unused
// reporting already happens even before we get here :(
// FIXME(eddyb) maybe move this to `attr`?
/// Returns only the spirv attributes that could successfully parsed.
/// For any malformed ones, an error is reported.
/// For any malformed ones, an error is reported prior to codegen, by a check pass.
pub fn parse_attrs<'a, 'tcx>(
cx: &'a CodegenCx<'tcx>,
attrs: &'tcx [Attribute],
) -> impl Iterator<Item = SpirvAttribute> + Captures<'tcx> + 'a {
parse_attrs_with_errors(&cx.sym, attrs).filter_map(move |parse_attr_result| {
parse_attrs_for_checking(&cx.sym, attrs)
.filter_map(move |parse_attr_result| {
// NOTE(eddyb) `delay_span_bug` ensures that if attribute checking fails
// to see an attribute error, it will cause an ICE instead.
parse_attr_result
.map_err(|(span, msg)| cx.tcx.sess.span_err(span, &msg))
.map_err(|(span, msg)| cx.tcx.sess.delay_span_bug(span, &msg))
.ok()
})
.map(|(_span, parsed_attr)| parsed_attr)
}
// FIXME(eddyb) find something nicer for the error type.
type ParseAttrError = (Span, String);
fn parse_attrs_with_errors<'a>(
// FIXME(eddyb) maybe move this to `attr`?
pub(crate) fn parse_attrs_for_checking<'a>(
sym: &'a Symbols,
attrs: &'a [Attribute],
) -> impl Iterator<Item = Result<SpirvAttribute, ParseAttrError>> + 'a {
) -> impl Iterator<Item = Result<(Span, SpirvAttribute), ParseAttrError>> + 'a {
attrs.iter().flat_map(move |attr| {
let is_spirv = match attr.kind {
AttrKind::Normal(ref item, _) => {
@ -519,18 +524,19 @@ fn parse_attrs_with_errors<'a>(
whole_attr_error
.into_iter()
.chain(args.into_iter().map(move |ref arg| {
if arg.has_name(sym.image_type) {
parse_image_type(sym, arg)
let span = arg.span();
let parsed_attr = if arg.has_name(sym.image_type) {
parse_image_type(sym, arg)?
} else if arg.has_name(sym.descriptor_set) {
Ok(SpirvAttribute::DescriptorSet(parse_attr_int_value(arg)?))
SpirvAttribute::DescriptorSet(parse_attr_int_value(arg)?)
} else if arg.has_name(sym.binding) {
Ok(SpirvAttribute::Binding(parse_attr_int_value(arg)?))
SpirvAttribute::Binding(parse_attr_int_value(arg)?)
} else {
let name = match arg.ident() {
Some(i) => i,
None => {
return Err((
arg.span(),
span,
"#[spirv(..)] attribute argument must be single identifier"
.to_string(),
));
@ -548,8 +554,9 @@ fn parse_attrs_with_errors<'a>(
})
.unwrap_or_else(|| {
Err((name.span, "unknown argument to spirv attribute".to_string()))
})
}
})?
};
Ok((span, parsed_attr))
}))
})
}