mirror of
https://github.com/EmbarkStudios/rust-gpu.git
synced 2024-11-22 06:45:13 +00:00
Require local size x dimension and remove gl_ (#495)
* Require local size and remove gl_ Removes the gl_ prefix from the compute shader attribute, shortens the thread dimension declaration to threads(x, y, z), requires the x size dimensions be specified, trailing ones may be elided for the y or z dimensions. * Implement review suggestions
This commit is contained in:
parent
a173208d80
commit
eebb2d3b32
@ -198,7 +198,7 @@ const EXECUTION_MODELS: &[(&str, ExecutionModel)] = {
|
||||
("tessellation_evaluation", TessellationEvaluation),
|
||||
("geometry", Geometry),
|
||||
("fragment", Fragment),
|
||||
("gl_compute", GLCompute),
|
||||
("compute", GLCompute),
|
||||
("kernel", Kernel),
|
||||
("task_nv", TaskNV),
|
||||
("mesh_nv", MeshNV),
|
||||
@ -218,6 +218,7 @@ enum ExecutionModeExtraDim {
|
||||
X,
|
||||
Y,
|
||||
Z,
|
||||
Tuple,
|
||||
}
|
||||
|
||||
const EXECUTION_MODES: &[(&str, ExecutionMode, ExecutionModeExtraDim)] = {
|
||||
@ -240,9 +241,7 @@ const EXECUTION_MODES: &[(&str, ExecutionMode, ExecutionModeExtraDim)] = {
|
||||
("depth_greater", DepthGreater, None),
|
||||
("depth_less", DepthLess, None),
|
||||
("depth_unchanged", DepthUnchanged, None),
|
||||
("local_size_x", LocalSize, X),
|
||||
("local_size_y", LocalSize, Y),
|
||||
("local_size_z", LocalSize, Z),
|
||||
("threads", LocalSize, Tuple),
|
||||
("local_size_hint_x", LocalSizeHint, X),
|
||||
("local_size_hint_y", LocalSizeHint, Y),
|
||||
("local_size_hint_z", LocalSizeHint, Z),
|
||||
@ -690,6 +689,40 @@ fn parse_attr_int_value(arg: &NestedMetaItem) -> Result<u32, ParseAttrError> {
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_local_size_attr(arg: &NestedMetaItem) -> Result<[u32; 3], ParseAttrError> {
|
||||
let arg = match arg.meta_item() {
|
||||
Some(arg) => arg,
|
||||
None => return Err((arg.span(), "attribute must have value".to_string())),
|
||||
};
|
||||
match arg.meta_item_list() {
|
||||
Some(tuple) if !tuple.is_empty() && tuple.len() < 4 => {
|
||||
let mut local_size = [1; 3];
|
||||
for (idx, lit) in tuple.iter().enumerate() {
|
||||
match lit.literal() {
|
||||
Some(&Lit {
|
||||
kind: LitKind::Int(x, LitIntType::Unsuffixed),
|
||||
..
|
||||
}) if x <= u32::MAX as u128 => local_size[idx] = x as u32,
|
||||
_ => return Err((lit.span(), "must be a u32 literal".to_string())),
|
||||
}
|
||||
}
|
||||
Ok(local_size)
|
||||
}
|
||||
Some(tuple) if tuple.is_empty() => Err((
|
||||
arg.span,
|
||||
"#[spirv(compute(threads(x, y, z)))] must have the x dimension specified, trailing ones may be elided".to_string(),
|
||||
)),
|
||||
Some(tuple) if tuple.len() > 3 => Err((
|
||||
arg.span,
|
||||
"#[spirv(compute(threads(x, y, z)))] is three dimensional".to_string(),
|
||||
)),
|
||||
_ => Err((
|
||||
arg.span,
|
||||
"#[spirv(compute(threads(x, y, z)))] must have 1 to 3 parameters, trailing ones may be elided".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
// for a given entry, gather up the additional attributes
|
||||
// in this case ExecutionMode's, some have extra arguments
|
||||
// others are specified with x, y, or z components
|
||||
@ -715,7 +748,7 @@ fn parse_entry_attrs(
|
||||
{
|
||||
use ExecutionModeExtraDim::*;
|
||||
let val = match extra_dim {
|
||||
None => Option::None,
|
||||
None | Tuple => Option::None,
|
||||
_ => Some(parse_attr_int_value(attr)?),
|
||||
};
|
||||
match execution_mode {
|
||||
@ -723,22 +756,15 @@ fn parse_entry_attrs(
|
||||
origin_mode.replace(*execution_mode);
|
||||
}
|
||||
LocalSize => {
|
||||
let val = val.unwrap();
|
||||
if local_size.is_none() {
|
||||
local_size.replace([1, 1, 1]);
|
||||
}
|
||||
let local_size = local_size.as_mut().unwrap();
|
||||
match extra_dim {
|
||||
X => {
|
||||
local_size[0] = val;
|
||||
}
|
||||
Y => {
|
||||
local_size[1] = val;
|
||||
}
|
||||
Z => {
|
||||
local_size[2] = val;
|
||||
}
|
||||
_ => unreachable!(),
|
||||
local_size.replace(parse_local_size_attr(attr)?);
|
||||
} else {
|
||||
return Err((
|
||||
attr_name.span,
|
||||
String::from(
|
||||
"`#[spirv(compute(threads))]` may only be specified once",
|
||||
),
|
||||
));
|
||||
}
|
||||
}
|
||||
LocalSizeHint => {
|
||||
@ -838,10 +864,18 @@ fn parse_entry_attrs(
|
||||
.push((origin_mode, ExecutionModeExtra::new([])));
|
||||
}
|
||||
GLCompute => {
|
||||
let local_size = local_size.unwrap_or([1, 1, 1]);
|
||||
if let Some(local_size) = local_size {
|
||||
entry
|
||||
.execution_modes
|
||||
.push((LocalSize, ExecutionModeExtra::new(local_size)));
|
||||
} else {
|
||||
return Err((
|
||||
arg.span(),
|
||||
String::from(
|
||||
"The `threads` argument must be specified when using `#[spirv(compute)]`",
|
||||
),
|
||||
));
|
||||
}
|
||||
}
|
||||
Kernel => {
|
||||
if let Some(local_size) = local_size {
|
||||
|
@ -13,5 +13,6 @@ extern crate spirv_std;
|
||||
#[macro_use]
|
||||
pub extern crate spirv_std_macros;
|
||||
|
||||
#[spirv(gl_compute)]
|
||||
// LocalSize/numthreads of (x = 32, y = 1, z = 1)
|
||||
#[spirv(compute(threads(32)))]
|
||||
pub fn main_cs() {}
|
||||
|
Loading…
Reference in New Issue
Block a user