mirror of
https://github.com/vulkano-rs/vulkano.git
synced 2024-11-25 00:04:15 +00:00
Added required_subgroup_size to PipelineShaderStageCreateInfo (#2235)
* Added required_subgroup_size to PipelineShaderStageCreateInfo * Added validation errors. * Fixed error msgs / vuids. * ComputeShaderExecution for validating local_size. * WorkgroupSizeId reflection. * contains_enum * Reworked ComputeShaderExecution. * panic msgs. * workgroup size validation * unused import * fixed test deprecated fn * catch workgroup size overflow * EntryPointInfo::local_size docs * comments * typo + error msg
This commit is contained in:
parent
4133a3bf63
commit
ee4e3089d3
@ -73,7 +73,49 @@ fn write_shader_execution(execution: &ShaderExecution) -> TokenStream {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ShaderExecution::Compute => quote! { ::vulkano::shader::ShaderExecution::Compute },
|
ShaderExecution::Compute(execution) => {
|
||||||
|
use ::quote::ToTokens;
|
||||||
|
use ::vulkano::shader::{ComputeShaderExecution, LocalSize};
|
||||||
|
|
||||||
|
struct LocalSizeToTokens(LocalSize);
|
||||||
|
|
||||||
|
impl ToTokens for LocalSizeToTokens {
|
||||||
|
fn to_tokens(&self, tokens: &mut TokenStream) {
|
||||||
|
match self.0 {
|
||||||
|
LocalSize::Literal(literal) => quote! {
|
||||||
|
::vulkano::shader::LocalSize::Literal(#literal)
|
||||||
|
},
|
||||||
|
LocalSize::SpecId(id) => quote! {
|
||||||
|
::vulkano::shader::LocalSize::SpecId(#id)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
.to_tokens(tokens);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match execution {
|
||||||
|
ComputeShaderExecution::LocalSize([x, y, z]) => {
|
||||||
|
let [x, y, z] = [
|
||||||
|
LocalSizeToTokens(*x),
|
||||||
|
LocalSizeToTokens(*y),
|
||||||
|
LocalSizeToTokens(*z),
|
||||||
|
];
|
||||||
|
quote! { ::vulkano::shader::ShaderExecution::Compute(
|
||||||
|
::vulkano::shader::ComputeShaderExecution::LocalSize([#x, #y, #z])
|
||||||
|
) }
|
||||||
|
}
|
||||||
|
ComputeShaderExecution::LocalSizeId([x, y, z]) => {
|
||||||
|
let [x, y, z] = [
|
||||||
|
LocalSizeToTokens(*x),
|
||||||
|
LocalSizeToTokens(*y),
|
||||||
|
LocalSizeToTokens(*z),
|
||||||
|
];
|
||||||
|
quote! { ::vulkano::shader::ShaderExecution::Compute(
|
||||||
|
::vulkano::shader::ComputeShaderExecution::LocalSizeId([#x, #y, #z])
|
||||||
|
) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
ShaderExecution::RayGeneration => {
|
ShaderExecution::RayGeneration => {
|
||||||
quote! { ::vulkano::shader::ShaderExecution::RayGeneration }
|
quote! { ::vulkano::shader::ShaderExecution::RayGeneration }
|
||||||
}
|
}
|
||||||
|
@ -74,11 +74,9 @@ impl ComputePipeline {
|
|||||||
if let Some(cache) = &cache {
|
if let Some(cache) = &cache {
|
||||||
assert_eq!(device, cache.device().as_ref());
|
assert_eq!(device, cache.device().as_ref());
|
||||||
}
|
}
|
||||||
|
|
||||||
create_info
|
create_info
|
||||||
.validate(device)
|
.validate(device)
|
||||||
.map_err(|err| err.add_context("create_info"))?;
|
.map_err(|err| err.add_context("create_info"))?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -100,12 +98,14 @@ impl ComputePipeline {
|
|||||||
let specialization_info_vk;
|
let specialization_info_vk;
|
||||||
let specialization_map_entries_vk: Vec<_>;
|
let specialization_map_entries_vk: Vec<_>;
|
||||||
let mut specialization_data_vk: Vec<u8>;
|
let mut specialization_data_vk: Vec<u8>;
|
||||||
|
let required_subgroup_size_create_info;
|
||||||
|
|
||||||
{
|
{
|
||||||
let &PipelineShaderStageCreateInfo {
|
let &PipelineShaderStageCreateInfo {
|
||||||
flags,
|
flags,
|
||||||
ref entry_point,
|
ref entry_point,
|
||||||
ref specialization_info,
|
ref specialization_info,
|
||||||
|
ref required_subgroup_size,
|
||||||
_ne: _,
|
_ne: _,
|
||||||
} = stage;
|
} = stage;
|
||||||
|
|
||||||
@ -135,7 +135,20 @@ impl ComputePipeline {
|
|||||||
data_size: specialization_data_vk.len(),
|
data_size: specialization_data_vk.len(),
|
||||||
p_data: specialization_data_vk.as_ptr() as *const _,
|
p_data: specialization_data_vk.as_ptr() as *const _,
|
||||||
};
|
};
|
||||||
|
required_subgroup_size_create_info =
|
||||||
|
required_subgroup_size.map(|required_subgroup_size| {
|
||||||
|
ash::vk::PipelineShaderStageRequiredSubgroupSizeCreateInfo {
|
||||||
|
required_subgroup_size,
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
});
|
||||||
stage_vk = ash::vk::PipelineShaderStageCreateInfo {
|
stage_vk = ash::vk::PipelineShaderStageCreateInfo {
|
||||||
|
p_next: required_subgroup_size_create_info.as_ref().map_or(
|
||||||
|
ptr::null(),
|
||||||
|
|required_subgroup_size_create_info| {
|
||||||
|
required_subgroup_size_create_info as *const _ as _
|
||||||
|
},
|
||||||
|
),
|
||||||
flags: flags.into(),
|
flags: flags.into(),
|
||||||
stage: ShaderStage::from(&entry_point_info.execution).into(),
|
stage: ShaderStage::from(&entry_point_info.execution).into(),
|
||||||
module: entry_point.module().handle(),
|
module: entry_point.module().handle(),
|
||||||
@ -333,12 +346,13 @@ impl ComputePipelineCreateInfo {
|
|||||||
flags: _,
|
flags: _,
|
||||||
ref entry_point,
|
ref entry_point,
|
||||||
specialization_info: _,
|
specialization_info: _,
|
||||||
|
required_subgroup_size: _vk,
|
||||||
_ne: _,
|
_ne: _,
|
||||||
} = &stage;
|
} = &stage;
|
||||||
|
|
||||||
let entry_point_info = entry_point.info();
|
let entry_point_info = entry_point.info();
|
||||||
|
|
||||||
if !matches!(entry_point_info.execution, ShaderExecution::Compute) {
|
if !matches!(entry_point_info.execution, ShaderExecution::Compute(_)) {
|
||||||
return Err(Box::new(ValidationError {
|
return Err(Box::new(ValidationError {
|
||||||
context: "stage.entry_point".into(),
|
context: "stage.entry_point".into(),
|
||||||
problem: "is not a `ShaderStage::Compute` entry point".into(),
|
problem: "is not a `ShaderStage::Compute` entry point".into(),
|
||||||
@ -514,4 +528,135 @@ mod tests {
|
|||||||
let data_buffer_content = data_buffer.read().unwrap();
|
let data_buffer_content = data_buffer.read().unwrap();
|
||||||
assert_eq!(*data_buffer_content, 0x12345678);
|
assert_eq!(*data_buffer_content, 0x12345678);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn required_subgroup_size() {
|
||||||
|
// This test checks whether required_subgroup_size works.
|
||||||
|
// It executes a single compute shader (one invocation) that writes the subgroup size
|
||||||
|
// to a buffer. The buffer content is then checked for the right value.
|
||||||
|
|
||||||
|
let (device, queue) = gfx_dev_and_queue!(subgroup_size_control);
|
||||||
|
|
||||||
|
let cs = unsafe {
|
||||||
|
/*
|
||||||
|
#version 450
|
||||||
|
|
||||||
|
#extension GL_KHR_shader_subgroup_basic: enable
|
||||||
|
|
||||||
|
layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout(set = 0, binding = 0) buffer Output {
|
||||||
|
uint write;
|
||||||
|
} write;
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
if (gl_GlobalInvocationID.x == 0) {
|
||||||
|
write.write = gl_SubgroupSize;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
const MODULE: [u32; 246] = [
|
||||||
|
119734787, 65536, 851978, 30, 0, 131089, 1, 131089, 61, 393227, 1, 1280527431,
|
||||||
|
1685353262, 808793134, 0, 196622, 0, 1, 458767, 5, 4, 1852399981, 0, 9, 23, 393232,
|
||||||
|
4, 17, 128, 1, 1, 196611, 2, 450, 655364, 1197427783, 1279741775, 1885560645,
|
||||||
|
1953718128, 1600482425, 1701734764, 1919509599, 1769235301, 25974, 524292,
|
||||||
|
1197427783, 1279741775, 1852399429, 1685417059, 1768185701, 1952671090, 6649449,
|
||||||
|
589828, 1264536647, 1935626824, 1701077352, 1970495346, 1869768546, 1650421877,
|
||||||
|
1667855201, 0, 262149, 4, 1852399981, 0, 524293, 9, 1197436007, 1633841004,
|
||||||
|
1986939244, 1952539503, 1231974249, 68, 262149, 18, 1886680399, 29813, 327686, 18,
|
||||||
|
0, 1953067639, 101, 262149, 20, 1953067639, 101, 393221, 23, 1398762599,
|
||||||
|
1919378037, 1399879023, 6650473, 262215, 9, 11, 28, 327752, 18, 0, 35, 0, 196679,
|
||||||
|
18, 3, 262215, 20, 34, 0, 262215, 20, 33, 0, 196679, 23, 0, 262215, 23, 11, 36,
|
||||||
|
196679, 24, 0, 262215, 29, 11, 25, 131091, 2, 196641, 3, 2, 262165, 6, 32, 0,
|
||||||
|
262167, 7, 6, 3, 262176, 8, 1, 7, 262203, 8, 9, 1, 262187, 6, 10, 0, 262176, 11, 1,
|
||||||
|
6, 131092, 14, 196638, 18, 6, 262176, 19, 2, 18, 262203, 19, 20, 2, 262165, 21, 32,
|
||||||
|
1, 262187, 21, 22, 0, 262203, 11, 23, 1, 262176, 25, 2, 6, 262187, 6, 27, 128,
|
||||||
|
262187, 6, 28, 1, 393260, 7, 29, 27, 28, 28, 327734, 2, 4, 0, 3, 131320, 5, 327745,
|
||||||
|
11, 12, 9, 10, 262205, 6, 13, 12, 327850, 14, 15, 13, 10, 196855, 17, 0, 262394,
|
||||||
|
15, 16, 17, 131320, 16, 262205, 6, 24, 23, 327745, 25, 26, 20, 22, 196670, 26, 24,
|
||||||
|
131321, 17, 131320, 17, 65789, 65592,
|
||||||
|
];
|
||||||
|
let module =
|
||||||
|
ShaderModule::new(device.clone(), ShaderModuleCreateInfo::new(&MODULE)).unwrap();
|
||||||
|
module.entry_point("main").unwrap()
|
||||||
|
};
|
||||||
|
|
||||||
|
let properties = device.physical_device().properties();
|
||||||
|
let subgroup_size = properties.min_subgroup_size.unwrap_or(1);
|
||||||
|
|
||||||
|
let pipeline = {
|
||||||
|
let stage = PipelineShaderStageCreateInfo {
|
||||||
|
required_subgroup_size: Some(subgroup_size),
|
||||||
|
..PipelineShaderStageCreateInfo::new(cs)
|
||||||
|
};
|
||||||
|
let layout = PipelineLayout::new(
|
||||||
|
device.clone(),
|
||||||
|
PipelineDescriptorSetLayoutCreateInfo::from_stages([&stage])
|
||||||
|
.into_pipeline_layout_create_info(device.clone())
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
ComputePipeline::new(
|
||||||
|
device.clone(),
|
||||||
|
None,
|
||||||
|
ComputePipelineCreateInfo::stage_layout(stage, layout),
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
};
|
||||||
|
|
||||||
|
let memory_allocator = StandardMemoryAllocator::new_default(device.clone());
|
||||||
|
let data_buffer = Buffer::from_data(
|
||||||
|
&memory_allocator,
|
||||||
|
BufferCreateInfo {
|
||||||
|
usage: BufferUsage::STORAGE_BUFFER,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
AllocationCreateInfo {
|
||||||
|
memory_type_filter: MemoryTypeFilter::PREFER_DEVICE
|
||||||
|
| MemoryTypeFilter::HOST_RANDOM_ACCESS,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let ds_allocator = StandardDescriptorSetAllocator::new(device.clone());
|
||||||
|
let set = PersistentDescriptorSet::new(
|
||||||
|
&ds_allocator,
|
||||||
|
pipeline.layout().set_layouts().get(0).unwrap().clone(),
|
||||||
|
[WriteDescriptorSet::buffer(0, data_buffer.clone())],
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let cb_allocator = StandardCommandBufferAllocator::new(device.clone(), Default::default());
|
||||||
|
let mut cbb = AutoCommandBufferBuilder::primary(
|
||||||
|
&cb_allocator,
|
||||||
|
queue.queue_family_index(),
|
||||||
|
CommandBufferUsage::OneTimeSubmit,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
cbb.bind_pipeline_compute(pipeline.clone())
|
||||||
|
.unwrap()
|
||||||
|
.bind_descriptor_sets(
|
||||||
|
PipelineBindPoint::Compute,
|
||||||
|
pipeline.layout().clone(),
|
||||||
|
0,
|
||||||
|
set,
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
.dispatch([128, 1, 1])
|
||||||
|
.unwrap();
|
||||||
|
let cb = cbb.build().unwrap();
|
||||||
|
|
||||||
|
let future = now(device)
|
||||||
|
.then_execute(queue, cb)
|
||||||
|
.unwrap()
|
||||||
|
.then_signal_fence_and_flush()
|
||||||
|
.unwrap();
|
||||||
|
future.wait(None).unwrap();
|
||||||
|
|
||||||
|
let data_buffer_content = data_buffer.read().unwrap();
|
||||||
|
assert_eq!(*data_buffer_content, subgroup_size);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -167,7 +167,6 @@ impl GraphicsPipeline {
|
|||||||
create_info
|
create_info
|
||||||
.validate(device)
|
.validate(device)
|
||||||
.map_err(|err| err.add_context("create_info"))?;
|
.map_err(|err| err.add_context("create_info"))?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -204,6 +203,8 @@ impl GraphicsPipeline {
|
|||||||
specialization_info_vk: ash::vk::SpecializationInfo,
|
specialization_info_vk: ash::vk::SpecializationInfo,
|
||||||
specialization_map_entries_vk: Vec<ash::vk::SpecializationMapEntry>,
|
specialization_map_entries_vk: Vec<ash::vk::SpecializationMapEntry>,
|
||||||
specialization_data_vk: Vec<u8>,
|
specialization_data_vk: Vec<u8>,
|
||||||
|
required_subgroup_size_create_info:
|
||||||
|
Option<ash::vk::PipelineShaderStageRequiredSubgroupSizeCreateInfo>,
|
||||||
}
|
}
|
||||||
|
|
||||||
let (mut stages_vk, mut per_stage_vk): (SmallVec<[_; 5]>, SmallVec<[_; 5]>) = stages
|
let (mut stages_vk, mut per_stage_vk): (SmallVec<[_; 5]>, SmallVec<[_; 5]>) = stages
|
||||||
@ -213,6 +214,7 @@ impl GraphicsPipeline {
|
|||||||
flags,
|
flags,
|
||||||
ref entry_point,
|
ref entry_point,
|
||||||
ref specialization_info,
|
ref specialization_info,
|
||||||
|
ref required_subgroup_size,
|
||||||
_ne: _,
|
_ne: _,
|
||||||
} = stage;
|
} = stage;
|
||||||
|
|
||||||
@ -235,7 +237,13 @@ impl GraphicsPipeline {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
let required_subgroup_size_create_info =
|
||||||
|
required_subgroup_size.map(|required_subgroup_size| {
|
||||||
|
ash::vk::PipelineShaderStageRequiredSubgroupSizeCreateInfo {
|
||||||
|
required_subgroup_size,
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
});
|
||||||
(
|
(
|
||||||
ash::vk::PipelineShaderStageCreateInfo {
|
ash::vk::PipelineShaderStageCreateInfo {
|
||||||
flags: flags.into(),
|
flags: flags.into(),
|
||||||
@ -255,6 +263,7 @@ impl GraphicsPipeline {
|
|||||||
},
|
},
|
||||||
specialization_map_entries_vk,
|
specialization_map_entries_vk,
|
||||||
specialization_data_vk,
|
specialization_data_vk,
|
||||||
|
required_subgroup_size_create_info,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
@ -267,10 +276,17 @@ impl GraphicsPipeline {
|
|||||||
specialization_info_vk,
|
specialization_info_vk,
|
||||||
specialization_map_entries_vk,
|
specialization_map_entries_vk,
|
||||||
specialization_data_vk,
|
specialization_data_vk,
|
||||||
|
required_subgroup_size_create_info,
|
||||||
},
|
},
|
||||||
) in (stages_vk.iter_mut()).zip(per_stage_vk.iter_mut())
|
) in (stages_vk.iter_mut()).zip(per_stage_vk.iter_mut())
|
||||||
{
|
{
|
||||||
*stage_vk = ash::vk::PipelineShaderStageCreateInfo {
|
*stage_vk = ash::vk::PipelineShaderStageCreateInfo {
|
||||||
|
p_next: required_subgroup_size_create_info.as_ref().map_or(
|
||||||
|
ptr::null(),
|
||||||
|
|required_subgroup_size_create_info| {
|
||||||
|
required_subgroup_size_create_info as *const _ as _
|
||||||
|
},
|
||||||
|
),
|
||||||
p_name: name_vk.as_ptr(),
|
p_name: name_vk.as_ptr(),
|
||||||
p_specialization_info: specialization_info_vk,
|
p_specialization_info: specialization_info_vk,
|
||||||
..*stage_vk
|
..*stage_vk
|
||||||
@ -2420,6 +2436,7 @@ impl GraphicsPipelineCreateInfo {
|
|||||||
flags: _,
|
flags: _,
|
||||||
ref entry_point,
|
ref entry_point,
|
||||||
specialization_info: _,
|
specialization_info: _,
|
||||||
|
required_subgroup_size: _vk,
|
||||||
_ne: _,
|
_ne: _,
|
||||||
} = stage;
|
} = stage;
|
||||||
|
|
||||||
|
@ -321,6 +321,21 @@ pub struct PipelineShaderStageCreateInfo {
|
|||||||
/// The default value is empty.
|
/// The default value is empty.
|
||||||
pub specialization_info: HashMap<u32, SpecializationConstant>,
|
pub specialization_info: HashMap<u32, SpecializationConstant>,
|
||||||
|
|
||||||
|
/// The required subgroup size.
|
||||||
|
///
|
||||||
|
/// Requires [`subgroup_size_control`](crate::device::Features::subgroup_size_control). The
|
||||||
|
/// shader stage must be included in
|
||||||
|
/// [`required_subgroup_size_stages`](crate::device::Properties::required_subgroup_size_stages).
|
||||||
|
/// Subgroup size must be power of 2 and within
|
||||||
|
/// [`min_subgroup_size`](crate::device::Properties::min_subgroup_size)
|
||||||
|
/// and [`max_subgroup_size`](crate::device::Properties::max_subgroup_size).
|
||||||
|
///
|
||||||
|
/// For compute shaders, `max_compute_workgroup_subgroups * required_subgroup_size` must be
|
||||||
|
/// greater than or equal to `workgroup_size.x * workgroup_size.y * workgroup_size.z`.
|
||||||
|
///
|
||||||
|
/// The default value is None.
|
||||||
|
pub required_subgroup_size: Option<u32>,
|
||||||
|
|
||||||
pub _ne: crate::NonExhaustive,
|
pub _ne: crate::NonExhaustive,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -332,6 +347,7 @@ impl PipelineShaderStageCreateInfo {
|
|||||||
flags: PipelineShaderStageCreateFlags::empty(),
|
flags: PipelineShaderStageCreateFlags::empty(),
|
||||||
entry_point,
|
entry_point,
|
||||||
specialization_info: HashMap::default(),
|
specialization_info: HashMap::default(),
|
||||||
|
required_subgroup_size: None,
|
||||||
_ne: crate::NonExhaustive(()),
|
_ne: crate::NonExhaustive(()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -341,6 +357,7 @@ impl PipelineShaderStageCreateInfo {
|
|||||||
flags,
|
flags,
|
||||||
ref entry_point,
|
ref entry_point,
|
||||||
ref specialization_info,
|
ref specialization_info,
|
||||||
|
ref required_subgroup_size,
|
||||||
_ne: _,
|
_ne: _,
|
||||||
} = self;
|
} = self;
|
||||||
|
|
||||||
@ -469,10 +486,145 @@ impl PipelineShaderStageCreateInfo {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let workgroup_size = if let Some(local_size) =
|
||||||
|
entry_point_info.local_size(specialization_info)?
|
||||||
|
{
|
||||||
|
let [x, y, z] = local_size;
|
||||||
|
if x == 0 || y == 0 || z == 0 {
|
||||||
|
return Err(Box::new(ValidationError {
|
||||||
|
problem: format!("`workgroup size` {local_size:?} cannot be 0").into(),
|
||||||
|
..Default::default()
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
let properties = device.physical_device().properties();
|
||||||
|
if stage_enum == ShaderStage::Compute {
|
||||||
|
let max_compute_work_group_size = properties.max_compute_work_group_size;
|
||||||
|
let [max_x, max_y, max_z] = max_compute_work_group_size;
|
||||||
|
if x > max_x || y > max_y || z > max_z {
|
||||||
|
return Err(Box::new(ValidationError {
|
||||||
|
problem: format!("`workgroup size` {local_size:?} is greater than `max_compute_work_group_size` {max_compute_work_group_size:?}").into(),
|
||||||
|
..Default::default()
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
let max_invocations = properties.max_compute_work_group_invocations;
|
||||||
|
if let Some(workgroup_size) = (|| x.checked_mul(y)?.checked_mul(z))() {
|
||||||
|
if workgroup_size > max_invocations {
|
||||||
|
return Err(Box::new(ValidationError {
|
||||||
|
problem: format!("the product of `workgroup size` {local_size:?} = {workgroup_size} is greater than `max_compute_work_group_invocations` {max_invocations}").into(),
|
||||||
|
..Default::default()
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
Some(workgroup_size)
|
||||||
|
} else {
|
||||||
|
return Err(Box::new(ValidationError {
|
||||||
|
problem: format!("the product of `workgroup size` {local_size:?} = (overflow) is greater than `max_compute_work_group_invocations` {max_invocations}").into(),
|
||||||
|
..Default::default()
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// TODO: Additional stages when `.local_size()` supports them.
|
||||||
|
unreachable!()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(required_subgroup_size) = required_subgroup_size {
|
||||||
|
validate_required_subgroup_size(
|
||||||
|
device,
|
||||||
|
stage_enum,
|
||||||
|
workgroup_size,
|
||||||
|
*required_subgroup_size,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn validate_required_subgroup_size(
|
||||||
|
device: &Device,
|
||||||
|
stage: ShaderStage,
|
||||||
|
workgroup_size: Option<u32>,
|
||||||
|
subgroup_size: u32,
|
||||||
|
) -> Result<(), Box<ValidationError>> {
|
||||||
|
if !device.enabled_features().subgroup_size_control {
|
||||||
|
return Err(Box::new(ValidationError {
|
||||||
|
context: "required_subgroup_size".into(),
|
||||||
|
requires_one_of: RequiresOneOf(&[RequiresAllOf(&[Requires::Feature(
|
||||||
|
"subgroup_size_control",
|
||||||
|
)])]),
|
||||||
|
vuids: &["VUID-VkPipelineShaderStageCreateInfo-pNext-02755"],
|
||||||
|
..Default::default()
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
let properties = device.physical_device().properties();
|
||||||
|
if !properties
|
||||||
|
.required_subgroup_size_stages
|
||||||
|
.unwrap_or_default()
|
||||||
|
.contains_enum(stage)
|
||||||
|
{
|
||||||
|
return Err(Box::new(ValidationError {
|
||||||
|
context: "required_subgroup_size".into(),
|
||||||
|
problem: format!("`shader stage` {stage:?} is not in `required_subgroup_size_stages`")
|
||||||
|
.into(),
|
||||||
|
vuids: &["VUID-VkPipelineShaderStageCreateInfo-pNext-02755"],
|
||||||
|
..Default::default()
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
if !subgroup_size.is_power_of_two() {
|
||||||
|
return Err(Box::new(ValidationError {
|
||||||
|
context: "required_subgroup_size".into(),
|
||||||
|
problem: format!("`subgroup_size` {subgroup_size} is not a power of 2").into(),
|
||||||
|
vuids: &["VUID-VkPipelineShaderStageRequiredSubgroupSizeCreateInfo-requiredSubgroupSize-02760"],
|
||||||
|
..Default::default()
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
let min_subgroup_size = properties.min_subgroup_size.unwrap_or(1);
|
||||||
|
if subgroup_size < min_subgroup_size {
|
||||||
|
return Err(Box::new(ValidationError {
|
||||||
|
context: "required_subgroup_size".into(),
|
||||||
|
problem: format!(
|
||||||
|
"`subgroup_size` {subgroup_size} is less than `min_subgroup_size` {min_subgroup_size}"
|
||||||
|
)
|
||||||
|
.into(),
|
||||||
|
vuids: &["VUID-VkPipelineShaderStageRequiredSubgroupSizeCreateInfo-requiredSubgroupSize-02761"],
|
||||||
|
..Default::default()
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
let max_subgroup_size = properties.max_subgroup_size.unwrap_or(128);
|
||||||
|
if subgroup_size > max_subgroup_size {
|
||||||
|
return Err(Box::new(ValidationError {
|
||||||
|
context: "required_subgroup_size".into(),
|
||||||
|
problem:
|
||||||
|
format!("`subgroup_size` {subgroup_size} is greater than `max_subgroup_size` {max_subgroup_size}")
|
||||||
|
.into(),
|
||||||
|
vuids: &["VUID-VkPipelineShaderStageRequiredSubgroupSizeCreateInfo-requiredSubgroupSize-02762"],
|
||||||
|
..Default::default()
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
if let Some(workgroup_size) = workgroup_size {
|
||||||
|
if stage == ShaderStage::Compute {
|
||||||
|
let max_compute_workgroup_subgroups = properties
|
||||||
|
.max_compute_workgroup_subgroups
|
||||||
|
.unwrap_or_default();
|
||||||
|
if max_compute_workgroup_subgroups
|
||||||
|
.checked_mul(subgroup_size)
|
||||||
|
.unwrap_or(u32::MAX)
|
||||||
|
< workgroup_size
|
||||||
|
{
|
||||||
|
return Err(Box::new(ValidationError {
|
||||||
|
context: "required_subgroup_size".into(),
|
||||||
|
problem: format!("`subgroup_size` {subgroup_size} creates more than {max_compute_workgroup_subgroups} subgroups").into(),
|
||||||
|
vuids: &["VUID-VkPipelineShaderStageCreateInfo-pNext-02756"],
|
||||||
|
..Default::default()
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
vulkan_bitflags! {
|
vulkan_bitflags! {
|
||||||
#[non_exhaustive]
|
#[non_exhaustive]
|
||||||
|
|
||||||
|
@ -549,6 +549,78 @@ pub struct EntryPointInfo {
|
|||||||
pub output_interface: ShaderInterface,
|
pub output_interface: ShaderInterface,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl EntryPointInfo {
|
||||||
|
/// The local size in Compute shaders, None otherwise.
|
||||||
|
///
|
||||||
|
/// `specialization_info` is used for LocalSizeId / WorkgroupSizeId, using specialization_constants if not found.
|
||||||
|
/// Errors if specialization constants are not found or are not u32's.
|
||||||
|
pub(crate) fn local_size(
|
||||||
|
&self,
|
||||||
|
specialization_info: &HashMap<u32, SpecializationConstant>,
|
||||||
|
) -> Result<Option<[u32; 3]>, Box<ValidationError>> {
|
||||||
|
if let ShaderExecution::Compute(execution) = self.execution {
|
||||||
|
match execution {
|
||||||
|
ComputeShaderExecution::LocalSize(local_size)
|
||||||
|
| ComputeShaderExecution::LocalSizeId(local_size) => {
|
||||||
|
let mut output = [0; 3];
|
||||||
|
for (output, local_size) in output.iter_mut().zip(local_size) {
|
||||||
|
let id = match local_size {
|
||||||
|
LocalSize::Literal(literal) => {
|
||||||
|
*output = literal;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
LocalSize::SpecId(id) => id,
|
||||||
|
};
|
||||||
|
if let Some(default) = self.specialization_constants.get(&id) {
|
||||||
|
let default_value = if let SpecializationConstant::U32(default_value) =
|
||||||
|
default
|
||||||
|
{
|
||||||
|
default_value
|
||||||
|
} else {
|
||||||
|
return Err(Box::new(ValidationError {
|
||||||
|
problem: format!(
|
||||||
|
"`entry_point.info().specialization_constants[{id}]` is not a 32 bit integer"
|
||||||
|
)
|
||||||
|
.into(),
|
||||||
|
..Default::default()
|
||||||
|
}));
|
||||||
|
};
|
||||||
|
if let Some(provided) = specialization_info.get(&id) {
|
||||||
|
if let SpecializationConstant::U32(provided_value) = provided {
|
||||||
|
*output = *provided_value;
|
||||||
|
} else {
|
||||||
|
return Err(Box::new(ValidationError {
|
||||||
|
problem: format!(
|
||||||
|
"`specialization_info[{0}]` does not have the same type as \
|
||||||
|
`entry_point.info().specialization_constants[{0}]`",
|
||||||
|
id,
|
||||||
|
)
|
||||||
|
.into(),
|
||||||
|
vuids: &["VUID-VkSpecializationMapEntry-constantID-00776"],
|
||||||
|
..Default::default()
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*output = *default_value;
|
||||||
|
} else {
|
||||||
|
return Err(Box::new(ValidationError {
|
||||||
|
problem: format!(
|
||||||
|
"specialization constant {id} not found in `entry_point.info().specialization_constants`"
|
||||||
|
)
|
||||||
|
.into(),
|
||||||
|
..Default::default()
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(Some(output))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Represents a shader entry point in a shader module.
|
/// Represents a shader entry point in a shader module.
|
||||||
///
|
///
|
||||||
/// Can be obtained by calling [`entry_point`](ShaderModule::entry_point) on the shader module.
|
/// Can be obtained by calling [`entry_point`](ShaderModule::entry_point) on the shader module.
|
||||||
@ -581,15 +653,15 @@ pub enum ShaderExecution {
|
|||||||
TessellationEvaluation,
|
TessellationEvaluation,
|
||||||
Geometry(GeometryShaderExecution),
|
Geometry(GeometryShaderExecution),
|
||||||
Fragment(FragmentShaderExecution),
|
Fragment(FragmentShaderExecution),
|
||||||
Compute,
|
Compute(ComputeShaderExecution),
|
||||||
RayGeneration,
|
RayGeneration,
|
||||||
AnyHit,
|
AnyHit,
|
||||||
ClosestHit,
|
ClosestHit,
|
||||||
Miss,
|
Miss,
|
||||||
Intersection,
|
Intersection,
|
||||||
Callable,
|
Callable,
|
||||||
Task,
|
Task, // TODO: like compute?
|
||||||
Mesh,
|
Mesh, // TODO: like compute?
|
||||||
SubpassShading,
|
SubpassShading,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -601,7 +673,7 @@ impl From<&ShaderExecution> for ExecutionModel {
|
|||||||
ShaderExecution::TessellationEvaluation => Self::TessellationEvaluation,
|
ShaderExecution::TessellationEvaluation => Self::TessellationEvaluation,
|
||||||
ShaderExecution::Geometry(_) => Self::Geometry,
|
ShaderExecution::Geometry(_) => Self::Geometry,
|
||||||
ShaderExecution::Fragment(_) => Self::Fragment,
|
ShaderExecution::Fragment(_) => Self::Fragment,
|
||||||
ShaderExecution::Compute => Self::GLCompute,
|
ShaderExecution::Compute(_) => Self::GLCompute,
|
||||||
ShaderExecution::RayGeneration => Self::RayGenerationKHR,
|
ShaderExecution::RayGeneration => Self::RayGenerationKHR,
|
||||||
ShaderExecution::AnyHit => Self::AnyHitKHR,
|
ShaderExecution::AnyHit => Self::AnyHitKHR,
|
||||||
ShaderExecution::ClosestHit => Self::ClosestHitKHR,
|
ShaderExecution::ClosestHit => Self::ClosestHitKHR,
|
||||||
@ -699,6 +771,32 @@ pub enum FragmentTestsStages {
|
|||||||
EarlyAndLate,
|
EarlyAndLate,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// LocalSize.
|
||||||
|
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||||
|
pub enum LocalSize {
|
||||||
|
Literal(u32),
|
||||||
|
SpecId(u32),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The mode in which the compute shader executes.
|
||||||
|
///
|
||||||
|
/// The workgroup size is specified for x, y, and z dimensions.
|
||||||
|
///
|
||||||
|
/// Constants are resolved to literals, while specialization constants
|
||||||
|
/// map to spec ids.
|
||||||
|
///
|
||||||
|
/// The `WorkgroupSize` builtin overrides the values specified in the
|
||||||
|
/// execution mode. It can decorate a 3 component ConstantComposite or
|
||||||
|
/// SpecConstantComposite vector.
|
||||||
|
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||||
|
pub enum ComputeShaderExecution {
|
||||||
|
/// Workgroup size in x, y, and z.
|
||||||
|
LocalSize([LocalSize; 3]),
|
||||||
|
/// Requires spirv 1.2.
|
||||||
|
/// Like `LocalSize`, but uses ids instead of literals.
|
||||||
|
LocalSizeId([LocalSize; 3]),
|
||||||
|
}
|
||||||
|
|
||||||
/// The requirements imposed by a shader on a binding within a descriptor set layout, and on any
|
/// The requirements imposed by a shader on a binding within a descriptor set layout, and on any
|
||||||
/// resource that is bound to that binding.
|
/// resource that is bound to that binding.
|
||||||
#[derive(Clone, Debug, Default)]
|
#[derive(Clone, Debug, Default)]
|
||||||
@ -1263,7 +1361,7 @@ impl From<&ShaderExecution> for ShaderStage {
|
|||||||
ShaderExecution::TessellationEvaluation => Self::TessellationEvaluation,
|
ShaderExecution::TessellationEvaluation => Self::TessellationEvaluation,
|
||||||
ShaderExecution::Geometry(_) => Self::Geometry,
|
ShaderExecution::Geometry(_) => Self::Geometry,
|
||||||
ShaderExecution::Fragment(_) => Self::Fragment,
|
ShaderExecution::Fragment(_) => Self::Fragment,
|
||||||
ShaderExecution::Compute => Self::Compute,
|
ShaderExecution::Compute(_) => Self::Compute,
|
||||||
ShaderExecution::RayGeneration => Self::Raygen,
|
ShaderExecution::RayGeneration => Self::Raygen,
|
||||||
ShaderExecution::AnyHit => Self::AnyHit,
|
ShaderExecution::AnyHit => Self::AnyHit,
|
||||||
ShaderExecution::ClosestHit => Self::ClosestHit,
|
ShaderExecution::ClosestHit => Self::ClosestHit,
|
||||||
|
@ -16,12 +16,13 @@ use crate::{
|
|||||||
pipeline::layout::PushConstantRange,
|
pipeline::layout::PushConstantRange,
|
||||||
shader::{
|
shader::{
|
||||||
spirv::{
|
spirv::{
|
||||||
Capability, Decoration, Dim, ExecutionMode, ExecutionModel, Id, Instruction, Spirv,
|
BuiltIn, Capability, Decoration, Dim, ExecutionMode, ExecutionModel, Id, Instruction,
|
||||||
StorageClass,
|
Spirv, StorageClass,
|
||||||
},
|
},
|
||||||
DescriptorIdentifier, DescriptorRequirements, EntryPointInfo, GeometryShaderExecution,
|
ComputeShaderExecution, DescriptorIdentifier, DescriptorRequirements, EntryPointInfo,
|
||||||
GeometryShaderInput, NumericType, ShaderExecution, ShaderInterface, ShaderInterfaceEntry,
|
GeometryShaderExecution, GeometryShaderInput, LocalSize, NumericType, ShaderExecution,
|
||||||
ShaderInterfaceEntryType, ShaderStage, SpecializationConstant,
|
ShaderInterface, ShaderInterfaceEntry, ShaderInterfaceEntryType, ShaderStage,
|
||||||
|
SpecializationConstant,
|
||||||
},
|
},
|
||||||
DeviceSize,
|
DeviceSize,
|
||||||
};
|
};
|
||||||
@ -55,6 +56,9 @@ pub fn spirv_extensions(spirv: &Spirv) -> impl Iterator<Item = &str> {
|
|||||||
#[inline]
|
#[inline]
|
||||||
pub fn entry_points(spirv: &Spirv) -> impl Iterator<Item = EntryPointInfo> + '_ {
|
pub fn entry_points(spirv: &Spirv) -> impl Iterator<Item = EntryPointInfo> + '_ {
|
||||||
let interface_variables = interface_variables(spirv);
|
let interface_variables = interface_variables(spirv);
|
||||||
|
let u32_constants = u32_constants(spirv);
|
||||||
|
let specialization_constant_ids = specialization_constant_ids(spirv);
|
||||||
|
let workgroup_size_decorations = workgroup_size_decorations(spirv);
|
||||||
|
|
||||||
spirv.iter_entry_point().filter_map(move |instruction| {
|
spirv.iter_entry_point().filter_map(move |instruction| {
|
||||||
let (execution_model, function_id, entry_point_name, interface) = match instruction {
|
let (execution_model, function_id, entry_point_name, interface) = match instruction {
|
||||||
@ -68,7 +72,14 @@ pub fn entry_points(spirv: &Spirv) -> impl Iterator<Item = EntryPointInfo> + '_
|
|||||||
_ => return None,
|
_ => return None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let execution = shader_execution(spirv, execution_model, function_id);
|
let execution = shader_execution(
|
||||||
|
spirv,
|
||||||
|
execution_model,
|
||||||
|
function_id,
|
||||||
|
&u32_constants,
|
||||||
|
&specialization_constant_ids,
|
||||||
|
&workgroup_size_decorations,
|
||||||
|
);
|
||||||
let stage = ShaderStage::from(&execution);
|
let stage = ShaderStage::from(&execution);
|
||||||
|
|
||||||
let descriptor_binding_requirements = inspect_entry_point(
|
let descriptor_binding_requirements = inspect_entry_point(
|
||||||
@ -109,11 +120,87 @@ pub fn entry_points(spirv: &Spirv) -> impl Iterator<Item = EntryPointInfo> + '_
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Extracts the u32 constants from `spirv`.
|
||||||
|
fn u32_constants(spirv: &Spirv) -> HashMap<Id, u32> {
|
||||||
|
let type_u32s: HashSet<Id> = spirv
|
||||||
|
.iter_global()
|
||||||
|
.filter_map(|inst| {
|
||||||
|
if let Instruction::TypeInt {
|
||||||
|
result_id,
|
||||||
|
width,
|
||||||
|
signedness,
|
||||||
|
} = inst
|
||||||
|
{
|
||||||
|
if *width == 0 && *signedness == 0 {
|
||||||
|
return Some(*result_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
spirv
|
||||||
|
.iter_decoration()
|
||||||
|
.filter_map(|inst| {
|
||||||
|
if let Instruction::Constant {
|
||||||
|
result_type_id,
|
||||||
|
result_id,
|
||||||
|
value,
|
||||||
|
} = inst
|
||||||
|
{
|
||||||
|
if type_u32s.contains(result_type_id) {
|
||||||
|
if let [value] = value.as_slice() {
|
||||||
|
return Some((*result_id, *value));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extracts the specialization constant ids from `spirv`.
|
||||||
|
fn specialization_constant_ids(spirv: &Spirv) -> HashMap<Id, u32> {
|
||||||
|
spirv
|
||||||
|
.iter_decoration()
|
||||||
|
.filter_map(|inst| {
|
||||||
|
if let Instruction::Decorate { target, decoration } = inst {
|
||||||
|
if let Decoration::SpecId {
|
||||||
|
specialization_constant_id,
|
||||||
|
} = decoration
|
||||||
|
{
|
||||||
|
return Some((*target, *specialization_constant_id));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extracts the `WorkgroupSize` builtin Id's from `spirv`.
|
||||||
|
fn workgroup_size_decorations(spirv: &Spirv) -> HashSet<Id> {
|
||||||
|
spirv
|
||||||
|
.iter_decoration()
|
||||||
|
.filter_map(|inst| {
|
||||||
|
if let Instruction::Decorate { target, decoration } = inst {
|
||||||
|
if let Decoration::BuiltIn { built_in } = decoration {
|
||||||
|
if *built_in == BuiltIn::WorkgroupSize {
|
||||||
|
return Some(*target);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
/// Extracts the `ShaderExecution` for the entry point `function_id` from `spirv`.
|
/// Extracts the `ShaderExecution` for the entry point `function_id` from `spirv`.
|
||||||
fn shader_execution(
|
fn shader_execution(
|
||||||
spirv: &Spirv,
|
spirv: &Spirv,
|
||||||
execution_model: ExecutionModel,
|
execution_model: ExecutionModel,
|
||||||
function_id: Id,
|
function_id: Id,
|
||||||
|
u32_constants: &HashMap<Id, u32>,
|
||||||
|
specialization_constant_ids: &HashMap<Id, u32>,
|
||||||
|
workgroup_size_decorations: &HashSet<Id>,
|
||||||
) -> ShaderExecution {
|
) -> ShaderExecution {
|
||||||
match execution_model {
|
match execution_model {
|
||||||
ExecutionModel::Vertex => ShaderExecution::Vertex,
|
ExecutionModel::Vertex => ShaderExecution::Vertex,
|
||||||
@ -187,7 +274,124 @@ fn shader_execution(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
ExecutionModel::GLCompute => ShaderExecution::Compute,
|
ExecutionModel::GLCompute => {
|
||||||
|
let mut execution = ComputeShaderExecution::LocalSize([LocalSize::Literal(0); 3]);
|
||||||
|
for instruction in spirv.iter_execution_mode() {
|
||||||
|
match instruction {
|
||||||
|
Instruction::ExecutionMode { entry_point, mode }
|
||||||
|
if *entry_point == function_id =>
|
||||||
|
{
|
||||||
|
if let ExecutionMode::LocalSize {
|
||||||
|
x_size,
|
||||||
|
y_size,
|
||||||
|
z_size,
|
||||||
|
} = mode
|
||||||
|
{
|
||||||
|
execution = ComputeShaderExecution::LocalSize([
|
||||||
|
LocalSize::Literal(*x_size),
|
||||||
|
LocalSize::Literal(*y_size),
|
||||||
|
LocalSize::Literal(*z_size),
|
||||||
|
]);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Instruction::ExecutionModeId { entry_point, mode }
|
||||||
|
if *entry_point == function_id =>
|
||||||
|
{
|
||||||
|
if let ExecutionMode::LocalSizeId {
|
||||||
|
x_size,
|
||||||
|
y_size,
|
||||||
|
z_size,
|
||||||
|
} = mode
|
||||||
|
{
|
||||||
|
let mut local_size = [LocalSize::Literal(0); 3];
|
||||||
|
for (local_size, id) in
|
||||||
|
local_size.iter_mut().zip([*x_size, *y_size, *z_size])
|
||||||
|
{
|
||||||
|
if let Some(constant) = u32_constants.get(&id) {
|
||||||
|
*local_size = LocalSize::Literal(*constant);
|
||||||
|
} else if let Some(spec_id) = specialization_constant_ids.get(&id) {
|
||||||
|
*local_size = LocalSize::SpecId(*spec_id);
|
||||||
|
} else {
|
||||||
|
panic!("LocalSizeId {id:?} not defined!");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
execution = ComputeShaderExecution::LocalSizeId(local_size);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => continue,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
if !workgroup_size_decorations.is_empty() {
|
||||||
|
let mut in_function = false;
|
||||||
|
for instruction in spirv.instructions() {
|
||||||
|
if !in_function {
|
||||||
|
match *instruction {
|
||||||
|
Instruction::Function { result_id, .. } if result_id == function_id => {
|
||||||
|
in_function = true;
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let mut local_size = [LocalSize::Literal(0); 3];
|
||||||
|
match instruction {
|
||||||
|
Instruction::ConstantComposite {
|
||||||
|
result_type_id: _,
|
||||||
|
result_id,
|
||||||
|
constituents,
|
||||||
|
} => {
|
||||||
|
if workgroup_size_decorations.contains(result_id) {
|
||||||
|
if constituents.len() != 3 {
|
||||||
|
panic!("WorkgroupSize must be 3 component vector!");
|
||||||
|
}
|
||||||
|
for (local_size, id) in
|
||||||
|
local_size.iter_mut().zip(constituents.iter())
|
||||||
|
{
|
||||||
|
if let Some(constant) = u32_constants.get(id) {
|
||||||
|
*local_size = LocalSize::Literal(*constant);
|
||||||
|
} else {
|
||||||
|
panic!("WorkgroupSize Constant {id:?} not defined!");
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Instruction::SpecConstantComposite {
|
||||||
|
result_type_id: _,
|
||||||
|
result_id,
|
||||||
|
constituents,
|
||||||
|
} => {
|
||||||
|
if workgroup_size_decorations.contains(result_id) {
|
||||||
|
if constituents.len() != 3 {
|
||||||
|
panic!("WorkgroupSize must be 3 component vector!");
|
||||||
|
}
|
||||||
|
for (local_size, id) in
|
||||||
|
local_size.iter_mut().zip(constituents.iter())
|
||||||
|
{
|
||||||
|
if let Some(spec_id) = specialization_constant_ids.get(id) {
|
||||||
|
*local_size = LocalSize::SpecId(*spec_id);
|
||||||
|
} else {
|
||||||
|
panic!("WorkgroupSize SpecializationConstant {id:?} not defined!");
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Instruction::FunctionEnd => break,
|
||||||
|
_ => continue,
|
||||||
|
}
|
||||||
|
match &mut execution {
|
||||||
|
ComputeShaderExecution::LocalSize(output) => {
|
||||||
|
*output = local_size;
|
||||||
|
}
|
||||||
|
ComputeShaderExecution::LocalSizeId(output) => {
|
||||||
|
*output = local_size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ShaderExecution::Compute(execution)
|
||||||
|
}
|
||||||
|
|
||||||
ExecutionModel::RayGenerationKHR => ShaderExecution::RayGeneration,
|
ExecutionModel::RayGenerationKHR => ShaderExecution::RayGeneration,
|
||||||
ExecutionModel::IntersectionKHR => ShaderExecution::Intersection,
|
ExecutionModel::IntersectionKHR => ShaderExecution::Intersection,
|
||||||
|
Loading…
Reference in New Issue
Block a user