diff --git a/crates/rustc_codegen_spirv/src/attr.rs b/crates/rustc_codegen_spirv/src/attr.rs new file mode 100644 index 0000000000..218c0a012c --- /dev/null +++ b/crates/rustc_codegen_spirv/src/attr.rs @@ -0,0 +1,312 @@ +//! `#[spirv(...)]` attribute support. +//! +//! The attribute-checking parts of this try to follow `rustc_passes::check_attr`. + +use crate::symbols::{SpirvAttribute, Symbols}; +use rustc_ast::Attribute; +use rustc_hir as hir; +use rustc_hir::def_id::LocalDefId; +use rustc_hir::intravisit::{self, NestedVisitorMap, Visitor}; +use rustc_hir::{HirId, MethodKind, Target, CRATE_HIR_ID}; +use rustc_middle::hir::map::Map; +use rustc_middle::ty::query::Providers; +use rustc_middle::ty::TyCtxt; +use std::fmt; +use std::rc::Rc; + +// FIXME(eddyb) make this reusable from somewhere in `rustc`. +pub(crate) fn target_from_impl_item<'tcx>( + tcx: TyCtxt<'tcx>, + impl_item: &hir::ImplItem<'_>, +) -> Target { + match impl_item.kind { + hir::ImplItemKind::Const(..) => Target::AssocConst, + hir::ImplItemKind::Fn(..) => { + let parent_hir_id = tcx.hir().get_parent_item(impl_item.hir_id); + let containing_item = tcx.hir().expect_item(parent_hir_id); + let containing_impl_is_for_trait = match &containing_item.kind { + hir::ItemKind::Impl { of_trait, .. } => of_trait.is_some(), + _ => unreachable!("parent of an ImplItem must be an Impl"), + }; + if containing_impl_is_for_trait { + Target::Method(MethodKind::Trait { body: true }) + } else { + Target::Method(MethodKind::Inherent) + } + } + hir::ImplItemKind::TyAlias(..) => Target::AssocTy, + } +} + +// HACK(eddyb) current `Target` (after rust-lang/rust#80641 + rust-lang/rust#80920), +// emulated before we can rustup to that point and use the new variants directly. +enum TargetNew { + Old(Target), + + // Added by rust-lang/rust#80641. + Field, + Arm, + MacroDef, + + // Added by rust-lang/rust#80920. + Param, +} + +impl From for TargetNew { + fn from(target: Target) -> Self { + Self::Old(target) + } +} + +impl fmt::Display for TargetNew { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let description = match self { + Self::Old(target) => return write!(f, "{}", target), + + Self::Field => "struct field", + Self::Arm => "match arm", + Self::MacroDef => "macro def", + + Self::Param => "function param", + }; + f.write_str(description) + } +} + +struct CheckSpirvAttrVisitor<'tcx> { + tcx: TyCtxt<'tcx>, + sym: Rc, +} + +impl CheckSpirvAttrVisitor<'_> { + fn check_spirv_attributes( + &self, + hir_id: HirId, + attrs: &[Attribute], + target: impl Into, + ) { + let parse_attrs = |attrs| crate::symbols::parse_attrs_for_checking(&self.sym, attrs); + + let target = target.into(); + for parse_attr_result in parse_attrs(attrs) { + let (span, attr) = match parse_attr_result { + Ok(span_and_attr) => span_and_attr, + Err((span, msg)) => { + self.tcx.sess.span_err(span, &msg); + continue; + } + }; + + /// Error newtype marker used below for readability. + struct Expected(T); + + let valid_target = match attr { + SpirvAttribute::Builtin(_) + | SpirvAttribute::DescriptorSet(_) + | SpirvAttribute::Binding(_) + | SpirvAttribute::Flat => match target { + TargetNew::Param => { + let parent_hir_id = self.tcx.hir().get_parent_node(hir_id); + let parent_is_entry_point = + parse_attrs(self.tcx.hir().attrs(parent_hir_id)) + .filter_map(|r| r.ok()) + .any(|(_, attr)| matches!(attr, SpirvAttribute::Entry(_))); + if !parent_is_entry_point { + self.tcx.sess.span_err( + span, + "attribute is only valid on a parameter of an entry-point function", + ); + } + Ok(()) + } + + _ => Err(Expected("function parameter")), + }, + + SpirvAttribute::Entry(_) => match target { + TargetNew::Old(Target::Fn) + | TargetNew::Old(Target::Method(MethodKind::Trait { body: true })) + | TargetNew::Old(Target::Method(MethodKind::Inherent)) => { + // FIXME(eddyb) further check entry-point attribute validity, + // e.g. signature, shouldn't have `#[inline]` or generics, etc. + Ok(()) + } + + _ => Err(Expected("function")), + }, + + SpirvAttribute::UnrollLoops => match target { + TargetNew::Old(Target::Fn) + | TargetNew::Old(Target::Closure) + | TargetNew::Old(Target::Method(MethodKind::Trait { body: true })) + | TargetNew::Old(Target::Method(MethodKind::Inherent)) => Ok(()), + + _ => Err(Expected("function or closure")), + }, + + SpirvAttribute::StorageClass(_) + | SpirvAttribute::ImageType { .. } + | SpirvAttribute::Sampler + | SpirvAttribute::SampledImage + | SpirvAttribute::Block => match target { + TargetNew::Old(Target::Struct) => { + // FIXME(eddyb) further check type attribute validity, + // e.g. layout, generics, other attributes, etc. + Ok(()) + } + + _ => Err(Expected("struct")), + }, + }; + match valid_target { + Ok(()) => {} + Err(Expected(expected_target)) => self.tcx.sess.span_err( + span, + &format!( + "attribute is only valid on a {}, not on a {}", + expected_target, target + ), + ), + } + } + } +} + +// FIXME(eddyb) DRY this somehow and make it reusable from somewhere in `rustc`. +impl<'tcx> Visitor<'tcx> for CheckSpirvAttrVisitor<'tcx> { + type Map = Map<'tcx>; + + fn nested_visit_map(&mut self) -> NestedVisitorMap { + NestedVisitorMap::OnlyBodies(self.tcx.hir()) + } + + fn visit_item(&mut self, item: &'tcx hir::Item<'tcx>) { + let target = Target::from_item(item); + self.check_spirv_attributes(item.hir_id, item.attrs, target); + intravisit::walk_item(self, item) + } + + fn visit_generic_param(&mut self, generic_param: &'tcx hir::GenericParam<'tcx>) { + let target = Target::from_generic_param(generic_param); + self.check_spirv_attributes(generic_param.hir_id, generic_param.attrs, target); + intravisit::walk_generic_param(self, generic_param) + } + + fn visit_trait_item(&mut self, trait_item: &'tcx hir::TraitItem<'tcx>) { + let target = Target::from_trait_item(trait_item); + self.check_spirv_attributes(trait_item.hir_id, trait_item.attrs, target); + intravisit::walk_trait_item(self, trait_item) + } + + fn visit_struct_field(&mut self, struct_field: &'tcx hir::StructField<'tcx>) { + self.check_spirv_attributes(struct_field.hir_id, struct_field.attrs, TargetNew::Field); + intravisit::walk_struct_field(self, struct_field); + } + + fn visit_arm(&mut self, arm: &'tcx hir::Arm<'tcx>) { + self.check_spirv_attributes(arm.hir_id, arm.attrs, TargetNew::Arm); + intravisit::walk_arm(self, arm); + } + + fn visit_foreign_item(&mut self, f_item: &'tcx hir::ForeignItem<'tcx>) { + let target = Target::from_foreign_item(f_item); + self.check_spirv_attributes(f_item.hir_id, f_item.attrs, target); + intravisit::walk_foreign_item(self, f_item) + } + + fn visit_impl_item(&mut self, impl_item: &'tcx hir::ImplItem<'tcx>) { + let target = target_from_impl_item(self.tcx, impl_item); + self.check_spirv_attributes(impl_item.hir_id, impl_item.attrs, target); + intravisit::walk_impl_item(self, impl_item) + } + + fn visit_stmt(&mut self, stmt: &'tcx hir::Stmt<'tcx>) { + // When checking statements ignore expressions, they will be checked later. + if let hir::StmtKind::Local(l) = stmt.kind { + self.check_spirv_attributes(l.hir_id, &l.attrs, Target::Statement); + } + intravisit::walk_stmt(self, stmt) + } + + fn visit_expr(&mut self, expr: &'tcx hir::Expr<'tcx>) { + let target = match expr.kind { + hir::ExprKind::Closure(..) => Target::Closure, + _ => Target::Expression, + }; + + self.check_spirv_attributes(expr.hir_id, &expr.attrs, target); + intravisit::walk_expr(self, expr) + } + + fn visit_variant( + &mut self, + variant: &'tcx hir::Variant<'tcx>, + generics: &'tcx hir::Generics<'tcx>, + item_id: HirId, + ) { + self.check_spirv_attributes(variant.id, variant.attrs, Target::Variant); + intravisit::walk_variant(self, variant, generics, item_id) + } + + fn visit_macro_def(&mut self, macro_def: &'tcx hir::MacroDef<'tcx>) { + self.check_spirv_attributes(macro_def.hir_id, macro_def.attrs, TargetNew::MacroDef); + intravisit::walk_macro_def(self, macro_def); + } + + fn visit_param(&mut self, param: &'tcx hir::Param<'tcx>) { + self.check_spirv_attributes(param.hir_id, param.attrs, TargetNew::Param); + + intravisit::walk_param(self, param); + } +} + +fn check_invalid_macro_level_spirv_attr(tcx: TyCtxt<'_>, sym: &Symbols, attrs: &[Attribute]) { + for attr in attrs { + if tcx.sess.check_name(attr, sym.spirv) { + tcx.sess + .span_err(attr.span, "#[spirv(..)] cannot be applied to a macro"); + } + } +} + +// FIXME(eddyb) DRY this somehow and make it reusable from somewhere in `rustc`. +fn check_mod_attrs(tcx: TyCtxt<'_>, module_def_id: LocalDefId) { + let check_spirv_attr_visitor = &mut CheckSpirvAttrVisitor { + tcx, + sym: Symbols::get(), + }; + tcx.hir().visit_item_likes_in_module( + module_def_id, + &mut check_spirv_attr_visitor.as_deep_visitor(), + ); + // FIXME(eddyb) use `tcx.hir().visit_exported_macros_in_krate(...)` after rustup. + for id in tcx.hir().krate().exported_macros { + check_spirv_attr_visitor.visit_macro_def(match tcx.hir().find(id.hir_id) { + Some(hir::Node::MacroDef(macro_def)) => macro_def, + _ => unreachable!(), + }); + } + check_invalid_macro_level_spirv_attr( + tcx, + &check_spirv_attr_visitor.sym, + tcx.hir().krate().non_exported_macro_attrs, + ); + if module_def_id.is_top_level_module() { + check_spirv_attr_visitor.check_spirv_attributes( + CRATE_HIR_ID, + tcx.hir().krate_attrs(), + Target::Mod, + ); + } +} + +pub(crate) fn provide(providers: &mut Providers) { + *providers = Providers { + check_mod_attrs: |tcx, def_id| { + // Run both the default checks, and our `#[spirv(...)]` ones. + (rustc_interface::DEFAULT_QUERY_PROVIDERS.check_mod_attrs)(tcx, def_id); + check_mod_attrs(tcx, def_id) + }, + ..*providers + }; +} diff --git a/crates/rustc_codegen_spirv/src/lib.rs b/crates/rustc_codegen_spirv/src/lib.rs index 99f0f10eaf..31d63a08f3 100644 --- a/crates/rustc_codegen_spirv/src/lib.rs +++ b/crates/rustc_codegen_spirv/src/lib.rs @@ -81,6 +81,7 @@ macro_rules! assert_ty_eq { } mod abi; +mod attr; mod builder; mod builder_spirv; mod codegen_cx; @@ -299,6 +300,9 @@ impl CodegenBackend for SpirvCodegenBackend { inner }) }; + + // Extra hooks provided by other parts of `rustc_codegen_spirv`. + crate::attr::provide(providers); } fn provide_extern(&self, providers: &mut query::Providers) { diff --git a/crates/rustc_codegen_spirv/src/symbols.rs b/crates/rustc_codegen_spirv/src/symbols.rs index b251db157a..a875673fa7 100644 --- a/crates/rustc_codegen_spirv/src/symbols.rs +++ b/crates/rustc_codegen_spirv/src/symbols.rs @@ -448,6 +448,7 @@ impl From for Entry { } } +// FIXME(eddyb) maybe move this to `attr`? #[derive(Debug, Clone)] pub enum SpirvAttribute { Builtin(BuiltIn), @@ -471,28 +472,32 @@ pub enum SpirvAttribute { UnrollLoops, } -// Note that we could mark the attr as used via cx.tcx.sess.mark_attr_used(attr), but unused -// reporting already happens even before we get here :( +// FIXME(eddyb) maybe move this to `attr`? /// Returns only the spirv attributes that could successfully parsed. -/// For any malformed ones, an error is reported. +/// For any malformed ones, an error is reported prior to codegen, by a check pass. pub fn parse_attrs<'a, 'tcx>( cx: &'a CodegenCx<'tcx>, attrs: &'tcx [Attribute], ) -> impl Iterator + Captures<'tcx> + 'a { - parse_attrs_with_errors(&cx.sym, attrs).filter_map(move |parse_attr_result| { - parse_attr_result - .map_err(|(span, msg)| cx.tcx.sess.span_err(span, &msg)) - .ok() - }) + parse_attrs_for_checking(&cx.sym, attrs) + .filter_map(move |parse_attr_result| { + // NOTE(eddyb) `delay_span_bug` ensures that if attribute checking fails + // to see an attribute error, it will cause an ICE instead. + parse_attr_result + .map_err(|(span, msg)| cx.tcx.sess.delay_span_bug(span, &msg)) + .ok() + }) + .map(|(_span, parsed_attr)| parsed_attr) } // FIXME(eddyb) find something nicer for the error type. type ParseAttrError = (Span, String); -fn parse_attrs_with_errors<'a>( +// FIXME(eddyb) maybe move this to `attr`? +pub(crate) fn parse_attrs_for_checking<'a>( sym: &'a Symbols, attrs: &'a [Attribute], -) -> impl Iterator> + 'a { +) -> impl Iterator> + 'a { attrs.iter().flat_map(move |attr| { let is_spirv = match attr.kind { AttrKind::Normal(ref item, _) => { @@ -519,18 +524,19 @@ fn parse_attrs_with_errors<'a>( whole_attr_error .into_iter() .chain(args.into_iter().map(move |ref arg| { - if arg.has_name(sym.image_type) { - parse_image_type(sym, arg) + let span = arg.span(); + let parsed_attr = if arg.has_name(sym.image_type) { + parse_image_type(sym, arg)? } else if arg.has_name(sym.descriptor_set) { - Ok(SpirvAttribute::DescriptorSet(parse_attr_int_value(arg)?)) + SpirvAttribute::DescriptorSet(parse_attr_int_value(arg)?) } else if arg.has_name(sym.binding) { - Ok(SpirvAttribute::Binding(parse_attr_int_value(arg)?)) + SpirvAttribute::Binding(parse_attr_int_value(arg)?) } else { let name = match arg.ident() { Some(i) => i, None => { return Err(( - arg.span(), + span, "#[spirv(..)] attribute argument must be single identifier" .to_string(), )); @@ -548,8 +554,9 @@ fn parse_attrs_with_errors<'a>( }) .unwrap_or_else(|| { Err((name.span, "unknown argument to spirv attribute".to_string())) - }) - } + })? + }; + Ok((span, parsed_attr)) })) }) }