Require local size x dimension and remove gl_ (#495)

* Require local size and remove gl_

Removes the gl_ prefix from the compute shader attribute, shortens the thread dimension declaration to threads(x, y, z), requires the x size dimensions be specified, trailing ones may be elided for the y or z dimensions.

* Implement review suggestions
This commit is contained in:
Henno 2021-03-16 04:12:21 -04:00 committed by GitHub
parent a173208d80
commit eebb2d3b32
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 60 additions and 25 deletions

View File

@ -198,7 +198,7 @@ const EXECUTION_MODELS: &[(&str, ExecutionModel)] = {
("tessellation_evaluation", TessellationEvaluation),
("geometry", Geometry),
("fragment", Fragment),
("gl_compute", GLCompute),
("compute", GLCompute),
("kernel", Kernel),
("task_nv", TaskNV),
("mesh_nv", MeshNV),
@ -218,6 +218,7 @@ enum ExecutionModeExtraDim {
X,
Y,
Z,
Tuple,
}
const EXECUTION_MODES: &[(&str, ExecutionMode, ExecutionModeExtraDim)] = {
@ -240,9 +241,7 @@ const EXECUTION_MODES: &[(&str, ExecutionMode, ExecutionModeExtraDim)] = {
("depth_greater", DepthGreater, None),
("depth_less", DepthLess, None),
("depth_unchanged", DepthUnchanged, None),
("local_size_x", LocalSize, X),
("local_size_y", LocalSize, Y),
("local_size_z", LocalSize, Z),
("threads", LocalSize, Tuple),
("local_size_hint_x", LocalSizeHint, X),
("local_size_hint_y", LocalSizeHint, Y),
("local_size_hint_z", LocalSizeHint, Z),
@ -690,6 +689,40 @@ fn parse_attr_int_value(arg: &NestedMetaItem) -> Result<u32, ParseAttrError> {
}
}
fn parse_local_size_attr(arg: &NestedMetaItem) -> Result<[u32; 3], ParseAttrError> {
let arg = match arg.meta_item() {
Some(arg) => arg,
None => return Err((arg.span(), "attribute must have value".to_string())),
};
match arg.meta_item_list() {
Some(tuple) if !tuple.is_empty() && tuple.len() < 4 => {
let mut local_size = [1; 3];
for (idx, lit) in tuple.iter().enumerate() {
match lit.literal() {
Some(&Lit {
kind: LitKind::Int(x, LitIntType::Unsuffixed),
..
}) if x <= u32::MAX as u128 => local_size[idx] = x as u32,
_ => return Err((lit.span(), "must be a u32 literal".to_string())),
}
}
Ok(local_size)
}
Some(tuple) if tuple.is_empty() => Err((
arg.span,
"#[spirv(compute(threads(x, y, z)))] must have the x dimension specified, trailing ones may be elided".to_string(),
)),
Some(tuple) if tuple.len() > 3 => Err((
arg.span,
"#[spirv(compute(threads(x, y, z)))] is three dimensional".to_string(),
)),
_ => Err((
arg.span,
"#[spirv(compute(threads(x, y, z)))] must have 1 to 3 parameters, trailing ones may be elided".to_string(),
)),
}
}
// for a given entry, gather up the additional attributes
// in this case ExecutionMode's, some have extra arguments
// others are specified with x, y, or z components
@ -715,7 +748,7 @@ fn parse_entry_attrs(
{
use ExecutionModeExtraDim::*;
let val = match extra_dim {
None => Option::None,
None | Tuple => Option::None,
_ => Some(parse_attr_int_value(attr)?),
};
match execution_mode {
@ -723,22 +756,15 @@ fn parse_entry_attrs(
origin_mode.replace(*execution_mode);
}
LocalSize => {
let val = val.unwrap();
if local_size.is_none() {
local_size.replace([1, 1, 1]);
}
let local_size = local_size.as_mut().unwrap();
match extra_dim {
X => {
local_size[0] = val;
}
Y => {
local_size[1] = val;
}
Z => {
local_size[2] = val;
}
_ => unreachable!(),
local_size.replace(parse_local_size_attr(attr)?);
} else {
return Err((
attr_name.span,
String::from(
"`#[spirv(compute(threads))]` may only be specified once",
),
));
}
}
LocalSizeHint => {
@ -838,10 +864,18 @@ fn parse_entry_attrs(
.push((origin_mode, ExecutionModeExtra::new([])));
}
GLCompute => {
let local_size = local_size.unwrap_or([1, 1, 1]);
entry
.execution_modes
.push((LocalSize, ExecutionModeExtra::new(local_size)));
if let Some(local_size) = local_size {
entry
.execution_modes
.push((LocalSize, ExecutionModeExtra::new(local_size)));
} else {
return Err((
arg.span(),
String::from(
"The `threads` argument must be specified when using `#[spirv(compute)]`",
),
));
}
}
Kernel => {
if let Some(local_size) = local_size {

View File

@ -13,5 +13,6 @@ extern crate spirv_std;
#[macro_use]
pub extern crate spirv_std_macros;
#[spirv(gl_compute)]
// LocalSize/numthreads of (x = 32, y = 1, z = 1)
#[spirv(compute(threads(32)))]
pub fn main_cs() {}