diff --git a/docs/src/SUMMARY.md b/docs/src/SUMMARY.md index e0ff9f9022..91987d3dc6 100644 --- a/docs/src/SUMMARY.md +++ b/docs/src/SUMMARY.md @@ -1,3 +1,4 @@ # Summary - [Introduction](./introduction.md) +- [Attribute syntax](./attributes.md) diff --git a/docs/src/attributes.md b/docs/src/attributes.md new file mode 100644 index 0000000000..a1ca8f70c0 --- /dev/null +++ b/docs/src/attributes.md @@ -0,0 +1,48 @@ +# Attribute syntax + +rust-gpu introduces a number of SPIR-V related attributes to express behavior specific to SPIR-V not exposed in the base rust language. + +There are a few different categories of attributes: + +## Entry points + +When declaring an entry point to your shader, SPIR-V needs to know what type of function it is. For example, it could be a fragment shader, or vertex shader. Specifying this attribute is also the way rust-gpu knows that you would like to export a function as an entry point, no other functions are exported. + +Example: + +```rust +#[spirv(fragment)] +fn main() { } +``` + +Common values are `#[spirv(fragment)]` and `#[spirv(vertex)]`. A list of all supported names can be found in [spirv_headers](https://docs.rs/spirv_headers/1.5.0/spirv_headers/enum.ExecutionModel.html) - convert the enum name to snake_case for the rust-gpu attribute name. + +## Builtins + +When declaring inputs and outputs, sometimes you want to declare it as a "builtin". This means many things, but one example is `gl_Position` from glsl - the GPU assigns inherent meaning to the variable and uses it for placing the vertex in clip space. The equivalent in rust-gpu is called `position`. + +Example: + +```rust: +#[spirv(fragment)] +fn main( + #[spirv(position)] mut out_pos: Output, +) { } +``` + +Common values are `#[spirv(position)]`, `#[spirv(vertex_id)]`, and many more. A list of all supported names can be found in [spirv_headers](https://docs.rs/spirv_headers/1.5.0/spirv_headers/enum.BuiltIn.html) - convert the enum name to snake_case for the rust-gpu attribute name. + +## Descriptor set and binding + +A SPIR-V shader must declare where uniform variables are located with explicit indices that match up with CPU-side code. This can be done with the `descriptor_set` and `binding` attributes. Note that `descriptor_set = 0` is reserved for future use, and cannot be used. + +Example: + +```rust: +#[spirv(fragment)] +fn main( + #[spirv(descriptor_set = 2, binding = 5)] mut var: Uniform, +) { } +``` + +Both descriptor_set and binding take an integer argument that specifies the uniform's index. diff --git a/examples/example-shader/src/lib.rs b/examples/example-shader/src/lib.rs index 44cb4efcb7..a74b0e8102 100644 --- a/examples/example-shader/src/lib.rs +++ b/examples/example-shader/src/lib.rs @@ -194,7 +194,7 @@ pub fn fs(screen_pos: Vec2) -> Vec4 { } #[allow(unused_attributes)] -#[spirv(entry = "fragment")] +#[spirv(fragment)] pub fn main_fs(input: Input, mut output: Output) { let v = input.load(); let color = fs(Vec2::new(v.0, v.1)); @@ -202,11 +202,11 @@ pub fn main_fs(input: Input, mut output: Output) { } #[allow(unused_attributes)] -#[spirv(entry = "vertex")] +#[spirv(vertex)] pub fn main_vs( in_pos: Input, _in_color: Input, - #[spirv(builtin = "position")] mut out_pos: Output, + #[spirv(position)] mut out_pos: Output, mut out_color: Output, ) { out_pos.store(in_pos.load()); diff --git a/rustc_codegen_spirv/src/abi.rs b/rustc_codegen_spirv/src/abi.rs index 26cf376585..3c8ef3b34e 100644 --- a/rustc_codegen_spirv/src/abi.rs +++ b/rustc_codegen_spirv/src/abi.rs @@ -3,7 +3,7 @@ use crate::codegen_cx::CodegenCx; use crate::spirv_type::SpirvType; -use crate::symbols::{parse_attr, SpirvAttribute}; +use crate::symbols::{parse_attrs, SpirvAttribute}; use rspirv::spirv::{StorageClass, Word}; use rustc_middle::bug; use rustc_middle::ty::layout::{FnAbiExt, TyAndLayout}; @@ -538,8 +538,8 @@ fn dig_scalar_pointee_adt<'tcx>( // TODO: Enforce this is only used in spirv-std. fn get_storage_class<'tcx>(cx: &CodegenCx<'tcx>, ty: TyAndLayout<'tcx>) -> Option { if let TyKind::Adt(adt, _substs) = ty.ty.kind() { - for attr in cx.tcx.get_attrs(adt.did) { - if let Some(SpirvAttribute::StorageClass(storage_class)) = parse_attr(cx, attr) { + for attr in parse_attrs(cx, cx.tcx.get_attrs(adt.did)) { + if let SpirvAttribute::StorageClass(storage_class) = attr { return Some(storage_class); } } diff --git a/rustc_codegen_spirv/src/codegen_cx/declare.rs b/rustc_codegen_spirv/src/codegen_cx/declare.rs index 102bc7dde0..1bf8f976c4 100644 --- a/rustc_codegen_spirv/src/codegen_cx/declare.rs +++ b/rustc_codegen_spirv/src/codegen_cx/declare.rs @@ -2,14 +2,10 @@ use super::CodegenCx; use crate::abi::ConvSpirvType; use crate::builder_spirv::{SpirvConst, SpirvValue, SpirvValueExt}; use crate::spirv_type::SpirvType; -use crate::symbols::{parse_attr, SpirvAttribute}; -use rspirv::dr::Operand; -use rspirv::spirv::{ - Decoration, ExecutionMode, ExecutionModel, FunctionControl, LinkageType, StorageClass, Word, -}; +use crate::symbols::{parse_attrs, SpirvAttribute}; +use rspirv::spirv::{FunctionControl, LinkageType, StorageClass, Word}; use rustc_attr::InlineAttr; use rustc_codegen_ssa::traits::{PreDefineMethods, StaticMethods}; -use rustc_hir::Param; use rustc_middle::bug; use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, CodegenFnAttrs}; use rustc_middle::mir::mono::{Linkage, MonoItem, Visibility}; @@ -18,9 +14,7 @@ use rustc_middle::ty::{Instance, ParamEnv, Ty, TypeFoldable}; use rustc_span::def_id::DefId; use rustc_span::Span; use rustc_target::abi::call::FnAbi; -use rustc_target::abi::call::PassMode; use rustc_target::abi::{Align, LayoutOf}; -use std::collections::HashMap; fn attrs_to_spirv(attrs: &CodegenFnAttrs) -> FunctionControl { let mut control = FunctionControl::NONE; @@ -156,194 +150,6 @@ impl<'tcx> CodegenCx<'tcx> { self.zombie_with_span(result.def, span, "Globals are not supported yet"); result } - - // Entry points declare their "interface" (all uniforms, inputs, outputs, etc.) as parameters. spir-v uses globals - // to declare the interface. So, we need to generate a lil stub for the "real" main that collects all those global - // variables and calls the user-defined main function. - fn shader_entry_stub( - &self, - entry_func: SpirvValue, - hir_params: &[Param<'_>], - name: String, - execution_model: ExecutionModel, - ) { - let void = SpirvType::Void.def(self); - let fn_void_void = SpirvType::Function { - return_type: void, - arguments: vec![], - } - .def(self); - let (entry_func_return, entry_func_args) = match self.lookup_type(entry_func.ty) { - SpirvType::Function { - return_type, - arguments, - } => (return_type, arguments), - other => self.tcx.sess.fatal(&format!( - "Invalid entry_stub type: {}", - other.debug(entry_func.ty, self) - )), - }; - let mut emit = self.emit_global(); - let mut decoration_locations = HashMap::new(); - // Create OpVariables before OpFunction so they're global instead of local vars. - let arguments = entry_func_args - .iter() - .zip(hir_params) - .map(|(&arg, hir_param)| { - let storage_class = match self.lookup_type(arg) { - SpirvType::Pointer { storage_class, .. } => storage_class, - other => self.tcx.sess.fatal(&format!( - "Invalid entry arg type {}", - other.debug(arg, self) - )), - }; - let mut has_location = match storage_class { - StorageClass::Input | StorageClass::Output | StorageClass::UniformConstant => { - true - } - _ => false, - }; - // Note: this *declares* the variable too. - let variable = emit.variable(arg, None, storage_class, None); - for attr in hir_param.attrs { - if let Some(SpirvAttribute::Builtin(builtin)) = parse_attr(self, attr) { - emit.decorate( - variable, - Decoration::BuiltIn, - std::iter::once(Operand::BuiltIn(builtin)), - ); - has_location = false; - } - } - // Assign locations from left to right, incrementing each storage class - // individually. - // TODO: Is this right for UniformConstant? Do they share locations with - // input/outpus? - if has_location { - let location = decoration_locations - .entry(storage_class) - .or_insert_with(|| 0); - emit.decorate( - variable, - Decoration::Location, - std::iter::once(Operand::LiteralInt32(*location)), - ); - *location += 1; - } - variable - }) - .collect::>(); - let fn_id = emit - .begin_function(void, None, FunctionControl::NONE, fn_void_void) - .unwrap(); - emit.begin_block(None).unwrap(); - emit.function_call( - entry_func_return, - None, - entry_func.def, - arguments.iter().copied(), - ) - .unwrap(); - emit.ret().unwrap(); - emit.end_function().unwrap(); - - let interface = arguments; - emit.entry_point(execution_model, fn_id, name, interface); - if execution_model == ExecutionModel::Fragment { - // TODO: Make this configurable. - emit.execution_mode(fn_id, ExecutionMode::OriginUpperLeft, &[]); - } - } - - // Kernel mode takes its interface as function parameters(??) - // OpEntryPoints cannot be OpLinkage, so write out a stub to call through. - fn kernel_entry_stub( - &self, - entry_func: SpirvValue, - name: String, - execution_model: ExecutionModel, - ) { - let (entry_func_return, entry_func_args) = match self.lookup_type(entry_func.ty) { - SpirvType::Function { - return_type, - arguments, - } => (return_type, arguments), - other => self.tcx.sess.fatal(&format!( - "Invalid kernel_entry_stub type: {}", - other.debug(entry_func.ty, self) - )), - }; - let mut emit = self.emit_global(); - let fn_id = emit - .begin_function( - entry_func_return, - None, - FunctionControl::NONE, - entry_func.ty, - ) - .unwrap(); - let arguments = entry_func_args - .iter() - .map(|&ty| emit.function_parameter(ty).unwrap()) - .collect::>(); - emit.begin_block(None).unwrap(); - let call_result = emit - .function_call(entry_func_return, None, entry_func.def, arguments) - .unwrap(); - if self.lookup_type(entry_func_return) == SpirvType::Void { - emit.ret().unwrap(); - } else { - emit.ret_value(call_result).unwrap(); - } - emit.end_function().unwrap(); - - emit.entry_point(execution_model, fn_id, name, &[]); - } - - fn entry_stub( - &self, - instance: &Instance<'_>, - fn_abi: &FnAbi<'_, Ty<'_>>, - entry_func: SpirvValue, - name: String, - execution_model: ExecutionModel, - ) { - let local_id = match instance.def_id().as_local() { - Some(id) => id, - None => { - self.tcx - .sess - .err(&format!("Cannot declare {} as an entry point", name)); - return; - } - }; - let fn_hir_id = self.tcx.hir().local_def_id_to_hir_id(local_id); - let body = self.tcx.hir().body(self.tcx.hir().body_owned_by(fn_hir_id)); - for (abi, arg) in fn_abi.args.iter().zip(body.params) { - if let PassMode::Direct(_) = abi.mode { - } else { - self.tcx.sess.span_err( - arg.span, - &format!("PassMode {:?} invalid for entry point parameter", abi.mode), - ) - } - } - if let PassMode::Ignore = fn_abi.ret.mode { - } else { - self.tcx.sess.span_err( - self.tcx.hir().span(fn_hir_id), - &format!( - "PassMode {:?} invalid for entry point return type", - fn_abi.ret.mode - ), - ) - } - if execution_model == ExecutionModel::Kernel { - self.kernel_entry_stub(entry_func, name, execution_model); - } else { - self.shader_entry_stub(entry_func, body.params, name, execution_model); - } - } } impl<'tcx> PreDefineMethods<'tcx> for CodegenCx<'tcx> { @@ -403,16 +209,16 @@ impl<'tcx> PreDefineMethods<'tcx> for CodegenCx<'tcx> { let declared = self.declare_fn_ext(symbol_name, Some(&human_name), linkage2, spv_attrs, &fn_abi); - for attr in self.tcx.get_attrs(instance.def_id()) { - match parse_attr(self, attr) { - Some(SpirvAttribute::Entry(execution_model)) => self.entry_stub( + for attr in parse_attrs(self, self.tcx.get_attrs(instance.def_id())) { + match attr { + SpirvAttribute::Entry(execution_model) => self.entry_stub( &instance, &fn_abi, declared, human_name.clone(), execution_model, ), - Some(SpirvAttribute::ReallyUnsafeIgnoreBitcasts) => { + SpirvAttribute::ReallyUnsafeIgnoreBitcasts => { self.really_unsafe_ignore_bitcasts .borrow_mut() .insert(declared); diff --git a/rustc_codegen_spirv/src/codegen_cx/entry.rs b/rustc_codegen_spirv/src/codegen_cx/entry.rs new file mode 100644 index 0000000000..39910dae3f --- /dev/null +++ b/rustc_codegen_spirv/src/codegen_cx/entry.rs @@ -0,0 +1,245 @@ +use super::CodegenCx; +use crate::builder_spirv::SpirvValue; +use crate::spirv_type::SpirvType; +use crate::symbols::{parse_attrs, SpirvAttribute}; +use rspirv::dr::Operand; +use rspirv::spirv::{ + Decoration, ExecutionMode, ExecutionModel, FunctionControl, StorageClass, Word, +}; +use rustc_hir::Param; +use rustc_middle::ty::{Instance, Ty}; +use rustc_target::abi::call::{FnAbi, PassMode}; +use std::collections::HashMap; + +impl<'tcx> CodegenCx<'tcx> { + // Entry points declare their "interface" (all uniforms, inputs, outputs, etc.) as parameters. + // spir-v uses globals to declare the interface. So, we need to generate a lil stub for the + // "real" main that collects all those global variables and calls the user-defined main + // function. + pub fn entry_stub( + &self, + instance: &Instance<'_>, + fn_abi: &FnAbi<'_, Ty<'_>>, + entry_func: SpirvValue, + name: String, + execution_model: ExecutionModel, + ) { + let local_id = match instance.def_id().as_local() { + Some(id) => id, + None => { + self.tcx + .sess + .err(&format!("Cannot declare {} as an entry point", name)); + return; + } + }; + let fn_hir_id = self.tcx.hir().local_def_id_to_hir_id(local_id); + let body = self.tcx.hir().body(self.tcx.hir().body_owned_by(fn_hir_id)); + for (abi, arg) in fn_abi.args.iter().zip(body.params) { + if let PassMode::Direct(_) = abi.mode { + } else { + self.tcx.sess.span_err( + arg.span, + &format!("PassMode {:?} invalid for entry point parameter", abi.mode), + ) + } + } + if let PassMode::Ignore = fn_abi.ret.mode { + } else { + self.tcx.sess.span_err( + self.tcx.hir().span(fn_hir_id), + &format!( + "PassMode {:?} invalid for entry point return type", + fn_abi.ret.mode + ), + ) + } + if execution_model == ExecutionModel::Kernel { + self.kernel_entry_stub(entry_func, name, execution_model); + } else { + self.shader_entry_stub(entry_func, body.params, name, execution_model); + } + } + + fn shader_entry_stub( + &self, + entry_func: SpirvValue, + hir_params: &[Param<'tcx>], + name: String, + execution_model: ExecutionModel, + ) { + let void = SpirvType::Void.def(self); + let fn_void_void = SpirvType::Function { + return_type: void, + arguments: vec![], + } + .def(self); + let (entry_func_return, entry_func_args) = match self.lookup_type(entry_func.ty) { + SpirvType::Function { + return_type, + arguments, + } => (return_type, arguments), + other => self.tcx.sess.fatal(&format!( + "Invalid entry_stub type: {}", + other.debug(entry_func.ty, self) + )), + }; + let mut decoration_locations = HashMap::new(); + // Create OpVariables before OpFunction so they're global instead of local vars. + let arguments = entry_func_args + .iter() + .zip(hir_params) + .map(|(&arg, hir_param)| { + self.declare_parameter(arg, hir_param, &mut decoration_locations) + }) + .collect::>(); + let mut emit = self.emit_global(); + let fn_id = emit + .begin_function(void, None, FunctionControl::NONE, fn_void_void) + .unwrap(); + emit.begin_block(None).unwrap(); + emit.function_call( + entry_func_return, + None, + entry_func.def, + arguments.iter().map(|&(a, _)| a), + ) + .unwrap(); + emit.ret().unwrap(); + emit.end_function().unwrap(); + + let interface: Vec<_> = if emit.version().unwrap() > (1, 3) { + // SPIR-V >= v1.4 includes all OpVariables in the interface. + arguments.into_iter().map(|(a, _)| a).collect() + } else { + // SPIR-V <= v1.3 only includes Input and Output in the interface. + arguments + .into_iter() + .filter(|&(_, s)| s == StorageClass::Input || s == StorageClass::Output) + .map(|(a, _)| a) + .collect() + }; + emit.entry_point(execution_model, fn_id, name, interface); + if execution_model == ExecutionModel::Fragment { + // TODO: Make this configurable. + emit.execution_mode(fn_id, ExecutionMode::OriginUpperLeft, &[]); + } + } + + fn declare_parameter( + &self, + arg: Word, + hir_param: &Param<'tcx>, + decoration_locations: &mut HashMap, + ) -> (Word, StorageClass) { + let storage_class = match self.lookup_type(arg) { + SpirvType::Pointer { storage_class, .. } => storage_class, + other => self.tcx.sess.fatal(&format!( + "Invalid entry arg type {}", + other.debug(arg, self) + )), + }; + let mut has_location = matches!( + storage_class, + StorageClass::Input | StorageClass::Output | StorageClass::UniformConstant + ); + // Note: this *declares* the variable too. + let variable = self.emit_global().variable(arg, None, storage_class, None); + for attr in parse_attrs(self, hir_param.attrs) { + match attr { + SpirvAttribute::Builtin(builtin) => { + self.emit_global().decorate( + variable, + Decoration::BuiltIn, + std::iter::once(Operand::BuiltIn(builtin)), + ); + has_location = false; + } + SpirvAttribute::DescriptorSet(index) => { + self.emit_global().decorate( + variable, + Decoration::DescriptorSet, + std::iter::once(Operand::LiteralInt32(index)), + ); + has_location = false; + } + SpirvAttribute::Binding(index) => { + if index == 0 { + self.tcx.sess.span_err( + hir_param.span, + "descriptor_set 0 is reserved for internal / future use", + ); + } + self.emit_global().decorate( + variable, + Decoration::Binding, + std::iter::once(Operand::LiteralInt32(index)), + ); + has_location = false; + } + _ => {} + } + } + // Assign locations from left to right, incrementing each storage class + // individually. + // TODO: Is this right for UniformConstant? Do they share locations with + // input/outpus? + if has_location { + let location = decoration_locations + .entry(storage_class) + .or_insert_with(|| 0); + self.emit_global().decorate( + variable, + Decoration::Location, + std::iter::once(Operand::LiteralInt32(*location)), + ); + *location += 1; + } + (variable, storage_class) + } + + // Kernel mode takes its interface as function parameters(??) + // OpEntryPoints cannot be OpLinkage, so write out a stub to call through. + fn kernel_entry_stub( + &self, + entry_func: SpirvValue, + name: String, + execution_model: ExecutionModel, + ) { + let (entry_func_return, entry_func_args) = match self.lookup_type(entry_func.ty) { + SpirvType::Function { + return_type, + arguments, + } => (return_type, arguments), + other => self.tcx.sess.fatal(&format!( + "Invalid kernel_entry_stub type: {}", + other.debug(entry_func.ty, self) + )), + }; + let mut emit = self.emit_global(); + let fn_id = emit + .begin_function( + entry_func_return, + None, + FunctionControl::NONE, + entry_func.ty, + ) + .unwrap(); + let arguments = entry_func_args + .iter() + .map(|&ty| emit.function_parameter(ty).unwrap()) + .collect::>(); + emit.begin_block(None).unwrap(); + let call_result = emit + .function_call(entry_func_return, None, entry_func.def, arguments) + .unwrap(); + if self.lookup_type(entry_func_return) == SpirvType::Void { + emit.ret().unwrap(); + } else { + emit.ret_value(call_result).unwrap(); + } + emit.end_function().unwrap(); + + emit.entry_point(execution_model, fn_id, name, &[]); + } +} diff --git a/rustc_codegen_spirv/src/codegen_cx/mod.rs b/rustc_codegen_spirv/src/codegen_cx/mod.rs index 5974b05563..15203596e5 100644 --- a/rustc_codegen_spirv/src/codegen_cx/mod.rs +++ b/rustc_codegen_spirv/src/codegen_cx/mod.rs @@ -1,5 +1,6 @@ mod constant; mod declare; +mod entry; mod type_; use crate::builder::ExtInst; diff --git a/rustc_codegen_spirv/src/symbols.rs b/rustc_codegen_spirv/src/symbols.rs index 1f44cd1b58..beb8494418 100644 --- a/rustc_codegen_spirv/src/symbols.rs +++ b/rustc_codegen_spirv/src/symbols.rs @@ -1,6 +1,6 @@ use crate::codegen_cx::CodegenCx; use rspirv::spirv::{BuiltIn, ExecutionModel, StorageClass}; -use rustc_ast::ast::{AttrKind, Attribute}; +use rustc_ast::ast::{AttrKind, Attribute, Lit, LitIntType, LitKind, NestedMetaItem}; use rustc_span::symbol::Symbol; use std::collections::HashMap; @@ -13,19 +13,14 @@ pub struct Symbols { pub spirv: Symbol, pub spirv_std: Symbol, pub kernel: Symbol, - pub builtin: Symbol, - pub storage_class: Symbol, - pub entry: Symbol, - pub really_unsafe_ignore_bitcasts: Symbol, - - builtins: HashMap, - storage_classes: HashMap, - execution_models: HashMap, + descriptor_set: Symbol, + binding: Symbol, + attributes: HashMap, } -fn make_builtins() -> HashMap { +const BUILTINS: &[(&str, BuiltIn)] = { use BuiltIn::*; - [ + &[ ("position", Position), ("point_size", PointSize), ("clip_distance", ClipDistance), @@ -126,15 +121,12 @@ fn make_builtins() -> HashMap { ("warp_id_nv", WarpIDNV), ("SMIDNV", SMIDNV), ] - .iter() - .map(|&(a, b)| (Symbol::intern(a), b)) - .collect() -} +}; -fn make_storage_classes() -> HashMap { +const STORAGE_CLASSES: &[(&str, StorageClass)] = { use StorageClass::*; // make sure these strings stay synced with spirv-std's pointer types - [ + &[ ("uniform_constant", UniformConstant), ("input", Input), ("uniform", Uniform), @@ -165,14 +157,11 @@ fn make_storage_classes() -> HashMap { ), ("physical_storage_buffer", PhysicalStorageBuffer), ] - .iter() - .map(|&(a, b)| (Symbol::intern(a), b)) - .collect() -} +}; -fn make_execution_models() -> HashMap { +const EXECUTION_MODELS: &[(&str, ExecutionModel)] = { use ExecutionModel::*; - [ + &[ ("vertex", Vertex), ("tessellation_control", TessellationControl), ("tessellation_evaluation", TessellationEvaluation), @@ -189,137 +178,143 @@ fn make_execution_models() -> HashMap { ("miss_nv", MissNV), ("callable_nv", CallableNV), ] - .iter() - .map(|&(a, b)| (Symbol::intern(a), b)) - .collect() -} +}; impl Symbols { pub fn new() -> Self { + let builtins = BUILTINS + .iter() + .map(|&(a, b)| (a, SpirvAttribute::Builtin(b))); + let storage_classes = STORAGE_CLASSES + .iter() + .map(|&(a, b)| (a, SpirvAttribute::StorageClass(b))); + let execution_models = EXECUTION_MODELS + .iter() + .map(|&(a, b)| (a, SpirvAttribute::Entry(b))); + let custom = std::iter::once(( + "really_unsafe_ignore_bitcasts", + SpirvAttribute::ReallyUnsafeIgnoreBitcasts, + )); + let attributes_iter = builtins + .chain(storage_classes) + .chain(execution_models) + .chain(custom) + .map(|(a, b)| (Symbol::intern(a), b)); + let mut attributes = HashMap::new(); + for attr in attributes_iter { + let old = attributes.insert(attr.0, attr.1); + // `.collect()` into a HashMap does not error on duplicates, so manually write out the + // loop here to error on duplicates. + assert!(old.is_none()); + } Self { spirv: Symbol::intern("spirv"), spirv_std: Symbol::intern("spirv_std"), kernel: Symbol::intern("kernel"), - builtin: Symbol::intern("builtin"), - storage_class: Symbol::intern("storage_class"), - entry: Symbol::intern("entry"), - really_unsafe_ignore_bitcasts: Symbol::intern("really_unsafe_ignore_bitcasts"), - builtins: make_builtins(), - storage_classes: make_storage_classes(), - execution_models: make_execution_models(), + descriptor_set: Symbol::intern("descriptor_set"), + binding: Symbol::intern("binding"), + attributes, } } - - pub fn symbol_to_builtin(&self, sym: Symbol) -> Option { - self.builtins.get(&sym).copied() - } - - pub fn symbol_to_storageclass(&self, sym: Symbol) -> Option { - self.storage_classes.get(&sym).copied() - } - - pub fn symbol_to_execution_model(&self, sym: Symbol) -> Option { - self.execution_models.get(&sym).copied() - } } +#[derive(Debug, Clone)] pub enum SpirvAttribute { Builtin(BuiltIn), StorageClass(StorageClass), Entry(ExecutionModel), + DescriptorSet(u32), + Binding(u32), ReallyUnsafeIgnoreBitcasts, } -// Note that we could mark the attr as used via cx.tcx.sess.mark_attr_used(attr), but unused reporting already happens -// even before we get here :( -/// Returns None if this attribute is not a spirv attribute, or if it's malformed (and an error is reported). -pub fn parse_attr<'tcx>(cx: &CodegenCx<'tcx>, attr: &Attribute) -> Option { - // Example attributes that we parse here: - // #[spirv(storage_class = "uniform")] - // #[spirv(entry = "kernel")] - let is_spirv = match attr.kind { - AttrKind::Normal(ref item) => { - // TODO: We ignore the rest of the path. Is this right? - let last = item.path.segments.last(); - last.map_or(false, |seg| seg.ident.name == cx.sym.spirv) - } - AttrKind::DocComment(..) => false, - }; - if !is_spirv { - return None; - } - let args = if let Some(args) = attr.meta_item_list() { - args - } else { - cx.tcx - .sess - .span_err(attr.span, "#[spirv(..)] attribute must have one argument"); - return None; - }; - if args.len() != 1 { - cx.tcx - .sess - .span_err(attr.span, "#[spirv(..)] attribute must have one argument"); - return None; - } - let arg = &args[0]; - if arg.has_name(cx.sym.builtin) { - if let Some(builtin_arg) = arg.value_str() { - match cx.sym.symbol_to_builtin(builtin_arg) { - Some(builtin) => Some(SpirvAttribute::Builtin(builtin)), - None => { - cx.tcx.sess.span_err(attr.span, "unknown spir-v builtin"); - None - } +// Note that we could mark the attr as used via cx.tcx.sess.mark_attr_used(attr), but unused +// reporting already happens even before we get here :( +/// Returns empty if this attribute is not a spirv attribute, or if it's malformed (and an error is +/// reported). +pub fn parse_attrs( + cx: &CodegenCx<'_>, + attrs: &[Attribute], +) -> impl Iterator { + let result = attrs.iter().flat_map(|attr| { + let is_spirv = match attr.kind { + AttrKind::Normal(ref item) => { + // TODO: We ignore the rest of the path. Is this right? + let last = item.path.segments.last(); + last.map_or(false, |seg| seg.ident.name == cx.sym.spirv) } + AttrKind::DocComment(..) => false, + }; + let args = if !is_spirv { + // Use an empty vec here to return empty + Vec::new() + } else if let Some(args) = attr.meta_item_list() { + args } else { cx.tcx.sess.span_err( attr.span, - "builtin must have value: #[spirv(builtin = \"..\")]", + "#[spirv(..)] attribute must have at least one argument", ); - None - } - } else if arg.has_name(cx.sym.storage_class) { - if let Some(storage_arg) = arg.value_str() { - match cx.sym.symbol_to_storageclass(storage_arg) { - Some(storage_class) => Some(SpirvAttribute::StorageClass(storage_class)), - None => { - cx.tcx - .sess - .span_err(attr.span, "unknown spir-v storage class"); - None + Vec::new() + }; + args.into_iter().filter_map(move |ref arg| { + if arg.has_name(cx.sym.descriptor_set) { + match parse_attr_int_value(cx, arg) { + Some(x) => Some(SpirvAttribute::DescriptorSet(x)), + None => None, + } + } else if arg.has_name(cx.sym.binding) { + match parse_attr_int_value(cx, arg) { + Some(x) => Some(SpirvAttribute::Binding(x)), + None => None, + } + } else { + let name = match arg.ident() { + Some(i) => i, + None => { + cx.tcx.sess.span_err( + arg.span(), + "#[spirv(..)] attribute argument must be single identifier", + ); + return None; + } + }; + match cx.sym.attributes.get(&name.name) { + Some(a) => Some(a.clone()), + None => { + cx.tcx + .sess + .span_err(name.span, "unknown argument to spirv attribute"); + None + } } } - } else { - cx.tcx.sess.span_err( - attr.span, - "storage_class must have value: #[spirv(storage_class = \"..\")]", - ); - None - } - } else if arg.has_name(cx.sym.entry) { - if let Some(storage_arg) = arg.value_str() { - match cx.sym.symbol_to_execution_model(storage_arg) { - Some(execution_model) => Some(SpirvAttribute::Entry(execution_model)), - None => { - cx.tcx - .sess - .span_err(attr.span, "unknown spir-v execution model"); - None - } - } - } else { + }) + }); + // lifetimes are hard :( + result.collect::>().into_iter() +} + +fn parse_attr_int_value(cx: &CodegenCx<'_>, arg: &NestedMetaItem) -> Option { + let arg = match arg.meta_item() { + Some(arg) => arg, + None => { cx.tcx .sess - .span_err(attr.span, "entry must have value: #[spirv(entry = \"..\")]"); + .span_err(arg.span(), "attribute must have value"); + return None; + } + }; + match arg.name_value_literal() { + Some(&Lit { + kind: LitKind::Int(x, LitIntType::Unsuffixed), + .. + }) if x <= u32::MAX as u128 => Some(x as u32), + _ => { + cx.tcx + .sess + .span_err(arg.span, "attribute value must be integer"); None } - } else if arg.has_name(cx.sym.really_unsafe_ignore_bitcasts) { - Some(SpirvAttribute::ReallyUnsafeIgnoreBitcasts) - } else { - cx.tcx - .sess - .span_err(attr.span, "unknown argument to spirv attribute"); - None } } diff --git a/spirv-builder/src/test/basic.rs b/spirv-builder/src/test/basic.rs index 6d28e20687..e4db88969a 100644 --- a/spirv-builder/src/test/basic.rs +++ b/spirv-builder/src/test/basic.rs @@ -4,7 +4,7 @@ use super::val; fn hello_world() { val(r#" #[allow(unused_attributes)] -#[spirv(entry = "fragment")] +#[spirv(fragment)] pub fn main() { } "#); diff --git a/spirv-std/src/lib.rs b/spirv-std/src/lib.rs index ffb6654242..9ed436639f 100644 --- a/spirv-std/src/lib.rs +++ b/spirv-std/src/lib.rs @@ -54,9 +54,9 @@ macro_rules! pointer_addrspace_write { } macro_rules! pointer_addrspace { - ($storage_class:literal, $type_name:ident, $writeable:tt) => { + ($storage_class:ident, $type_name:ident, $writeable:tt) => { #[allow(unused_attributes)] - #[spirv(storage_class = $storage_class)] + #[spirv($storage_class)] pub struct $type_name<'a, T> { x: &'a mut T, } @@ -76,26 +76,26 @@ macro_rules! pointer_addrspace { // Make sure these strings stay synced with symbols.rs // Note the type names don't have to match anything, they can be renamed (only the string must match) -pointer_addrspace!("uniform_constant", UniformConstant, false); -pointer_addrspace!("input", Input, false); -pointer_addrspace!("uniform", Uniform, true); -pointer_addrspace!("output", Output, true); -pointer_addrspace!("workgroup", Workgroup, true); -pointer_addrspace!("cross_workgroup", CrossWorkgroup, true); -pointer_addrspace!("private", Private, true); -pointer_addrspace!("function", Function, true); -pointer_addrspace!("generic", Generic, true); -pointer_addrspace!("push_constant", PushConstant, false); -pointer_addrspace!("atomic_counter", AtomicCounter, true); -pointer_addrspace!("image", Image, true); -pointer_addrspace!("storage_buffer", StorageBuffer, true); -pointer_addrspace!("callable_data_khr", CallableDataKHR, true); -pointer_addrspace!("incoming_callable_data_khr", IncomingCallableDataKHR, true); -pointer_addrspace!("ray_payload_khr", RayPayloadKHR, true); -pointer_addrspace!("hit_attribute_khr", HitAttributeKHR, true); -pointer_addrspace!("incoming_ray_payload_khr", IncomingRayPayloadKHR, true); -pointer_addrspace!("shader_record_buffer_khr", ShaderRecordBufferKHR, true); -pointer_addrspace!("physical_storage_buffer", PhysicalStorageBuffer, true); +pointer_addrspace!(uniform_constant, UniformConstant, false); +pointer_addrspace!(input, Input, false); +pointer_addrspace!(uniform, Uniform, true); +pointer_addrspace!(output, Output, true); +pointer_addrspace!(workgroup, Workgroup, true); +pointer_addrspace!(cross_workgroup, CrossWorkgroup, true); +pointer_addrspace!(private, Private, true); +pointer_addrspace!(function, Function, true); +pointer_addrspace!(generic, Generic, true); +pointer_addrspace!(push_constant, PushConstant, false); +pointer_addrspace!(atomic_counter, AtomicCounter, true); +pointer_addrspace!(image, Image, true); +pointer_addrspace!(storage_buffer, StorageBuffer, true); +pointer_addrspace!(callable_data_khr, CallableDataKHR, true); +pointer_addrspace!(incoming_callable_data_khr, IncomingCallableDataKHR, true); +pointer_addrspace!(ray_payload_khr, RayPayloadKHR, true); +pointer_addrspace!(hit_attribute_khr, HitAttributeKHR, true); +pointer_addrspace!(incoming_ray_payload_khr, IncomingRayPayloadKHR, true); +pointer_addrspace!(shader_record_buffer_khr, ShaderRecordBufferKHR, true); +pointer_addrspace!(physical_storage_buffer, PhysicalStorageBuffer, true); #[allow(non_camel_case_types)] #[derive(Debug, Clone, Copy)]