mirror of
https://github.com/EmbarkStudios/rust-gpu.git
synced 2024-11-22 06:45:13 +00:00
Refactor attributes and add descriptor_set/binding (#145)
* Move entry declarations to their own file Also clean up attribute parsing (and make it allow multiple arguments in the process) * Add descriptor_set and binding attributes * clippy fix * Fix test * Reserve descriptor_set 0 for future use * Add book page on attributes
This commit is contained in:
parent
fe18434bff
commit
b32c04b3fd
@ -1,3 +1,4 @@
|
||||
# Summary
|
||||
|
||||
- [Introduction](./introduction.md)
|
||||
- [Attribute syntax](./attributes.md)
|
||||
|
48
docs/src/attributes.md
Normal file
48
docs/src/attributes.md
Normal file
@ -0,0 +1,48 @@
|
||||
# Attribute syntax
|
||||
|
||||
rust-gpu introduces a number of SPIR-V related attributes to express behavior specific to SPIR-V not exposed in the base rust language.
|
||||
|
||||
There are a few different categories of attributes:
|
||||
|
||||
## Entry points
|
||||
|
||||
When declaring an entry point to your shader, SPIR-V needs to know what type of function it is. For example, it could be a fragment shader, or vertex shader. Specifying this attribute is also the way rust-gpu knows that you would like to export a function as an entry point, no other functions are exported.
|
||||
|
||||
Example:
|
||||
|
||||
```rust
|
||||
#[spirv(fragment)]
|
||||
fn main() { }
|
||||
```
|
||||
|
||||
Common values are `#[spirv(fragment)]` and `#[spirv(vertex)]`. A list of all supported names can be found in [spirv_headers](https://docs.rs/spirv_headers/1.5.0/spirv_headers/enum.ExecutionModel.html) - convert the enum name to snake_case for the rust-gpu attribute name.
|
||||
|
||||
## Builtins
|
||||
|
||||
When declaring inputs and outputs, sometimes you want to declare it as a "builtin". This means many things, but one example is `gl_Position` from glsl - the GPU assigns inherent meaning to the variable and uses it for placing the vertex in clip space. The equivalent in rust-gpu is called `position`.
|
||||
|
||||
Example:
|
||||
|
||||
```rust:
|
||||
#[spirv(fragment)]
|
||||
fn main(
|
||||
#[spirv(position)] mut out_pos: Output<Vec4>,
|
||||
) { }
|
||||
```
|
||||
|
||||
Common values are `#[spirv(position)]`, `#[spirv(vertex_id)]`, and many more. A list of all supported names can be found in [spirv_headers](https://docs.rs/spirv_headers/1.5.0/spirv_headers/enum.BuiltIn.html) - convert the enum name to snake_case for the rust-gpu attribute name.
|
||||
|
||||
## Descriptor set and binding
|
||||
|
||||
A SPIR-V shader must declare where uniform variables are located with explicit indices that match up with CPU-side code. This can be done with the `descriptor_set` and `binding` attributes. Note that `descriptor_set = 0` is reserved for future use, and cannot be used.
|
||||
|
||||
Example:
|
||||
|
||||
```rust:
|
||||
#[spirv(fragment)]
|
||||
fn main(
|
||||
#[spirv(descriptor_set = 2, binding = 5)] mut var: Uniform<Vec4>,
|
||||
) { }
|
||||
```
|
||||
|
||||
Both descriptor_set and binding take an integer argument that specifies the uniform's index.
|
@ -194,7 +194,7 @@ pub fn fs(screen_pos: Vec2) -> Vec4 {
|
||||
}
|
||||
|
||||
#[allow(unused_attributes)]
|
||||
#[spirv(entry = "fragment")]
|
||||
#[spirv(fragment)]
|
||||
pub fn main_fs(input: Input<Vec4>, mut output: Output<Vec4>) {
|
||||
let v = input.load();
|
||||
let color = fs(Vec2::new(v.0, v.1));
|
||||
@ -202,11 +202,11 @@ pub fn main_fs(input: Input<Vec4>, mut output: Output<Vec4>) {
|
||||
}
|
||||
|
||||
#[allow(unused_attributes)]
|
||||
#[spirv(entry = "vertex")]
|
||||
#[spirv(vertex)]
|
||||
pub fn main_vs(
|
||||
in_pos: Input<Vec4>,
|
||||
_in_color: Input<Vec4>,
|
||||
#[spirv(builtin = "position")] mut out_pos: Output<Vec4>,
|
||||
#[spirv(position)] mut out_pos: Output<Vec4>,
|
||||
mut out_color: Output<Vec4>,
|
||||
) {
|
||||
out_pos.store(in_pos.load());
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
use crate::codegen_cx::CodegenCx;
|
||||
use crate::spirv_type::SpirvType;
|
||||
use crate::symbols::{parse_attr, SpirvAttribute};
|
||||
use crate::symbols::{parse_attrs, SpirvAttribute};
|
||||
use rspirv::spirv::{StorageClass, Word};
|
||||
use rustc_middle::bug;
|
||||
use rustc_middle::ty::layout::{FnAbiExt, TyAndLayout};
|
||||
@ -538,8 +538,8 @@ fn dig_scalar_pointee_adt<'tcx>(
|
||||
// TODO: Enforce this is only used in spirv-std.
|
||||
fn get_storage_class<'tcx>(cx: &CodegenCx<'tcx>, ty: TyAndLayout<'tcx>) -> Option<StorageClass> {
|
||||
if let TyKind::Adt(adt, _substs) = ty.ty.kind() {
|
||||
for attr in cx.tcx.get_attrs(adt.did) {
|
||||
if let Some(SpirvAttribute::StorageClass(storage_class)) = parse_attr(cx, attr) {
|
||||
for attr in parse_attrs(cx, cx.tcx.get_attrs(adt.did)) {
|
||||
if let SpirvAttribute::StorageClass(storage_class) = attr {
|
||||
return Some(storage_class);
|
||||
}
|
||||
}
|
||||
|
@ -2,14 +2,10 @@ use super::CodegenCx;
|
||||
use crate::abi::ConvSpirvType;
|
||||
use crate::builder_spirv::{SpirvConst, SpirvValue, SpirvValueExt};
|
||||
use crate::spirv_type::SpirvType;
|
||||
use crate::symbols::{parse_attr, SpirvAttribute};
|
||||
use rspirv::dr::Operand;
|
||||
use rspirv::spirv::{
|
||||
Decoration, ExecutionMode, ExecutionModel, FunctionControl, LinkageType, StorageClass, Word,
|
||||
};
|
||||
use crate::symbols::{parse_attrs, SpirvAttribute};
|
||||
use rspirv::spirv::{FunctionControl, LinkageType, StorageClass, Word};
|
||||
use rustc_attr::InlineAttr;
|
||||
use rustc_codegen_ssa::traits::{PreDefineMethods, StaticMethods};
|
||||
use rustc_hir::Param;
|
||||
use rustc_middle::bug;
|
||||
use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, CodegenFnAttrs};
|
||||
use rustc_middle::mir::mono::{Linkage, MonoItem, Visibility};
|
||||
@ -18,9 +14,7 @@ use rustc_middle::ty::{Instance, ParamEnv, Ty, TypeFoldable};
|
||||
use rustc_span::def_id::DefId;
|
||||
use rustc_span::Span;
|
||||
use rustc_target::abi::call::FnAbi;
|
||||
use rustc_target::abi::call::PassMode;
|
||||
use rustc_target::abi::{Align, LayoutOf};
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn attrs_to_spirv(attrs: &CodegenFnAttrs) -> FunctionControl {
|
||||
let mut control = FunctionControl::NONE;
|
||||
@ -156,194 +150,6 @@ impl<'tcx> CodegenCx<'tcx> {
|
||||
self.zombie_with_span(result.def, span, "Globals are not supported yet");
|
||||
result
|
||||
}
|
||||
|
||||
// Entry points declare their "interface" (all uniforms, inputs, outputs, etc.) as parameters. spir-v uses globals
|
||||
// to declare the interface. So, we need to generate a lil stub for the "real" main that collects all those global
|
||||
// variables and calls the user-defined main function.
|
||||
fn shader_entry_stub(
|
||||
&self,
|
||||
entry_func: SpirvValue,
|
||||
hir_params: &[Param<'_>],
|
||||
name: String,
|
||||
execution_model: ExecutionModel,
|
||||
) {
|
||||
let void = SpirvType::Void.def(self);
|
||||
let fn_void_void = SpirvType::Function {
|
||||
return_type: void,
|
||||
arguments: vec![],
|
||||
}
|
||||
.def(self);
|
||||
let (entry_func_return, entry_func_args) = match self.lookup_type(entry_func.ty) {
|
||||
SpirvType::Function {
|
||||
return_type,
|
||||
arguments,
|
||||
} => (return_type, arguments),
|
||||
other => self.tcx.sess.fatal(&format!(
|
||||
"Invalid entry_stub type: {}",
|
||||
other.debug(entry_func.ty, self)
|
||||
)),
|
||||
};
|
||||
let mut emit = self.emit_global();
|
||||
let mut decoration_locations = HashMap::new();
|
||||
// Create OpVariables before OpFunction so they're global instead of local vars.
|
||||
let arguments = entry_func_args
|
||||
.iter()
|
||||
.zip(hir_params)
|
||||
.map(|(&arg, hir_param)| {
|
||||
let storage_class = match self.lookup_type(arg) {
|
||||
SpirvType::Pointer { storage_class, .. } => storage_class,
|
||||
other => self.tcx.sess.fatal(&format!(
|
||||
"Invalid entry arg type {}",
|
||||
other.debug(arg, self)
|
||||
)),
|
||||
};
|
||||
let mut has_location = match storage_class {
|
||||
StorageClass::Input | StorageClass::Output | StorageClass::UniformConstant => {
|
||||
true
|
||||
}
|
||||
_ => false,
|
||||
};
|
||||
// Note: this *declares* the variable too.
|
||||
let variable = emit.variable(arg, None, storage_class, None);
|
||||
for attr in hir_param.attrs {
|
||||
if let Some(SpirvAttribute::Builtin(builtin)) = parse_attr(self, attr) {
|
||||
emit.decorate(
|
||||
variable,
|
||||
Decoration::BuiltIn,
|
||||
std::iter::once(Operand::BuiltIn(builtin)),
|
||||
);
|
||||
has_location = false;
|
||||
}
|
||||
}
|
||||
// Assign locations from left to right, incrementing each storage class
|
||||
// individually.
|
||||
// TODO: Is this right for UniformConstant? Do they share locations with
|
||||
// input/outpus?
|
||||
if has_location {
|
||||
let location = decoration_locations
|
||||
.entry(storage_class)
|
||||
.or_insert_with(|| 0);
|
||||
emit.decorate(
|
||||
variable,
|
||||
Decoration::Location,
|
||||
std::iter::once(Operand::LiteralInt32(*location)),
|
||||
);
|
||||
*location += 1;
|
||||
}
|
||||
variable
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let fn_id = emit
|
||||
.begin_function(void, None, FunctionControl::NONE, fn_void_void)
|
||||
.unwrap();
|
||||
emit.begin_block(None).unwrap();
|
||||
emit.function_call(
|
||||
entry_func_return,
|
||||
None,
|
||||
entry_func.def,
|
||||
arguments.iter().copied(),
|
||||
)
|
||||
.unwrap();
|
||||
emit.ret().unwrap();
|
||||
emit.end_function().unwrap();
|
||||
|
||||
let interface = arguments;
|
||||
emit.entry_point(execution_model, fn_id, name, interface);
|
||||
if execution_model == ExecutionModel::Fragment {
|
||||
// TODO: Make this configurable.
|
||||
emit.execution_mode(fn_id, ExecutionMode::OriginUpperLeft, &[]);
|
||||
}
|
||||
}
|
||||
|
||||
// Kernel mode takes its interface as function parameters(??)
|
||||
// OpEntryPoints cannot be OpLinkage, so write out a stub to call through.
|
||||
fn kernel_entry_stub(
|
||||
&self,
|
||||
entry_func: SpirvValue,
|
||||
name: String,
|
||||
execution_model: ExecutionModel,
|
||||
) {
|
||||
let (entry_func_return, entry_func_args) = match self.lookup_type(entry_func.ty) {
|
||||
SpirvType::Function {
|
||||
return_type,
|
||||
arguments,
|
||||
} => (return_type, arguments),
|
||||
other => self.tcx.sess.fatal(&format!(
|
||||
"Invalid kernel_entry_stub type: {}",
|
||||
other.debug(entry_func.ty, self)
|
||||
)),
|
||||
};
|
||||
let mut emit = self.emit_global();
|
||||
let fn_id = emit
|
||||
.begin_function(
|
||||
entry_func_return,
|
||||
None,
|
||||
FunctionControl::NONE,
|
||||
entry_func.ty,
|
||||
)
|
||||
.unwrap();
|
||||
let arguments = entry_func_args
|
||||
.iter()
|
||||
.map(|&ty| emit.function_parameter(ty).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
emit.begin_block(None).unwrap();
|
||||
let call_result = emit
|
||||
.function_call(entry_func_return, None, entry_func.def, arguments)
|
||||
.unwrap();
|
||||
if self.lookup_type(entry_func_return) == SpirvType::Void {
|
||||
emit.ret().unwrap();
|
||||
} else {
|
||||
emit.ret_value(call_result).unwrap();
|
||||
}
|
||||
emit.end_function().unwrap();
|
||||
|
||||
emit.entry_point(execution_model, fn_id, name, &[]);
|
||||
}
|
||||
|
||||
fn entry_stub(
|
||||
&self,
|
||||
instance: &Instance<'_>,
|
||||
fn_abi: &FnAbi<'_, Ty<'_>>,
|
||||
entry_func: SpirvValue,
|
||||
name: String,
|
||||
execution_model: ExecutionModel,
|
||||
) {
|
||||
let local_id = match instance.def_id().as_local() {
|
||||
Some(id) => id,
|
||||
None => {
|
||||
self.tcx
|
||||
.sess
|
||||
.err(&format!("Cannot declare {} as an entry point", name));
|
||||
return;
|
||||
}
|
||||
};
|
||||
let fn_hir_id = self.tcx.hir().local_def_id_to_hir_id(local_id);
|
||||
let body = self.tcx.hir().body(self.tcx.hir().body_owned_by(fn_hir_id));
|
||||
for (abi, arg) in fn_abi.args.iter().zip(body.params) {
|
||||
if let PassMode::Direct(_) = abi.mode {
|
||||
} else {
|
||||
self.tcx.sess.span_err(
|
||||
arg.span,
|
||||
&format!("PassMode {:?} invalid for entry point parameter", abi.mode),
|
||||
)
|
||||
}
|
||||
}
|
||||
if let PassMode::Ignore = fn_abi.ret.mode {
|
||||
} else {
|
||||
self.tcx.sess.span_err(
|
||||
self.tcx.hir().span(fn_hir_id),
|
||||
&format!(
|
||||
"PassMode {:?} invalid for entry point return type",
|
||||
fn_abi.ret.mode
|
||||
),
|
||||
)
|
||||
}
|
||||
if execution_model == ExecutionModel::Kernel {
|
||||
self.kernel_entry_stub(entry_func, name, execution_model);
|
||||
} else {
|
||||
self.shader_entry_stub(entry_func, body.params, name, execution_model);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'tcx> PreDefineMethods<'tcx> for CodegenCx<'tcx> {
|
||||
@ -403,16 +209,16 @@ impl<'tcx> PreDefineMethods<'tcx> for CodegenCx<'tcx> {
|
||||
let declared =
|
||||
self.declare_fn_ext(symbol_name, Some(&human_name), linkage2, spv_attrs, &fn_abi);
|
||||
|
||||
for attr in self.tcx.get_attrs(instance.def_id()) {
|
||||
match parse_attr(self, attr) {
|
||||
Some(SpirvAttribute::Entry(execution_model)) => self.entry_stub(
|
||||
for attr in parse_attrs(self, self.tcx.get_attrs(instance.def_id())) {
|
||||
match attr {
|
||||
SpirvAttribute::Entry(execution_model) => self.entry_stub(
|
||||
&instance,
|
||||
&fn_abi,
|
||||
declared,
|
||||
human_name.clone(),
|
||||
execution_model,
|
||||
),
|
||||
Some(SpirvAttribute::ReallyUnsafeIgnoreBitcasts) => {
|
||||
SpirvAttribute::ReallyUnsafeIgnoreBitcasts => {
|
||||
self.really_unsafe_ignore_bitcasts
|
||||
.borrow_mut()
|
||||
.insert(declared);
|
||||
|
245
rustc_codegen_spirv/src/codegen_cx/entry.rs
Normal file
245
rustc_codegen_spirv/src/codegen_cx/entry.rs
Normal file
@ -0,0 +1,245 @@
|
||||
use super::CodegenCx;
|
||||
use crate::builder_spirv::SpirvValue;
|
||||
use crate::spirv_type::SpirvType;
|
||||
use crate::symbols::{parse_attrs, SpirvAttribute};
|
||||
use rspirv::dr::Operand;
|
||||
use rspirv::spirv::{
|
||||
Decoration, ExecutionMode, ExecutionModel, FunctionControl, StorageClass, Word,
|
||||
};
|
||||
use rustc_hir::Param;
|
||||
use rustc_middle::ty::{Instance, Ty};
|
||||
use rustc_target::abi::call::{FnAbi, PassMode};
|
||||
use std::collections::HashMap;
|
||||
|
||||
impl<'tcx> CodegenCx<'tcx> {
|
||||
// Entry points declare their "interface" (all uniforms, inputs, outputs, etc.) as parameters.
|
||||
// spir-v uses globals to declare the interface. So, we need to generate a lil stub for the
|
||||
// "real" main that collects all those global variables and calls the user-defined main
|
||||
// function.
|
||||
pub fn entry_stub(
|
||||
&self,
|
||||
instance: &Instance<'_>,
|
||||
fn_abi: &FnAbi<'_, Ty<'_>>,
|
||||
entry_func: SpirvValue,
|
||||
name: String,
|
||||
execution_model: ExecutionModel,
|
||||
) {
|
||||
let local_id = match instance.def_id().as_local() {
|
||||
Some(id) => id,
|
||||
None => {
|
||||
self.tcx
|
||||
.sess
|
||||
.err(&format!("Cannot declare {} as an entry point", name));
|
||||
return;
|
||||
}
|
||||
};
|
||||
let fn_hir_id = self.tcx.hir().local_def_id_to_hir_id(local_id);
|
||||
let body = self.tcx.hir().body(self.tcx.hir().body_owned_by(fn_hir_id));
|
||||
for (abi, arg) in fn_abi.args.iter().zip(body.params) {
|
||||
if let PassMode::Direct(_) = abi.mode {
|
||||
} else {
|
||||
self.tcx.sess.span_err(
|
||||
arg.span,
|
||||
&format!("PassMode {:?} invalid for entry point parameter", abi.mode),
|
||||
)
|
||||
}
|
||||
}
|
||||
if let PassMode::Ignore = fn_abi.ret.mode {
|
||||
} else {
|
||||
self.tcx.sess.span_err(
|
||||
self.tcx.hir().span(fn_hir_id),
|
||||
&format!(
|
||||
"PassMode {:?} invalid for entry point return type",
|
||||
fn_abi.ret.mode
|
||||
),
|
||||
)
|
||||
}
|
||||
if execution_model == ExecutionModel::Kernel {
|
||||
self.kernel_entry_stub(entry_func, name, execution_model);
|
||||
} else {
|
||||
self.shader_entry_stub(entry_func, body.params, name, execution_model);
|
||||
}
|
||||
}
|
||||
|
||||
fn shader_entry_stub(
|
||||
&self,
|
||||
entry_func: SpirvValue,
|
||||
hir_params: &[Param<'tcx>],
|
||||
name: String,
|
||||
execution_model: ExecutionModel,
|
||||
) {
|
||||
let void = SpirvType::Void.def(self);
|
||||
let fn_void_void = SpirvType::Function {
|
||||
return_type: void,
|
||||
arguments: vec![],
|
||||
}
|
||||
.def(self);
|
||||
let (entry_func_return, entry_func_args) = match self.lookup_type(entry_func.ty) {
|
||||
SpirvType::Function {
|
||||
return_type,
|
||||
arguments,
|
||||
} => (return_type, arguments),
|
||||
other => self.tcx.sess.fatal(&format!(
|
||||
"Invalid entry_stub type: {}",
|
||||
other.debug(entry_func.ty, self)
|
||||
)),
|
||||
};
|
||||
let mut decoration_locations = HashMap::new();
|
||||
// Create OpVariables before OpFunction so they're global instead of local vars.
|
||||
let arguments = entry_func_args
|
||||
.iter()
|
||||
.zip(hir_params)
|
||||
.map(|(&arg, hir_param)| {
|
||||
self.declare_parameter(arg, hir_param, &mut decoration_locations)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let mut emit = self.emit_global();
|
||||
let fn_id = emit
|
||||
.begin_function(void, None, FunctionControl::NONE, fn_void_void)
|
||||
.unwrap();
|
||||
emit.begin_block(None).unwrap();
|
||||
emit.function_call(
|
||||
entry_func_return,
|
||||
None,
|
||||
entry_func.def,
|
||||
arguments.iter().map(|&(a, _)| a),
|
||||
)
|
||||
.unwrap();
|
||||
emit.ret().unwrap();
|
||||
emit.end_function().unwrap();
|
||||
|
||||
let interface: Vec<_> = if emit.version().unwrap() > (1, 3) {
|
||||
// SPIR-V >= v1.4 includes all OpVariables in the interface.
|
||||
arguments.into_iter().map(|(a, _)| a).collect()
|
||||
} else {
|
||||
// SPIR-V <= v1.3 only includes Input and Output in the interface.
|
||||
arguments
|
||||
.into_iter()
|
||||
.filter(|&(_, s)| s == StorageClass::Input || s == StorageClass::Output)
|
||||
.map(|(a, _)| a)
|
||||
.collect()
|
||||
};
|
||||
emit.entry_point(execution_model, fn_id, name, interface);
|
||||
if execution_model == ExecutionModel::Fragment {
|
||||
// TODO: Make this configurable.
|
||||
emit.execution_mode(fn_id, ExecutionMode::OriginUpperLeft, &[]);
|
||||
}
|
||||
}
|
||||
|
||||
fn declare_parameter(
|
||||
&self,
|
||||
arg: Word,
|
||||
hir_param: &Param<'tcx>,
|
||||
decoration_locations: &mut HashMap<StorageClass, u32>,
|
||||
) -> (Word, StorageClass) {
|
||||
let storage_class = match self.lookup_type(arg) {
|
||||
SpirvType::Pointer { storage_class, .. } => storage_class,
|
||||
other => self.tcx.sess.fatal(&format!(
|
||||
"Invalid entry arg type {}",
|
||||
other.debug(arg, self)
|
||||
)),
|
||||
};
|
||||
let mut has_location = matches!(
|
||||
storage_class,
|
||||
StorageClass::Input | StorageClass::Output | StorageClass::UniformConstant
|
||||
);
|
||||
// Note: this *declares* the variable too.
|
||||
let variable = self.emit_global().variable(arg, None, storage_class, None);
|
||||
for attr in parse_attrs(self, hir_param.attrs) {
|
||||
match attr {
|
||||
SpirvAttribute::Builtin(builtin) => {
|
||||
self.emit_global().decorate(
|
||||
variable,
|
||||
Decoration::BuiltIn,
|
||||
std::iter::once(Operand::BuiltIn(builtin)),
|
||||
);
|
||||
has_location = false;
|
||||
}
|
||||
SpirvAttribute::DescriptorSet(index) => {
|
||||
self.emit_global().decorate(
|
||||
variable,
|
||||
Decoration::DescriptorSet,
|
||||
std::iter::once(Operand::LiteralInt32(index)),
|
||||
);
|
||||
has_location = false;
|
||||
}
|
||||
SpirvAttribute::Binding(index) => {
|
||||
if index == 0 {
|
||||
self.tcx.sess.span_err(
|
||||
hir_param.span,
|
||||
"descriptor_set 0 is reserved for internal / future use",
|
||||
);
|
||||
}
|
||||
self.emit_global().decorate(
|
||||
variable,
|
||||
Decoration::Binding,
|
||||
std::iter::once(Operand::LiteralInt32(index)),
|
||||
);
|
||||
has_location = false;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
// Assign locations from left to right, incrementing each storage class
|
||||
// individually.
|
||||
// TODO: Is this right for UniformConstant? Do they share locations with
|
||||
// input/outpus?
|
||||
if has_location {
|
||||
let location = decoration_locations
|
||||
.entry(storage_class)
|
||||
.or_insert_with(|| 0);
|
||||
self.emit_global().decorate(
|
||||
variable,
|
||||
Decoration::Location,
|
||||
std::iter::once(Operand::LiteralInt32(*location)),
|
||||
);
|
||||
*location += 1;
|
||||
}
|
||||
(variable, storage_class)
|
||||
}
|
||||
|
||||
// Kernel mode takes its interface as function parameters(??)
|
||||
// OpEntryPoints cannot be OpLinkage, so write out a stub to call through.
|
||||
fn kernel_entry_stub(
|
||||
&self,
|
||||
entry_func: SpirvValue,
|
||||
name: String,
|
||||
execution_model: ExecutionModel,
|
||||
) {
|
||||
let (entry_func_return, entry_func_args) = match self.lookup_type(entry_func.ty) {
|
||||
SpirvType::Function {
|
||||
return_type,
|
||||
arguments,
|
||||
} => (return_type, arguments),
|
||||
other => self.tcx.sess.fatal(&format!(
|
||||
"Invalid kernel_entry_stub type: {}",
|
||||
other.debug(entry_func.ty, self)
|
||||
)),
|
||||
};
|
||||
let mut emit = self.emit_global();
|
||||
let fn_id = emit
|
||||
.begin_function(
|
||||
entry_func_return,
|
||||
None,
|
||||
FunctionControl::NONE,
|
||||
entry_func.ty,
|
||||
)
|
||||
.unwrap();
|
||||
let arguments = entry_func_args
|
||||
.iter()
|
||||
.map(|&ty| emit.function_parameter(ty).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
emit.begin_block(None).unwrap();
|
||||
let call_result = emit
|
||||
.function_call(entry_func_return, None, entry_func.def, arguments)
|
||||
.unwrap();
|
||||
if self.lookup_type(entry_func_return) == SpirvType::Void {
|
||||
emit.ret().unwrap();
|
||||
} else {
|
||||
emit.ret_value(call_result).unwrap();
|
||||
}
|
||||
emit.end_function().unwrap();
|
||||
|
||||
emit.entry_point(execution_model, fn_id, name, &[]);
|
||||
}
|
||||
}
|
@ -1,5 +1,6 @@
|
||||
mod constant;
|
||||
mod declare;
|
||||
mod entry;
|
||||
mod type_;
|
||||
|
||||
use crate::builder::ExtInst;
|
||||
|
@ -1,6 +1,6 @@
|
||||
use crate::codegen_cx::CodegenCx;
|
||||
use rspirv::spirv::{BuiltIn, ExecutionModel, StorageClass};
|
||||
use rustc_ast::ast::{AttrKind, Attribute};
|
||||
use rustc_ast::ast::{AttrKind, Attribute, Lit, LitIntType, LitKind, NestedMetaItem};
|
||||
use rustc_span::symbol::Symbol;
|
||||
use std::collections::HashMap;
|
||||
|
||||
@ -13,19 +13,14 @@ pub struct Symbols {
|
||||
pub spirv: Symbol,
|
||||
pub spirv_std: Symbol,
|
||||
pub kernel: Symbol,
|
||||
pub builtin: Symbol,
|
||||
pub storage_class: Symbol,
|
||||
pub entry: Symbol,
|
||||
pub really_unsafe_ignore_bitcasts: Symbol,
|
||||
|
||||
builtins: HashMap<Symbol, BuiltIn>,
|
||||
storage_classes: HashMap<Symbol, StorageClass>,
|
||||
execution_models: HashMap<Symbol, ExecutionModel>,
|
||||
descriptor_set: Symbol,
|
||||
binding: Symbol,
|
||||
attributes: HashMap<Symbol, SpirvAttribute>,
|
||||
}
|
||||
|
||||
fn make_builtins() -> HashMap<Symbol, BuiltIn> {
|
||||
const BUILTINS: &[(&str, BuiltIn)] = {
|
||||
use BuiltIn::*;
|
||||
[
|
||||
&[
|
||||
("position", Position),
|
||||
("point_size", PointSize),
|
||||
("clip_distance", ClipDistance),
|
||||
@ -126,15 +121,12 @@ fn make_builtins() -> HashMap<Symbol, BuiltIn> {
|
||||
("warp_id_nv", WarpIDNV),
|
||||
("SMIDNV", SMIDNV),
|
||||
]
|
||||
.iter()
|
||||
.map(|&(a, b)| (Symbol::intern(a), b))
|
||||
.collect()
|
||||
}
|
||||
};
|
||||
|
||||
fn make_storage_classes() -> HashMap<Symbol, StorageClass> {
|
||||
const STORAGE_CLASSES: &[(&str, StorageClass)] = {
|
||||
use StorageClass::*;
|
||||
// make sure these strings stay synced with spirv-std's pointer types
|
||||
[
|
||||
&[
|
||||
("uniform_constant", UniformConstant),
|
||||
("input", Input),
|
||||
("uniform", Uniform),
|
||||
@ -165,14 +157,11 @@ fn make_storage_classes() -> HashMap<Symbol, StorageClass> {
|
||||
),
|
||||
("physical_storage_buffer", PhysicalStorageBuffer),
|
||||
]
|
||||
.iter()
|
||||
.map(|&(a, b)| (Symbol::intern(a), b))
|
||||
.collect()
|
||||
}
|
||||
};
|
||||
|
||||
fn make_execution_models() -> HashMap<Symbol, ExecutionModel> {
|
||||
const EXECUTION_MODELS: &[(&str, ExecutionModel)] = {
|
||||
use ExecutionModel::*;
|
||||
[
|
||||
&[
|
||||
("vertex", Vertex),
|
||||
("tessellation_control", TessellationControl),
|
||||
("tessellation_evaluation", TessellationEvaluation),
|
||||
@ -189,54 +178,65 @@ fn make_execution_models() -> HashMap<Symbol, ExecutionModel> {
|
||||
("miss_nv", MissNV),
|
||||
("callable_nv", CallableNV),
|
||||
]
|
||||
.iter()
|
||||
.map(|&(a, b)| (Symbol::intern(a), b))
|
||||
.collect()
|
||||
}
|
||||
};
|
||||
|
||||
impl Symbols {
|
||||
pub fn new() -> Self {
|
||||
let builtins = BUILTINS
|
||||
.iter()
|
||||
.map(|&(a, b)| (a, SpirvAttribute::Builtin(b)));
|
||||
let storage_classes = STORAGE_CLASSES
|
||||
.iter()
|
||||
.map(|&(a, b)| (a, SpirvAttribute::StorageClass(b)));
|
||||
let execution_models = EXECUTION_MODELS
|
||||
.iter()
|
||||
.map(|&(a, b)| (a, SpirvAttribute::Entry(b)));
|
||||
let custom = std::iter::once((
|
||||
"really_unsafe_ignore_bitcasts",
|
||||
SpirvAttribute::ReallyUnsafeIgnoreBitcasts,
|
||||
));
|
||||
let attributes_iter = builtins
|
||||
.chain(storage_classes)
|
||||
.chain(execution_models)
|
||||
.chain(custom)
|
||||
.map(|(a, b)| (Symbol::intern(a), b));
|
||||
let mut attributes = HashMap::new();
|
||||
for attr in attributes_iter {
|
||||
let old = attributes.insert(attr.0, attr.1);
|
||||
// `.collect()` into a HashMap does not error on duplicates, so manually write out the
|
||||
// loop here to error on duplicates.
|
||||
assert!(old.is_none());
|
||||
}
|
||||
Self {
|
||||
spirv: Symbol::intern("spirv"),
|
||||
spirv_std: Symbol::intern("spirv_std"),
|
||||
kernel: Symbol::intern("kernel"),
|
||||
builtin: Symbol::intern("builtin"),
|
||||
storage_class: Symbol::intern("storage_class"),
|
||||
entry: Symbol::intern("entry"),
|
||||
really_unsafe_ignore_bitcasts: Symbol::intern("really_unsafe_ignore_bitcasts"),
|
||||
builtins: make_builtins(),
|
||||
storage_classes: make_storage_classes(),
|
||||
execution_models: make_execution_models(),
|
||||
descriptor_set: Symbol::intern("descriptor_set"),
|
||||
binding: Symbol::intern("binding"),
|
||||
attributes,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn symbol_to_builtin(&self, sym: Symbol) -> Option<BuiltIn> {
|
||||
self.builtins.get(&sym).copied()
|
||||
}
|
||||
|
||||
pub fn symbol_to_storageclass(&self, sym: Symbol) -> Option<StorageClass> {
|
||||
self.storage_classes.get(&sym).copied()
|
||||
}
|
||||
|
||||
pub fn symbol_to_execution_model(&self, sym: Symbol) -> Option<ExecutionModel> {
|
||||
self.execution_models.get(&sym).copied()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum SpirvAttribute {
|
||||
Builtin(BuiltIn),
|
||||
StorageClass(StorageClass),
|
||||
Entry(ExecutionModel),
|
||||
DescriptorSet(u32),
|
||||
Binding(u32),
|
||||
ReallyUnsafeIgnoreBitcasts,
|
||||
}
|
||||
|
||||
// Note that we could mark the attr as used via cx.tcx.sess.mark_attr_used(attr), but unused reporting already happens
|
||||
// even before we get here :(
|
||||
/// Returns None if this attribute is not a spirv attribute, or if it's malformed (and an error is reported).
|
||||
pub fn parse_attr<'tcx>(cx: &CodegenCx<'tcx>, attr: &Attribute) -> Option<SpirvAttribute> {
|
||||
// Example attributes that we parse here:
|
||||
// #[spirv(storage_class = "uniform")]
|
||||
// #[spirv(entry = "kernel")]
|
||||
// Note that we could mark the attr as used via cx.tcx.sess.mark_attr_used(attr), but unused
|
||||
// reporting already happens even before we get here :(
|
||||
/// Returns empty if this attribute is not a spirv attribute, or if it's malformed (and an error is
|
||||
/// reported).
|
||||
pub fn parse_attrs(
|
||||
cx: &CodegenCx<'_>,
|
||||
attrs: &[Attribute],
|
||||
) -> impl Iterator<Item = SpirvAttribute> {
|
||||
let result = attrs.iter().flat_map(|attr| {
|
||||
let is_spirv = match attr.kind {
|
||||
AttrKind::Normal(ref item) => {
|
||||
// TODO: We ignore the rest of the path. Is this right?
|
||||
@ -245,81 +245,76 @@ pub fn parse_attr<'tcx>(cx: &CodegenCx<'tcx>, attr: &Attribute) -> Option<SpirvA
|
||||
}
|
||||
AttrKind::DocComment(..) => false,
|
||||
};
|
||||
if !is_spirv {
|
||||
return None;
|
||||
}
|
||||
let args = if let Some(args) = attr.meta_item_list() {
|
||||
let args = if !is_spirv {
|
||||
// Use an empty vec here to return empty
|
||||
Vec::new()
|
||||
} else if let Some(args) = attr.meta_item_list() {
|
||||
args
|
||||
} else {
|
||||
cx.tcx
|
||||
.sess
|
||||
.span_err(attr.span, "#[spirv(..)] attribute must have one argument");
|
||||
return None;
|
||||
cx.tcx.sess.span_err(
|
||||
attr.span,
|
||||
"#[spirv(..)] attribute must have at least one argument",
|
||||
);
|
||||
Vec::new()
|
||||
};
|
||||
if args.len() != 1 {
|
||||
cx.tcx
|
||||
.sess
|
||||
.span_err(attr.span, "#[spirv(..)] attribute must have one argument");
|
||||
args.into_iter().filter_map(move |ref arg| {
|
||||
if arg.has_name(cx.sym.descriptor_set) {
|
||||
match parse_attr_int_value(cx, arg) {
|
||||
Some(x) => Some(SpirvAttribute::DescriptorSet(x)),
|
||||
None => None,
|
||||
}
|
||||
} else if arg.has_name(cx.sym.binding) {
|
||||
match parse_attr_int_value(cx, arg) {
|
||||
Some(x) => Some(SpirvAttribute::Binding(x)),
|
||||
None => None,
|
||||
}
|
||||
} else {
|
||||
let name = match arg.ident() {
|
||||
Some(i) => i,
|
||||
None => {
|
||||
cx.tcx.sess.span_err(
|
||||
arg.span(),
|
||||
"#[spirv(..)] attribute argument must be single identifier",
|
||||
);
|
||||
return None;
|
||||
}
|
||||
let arg = &args[0];
|
||||
if arg.has_name(cx.sym.builtin) {
|
||||
if let Some(builtin_arg) = arg.value_str() {
|
||||
match cx.sym.symbol_to_builtin(builtin_arg) {
|
||||
Some(builtin) => Some(SpirvAttribute::Builtin(builtin)),
|
||||
None => {
|
||||
cx.tcx.sess.span_err(attr.span, "unknown spir-v builtin");
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
cx.tcx.sess.span_err(
|
||||
attr.span,
|
||||
"builtin must have value: #[spirv(builtin = \"..\")]",
|
||||
);
|
||||
None
|
||||
}
|
||||
} else if arg.has_name(cx.sym.storage_class) {
|
||||
if let Some(storage_arg) = arg.value_str() {
|
||||
match cx.sym.symbol_to_storageclass(storage_arg) {
|
||||
Some(storage_class) => Some(SpirvAttribute::StorageClass(storage_class)),
|
||||
};
|
||||
match cx.sym.attributes.get(&name.name) {
|
||||
Some(a) => Some(a.clone()),
|
||||
None => {
|
||||
cx.tcx
|
||||
.sess
|
||||
.span_err(attr.span, "unknown spir-v storage class");
|
||||
.span_err(name.span, "unknown argument to spirv attribute");
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
cx.tcx.sess.span_err(
|
||||
attr.span,
|
||||
"storage_class must have value: #[spirv(storage_class = \"..\")]",
|
||||
);
|
||||
None
|
||||
}
|
||||
} else if arg.has_name(cx.sym.entry) {
|
||||
if let Some(storage_arg) = arg.value_str() {
|
||||
match cx.sym.symbol_to_execution_model(storage_arg) {
|
||||
Some(execution_model) => Some(SpirvAttribute::Entry(execution_model)),
|
||||
})
|
||||
});
|
||||
// lifetimes are hard :(
|
||||
result.collect::<Vec<_>>().into_iter()
|
||||
}
|
||||
|
||||
fn parse_attr_int_value(cx: &CodegenCx<'_>, arg: &NestedMetaItem) -> Option<u32> {
|
||||
let arg = match arg.meta_item() {
|
||||
Some(arg) => arg,
|
||||
None => {
|
||||
cx.tcx
|
||||
.sess
|
||||
.span_err(attr.span, "unknown spir-v execution model");
|
||||
None
|
||||
.span_err(arg.span(), "attribute must have value");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
};
|
||||
match arg.name_value_literal() {
|
||||
Some(&Lit {
|
||||
kind: LitKind::Int(x, LitIntType::Unsuffixed),
|
||||
..
|
||||
}) if x <= u32::MAX as u128 => Some(x as u32),
|
||||
_ => {
|
||||
cx.tcx
|
||||
.sess
|
||||
.span_err(attr.span, "entry must have value: #[spirv(entry = \"..\")]");
|
||||
.span_err(arg.span, "attribute value must be integer");
|
||||
None
|
||||
}
|
||||
} else if arg.has_name(cx.sym.really_unsafe_ignore_bitcasts) {
|
||||
Some(SpirvAttribute::ReallyUnsafeIgnoreBitcasts)
|
||||
} else {
|
||||
cx.tcx
|
||||
.sess
|
||||
.span_err(attr.span, "unknown argument to spirv attribute");
|
||||
None
|
||||
}
|
||||
}
|
||||
|
@ -4,7 +4,7 @@ use super::val;
|
||||
fn hello_world() {
|
||||
val(r#"
|
||||
#[allow(unused_attributes)]
|
||||
#[spirv(entry = "fragment")]
|
||||
#[spirv(fragment)]
|
||||
pub fn main() {
|
||||
}
|
||||
"#);
|
||||
|
@ -54,9 +54,9 @@ macro_rules! pointer_addrspace_write {
|
||||
}
|
||||
|
||||
macro_rules! pointer_addrspace {
|
||||
($storage_class:literal, $type_name:ident, $writeable:tt) => {
|
||||
($storage_class:ident, $type_name:ident, $writeable:tt) => {
|
||||
#[allow(unused_attributes)]
|
||||
#[spirv(storage_class = $storage_class)]
|
||||
#[spirv($storage_class)]
|
||||
pub struct $type_name<'a, T> {
|
||||
x: &'a mut T,
|
||||
}
|
||||
@ -76,26 +76,26 @@ macro_rules! pointer_addrspace {
|
||||
|
||||
// Make sure these strings stay synced with symbols.rs
|
||||
// Note the type names don't have to match anything, they can be renamed (only the string must match)
|
||||
pointer_addrspace!("uniform_constant", UniformConstant, false);
|
||||
pointer_addrspace!("input", Input, false);
|
||||
pointer_addrspace!("uniform", Uniform, true);
|
||||
pointer_addrspace!("output", Output, true);
|
||||
pointer_addrspace!("workgroup", Workgroup, true);
|
||||
pointer_addrspace!("cross_workgroup", CrossWorkgroup, true);
|
||||
pointer_addrspace!("private", Private, true);
|
||||
pointer_addrspace!("function", Function, true);
|
||||
pointer_addrspace!("generic", Generic, true);
|
||||
pointer_addrspace!("push_constant", PushConstant, false);
|
||||
pointer_addrspace!("atomic_counter", AtomicCounter, true);
|
||||
pointer_addrspace!("image", Image, true);
|
||||
pointer_addrspace!("storage_buffer", StorageBuffer, true);
|
||||
pointer_addrspace!("callable_data_khr", CallableDataKHR, true);
|
||||
pointer_addrspace!("incoming_callable_data_khr", IncomingCallableDataKHR, true);
|
||||
pointer_addrspace!("ray_payload_khr", RayPayloadKHR, true);
|
||||
pointer_addrspace!("hit_attribute_khr", HitAttributeKHR, true);
|
||||
pointer_addrspace!("incoming_ray_payload_khr", IncomingRayPayloadKHR, true);
|
||||
pointer_addrspace!("shader_record_buffer_khr", ShaderRecordBufferKHR, true);
|
||||
pointer_addrspace!("physical_storage_buffer", PhysicalStorageBuffer, true);
|
||||
pointer_addrspace!(uniform_constant, UniformConstant, false);
|
||||
pointer_addrspace!(input, Input, false);
|
||||
pointer_addrspace!(uniform, Uniform, true);
|
||||
pointer_addrspace!(output, Output, true);
|
||||
pointer_addrspace!(workgroup, Workgroup, true);
|
||||
pointer_addrspace!(cross_workgroup, CrossWorkgroup, true);
|
||||
pointer_addrspace!(private, Private, true);
|
||||
pointer_addrspace!(function, Function, true);
|
||||
pointer_addrspace!(generic, Generic, true);
|
||||
pointer_addrspace!(push_constant, PushConstant, false);
|
||||
pointer_addrspace!(atomic_counter, AtomicCounter, true);
|
||||
pointer_addrspace!(image, Image, true);
|
||||
pointer_addrspace!(storage_buffer, StorageBuffer, true);
|
||||
pointer_addrspace!(callable_data_khr, CallableDataKHR, true);
|
||||
pointer_addrspace!(incoming_callable_data_khr, IncomingCallableDataKHR, true);
|
||||
pointer_addrspace!(ray_payload_khr, RayPayloadKHR, true);
|
||||
pointer_addrspace!(hit_attribute_khr, HitAttributeKHR, true);
|
||||
pointer_addrspace!(incoming_ray_payload_khr, IncomingRayPayloadKHR, true);
|
||||
pointer_addrspace!(shader_record_buffer_khr, ShaderRecordBufferKHR, true);
|
||||
pointer_addrspace!(physical_storage_buffer, PhysicalStorageBuffer, true);
|
||||
|
||||
#[allow(non_camel_case_types)]
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
|
Loading…
Reference in New Issue
Block a user