Refactor attributes and add descriptor_set/binding (#145)

* Move entry declarations to their own file

Also clean up attribute parsing (and make it allow multiple arguments in
the process)

* Add descriptor_set and binding attributes

* clippy fix

* Fix test

* Reserve descriptor_set 0 for future use

* Add book page on attributes
This commit is contained in:
Ashley Hauck 2020-10-26 16:23:21 +01:00 committed by GitHub
parent fe18434bff
commit b32c04b3fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 451 additions and 355 deletions

View File

@ -1,3 +1,4 @@
# Summary # Summary
- [Introduction](./introduction.md) - [Introduction](./introduction.md)
- [Attribute syntax](./attributes.md)

48
docs/src/attributes.md Normal file
View File

@ -0,0 +1,48 @@
# Attribute syntax
rust-gpu introduces a number of SPIR-V related attributes to express behavior specific to SPIR-V not exposed in the base rust language.
There are a few different categories of attributes:
## Entry points
When declaring an entry point to your shader, SPIR-V needs to know what type of function it is. For example, it could be a fragment shader, or vertex shader. Specifying this attribute is also the way rust-gpu knows that you would like to export a function as an entry point, no other functions are exported.
Example:
```rust
#[spirv(fragment)]
fn main() { }
```
Common values are `#[spirv(fragment)]` and `#[spirv(vertex)]`. A list of all supported names can be found in [spirv_headers](https://docs.rs/spirv_headers/1.5.0/spirv_headers/enum.ExecutionModel.html) - convert the enum name to snake_case for the rust-gpu attribute name.
## Builtins
When declaring inputs and outputs, sometimes you want to declare it as a "builtin". This means many things, but one example is `gl_Position` from glsl - the GPU assigns inherent meaning to the variable and uses it for placing the vertex in clip space. The equivalent in rust-gpu is called `position`.
Example:
```rust:
#[spirv(fragment)]
fn main(
#[spirv(position)] mut out_pos: Output<Vec4>,
) { }
```
Common values are `#[spirv(position)]`, `#[spirv(vertex_id)]`, and many more. A list of all supported names can be found in [spirv_headers](https://docs.rs/spirv_headers/1.5.0/spirv_headers/enum.BuiltIn.html) - convert the enum name to snake_case for the rust-gpu attribute name.
## Descriptor set and binding
A SPIR-V shader must declare where uniform variables are located with explicit indices that match up with CPU-side code. This can be done with the `descriptor_set` and `binding` attributes. Note that `descriptor_set = 0` is reserved for future use, and cannot be used.
Example:
```rust:
#[spirv(fragment)]
fn main(
#[spirv(descriptor_set = 2, binding = 5)] mut var: Uniform<Vec4>,
) { }
```
Both descriptor_set and binding take an integer argument that specifies the uniform's index.

View File

@ -194,7 +194,7 @@ pub fn fs(screen_pos: Vec2) -> Vec4 {
} }
#[allow(unused_attributes)] #[allow(unused_attributes)]
#[spirv(entry = "fragment")] #[spirv(fragment)]
pub fn main_fs(input: Input<Vec4>, mut output: Output<Vec4>) { pub fn main_fs(input: Input<Vec4>, mut output: Output<Vec4>) {
let v = input.load(); let v = input.load();
let color = fs(Vec2::new(v.0, v.1)); let color = fs(Vec2::new(v.0, v.1));
@ -202,11 +202,11 @@ pub fn main_fs(input: Input<Vec4>, mut output: Output<Vec4>) {
} }
#[allow(unused_attributes)] #[allow(unused_attributes)]
#[spirv(entry = "vertex")] #[spirv(vertex)]
pub fn main_vs( pub fn main_vs(
in_pos: Input<Vec4>, in_pos: Input<Vec4>,
_in_color: Input<Vec4>, _in_color: Input<Vec4>,
#[spirv(builtin = "position")] mut out_pos: Output<Vec4>, #[spirv(position)] mut out_pos: Output<Vec4>,
mut out_color: Output<Vec4>, mut out_color: Output<Vec4>,
) { ) {
out_pos.store(in_pos.load()); out_pos.store(in_pos.load());

View File

@ -3,7 +3,7 @@
use crate::codegen_cx::CodegenCx; use crate::codegen_cx::CodegenCx;
use crate::spirv_type::SpirvType; use crate::spirv_type::SpirvType;
use crate::symbols::{parse_attr, SpirvAttribute}; use crate::symbols::{parse_attrs, SpirvAttribute};
use rspirv::spirv::{StorageClass, Word}; use rspirv::spirv::{StorageClass, Word};
use rustc_middle::bug; use rustc_middle::bug;
use rustc_middle::ty::layout::{FnAbiExt, TyAndLayout}; use rustc_middle::ty::layout::{FnAbiExt, TyAndLayout};
@ -538,8 +538,8 @@ fn dig_scalar_pointee_adt<'tcx>(
// TODO: Enforce this is only used in spirv-std. // TODO: Enforce this is only used in spirv-std.
fn get_storage_class<'tcx>(cx: &CodegenCx<'tcx>, ty: TyAndLayout<'tcx>) -> Option<StorageClass> { fn get_storage_class<'tcx>(cx: &CodegenCx<'tcx>, ty: TyAndLayout<'tcx>) -> Option<StorageClass> {
if let TyKind::Adt(adt, _substs) = ty.ty.kind() { if let TyKind::Adt(adt, _substs) = ty.ty.kind() {
for attr in cx.tcx.get_attrs(adt.did) { for attr in parse_attrs(cx, cx.tcx.get_attrs(adt.did)) {
if let Some(SpirvAttribute::StorageClass(storage_class)) = parse_attr(cx, attr) { if let SpirvAttribute::StorageClass(storage_class) = attr {
return Some(storage_class); return Some(storage_class);
} }
} }

View File

@ -2,14 +2,10 @@ use super::CodegenCx;
use crate::abi::ConvSpirvType; use crate::abi::ConvSpirvType;
use crate::builder_spirv::{SpirvConst, SpirvValue, SpirvValueExt}; use crate::builder_spirv::{SpirvConst, SpirvValue, SpirvValueExt};
use crate::spirv_type::SpirvType; use crate::spirv_type::SpirvType;
use crate::symbols::{parse_attr, SpirvAttribute}; use crate::symbols::{parse_attrs, SpirvAttribute};
use rspirv::dr::Operand; use rspirv::spirv::{FunctionControl, LinkageType, StorageClass, Word};
use rspirv::spirv::{
Decoration, ExecutionMode, ExecutionModel, FunctionControl, LinkageType, StorageClass, Word,
};
use rustc_attr::InlineAttr; use rustc_attr::InlineAttr;
use rustc_codegen_ssa::traits::{PreDefineMethods, StaticMethods}; use rustc_codegen_ssa::traits::{PreDefineMethods, StaticMethods};
use rustc_hir::Param;
use rustc_middle::bug; use rustc_middle::bug;
use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, CodegenFnAttrs}; use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, CodegenFnAttrs};
use rustc_middle::mir::mono::{Linkage, MonoItem, Visibility}; use rustc_middle::mir::mono::{Linkage, MonoItem, Visibility};
@ -18,9 +14,7 @@ use rustc_middle::ty::{Instance, ParamEnv, Ty, TypeFoldable};
use rustc_span::def_id::DefId; use rustc_span::def_id::DefId;
use rustc_span::Span; use rustc_span::Span;
use rustc_target::abi::call::FnAbi; use rustc_target::abi::call::FnAbi;
use rustc_target::abi::call::PassMode;
use rustc_target::abi::{Align, LayoutOf}; use rustc_target::abi::{Align, LayoutOf};
use std::collections::HashMap;
fn attrs_to_spirv(attrs: &CodegenFnAttrs) -> FunctionControl { fn attrs_to_spirv(attrs: &CodegenFnAttrs) -> FunctionControl {
let mut control = FunctionControl::NONE; let mut control = FunctionControl::NONE;
@ -156,194 +150,6 @@ impl<'tcx> CodegenCx<'tcx> {
self.zombie_with_span(result.def, span, "Globals are not supported yet"); self.zombie_with_span(result.def, span, "Globals are not supported yet");
result result
} }
// Entry points declare their "interface" (all uniforms, inputs, outputs, etc.) as parameters. spir-v uses globals
// to declare the interface. So, we need to generate a lil stub for the "real" main that collects all those global
// variables and calls the user-defined main function.
fn shader_entry_stub(
&self,
entry_func: SpirvValue,
hir_params: &[Param<'_>],
name: String,
execution_model: ExecutionModel,
) {
let void = SpirvType::Void.def(self);
let fn_void_void = SpirvType::Function {
return_type: void,
arguments: vec![],
}
.def(self);
let (entry_func_return, entry_func_args) = match self.lookup_type(entry_func.ty) {
SpirvType::Function {
return_type,
arguments,
} => (return_type, arguments),
other => self.tcx.sess.fatal(&format!(
"Invalid entry_stub type: {}",
other.debug(entry_func.ty, self)
)),
};
let mut emit = self.emit_global();
let mut decoration_locations = HashMap::new();
// Create OpVariables before OpFunction so they're global instead of local vars.
let arguments = entry_func_args
.iter()
.zip(hir_params)
.map(|(&arg, hir_param)| {
let storage_class = match self.lookup_type(arg) {
SpirvType::Pointer { storage_class, .. } => storage_class,
other => self.tcx.sess.fatal(&format!(
"Invalid entry arg type {}",
other.debug(arg, self)
)),
};
let mut has_location = match storage_class {
StorageClass::Input | StorageClass::Output | StorageClass::UniformConstant => {
true
}
_ => false,
};
// Note: this *declares* the variable too.
let variable = emit.variable(arg, None, storage_class, None);
for attr in hir_param.attrs {
if let Some(SpirvAttribute::Builtin(builtin)) = parse_attr(self, attr) {
emit.decorate(
variable,
Decoration::BuiltIn,
std::iter::once(Operand::BuiltIn(builtin)),
);
has_location = false;
}
}
// Assign locations from left to right, incrementing each storage class
// individually.
// TODO: Is this right for UniformConstant? Do they share locations with
// input/outpus?
if has_location {
let location = decoration_locations
.entry(storage_class)
.or_insert_with(|| 0);
emit.decorate(
variable,
Decoration::Location,
std::iter::once(Operand::LiteralInt32(*location)),
);
*location += 1;
}
variable
})
.collect::<Vec<_>>();
let fn_id = emit
.begin_function(void, None, FunctionControl::NONE, fn_void_void)
.unwrap();
emit.begin_block(None).unwrap();
emit.function_call(
entry_func_return,
None,
entry_func.def,
arguments.iter().copied(),
)
.unwrap();
emit.ret().unwrap();
emit.end_function().unwrap();
let interface = arguments;
emit.entry_point(execution_model, fn_id, name, interface);
if execution_model == ExecutionModel::Fragment {
// TODO: Make this configurable.
emit.execution_mode(fn_id, ExecutionMode::OriginUpperLeft, &[]);
}
}
// Kernel mode takes its interface as function parameters(??)
// OpEntryPoints cannot be OpLinkage, so write out a stub to call through.
fn kernel_entry_stub(
&self,
entry_func: SpirvValue,
name: String,
execution_model: ExecutionModel,
) {
let (entry_func_return, entry_func_args) = match self.lookup_type(entry_func.ty) {
SpirvType::Function {
return_type,
arguments,
} => (return_type, arguments),
other => self.tcx.sess.fatal(&format!(
"Invalid kernel_entry_stub type: {}",
other.debug(entry_func.ty, self)
)),
};
let mut emit = self.emit_global();
let fn_id = emit
.begin_function(
entry_func_return,
None,
FunctionControl::NONE,
entry_func.ty,
)
.unwrap();
let arguments = entry_func_args
.iter()
.map(|&ty| emit.function_parameter(ty).unwrap())
.collect::<Vec<_>>();
emit.begin_block(None).unwrap();
let call_result = emit
.function_call(entry_func_return, None, entry_func.def, arguments)
.unwrap();
if self.lookup_type(entry_func_return) == SpirvType::Void {
emit.ret().unwrap();
} else {
emit.ret_value(call_result).unwrap();
}
emit.end_function().unwrap();
emit.entry_point(execution_model, fn_id, name, &[]);
}
fn entry_stub(
&self,
instance: &Instance<'_>,
fn_abi: &FnAbi<'_, Ty<'_>>,
entry_func: SpirvValue,
name: String,
execution_model: ExecutionModel,
) {
let local_id = match instance.def_id().as_local() {
Some(id) => id,
None => {
self.tcx
.sess
.err(&format!("Cannot declare {} as an entry point", name));
return;
}
};
let fn_hir_id = self.tcx.hir().local_def_id_to_hir_id(local_id);
let body = self.tcx.hir().body(self.tcx.hir().body_owned_by(fn_hir_id));
for (abi, arg) in fn_abi.args.iter().zip(body.params) {
if let PassMode::Direct(_) = abi.mode {
} else {
self.tcx.sess.span_err(
arg.span,
&format!("PassMode {:?} invalid for entry point parameter", abi.mode),
)
}
}
if let PassMode::Ignore = fn_abi.ret.mode {
} else {
self.tcx.sess.span_err(
self.tcx.hir().span(fn_hir_id),
&format!(
"PassMode {:?} invalid for entry point return type",
fn_abi.ret.mode
),
)
}
if execution_model == ExecutionModel::Kernel {
self.kernel_entry_stub(entry_func, name, execution_model);
} else {
self.shader_entry_stub(entry_func, body.params, name, execution_model);
}
}
} }
impl<'tcx> PreDefineMethods<'tcx> for CodegenCx<'tcx> { impl<'tcx> PreDefineMethods<'tcx> for CodegenCx<'tcx> {
@ -403,16 +209,16 @@ impl<'tcx> PreDefineMethods<'tcx> for CodegenCx<'tcx> {
let declared = let declared =
self.declare_fn_ext(symbol_name, Some(&human_name), linkage2, spv_attrs, &fn_abi); self.declare_fn_ext(symbol_name, Some(&human_name), linkage2, spv_attrs, &fn_abi);
for attr in self.tcx.get_attrs(instance.def_id()) { for attr in parse_attrs(self, self.tcx.get_attrs(instance.def_id())) {
match parse_attr(self, attr) { match attr {
Some(SpirvAttribute::Entry(execution_model)) => self.entry_stub( SpirvAttribute::Entry(execution_model) => self.entry_stub(
&instance, &instance,
&fn_abi, &fn_abi,
declared, declared,
human_name.clone(), human_name.clone(),
execution_model, execution_model,
), ),
Some(SpirvAttribute::ReallyUnsafeIgnoreBitcasts) => { SpirvAttribute::ReallyUnsafeIgnoreBitcasts => {
self.really_unsafe_ignore_bitcasts self.really_unsafe_ignore_bitcasts
.borrow_mut() .borrow_mut()
.insert(declared); .insert(declared);

View File

@ -0,0 +1,245 @@
use super::CodegenCx;
use crate::builder_spirv::SpirvValue;
use crate::spirv_type::SpirvType;
use crate::symbols::{parse_attrs, SpirvAttribute};
use rspirv::dr::Operand;
use rspirv::spirv::{
Decoration, ExecutionMode, ExecutionModel, FunctionControl, StorageClass, Word,
};
use rustc_hir::Param;
use rustc_middle::ty::{Instance, Ty};
use rustc_target::abi::call::{FnAbi, PassMode};
use std::collections::HashMap;
impl<'tcx> CodegenCx<'tcx> {
// Entry points declare their "interface" (all uniforms, inputs, outputs, etc.) as parameters.
// spir-v uses globals to declare the interface. So, we need to generate a lil stub for the
// "real" main that collects all those global variables and calls the user-defined main
// function.
pub fn entry_stub(
&self,
instance: &Instance<'_>,
fn_abi: &FnAbi<'_, Ty<'_>>,
entry_func: SpirvValue,
name: String,
execution_model: ExecutionModel,
) {
let local_id = match instance.def_id().as_local() {
Some(id) => id,
None => {
self.tcx
.sess
.err(&format!("Cannot declare {} as an entry point", name));
return;
}
};
let fn_hir_id = self.tcx.hir().local_def_id_to_hir_id(local_id);
let body = self.tcx.hir().body(self.tcx.hir().body_owned_by(fn_hir_id));
for (abi, arg) in fn_abi.args.iter().zip(body.params) {
if let PassMode::Direct(_) = abi.mode {
} else {
self.tcx.sess.span_err(
arg.span,
&format!("PassMode {:?} invalid for entry point parameter", abi.mode),
)
}
}
if let PassMode::Ignore = fn_abi.ret.mode {
} else {
self.tcx.sess.span_err(
self.tcx.hir().span(fn_hir_id),
&format!(
"PassMode {:?} invalid for entry point return type",
fn_abi.ret.mode
),
)
}
if execution_model == ExecutionModel::Kernel {
self.kernel_entry_stub(entry_func, name, execution_model);
} else {
self.shader_entry_stub(entry_func, body.params, name, execution_model);
}
}
fn shader_entry_stub(
&self,
entry_func: SpirvValue,
hir_params: &[Param<'tcx>],
name: String,
execution_model: ExecutionModel,
) {
let void = SpirvType::Void.def(self);
let fn_void_void = SpirvType::Function {
return_type: void,
arguments: vec![],
}
.def(self);
let (entry_func_return, entry_func_args) = match self.lookup_type(entry_func.ty) {
SpirvType::Function {
return_type,
arguments,
} => (return_type, arguments),
other => self.tcx.sess.fatal(&format!(
"Invalid entry_stub type: {}",
other.debug(entry_func.ty, self)
)),
};
let mut decoration_locations = HashMap::new();
// Create OpVariables before OpFunction so they're global instead of local vars.
let arguments = entry_func_args
.iter()
.zip(hir_params)
.map(|(&arg, hir_param)| {
self.declare_parameter(arg, hir_param, &mut decoration_locations)
})
.collect::<Vec<_>>();
let mut emit = self.emit_global();
let fn_id = emit
.begin_function(void, None, FunctionControl::NONE, fn_void_void)
.unwrap();
emit.begin_block(None).unwrap();
emit.function_call(
entry_func_return,
None,
entry_func.def,
arguments.iter().map(|&(a, _)| a),
)
.unwrap();
emit.ret().unwrap();
emit.end_function().unwrap();
let interface: Vec<_> = if emit.version().unwrap() > (1, 3) {
// SPIR-V >= v1.4 includes all OpVariables in the interface.
arguments.into_iter().map(|(a, _)| a).collect()
} else {
// SPIR-V <= v1.3 only includes Input and Output in the interface.
arguments
.into_iter()
.filter(|&(_, s)| s == StorageClass::Input || s == StorageClass::Output)
.map(|(a, _)| a)
.collect()
};
emit.entry_point(execution_model, fn_id, name, interface);
if execution_model == ExecutionModel::Fragment {
// TODO: Make this configurable.
emit.execution_mode(fn_id, ExecutionMode::OriginUpperLeft, &[]);
}
}
fn declare_parameter(
&self,
arg: Word,
hir_param: &Param<'tcx>,
decoration_locations: &mut HashMap<StorageClass, u32>,
) -> (Word, StorageClass) {
let storage_class = match self.lookup_type(arg) {
SpirvType::Pointer { storage_class, .. } => storage_class,
other => self.tcx.sess.fatal(&format!(
"Invalid entry arg type {}",
other.debug(arg, self)
)),
};
let mut has_location = matches!(
storage_class,
StorageClass::Input | StorageClass::Output | StorageClass::UniformConstant
);
// Note: this *declares* the variable too.
let variable = self.emit_global().variable(arg, None, storage_class, None);
for attr in parse_attrs(self, hir_param.attrs) {
match attr {
SpirvAttribute::Builtin(builtin) => {
self.emit_global().decorate(
variable,
Decoration::BuiltIn,
std::iter::once(Operand::BuiltIn(builtin)),
);
has_location = false;
}
SpirvAttribute::DescriptorSet(index) => {
self.emit_global().decorate(
variable,
Decoration::DescriptorSet,
std::iter::once(Operand::LiteralInt32(index)),
);
has_location = false;
}
SpirvAttribute::Binding(index) => {
if index == 0 {
self.tcx.sess.span_err(
hir_param.span,
"descriptor_set 0 is reserved for internal / future use",
);
}
self.emit_global().decorate(
variable,
Decoration::Binding,
std::iter::once(Operand::LiteralInt32(index)),
);
has_location = false;
}
_ => {}
}
}
// Assign locations from left to right, incrementing each storage class
// individually.
// TODO: Is this right for UniformConstant? Do they share locations with
// input/outpus?
if has_location {
let location = decoration_locations
.entry(storage_class)
.or_insert_with(|| 0);
self.emit_global().decorate(
variable,
Decoration::Location,
std::iter::once(Operand::LiteralInt32(*location)),
);
*location += 1;
}
(variable, storage_class)
}
// Kernel mode takes its interface as function parameters(??)
// OpEntryPoints cannot be OpLinkage, so write out a stub to call through.
fn kernel_entry_stub(
&self,
entry_func: SpirvValue,
name: String,
execution_model: ExecutionModel,
) {
let (entry_func_return, entry_func_args) = match self.lookup_type(entry_func.ty) {
SpirvType::Function {
return_type,
arguments,
} => (return_type, arguments),
other => self.tcx.sess.fatal(&format!(
"Invalid kernel_entry_stub type: {}",
other.debug(entry_func.ty, self)
)),
};
let mut emit = self.emit_global();
let fn_id = emit
.begin_function(
entry_func_return,
None,
FunctionControl::NONE,
entry_func.ty,
)
.unwrap();
let arguments = entry_func_args
.iter()
.map(|&ty| emit.function_parameter(ty).unwrap())
.collect::<Vec<_>>();
emit.begin_block(None).unwrap();
let call_result = emit
.function_call(entry_func_return, None, entry_func.def, arguments)
.unwrap();
if self.lookup_type(entry_func_return) == SpirvType::Void {
emit.ret().unwrap();
} else {
emit.ret_value(call_result).unwrap();
}
emit.end_function().unwrap();
emit.entry_point(execution_model, fn_id, name, &[]);
}
}

View File

@ -1,5 +1,6 @@
mod constant; mod constant;
mod declare; mod declare;
mod entry;
mod type_; mod type_;
use crate::builder::ExtInst; use crate::builder::ExtInst;

View File

@ -1,6 +1,6 @@
use crate::codegen_cx::CodegenCx; use crate::codegen_cx::CodegenCx;
use rspirv::spirv::{BuiltIn, ExecutionModel, StorageClass}; use rspirv::spirv::{BuiltIn, ExecutionModel, StorageClass};
use rustc_ast::ast::{AttrKind, Attribute}; use rustc_ast::ast::{AttrKind, Attribute, Lit, LitIntType, LitKind, NestedMetaItem};
use rustc_span::symbol::Symbol; use rustc_span::symbol::Symbol;
use std::collections::HashMap; use std::collections::HashMap;
@ -13,19 +13,14 @@ pub struct Symbols {
pub spirv: Symbol, pub spirv: Symbol,
pub spirv_std: Symbol, pub spirv_std: Symbol,
pub kernel: Symbol, pub kernel: Symbol,
pub builtin: Symbol, descriptor_set: Symbol,
pub storage_class: Symbol, binding: Symbol,
pub entry: Symbol, attributes: HashMap<Symbol, SpirvAttribute>,
pub really_unsafe_ignore_bitcasts: Symbol,
builtins: HashMap<Symbol, BuiltIn>,
storage_classes: HashMap<Symbol, StorageClass>,
execution_models: HashMap<Symbol, ExecutionModel>,
} }
fn make_builtins() -> HashMap<Symbol, BuiltIn> { const BUILTINS: &[(&str, BuiltIn)] = {
use BuiltIn::*; use BuiltIn::*;
[ &[
("position", Position), ("position", Position),
("point_size", PointSize), ("point_size", PointSize),
("clip_distance", ClipDistance), ("clip_distance", ClipDistance),
@ -126,15 +121,12 @@ fn make_builtins() -> HashMap<Symbol, BuiltIn> {
("warp_id_nv", WarpIDNV), ("warp_id_nv", WarpIDNV),
("SMIDNV", SMIDNV), ("SMIDNV", SMIDNV),
] ]
.iter() };
.map(|&(a, b)| (Symbol::intern(a), b))
.collect()
}
fn make_storage_classes() -> HashMap<Symbol, StorageClass> { const STORAGE_CLASSES: &[(&str, StorageClass)] = {
use StorageClass::*; use StorageClass::*;
// make sure these strings stay synced with spirv-std's pointer types // make sure these strings stay synced with spirv-std's pointer types
[ &[
("uniform_constant", UniformConstant), ("uniform_constant", UniformConstant),
("input", Input), ("input", Input),
("uniform", Uniform), ("uniform", Uniform),
@ -165,14 +157,11 @@ fn make_storage_classes() -> HashMap<Symbol, StorageClass> {
), ),
("physical_storage_buffer", PhysicalStorageBuffer), ("physical_storage_buffer", PhysicalStorageBuffer),
] ]
.iter() };
.map(|&(a, b)| (Symbol::intern(a), b))
.collect()
}
fn make_execution_models() -> HashMap<Symbol, ExecutionModel> { const EXECUTION_MODELS: &[(&str, ExecutionModel)] = {
use ExecutionModel::*; use ExecutionModel::*;
[ &[
("vertex", Vertex), ("vertex", Vertex),
("tessellation_control", TessellationControl), ("tessellation_control", TessellationControl),
("tessellation_evaluation", TessellationEvaluation), ("tessellation_evaluation", TessellationEvaluation),
@ -189,54 +178,65 @@ fn make_execution_models() -> HashMap<Symbol, ExecutionModel> {
("miss_nv", MissNV), ("miss_nv", MissNV),
("callable_nv", CallableNV), ("callable_nv", CallableNV),
] ]
.iter() };
.map(|&(a, b)| (Symbol::intern(a), b))
.collect()
}
impl Symbols { impl Symbols {
pub fn new() -> Self { pub fn new() -> Self {
let builtins = BUILTINS
.iter()
.map(|&(a, b)| (a, SpirvAttribute::Builtin(b)));
let storage_classes = STORAGE_CLASSES
.iter()
.map(|&(a, b)| (a, SpirvAttribute::StorageClass(b)));
let execution_models = EXECUTION_MODELS
.iter()
.map(|&(a, b)| (a, SpirvAttribute::Entry(b)));
let custom = std::iter::once((
"really_unsafe_ignore_bitcasts",
SpirvAttribute::ReallyUnsafeIgnoreBitcasts,
));
let attributes_iter = builtins
.chain(storage_classes)
.chain(execution_models)
.chain(custom)
.map(|(a, b)| (Symbol::intern(a), b));
let mut attributes = HashMap::new();
for attr in attributes_iter {
let old = attributes.insert(attr.0, attr.1);
// `.collect()` into a HashMap does not error on duplicates, so manually write out the
// loop here to error on duplicates.
assert!(old.is_none());
}
Self { Self {
spirv: Symbol::intern("spirv"), spirv: Symbol::intern("spirv"),
spirv_std: Symbol::intern("spirv_std"), spirv_std: Symbol::intern("spirv_std"),
kernel: Symbol::intern("kernel"), kernel: Symbol::intern("kernel"),
builtin: Symbol::intern("builtin"), descriptor_set: Symbol::intern("descriptor_set"),
storage_class: Symbol::intern("storage_class"), binding: Symbol::intern("binding"),
entry: Symbol::intern("entry"), attributes,
really_unsafe_ignore_bitcasts: Symbol::intern("really_unsafe_ignore_bitcasts"),
builtins: make_builtins(),
storage_classes: make_storage_classes(),
execution_models: make_execution_models(),
} }
} }
pub fn symbol_to_builtin(&self, sym: Symbol) -> Option<BuiltIn> {
self.builtins.get(&sym).copied()
}
pub fn symbol_to_storageclass(&self, sym: Symbol) -> Option<StorageClass> {
self.storage_classes.get(&sym).copied()
}
pub fn symbol_to_execution_model(&self, sym: Symbol) -> Option<ExecutionModel> {
self.execution_models.get(&sym).copied()
}
} }
#[derive(Debug, Clone)]
pub enum SpirvAttribute { pub enum SpirvAttribute {
Builtin(BuiltIn), Builtin(BuiltIn),
StorageClass(StorageClass), StorageClass(StorageClass),
Entry(ExecutionModel), Entry(ExecutionModel),
DescriptorSet(u32),
Binding(u32),
ReallyUnsafeIgnoreBitcasts, ReallyUnsafeIgnoreBitcasts,
} }
// Note that we could mark the attr as used via cx.tcx.sess.mark_attr_used(attr), but unused reporting already happens // Note that we could mark the attr as used via cx.tcx.sess.mark_attr_used(attr), but unused
// even before we get here :( // reporting already happens even before we get here :(
/// Returns None if this attribute is not a spirv attribute, or if it's malformed (and an error is reported). /// Returns empty if this attribute is not a spirv attribute, or if it's malformed (and an error is
pub fn parse_attr<'tcx>(cx: &CodegenCx<'tcx>, attr: &Attribute) -> Option<SpirvAttribute> { /// reported).
// Example attributes that we parse here: pub fn parse_attrs(
// #[spirv(storage_class = "uniform")] cx: &CodegenCx<'_>,
// #[spirv(entry = "kernel")] attrs: &[Attribute],
) -> impl Iterator<Item = SpirvAttribute> {
let result = attrs.iter().flat_map(|attr| {
let is_spirv = match attr.kind { let is_spirv = match attr.kind {
AttrKind::Normal(ref item) => { AttrKind::Normal(ref item) => {
// TODO: We ignore the rest of the path. Is this right? // TODO: We ignore the rest of the path. Is this right?
@ -245,81 +245,76 @@ pub fn parse_attr<'tcx>(cx: &CodegenCx<'tcx>, attr: &Attribute) -> Option<SpirvA
} }
AttrKind::DocComment(..) => false, AttrKind::DocComment(..) => false,
}; };
if !is_spirv { let args = if !is_spirv {
return None; // Use an empty vec here to return empty
} Vec::new()
let args = if let Some(args) = attr.meta_item_list() { } else if let Some(args) = attr.meta_item_list() {
args args
} else { } else {
cx.tcx cx.tcx.sess.span_err(
.sess attr.span,
.span_err(attr.span, "#[spirv(..)] attribute must have one argument"); "#[spirv(..)] attribute must have at least one argument",
return None; );
Vec::new()
}; };
if args.len() != 1 { args.into_iter().filter_map(move |ref arg| {
cx.tcx if arg.has_name(cx.sym.descriptor_set) {
.sess match parse_attr_int_value(cx, arg) {
.span_err(attr.span, "#[spirv(..)] attribute must have one argument"); Some(x) => Some(SpirvAttribute::DescriptorSet(x)),
None => None,
}
} else if arg.has_name(cx.sym.binding) {
match parse_attr_int_value(cx, arg) {
Some(x) => Some(SpirvAttribute::Binding(x)),
None => None,
}
} else {
let name = match arg.ident() {
Some(i) => i,
None => {
cx.tcx.sess.span_err(
arg.span(),
"#[spirv(..)] attribute argument must be single identifier",
);
return None; return None;
} }
let arg = &args[0]; };
if arg.has_name(cx.sym.builtin) { match cx.sym.attributes.get(&name.name) {
if let Some(builtin_arg) = arg.value_str() { Some(a) => Some(a.clone()),
match cx.sym.symbol_to_builtin(builtin_arg) {
Some(builtin) => Some(SpirvAttribute::Builtin(builtin)),
None => {
cx.tcx.sess.span_err(attr.span, "unknown spir-v builtin");
None
}
}
} else {
cx.tcx.sess.span_err(
attr.span,
"builtin must have value: #[spirv(builtin = \"..\")]",
);
None
}
} else if arg.has_name(cx.sym.storage_class) {
if let Some(storage_arg) = arg.value_str() {
match cx.sym.symbol_to_storageclass(storage_arg) {
Some(storage_class) => Some(SpirvAttribute::StorageClass(storage_class)),
None => { None => {
cx.tcx cx.tcx
.sess .sess
.span_err(attr.span, "unknown spir-v storage class"); .span_err(name.span, "unknown argument to spirv attribute");
None None
} }
} }
} else {
cx.tcx.sess.span_err(
attr.span,
"storage_class must have value: #[spirv(storage_class = \"..\")]",
);
None
} }
} else if arg.has_name(cx.sym.entry) { })
if let Some(storage_arg) = arg.value_str() { });
match cx.sym.symbol_to_execution_model(storage_arg) { // lifetimes are hard :(
Some(execution_model) => Some(SpirvAttribute::Entry(execution_model)), result.collect::<Vec<_>>().into_iter()
}
fn parse_attr_int_value(cx: &CodegenCx<'_>, arg: &NestedMetaItem) -> Option<u32> {
let arg = match arg.meta_item() {
Some(arg) => arg,
None => { None => {
cx.tcx cx.tcx
.sess .sess
.span_err(attr.span, "unknown spir-v execution model"); .span_err(arg.span(), "attribute must have value");
None return None;
} }
} };
} else { match arg.name_value_literal() {
Some(&Lit {
kind: LitKind::Int(x, LitIntType::Unsuffixed),
..
}) if x <= u32::MAX as u128 => Some(x as u32),
_ => {
cx.tcx cx.tcx
.sess .sess
.span_err(attr.span, "entry must have value: #[spirv(entry = \"..\")]"); .span_err(arg.span, "attribute value must be integer");
None None
} }
} else if arg.has_name(cx.sym.really_unsafe_ignore_bitcasts) {
Some(SpirvAttribute::ReallyUnsafeIgnoreBitcasts)
} else {
cx.tcx
.sess
.span_err(attr.span, "unknown argument to spirv attribute");
None
} }
} }

View File

@ -4,7 +4,7 @@ use super::val;
fn hello_world() { fn hello_world() {
val(r#" val(r#"
#[allow(unused_attributes)] #[allow(unused_attributes)]
#[spirv(entry = "fragment")] #[spirv(fragment)]
pub fn main() { pub fn main() {
} }
"#); "#);

View File

@ -54,9 +54,9 @@ macro_rules! pointer_addrspace_write {
} }
macro_rules! pointer_addrspace { macro_rules! pointer_addrspace {
($storage_class:literal, $type_name:ident, $writeable:tt) => { ($storage_class:ident, $type_name:ident, $writeable:tt) => {
#[allow(unused_attributes)] #[allow(unused_attributes)]
#[spirv(storage_class = $storage_class)] #[spirv($storage_class)]
pub struct $type_name<'a, T> { pub struct $type_name<'a, T> {
x: &'a mut T, x: &'a mut T,
} }
@ -76,26 +76,26 @@ macro_rules! pointer_addrspace {
// Make sure these strings stay synced with symbols.rs // Make sure these strings stay synced with symbols.rs
// Note the type names don't have to match anything, they can be renamed (only the string must match) // Note the type names don't have to match anything, they can be renamed (only the string must match)
pointer_addrspace!("uniform_constant", UniformConstant, false); pointer_addrspace!(uniform_constant, UniformConstant, false);
pointer_addrspace!("input", Input, false); pointer_addrspace!(input, Input, false);
pointer_addrspace!("uniform", Uniform, true); pointer_addrspace!(uniform, Uniform, true);
pointer_addrspace!("output", Output, true); pointer_addrspace!(output, Output, true);
pointer_addrspace!("workgroup", Workgroup, true); pointer_addrspace!(workgroup, Workgroup, true);
pointer_addrspace!("cross_workgroup", CrossWorkgroup, true); pointer_addrspace!(cross_workgroup, CrossWorkgroup, true);
pointer_addrspace!("private", Private, true); pointer_addrspace!(private, Private, true);
pointer_addrspace!("function", Function, true); pointer_addrspace!(function, Function, true);
pointer_addrspace!("generic", Generic, true); pointer_addrspace!(generic, Generic, true);
pointer_addrspace!("push_constant", PushConstant, false); pointer_addrspace!(push_constant, PushConstant, false);
pointer_addrspace!("atomic_counter", AtomicCounter, true); pointer_addrspace!(atomic_counter, AtomicCounter, true);
pointer_addrspace!("image", Image, true); pointer_addrspace!(image, Image, true);
pointer_addrspace!("storage_buffer", StorageBuffer, true); pointer_addrspace!(storage_buffer, StorageBuffer, true);
pointer_addrspace!("callable_data_khr", CallableDataKHR, true); pointer_addrspace!(callable_data_khr, CallableDataKHR, true);
pointer_addrspace!("incoming_callable_data_khr", IncomingCallableDataKHR, true); pointer_addrspace!(incoming_callable_data_khr, IncomingCallableDataKHR, true);
pointer_addrspace!("ray_payload_khr", RayPayloadKHR, true); pointer_addrspace!(ray_payload_khr, RayPayloadKHR, true);
pointer_addrspace!("hit_attribute_khr", HitAttributeKHR, true); pointer_addrspace!(hit_attribute_khr, HitAttributeKHR, true);
pointer_addrspace!("incoming_ray_payload_khr", IncomingRayPayloadKHR, true); pointer_addrspace!(incoming_ray_payload_khr, IncomingRayPayloadKHR, true);
pointer_addrspace!("shader_record_buffer_khr", ShaderRecordBufferKHR, true); pointer_addrspace!(shader_record_buffer_khr, ShaderRecordBufferKHR, true);
pointer_addrspace!("physical_storage_buffer", PhysicalStorageBuffer, true); pointer_addrspace!(physical_storage_buffer, PhysicalStorageBuffer, true);
#[allow(non_camel_case_types)] #[allow(non_camel_case_types)]
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]