mirror of
https://github.com/vulkano-rs/vulkano.git
synced 2024-11-21 22:34:43 +00:00
Added required_subgroup_size to PipelineShaderStageCreateInfo (#2235)
* Added required_subgroup_size to PipelineShaderStageCreateInfo * Added validation errors. * Fixed error msgs / vuids. * ComputeShaderExecution for validating local_size. * WorkgroupSizeId reflection. * contains_enum * Reworked ComputeShaderExecution. * panic msgs. * workgroup size validation * unused import * fixed test deprecated fn * catch workgroup size overflow * EntryPointInfo::local_size docs * comments * typo + error msg
This commit is contained in:
parent
4133a3bf63
commit
ee4e3089d3
@ -73,7 +73,49 @@ fn write_shader_execution(execution: &ShaderExecution) -> TokenStream {
|
||||
)
|
||||
}
|
||||
}
|
||||
ShaderExecution::Compute => quote! { ::vulkano::shader::ShaderExecution::Compute },
|
||||
ShaderExecution::Compute(execution) => {
|
||||
use ::quote::ToTokens;
|
||||
use ::vulkano::shader::{ComputeShaderExecution, LocalSize};
|
||||
|
||||
struct LocalSizeToTokens(LocalSize);
|
||||
|
||||
impl ToTokens for LocalSizeToTokens {
|
||||
fn to_tokens(&self, tokens: &mut TokenStream) {
|
||||
match self.0 {
|
||||
LocalSize::Literal(literal) => quote! {
|
||||
::vulkano::shader::LocalSize::Literal(#literal)
|
||||
},
|
||||
LocalSize::SpecId(id) => quote! {
|
||||
::vulkano::shader::LocalSize::SpecId(#id)
|
||||
},
|
||||
}
|
||||
.to_tokens(tokens);
|
||||
}
|
||||
}
|
||||
|
||||
match execution {
|
||||
ComputeShaderExecution::LocalSize([x, y, z]) => {
|
||||
let [x, y, z] = [
|
||||
LocalSizeToTokens(*x),
|
||||
LocalSizeToTokens(*y),
|
||||
LocalSizeToTokens(*z),
|
||||
];
|
||||
quote! { ::vulkano::shader::ShaderExecution::Compute(
|
||||
::vulkano::shader::ComputeShaderExecution::LocalSize([#x, #y, #z])
|
||||
) }
|
||||
}
|
||||
ComputeShaderExecution::LocalSizeId([x, y, z]) => {
|
||||
let [x, y, z] = [
|
||||
LocalSizeToTokens(*x),
|
||||
LocalSizeToTokens(*y),
|
||||
LocalSizeToTokens(*z),
|
||||
];
|
||||
quote! { ::vulkano::shader::ShaderExecution::Compute(
|
||||
::vulkano::shader::ComputeShaderExecution::LocalSizeId([#x, #y, #z])
|
||||
) }
|
||||
}
|
||||
}
|
||||
}
|
||||
ShaderExecution::RayGeneration => {
|
||||
quote! { ::vulkano::shader::ShaderExecution::RayGeneration }
|
||||
}
|
||||
|
@ -74,11 +74,9 @@ impl ComputePipeline {
|
||||
if let Some(cache) = &cache {
|
||||
assert_eq!(device, cache.device().as_ref());
|
||||
}
|
||||
|
||||
create_info
|
||||
.validate(device)
|
||||
.map_err(|err| err.add_context("create_info"))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -100,12 +98,14 @@ impl ComputePipeline {
|
||||
let specialization_info_vk;
|
||||
let specialization_map_entries_vk: Vec<_>;
|
||||
let mut specialization_data_vk: Vec<u8>;
|
||||
let required_subgroup_size_create_info;
|
||||
|
||||
{
|
||||
let &PipelineShaderStageCreateInfo {
|
||||
flags,
|
||||
ref entry_point,
|
||||
ref specialization_info,
|
||||
ref required_subgroup_size,
|
||||
_ne: _,
|
||||
} = stage;
|
||||
|
||||
@ -135,7 +135,20 @@ impl ComputePipeline {
|
||||
data_size: specialization_data_vk.len(),
|
||||
p_data: specialization_data_vk.as_ptr() as *const _,
|
||||
};
|
||||
required_subgroup_size_create_info =
|
||||
required_subgroup_size.map(|required_subgroup_size| {
|
||||
ash::vk::PipelineShaderStageRequiredSubgroupSizeCreateInfo {
|
||||
required_subgroup_size,
|
||||
..Default::default()
|
||||
}
|
||||
});
|
||||
stage_vk = ash::vk::PipelineShaderStageCreateInfo {
|
||||
p_next: required_subgroup_size_create_info.as_ref().map_or(
|
||||
ptr::null(),
|
||||
|required_subgroup_size_create_info| {
|
||||
required_subgroup_size_create_info as *const _ as _
|
||||
},
|
||||
),
|
||||
flags: flags.into(),
|
||||
stage: ShaderStage::from(&entry_point_info.execution).into(),
|
||||
module: entry_point.module().handle(),
|
||||
@ -333,12 +346,13 @@ impl ComputePipelineCreateInfo {
|
||||
flags: _,
|
||||
ref entry_point,
|
||||
specialization_info: _,
|
||||
required_subgroup_size: _vk,
|
||||
_ne: _,
|
||||
} = &stage;
|
||||
|
||||
let entry_point_info = entry_point.info();
|
||||
|
||||
if !matches!(entry_point_info.execution, ShaderExecution::Compute) {
|
||||
if !matches!(entry_point_info.execution, ShaderExecution::Compute(_)) {
|
||||
return Err(Box::new(ValidationError {
|
||||
context: "stage.entry_point".into(),
|
||||
problem: "is not a `ShaderStage::Compute` entry point".into(),
|
||||
@ -514,4 +528,135 @@ mod tests {
|
||||
let data_buffer_content = data_buffer.read().unwrap();
|
||||
assert_eq!(*data_buffer_content, 0x12345678);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn required_subgroup_size() {
|
||||
// This test checks whether required_subgroup_size works.
|
||||
// It executes a single compute shader (one invocation) that writes the subgroup size
|
||||
// to a buffer. The buffer content is then checked for the right value.
|
||||
|
||||
let (device, queue) = gfx_dev_and_queue!(subgroup_size_control);
|
||||
|
||||
let cs = unsafe {
|
||||
/*
|
||||
#version 450
|
||||
|
||||
#extension GL_KHR_shader_subgroup_basic: enable
|
||||
|
||||
layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout(set = 0, binding = 0) buffer Output {
|
||||
uint write;
|
||||
} write;
|
||||
|
||||
void main() {
|
||||
if (gl_GlobalInvocationID.x == 0) {
|
||||
write.write = gl_SubgroupSize;
|
||||
}
|
||||
}
|
||||
*/
|
||||
const MODULE: [u32; 246] = [
|
||||
119734787, 65536, 851978, 30, 0, 131089, 1, 131089, 61, 393227, 1, 1280527431,
|
||||
1685353262, 808793134, 0, 196622, 0, 1, 458767, 5, 4, 1852399981, 0, 9, 23, 393232,
|
||||
4, 17, 128, 1, 1, 196611, 2, 450, 655364, 1197427783, 1279741775, 1885560645,
|
||||
1953718128, 1600482425, 1701734764, 1919509599, 1769235301, 25974, 524292,
|
||||
1197427783, 1279741775, 1852399429, 1685417059, 1768185701, 1952671090, 6649449,
|
||||
589828, 1264536647, 1935626824, 1701077352, 1970495346, 1869768546, 1650421877,
|
||||
1667855201, 0, 262149, 4, 1852399981, 0, 524293, 9, 1197436007, 1633841004,
|
||||
1986939244, 1952539503, 1231974249, 68, 262149, 18, 1886680399, 29813, 327686, 18,
|
||||
0, 1953067639, 101, 262149, 20, 1953067639, 101, 393221, 23, 1398762599,
|
||||
1919378037, 1399879023, 6650473, 262215, 9, 11, 28, 327752, 18, 0, 35, 0, 196679,
|
||||
18, 3, 262215, 20, 34, 0, 262215, 20, 33, 0, 196679, 23, 0, 262215, 23, 11, 36,
|
||||
196679, 24, 0, 262215, 29, 11, 25, 131091, 2, 196641, 3, 2, 262165, 6, 32, 0,
|
||||
262167, 7, 6, 3, 262176, 8, 1, 7, 262203, 8, 9, 1, 262187, 6, 10, 0, 262176, 11, 1,
|
||||
6, 131092, 14, 196638, 18, 6, 262176, 19, 2, 18, 262203, 19, 20, 2, 262165, 21, 32,
|
||||
1, 262187, 21, 22, 0, 262203, 11, 23, 1, 262176, 25, 2, 6, 262187, 6, 27, 128,
|
||||
262187, 6, 28, 1, 393260, 7, 29, 27, 28, 28, 327734, 2, 4, 0, 3, 131320, 5, 327745,
|
||||
11, 12, 9, 10, 262205, 6, 13, 12, 327850, 14, 15, 13, 10, 196855, 17, 0, 262394,
|
||||
15, 16, 17, 131320, 16, 262205, 6, 24, 23, 327745, 25, 26, 20, 22, 196670, 26, 24,
|
||||
131321, 17, 131320, 17, 65789, 65592,
|
||||
];
|
||||
let module =
|
||||
ShaderModule::new(device.clone(), ShaderModuleCreateInfo::new(&MODULE)).unwrap();
|
||||
module.entry_point("main").unwrap()
|
||||
};
|
||||
|
||||
let properties = device.physical_device().properties();
|
||||
let subgroup_size = properties.min_subgroup_size.unwrap_or(1);
|
||||
|
||||
let pipeline = {
|
||||
let stage = PipelineShaderStageCreateInfo {
|
||||
required_subgroup_size: Some(subgroup_size),
|
||||
..PipelineShaderStageCreateInfo::new(cs)
|
||||
};
|
||||
let layout = PipelineLayout::new(
|
||||
device.clone(),
|
||||
PipelineDescriptorSetLayoutCreateInfo::from_stages([&stage])
|
||||
.into_pipeline_layout_create_info(device.clone())
|
||||
.unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
ComputePipeline::new(
|
||||
device.clone(),
|
||||
None,
|
||||
ComputePipelineCreateInfo::stage_layout(stage, layout),
|
||||
)
|
||||
.unwrap()
|
||||
};
|
||||
|
||||
let memory_allocator = StandardMemoryAllocator::new_default(device.clone());
|
||||
let data_buffer = Buffer::from_data(
|
||||
&memory_allocator,
|
||||
BufferCreateInfo {
|
||||
usage: BufferUsage::STORAGE_BUFFER,
|
||||
..Default::default()
|
||||
},
|
||||
AllocationCreateInfo {
|
||||
memory_type_filter: MemoryTypeFilter::PREFER_DEVICE
|
||||
| MemoryTypeFilter::HOST_RANDOM_ACCESS,
|
||||
..Default::default()
|
||||
},
|
||||
0,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let ds_allocator = StandardDescriptorSetAllocator::new(device.clone());
|
||||
let set = PersistentDescriptorSet::new(
|
||||
&ds_allocator,
|
||||
pipeline.layout().set_layouts().get(0).unwrap().clone(),
|
||||
[WriteDescriptorSet::buffer(0, data_buffer.clone())],
|
||||
[],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let cb_allocator = StandardCommandBufferAllocator::new(device.clone(), Default::default());
|
||||
let mut cbb = AutoCommandBufferBuilder::primary(
|
||||
&cb_allocator,
|
||||
queue.queue_family_index(),
|
||||
CommandBufferUsage::OneTimeSubmit,
|
||||
)
|
||||
.unwrap();
|
||||
cbb.bind_pipeline_compute(pipeline.clone())
|
||||
.unwrap()
|
||||
.bind_descriptor_sets(
|
||||
PipelineBindPoint::Compute,
|
||||
pipeline.layout().clone(),
|
||||
0,
|
||||
set,
|
||||
)
|
||||
.unwrap()
|
||||
.dispatch([128, 1, 1])
|
||||
.unwrap();
|
||||
let cb = cbb.build().unwrap();
|
||||
|
||||
let future = now(device)
|
||||
.then_execute(queue, cb)
|
||||
.unwrap()
|
||||
.then_signal_fence_and_flush()
|
||||
.unwrap();
|
||||
future.wait(None).unwrap();
|
||||
|
||||
let data_buffer_content = data_buffer.read().unwrap();
|
||||
assert_eq!(*data_buffer_content, subgroup_size);
|
||||
}
|
||||
}
|
||||
|
@ -167,7 +167,6 @@ impl GraphicsPipeline {
|
||||
create_info
|
||||
.validate(device)
|
||||
.map_err(|err| err.add_context("create_info"))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -204,6 +203,8 @@ impl GraphicsPipeline {
|
||||
specialization_info_vk: ash::vk::SpecializationInfo,
|
||||
specialization_map_entries_vk: Vec<ash::vk::SpecializationMapEntry>,
|
||||
specialization_data_vk: Vec<u8>,
|
||||
required_subgroup_size_create_info:
|
||||
Option<ash::vk::PipelineShaderStageRequiredSubgroupSizeCreateInfo>,
|
||||
}
|
||||
|
||||
let (mut stages_vk, mut per_stage_vk): (SmallVec<[_; 5]>, SmallVec<[_; 5]>) = stages
|
||||
@ -213,6 +214,7 @@ impl GraphicsPipeline {
|
||||
flags,
|
||||
ref entry_point,
|
||||
ref specialization_info,
|
||||
ref required_subgroup_size,
|
||||
_ne: _,
|
||||
} = stage;
|
||||
|
||||
@ -235,7 +237,13 @@ impl GraphicsPipeline {
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let required_subgroup_size_create_info =
|
||||
required_subgroup_size.map(|required_subgroup_size| {
|
||||
ash::vk::PipelineShaderStageRequiredSubgroupSizeCreateInfo {
|
||||
required_subgroup_size,
|
||||
..Default::default()
|
||||
}
|
||||
});
|
||||
(
|
||||
ash::vk::PipelineShaderStageCreateInfo {
|
||||
flags: flags.into(),
|
||||
@ -255,6 +263,7 @@ impl GraphicsPipeline {
|
||||
},
|
||||
specialization_map_entries_vk,
|
||||
specialization_data_vk,
|
||||
required_subgroup_size_create_info,
|
||||
},
|
||||
)
|
||||
})
|
||||
@ -267,10 +276,17 @@ impl GraphicsPipeline {
|
||||
specialization_info_vk,
|
||||
specialization_map_entries_vk,
|
||||
specialization_data_vk,
|
||||
required_subgroup_size_create_info,
|
||||
},
|
||||
) in (stages_vk.iter_mut()).zip(per_stage_vk.iter_mut())
|
||||
{
|
||||
*stage_vk = ash::vk::PipelineShaderStageCreateInfo {
|
||||
p_next: required_subgroup_size_create_info.as_ref().map_or(
|
||||
ptr::null(),
|
||||
|required_subgroup_size_create_info| {
|
||||
required_subgroup_size_create_info as *const _ as _
|
||||
},
|
||||
),
|
||||
p_name: name_vk.as_ptr(),
|
||||
p_specialization_info: specialization_info_vk,
|
||||
..*stage_vk
|
||||
@ -2420,6 +2436,7 @@ impl GraphicsPipelineCreateInfo {
|
||||
flags: _,
|
||||
ref entry_point,
|
||||
specialization_info: _,
|
||||
required_subgroup_size: _vk,
|
||||
_ne: _,
|
||||
} = stage;
|
||||
|
||||
|
@ -321,6 +321,21 @@ pub struct PipelineShaderStageCreateInfo {
|
||||
/// The default value is empty.
|
||||
pub specialization_info: HashMap<u32, SpecializationConstant>,
|
||||
|
||||
/// The required subgroup size.
|
||||
///
|
||||
/// Requires [`subgroup_size_control`](crate::device::Features::subgroup_size_control). The
|
||||
/// shader stage must be included in
|
||||
/// [`required_subgroup_size_stages`](crate::device::Properties::required_subgroup_size_stages).
|
||||
/// Subgroup size must be power of 2 and within
|
||||
/// [`min_subgroup_size`](crate::device::Properties::min_subgroup_size)
|
||||
/// and [`max_subgroup_size`](crate::device::Properties::max_subgroup_size).
|
||||
///
|
||||
/// For compute shaders, `max_compute_workgroup_subgroups * required_subgroup_size` must be
|
||||
/// greater than or equal to `workgroup_size.x * workgroup_size.y * workgroup_size.z`.
|
||||
///
|
||||
/// The default value is None.
|
||||
pub required_subgroup_size: Option<u32>,
|
||||
|
||||
pub _ne: crate::NonExhaustive,
|
||||
}
|
||||
|
||||
@ -332,6 +347,7 @@ impl PipelineShaderStageCreateInfo {
|
||||
flags: PipelineShaderStageCreateFlags::empty(),
|
||||
entry_point,
|
||||
specialization_info: HashMap::default(),
|
||||
required_subgroup_size: None,
|
||||
_ne: crate::NonExhaustive(()),
|
||||
}
|
||||
}
|
||||
@ -341,6 +357,7 @@ impl PipelineShaderStageCreateInfo {
|
||||
flags,
|
||||
ref entry_point,
|
||||
ref specialization_info,
|
||||
ref required_subgroup_size,
|
||||
_ne: _,
|
||||
} = self;
|
||||
|
||||
@ -469,10 +486,145 @@ impl PipelineShaderStageCreateInfo {
|
||||
}
|
||||
}
|
||||
|
||||
let workgroup_size = if let Some(local_size) =
|
||||
entry_point_info.local_size(specialization_info)?
|
||||
{
|
||||
let [x, y, z] = local_size;
|
||||
if x == 0 || y == 0 || z == 0 {
|
||||
return Err(Box::new(ValidationError {
|
||||
problem: format!("`workgroup size` {local_size:?} cannot be 0").into(),
|
||||
..Default::default()
|
||||
}));
|
||||
}
|
||||
let properties = device.physical_device().properties();
|
||||
if stage_enum == ShaderStage::Compute {
|
||||
let max_compute_work_group_size = properties.max_compute_work_group_size;
|
||||
let [max_x, max_y, max_z] = max_compute_work_group_size;
|
||||
if x > max_x || y > max_y || z > max_z {
|
||||
return Err(Box::new(ValidationError {
|
||||
problem: format!("`workgroup size` {local_size:?} is greater than `max_compute_work_group_size` {max_compute_work_group_size:?}").into(),
|
||||
..Default::default()
|
||||
}));
|
||||
}
|
||||
let max_invocations = properties.max_compute_work_group_invocations;
|
||||
if let Some(workgroup_size) = (|| x.checked_mul(y)?.checked_mul(z))() {
|
||||
if workgroup_size > max_invocations {
|
||||
return Err(Box::new(ValidationError {
|
||||
problem: format!("the product of `workgroup size` {local_size:?} = {workgroup_size} is greater than `max_compute_work_group_invocations` {max_invocations}").into(),
|
||||
..Default::default()
|
||||
}));
|
||||
}
|
||||
Some(workgroup_size)
|
||||
} else {
|
||||
return Err(Box::new(ValidationError {
|
||||
problem: format!("the product of `workgroup size` {local_size:?} = (overflow) is greater than `max_compute_work_group_invocations` {max_invocations}").into(),
|
||||
..Default::default()
|
||||
}));
|
||||
}
|
||||
} else {
|
||||
// TODO: Additional stages when `.local_size()` supports them.
|
||||
unreachable!()
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let Some(required_subgroup_size) = required_subgroup_size {
|
||||
validate_required_subgroup_size(
|
||||
device,
|
||||
stage_enum,
|
||||
workgroup_size,
|
||||
*required_subgroup_size,
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn validate_required_subgroup_size(
|
||||
device: &Device,
|
||||
stage: ShaderStage,
|
||||
workgroup_size: Option<u32>,
|
||||
subgroup_size: u32,
|
||||
) -> Result<(), Box<ValidationError>> {
|
||||
if !device.enabled_features().subgroup_size_control {
|
||||
return Err(Box::new(ValidationError {
|
||||
context: "required_subgroup_size".into(),
|
||||
requires_one_of: RequiresOneOf(&[RequiresAllOf(&[Requires::Feature(
|
||||
"subgroup_size_control",
|
||||
)])]),
|
||||
vuids: &["VUID-VkPipelineShaderStageCreateInfo-pNext-02755"],
|
||||
..Default::default()
|
||||
}));
|
||||
}
|
||||
let properties = device.physical_device().properties();
|
||||
if !properties
|
||||
.required_subgroup_size_stages
|
||||
.unwrap_or_default()
|
||||
.contains_enum(stage)
|
||||
{
|
||||
return Err(Box::new(ValidationError {
|
||||
context: "required_subgroup_size".into(),
|
||||
problem: format!("`shader stage` {stage:?} is not in `required_subgroup_size_stages`")
|
||||
.into(),
|
||||
vuids: &["VUID-VkPipelineShaderStageCreateInfo-pNext-02755"],
|
||||
..Default::default()
|
||||
}));
|
||||
}
|
||||
if !subgroup_size.is_power_of_two() {
|
||||
return Err(Box::new(ValidationError {
|
||||
context: "required_subgroup_size".into(),
|
||||
problem: format!("`subgroup_size` {subgroup_size} is not a power of 2").into(),
|
||||
vuids: &["VUID-VkPipelineShaderStageRequiredSubgroupSizeCreateInfo-requiredSubgroupSize-02760"],
|
||||
..Default::default()
|
||||
}));
|
||||
}
|
||||
let min_subgroup_size = properties.min_subgroup_size.unwrap_or(1);
|
||||
if subgroup_size < min_subgroup_size {
|
||||
return Err(Box::new(ValidationError {
|
||||
context: "required_subgroup_size".into(),
|
||||
problem: format!(
|
||||
"`subgroup_size` {subgroup_size} is less than `min_subgroup_size` {min_subgroup_size}"
|
||||
)
|
||||
.into(),
|
||||
vuids: &["VUID-VkPipelineShaderStageRequiredSubgroupSizeCreateInfo-requiredSubgroupSize-02761"],
|
||||
..Default::default()
|
||||
}));
|
||||
}
|
||||
let max_subgroup_size = properties.max_subgroup_size.unwrap_or(128);
|
||||
if subgroup_size > max_subgroup_size {
|
||||
return Err(Box::new(ValidationError {
|
||||
context: "required_subgroup_size".into(),
|
||||
problem:
|
||||
format!("`subgroup_size` {subgroup_size} is greater than `max_subgroup_size` {max_subgroup_size}")
|
||||
.into(),
|
||||
vuids: &["VUID-VkPipelineShaderStageRequiredSubgroupSizeCreateInfo-requiredSubgroupSize-02762"],
|
||||
..Default::default()
|
||||
}));
|
||||
}
|
||||
if let Some(workgroup_size) = workgroup_size {
|
||||
if stage == ShaderStage::Compute {
|
||||
let max_compute_workgroup_subgroups = properties
|
||||
.max_compute_workgroup_subgroups
|
||||
.unwrap_or_default();
|
||||
if max_compute_workgroup_subgroups
|
||||
.checked_mul(subgroup_size)
|
||||
.unwrap_or(u32::MAX)
|
||||
< workgroup_size
|
||||
{
|
||||
return Err(Box::new(ValidationError {
|
||||
context: "required_subgroup_size".into(),
|
||||
problem: format!("`subgroup_size` {subgroup_size} creates more than {max_compute_workgroup_subgroups} subgroups").into(),
|
||||
vuids: &["VUID-VkPipelineShaderStageCreateInfo-pNext-02756"],
|
||||
..Default::default()
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
vulkan_bitflags! {
|
||||
#[non_exhaustive]
|
||||
|
||||
|
@ -549,6 +549,78 @@ pub struct EntryPointInfo {
|
||||
pub output_interface: ShaderInterface,
|
||||
}
|
||||
|
||||
impl EntryPointInfo {
|
||||
/// The local size in Compute shaders, None otherwise.
|
||||
///
|
||||
/// `specialization_info` is used for LocalSizeId / WorkgroupSizeId, using specialization_constants if not found.
|
||||
/// Errors if specialization constants are not found or are not u32's.
|
||||
pub(crate) fn local_size(
|
||||
&self,
|
||||
specialization_info: &HashMap<u32, SpecializationConstant>,
|
||||
) -> Result<Option<[u32; 3]>, Box<ValidationError>> {
|
||||
if let ShaderExecution::Compute(execution) = self.execution {
|
||||
match execution {
|
||||
ComputeShaderExecution::LocalSize(local_size)
|
||||
| ComputeShaderExecution::LocalSizeId(local_size) => {
|
||||
let mut output = [0; 3];
|
||||
for (output, local_size) in output.iter_mut().zip(local_size) {
|
||||
let id = match local_size {
|
||||
LocalSize::Literal(literal) => {
|
||||
*output = literal;
|
||||
continue;
|
||||
}
|
||||
LocalSize::SpecId(id) => id,
|
||||
};
|
||||
if let Some(default) = self.specialization_constants.get(&id) {
|
||||
let default_value = if let SpecializationConstant::U32(default_value) =
|
||||
default
|
||||
{
|
||||
default_value
|
||||
} else {
|
||||
return Err(Box::new(ValidationError {
|
||||
problem: format!(
|
||||
"`entry_point.info().specialization_constants[{id}]` is not a 32 bit integer"
|
||||
)
|
||||
.into(),
|
||||
..Default::default()
|
||||
}));
|
||||
};
|
||||
if let Some(provided) = specialization_info.get(&id) {
|
||||
if let SpecializationConstant::U32(provided_value) = provided {
|
||||
*output = *provided_value;
|
||||
} else {
|
||||
return Err(Box::new(ValidationError {
|
||||
problem: format!(
|
||||
"`specialization_info[{0}]` does not have the same type as \
|
||||
`entry_point.info().specialization_constants[{0}]`",
|
||||
id,
|
||||
)
|
||||
.into(),
|
||||
vuids: &["VUID-VkSpecializationMapEntry-constantID-00776"],
|
||||
..Default::default()
|
||||
}));
|
||||
}
|
||||
}
|
||||
*output = *default_value;
|
||||
} else {
|
||||
return Err(Box::new(ValidationError {
|
||||
problem: format!(
|
||||
"specialization constant {id} not found in `entry_point.info().specialization_constants`"
|
||||
)
|
||||
.into(),
|
||||
..Default::default()
|
||||
}));
|
||||
}
|
||||
}
|
||||
Ok(Some(output))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a shader entry point in a shader module.
|
||||
///
|
||||
/// Can be obtained by calling [`entry_point`](ShaderModule::entry_point) on the shader module.
|
||||
@ -581,15 +653,15 @@ pub enum ShaderExecution {
|
||||
TessellationEvaluation,
|
||||
Geometry(GeometryShaderExecution),
|
||||
Fragment(FragmentShaderExecution),
|
||||
Compute,
|
||||
Compute(ComputeShaderExecution),
|
||||
RayGeneration,
|
||||
AnyHit,
|
||||
ClosestHit,
|
||||
Miss,
|
||||
Intersection,
|
||||
Callable,
|
||||
Task,
|
||||
Mesh,
|
||||
Task, // TODO: like compute?
|
||||
Mesh, // TODO: like compute?
|
||||
SubpassShading,
|
||||
}
|
||||
|
||||
@ -601,7 +673,7 @@ impl From<&ShaderExecution> for ExecutionModel {
|
||||
ShaderExecution::TessellationEvaluation => Self::TessellationEvaluation,
|
||||
ShaderExecution::Geometry(_) => Self::Geometry,
|
||||
ShaderExecution::Fragment(_) => Self::Fragment,
|
||||
ShaderExecution::Compute => Self::GLCompute,
|
||||
ShaderExecution::Compute(_) => Self::GLCompute,
|
||||
ShaderExecution::RayGeneration => Self::RayGenerationKHR,
|
||||
ShaderExecution::AnyHit => Self::AnyHitKHR,
|
||||
ShaderExecution::ClosestHit => Self::ClosestHitKHR,
|
||||
@ -699,6 +771,32 @@ pub enum FragmentTestsStages {
|
||||
EarlyAndLate,
|
||||
}
|
||||
|
||||
/// LocalSize.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum LocalSize {
|
||||
Literal(u32),
|
||||
SpecId(u32),
|
||||
}
|
||||
|
||||
/// The mode in which the compute shader executes.
|
||||
///
|
||||
/// The workgroup size is specified for x, y, and z dimensions.
|
||||
///
|
||||
/// Constants are resolved to literals, while specialization constants
|
||||
/// map to spec ids.
|
||||
///
|
||||
/// The `WorkgroupSize` builtin overrides the values specified in the
|
||||
/// execution mode. It can decorate a 3 component ConstantComposite or
|
||||
/// SpecConstantComposite vector.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum ComputeShaderExecution {
|
||||
/// Workgroup size in x, y, and z.
|
||||
LocalSize([LocalSize; 3]),
|
||||
/// Requires spirv 1.2.
|
||||
/// Like `LocalSize`, but uses ids instead of literals.
|
||||
LocalSizeId([LocalSize; 3]),
|
||||
}
|
||||
|
||||
/// The requirements imposed by a shader on a binding within a descriptor set layout, and on any
|
||||
/// resource that is bound to that binding.
|
||||
#[derive(Clone, Debug, Default)]
|
||||
@ -1263,7 +1361,7 @@ impl From<&ShaderExecution> for ShaderStage {
|
||||
ShaderExecution::TessellationEvaluation => Self::TessellationEvaluation,
|
||||
ShaderExecution::Geometry(_) => Self::Geometry,
|
||||
ShaderExecution::Fragment(_) => Self::Fragment,
|
||||
ShaderExecution::Compute => Self::Compute,
|
||||
ShaderExecution::Compute(_) => Self::Compute,
|
||||
ShaderExecution::RayGeneration => Self::Raygen,
|
||||
ShaderExecution::AnyHit => Self::AnyHit,
|
||||
ShaderExecution::ClosestHit => Self::ClosestHit,
|
||||
|
@ -16,12 +16,13 @@ use crate::{
|
||||
pipeline::layout::PushConstantRange,
|
||||
shader::{
|
||||
spirv::{
|
||||
Capability, Decoration, Dim, ExecutionMode, ExecutionModel, Id, Instruction, Spirv,
|
||||
StorageClass,
|
||||
BuiltIn, Capability, Decoration, Dim, ExecutionMode, ExecutionModel, Id, Instruction,
|
||||
Spirv, StorageClass,
|
||||
},
|
||||
DescriptorIdentifier, DescriptorRequirements, EntryPointInfo, GeometryShaderExecution,
|
||||
GeometryShaderInput, NumericType, ShaderExecution, ShaderInterface, ShaderInterfaceEntry,
|
||||
ShaderInterfaceEntryType, ShaderStage, SpecializationConstant,
|
||||
ComputeShaderExecution, DescriptorIdentifier, DescriptorRequirements, EntryPointInfo,
|
||||
GeometryShaderExecution, GeometryShaderInput, LocalSize, NumericType, ShaderExecution,
|
||||
ShaderInterface, ShaderInterfaceEntry, ShaderInterfaceEntryType, ShaderStage,
|
||||
SpecializationConstant,
|
||||
},
|
||||
DeviceSize,
|
||||
};
|
||||
@ -55,6 +56,9 @@ pub fn spirv_extensions(spirv: &Spirv) -> impl Iterator<Item = &str> {
|
||||
#[inline]
|
||||
pub fn entry_points(spirv: &Spirv) -> impl Iterator<Item = EntryPointInfo> + '_ {
|
||||
let interface_variables = interface_variables(spirv);
|
||||
let u32_constants = u32_constants(spirv);
|
||||
let specialization_constant_ids = specialization_constant_ids(spirv);
|
||||
let workgroup_size_decorations = workgroup_size_decorations(spirv);
|
||||
|
||||
spirv.iter_entry_point().filter_map(move |instruction| {
|
||||
let (execution_model, function_id, entry_point_name, interface) = match instruction {
|
||||
@ -68,7 +72,14 @@ pub fn entry_points(spirv: &Spirv) -> impl Iterator<Item = EntryPointInfo> + '_
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
let execution = shader_execution(spirv, execution_model, function_id);
|
||||
let execution = shader_execution(
|
||||
spirv,
|
||||
execution_model,
|
||||
function_id,
|
||||
&u32_constants,
|
||||
&specialization_constant_ids,
|
||||
&workgroup_size_decorations,
|
||||
);
|
||||
let stage = ShaderStage::from(&execution);
|
||||
|
||||
let descriptor_binding_requirements = inspect_entry_point(
|
||||
@ -109,11 +120,87 @@ pub fn entry_points(spirv: &Spirv) -> impl Iterator<Item = EntryPointInfo> + '_
|
||||
})
|
||||
}
|
||||
|
||||
/// Extracts the u32 constants from `spirv`.
|
||||
fn u32_constants(spirv: &Spirv) -> HashMap<Id, u32> {
|
||||
let type_u32s: HashSet<Id> = spirv
|
||||
.iter_global()
|
||||
.filter_map(|inst| {
|
||||
if let Instruction::TypeInt {
|
||||
result_id,
|
||||
width,
|
||||
signedness,
|
||||
} = inst
|
||||
{
|
||||
if *width == 0 && *signedness == 0 {
|
||||
return Some(*result_id);
|
||||
}
|
||||
}
|
||||
None
|
||||
})
|
||||
.collect();
|
||||
spirv
|
||||
.iter_decoration()
|
||||
.filter_map(|inst| {
|
||||
if let Instruction::Constant {
|
||||
result_type_id,
|
||||
result_id,
|
||||
value,
|
||||
} = inst
|
||||
{
|
||||
if type_u32s.contains(result_type_id) {
|
||||
if let [value] = value.as_slice() {
|
||||
return Some((*result_id, *value));
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Extracts the specialization constant ids from `spirv`.
|
||||
fn specialization_constant_ids(spirv: &Spirv) -> HashMap<Id, u32> {
|
||||
spirv
|
||||
.iter_decoration()
|
||||
.filter_map(|inst| {
|
||||
if let Instruction::Decorate { target, decoration } = inst {
|
||||
if let Decoration::SpecId {
|
||||
specialization_constant_id,
|
||||
} = decoration
|
||||
{
|
||||
return Some((*target, *specialization_constant_id));
|
||||
}
|
||||
}
|
||||
None
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Extracts the `WorkgroupSize` builtin Id's from `spirv`.
|
||||
fn workgroup_size_decorations(spirv: &Spirv) -> HashSet<Id> {
|
||||
spirv
|
||||
.iter_decoration()
|
||||
.filter_map(|inst| {
|
||||
if let Instruction::Decorate { target, decoration } = inst {
|
||||
if let Decoration::BuiltIn { built_in } = decoration {
|
||||
if *built_in == BuiltIn::WorkgroupSize {
|
||||
return Some(*target);
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Extracts the `ShaderExecution` for the entry point `function_id` from `spirv`.
|
||||
fn shader_execution(
|
||||
spirv: &Spirv,
|
||||
execution_model: ExecutionModel,
|
||||
function_id: Id,
|
||||
u32_constants: &HashMap<Id, u32>,
|
||||
specialization_constant_ids: &HashMap<Id, u32>,
|
||||
workgroup_size_decorations: &HashSet<Id>,
|
||||
) -> ShaderExecution {
|
||||
match execution_model {
|
||||
ExecutionModel::Vertex => ShaderExecution::Vertex,
|
||||
@ -187,7 +274,124 @@ fn shader_execution(
|
||||
})
|
||||
}
|
||||
|
||||
ExecutionModel::GLCompute => ShaderExecution::Compute,
|
||||
ExecutionModel::GLCompute => {
|
||||
let mut execution = ComputeShaderExecution::LocalSize([LocalSize::Literal(0); 3]);
|
||||
for instruction in spirv.iter_execution_mode() {
|
||||
match instruction {
|
||||
Instruction::ExecutionMode { entry_point, mode }
|
||||
if *entry_point == function_id =>
|
||||
{
|
||||
if let ExecutionMode::LocalSize {
|
||||
x_size,
|
||||
y_size,
|
||||
z_size,
|
||||
} = mode
|
||||
{
|
||||
execution = ComputeShaderExecution::LocalSize([
|
||||
LocalSize::Literal(*x_size),
|
||||
LocalSize::Literal(*y_size),
|
||||
LocalSize::Literal(*z_size),
|
||||
]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
Instruction::ExecutionModeId { entry_point, mode }
|
||||
if *entry_point == function_id =>
|
||||
{
|
||||
if let ExecutionMode::LocalSizeId {
|
||||
x_size,
|
||||
y_size,
|
||||
z_size,
|
||||
} = mode
|
||||
{
|
||||
let mut local_size = [LocalSize::Literal(0); 3];
|
||||
for (local_size, id) in
|
||||
local_size.iter_mut().zip([*x_size, *y_size, *z_size])
|
||||
{
|
||||
if let Some(constant) = u32_constants.get(&id) {
|
||||
*local_size = LocalSize::Literal(*constant);
|
||||
} else if let Some(spec_id) = specialization_constant_ids.get(&id) {
|
||||
*local_size = LocalSize::SpecId(*spec_id);
|
||||
} else {
|
||||
panic!("LocalSizeId {id:?} not defined!");
|
||||
}
|
||||
}
|
||||
execution = ComputeShaderExecution::LocalSizeId(local_size);
|
||||
break;
|
||||
}
|
||||
}
|
||||
_ => continue,
|
||||
};
|
||||
}
|
||||
if !workgroup_size_decorations.is_empty() {
|
||||
let mut in_function = false;
|
||||
for instruction in spirv.instructions() {
|
||||
if !in_function {
|
||||
match *instruction {
|
||||
Instruction::Function { result_id, .. } if result_id == function_id => {
|
||||
in_function = true;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
} else {
|
||||
let mut local_size = [LocalSize::Literal(0); 3];
|
||||
match instruction {
|
||||
Instruction::ConstantComposite {
|
||||
result_type_id: _,
|
||||
result_id,
|
||||
constituents,
|
||||
} => {
|
||||
if workgroup_size_decorations.contains(result_id) {
|
||||
if constituents.len() != 3 {
|
||||
panic!("WorkgroupSize must be 3 component vector!");
|
||||
}
|
||||
for (local_size, id) in
|
||||
local_size.iter_mut().zip(constituents.iter())
|
||||
{
|
||||
if let Some(constant) = u32_constants.get(id) {
|
||||
*local_size = LocalSize::Literal(*constant);
|
||||
} else {
|
||||
panic!("WorkgroupSize Constant {id:?} not defined!");
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
Instruction::SpecConstantComposite {
|
||||
result_type_id: _,
|
||||
result_id,
|
||||
constituents,
|
||||
} => {
|
||||
if workgroup_size_decorations.contains(result_id) {
|
||||
if constituents.len() != 3 {
|
||||
panic!("WorkgroupSize must be 3 component vector!");
|
||||
}
|
||||
for (local_size, id) in
|
||||
local_size.iter_mut().zip(constituents.iter())
|
||||
{
|
||||
if let Some(spec_id) = specialization_constant_ids.get(id) {
|
||||
*local_size = LocalSize::SpecId(*spec_id);
|
||||
} else {
|
||||
panic!("WorkgroupSize SpecializationConstant {id:?} not defined!");
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
Instruction::FunctionEnd => break,
|
||||
_ => continue,
|
||||
}
|
||||
match &mut execution {
|
||||
ComputeShaderExecution::LocalSize(output) => {
|
||||
*output = local_size;
|
||||
}
|
||||
ComputeShaderExecution::LocalSizeId(output) => {
|
||||
*output = local_size;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ShaderExecution::Compute(execution)
|
||||
}
|
||||
|
||||
ExecutionModel::RayGenerationKHR => ShaderExecution::RayGeneration,
|
||||
ExecutionModel::IntersectionKHR => ShaderExecution::Intersection,
|
||||
|
Loading…
Reference in New Issue
Block a user