extract set_push_constant from compute_pass_end_impl

This commit is contained in:
teoxoy 2024-06-25 14:21:57 +02:00 committed by Teodor Tanasoaia
parent 3a199cf258
commit 5fdb663f45

View File

@ -612,39 +612,15 @@ impl Global {
values_offset, values_offset,
} => { } => {
let scope = PassErrorScope::SetPushConstant; let scope = PassErrorScope::SetPushConstant;
set_push_constant(
let end_offset_bytes = offset + size_bytes; &state,
let values_end_offset = raw,
(values_offset + size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize; &base.push_constant_data,
let data_slice =
&base.push_constant_data[(values_offset as usize)..values_end_offset];
let pipeline_layout = state
.binder
.pipeline_layout
.as_ref()
//TODO: don't error here, lazily update the push constants
.ok_or(ComputePassErrorInner::Dispatch(
DispatchError::MissingPipeline,
))
.map_pass_err(scope)?;
pipeline_layout
.validate_push_constant_ranges(
wgt::ShaderStages::COMPUTE,
offset, offset,
end_offset_bytes, size_bytes,
values_offset,
) )
.map_pass_err(scope)?; .map_pass_err(scope)?;
unsafe {
raw.set_push_constants(
pipeline_layout.raw(),
wgt::ShaderStages::COMPUTE,
offset,
data_slice,
);
}
} }
ArcComputeCommand::Dispatch(groups) => { ArcComputeCommand::Dispatch(groups) => {
let scope = PassErrorScope::Dispatch { let scope = PassErrorScope::Dispatch {
@ -989,6 +965,44 @@ fn set_pipeline<A: HalApi>(
Ok(()) Ok(())
} }
fn set_push_constant<A: HalApi>(
state: &State<A>,
raw: &mut A::CommandEncoder,
push_constant_data: &[u32],
offset: u32,
size_bytes: u32,
values_offset: u32,
) -> Result<(), ComputePassErrorInner> {
let end_offset_bytes = offset + size_bytes;
let values_end_offset = (values_offset + size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
let data_slice = &push_constant_data[(values_offset as usize)..values_end_offset];
let pipeline_layout = state
.binder
.pipeline_layout
.as_ref()
//TODO: don't error here, lazily update the push constants
.ok_or(ComputePassErrorInner::Dispatch(
DispatchError::MissingPipeline,
))?;
pipeline_layout.validate_push_constant_ranges(
wgt::ShaderStages::COMPUTE,
offset,
end_offset_bytes,
)?;
unsafe {
raw.set_push_constants(
pipeline_layout.raw(),
wgt::ShaderStages::COMPUTE,
offset,
data_slice,
);
}
Ok(())
}
// Recording a compute pass. // Recording a compute pass.
impl Global { impl Global {
pub fn compute_pass_set_bind_group<A: HalApi>( pub fn compute_pass_set_bind_group<A: HalApi>(