extract set_pipeline from compute_pass_end_impl

This commit is contained in:
teoxoy 2024-06-25 14:18:41 +02:00 committed by Teodor Tanasoaia
parent bc2f8edf9b
commit 3a199cf258

View File

@ -604,68 +604,7 @@ impl Global {
}
ArcComputeCommand::SetPipeline(pipeline) => {
let scope = PassErrorScope::SetPipelineCompute(pipeline.as_info().id());
pipeline.same_device_as(cmd_buf).map_pass_err(scope)?;
state.pipeline = Some(pipeline.as_info().id());
let pipeline = state.tracker.compute_pipelines.insert_single(pipeline);
unsafe {
raw.set_compute_pipeline(pipeline.raw());
}
// Rebind resources
if state.binder.pipeline_layout.is_none()
|| !state
.binder
.pipeline_layout
.as_ref()
.unwrap()
.is_equal(&pipeline.layout)
{
let (start_index, entries) = state.binder.change_pipeline_layout(
&pipeline.layout,
&pipeline.late_sized_buffer_groups,
);
if !entries.is_empty() {
for (i, e) in entries.iter().enumerate() {
if let Some(group) = e.group.as_ref() {
let raw_bg =
group.try_raw(&state.snatch_guard).map_pass_err(scope)?;
unsafe {
raw.set_bind_group(
pipeline.layout.raw(),
start_index as u32 + i as u32,
raw_bg,
&e.dynamic_offsets,
);
}
}
}
}
// Clear push constant ranges
let non_overlapping = super::bind::compute_nonoverlapping_ranges(
&pipeline.layout.push_constant_ranges,
);
for range in non_overlapping {
let offset = range.range.start;
let size_bytes = range.range.end - offset;
super::push_constant_clear(
offset,
size_bytes,
|clear_offset, clear_data| unsafe {
raw.set_push_constants(
pipeline.layout.raw(),
wgt::ShaderStages::COMPUTE,
clear_offset,
clear_data,
);
},
);
}
}
set_pipeline(&mut state, raw, cmd_buf, pipeline).map_pass_err(scope)?;
}
ArcComputeCommand::SetPushConstant {
offset,
@ -987,6 +926,69 @@ fn set_bind_group<A: HalApi>(
Ok(())
}
fn set_pipeline<A: HalApi>(
state: &mut State<A>,
raw: &mut A::CommandEncoder,
cmd_buf: &CommandBuffer<A>,
pipeline: Arc<crate::pipeline::ComputePipeline<A>>,
) -> Result<(), ComputePassErrorInner> {
pipeline.same_device_as(cmd_buf)?;
state.pipeline = Some(pipeline.as_info().id());
let pipeline = state.tracker.compute_pipelines.insert_single(pipeline);
unsafe {
raw.set_compute_pipeline(pipeline.raw());
}
// Rebind resources
if state.binder.pipeline_layout.is_none()
|| !state
.binder
.pipeline_layout
.as_ref()
.unwrap()
.is_equal(&pipeline.layout)
{
let (start_index, entries) = state
.binder
.change_pipeline_layout(&pipeline.layout, &pipeline.late_sized_buffer_groups);
if !entries.is_empty() {
for (i, e) in entries.iter().enumerate() {
if let Some(group) = e.group.as_ref() {
let raw_bg = group.try_raw(&state.snatch_guard)?;
unsafe {
raw.set_bind_group(
pipeline.layout.raw(),
start_index as u32 + i as u32,
raw_bg,
&e.dynamic_offsets,
);
}
}
}
}
// Clear push constant ranges
let non_overlapping =
super::bind::compute_nonoverlapping_ranges(&pipeline.layout.push_constant_ranges);
for range in non_overlapping {
let offset = range.range.start;
let size_bytes = range.range.end - offset;
super::push_constant_clear(offset, size_bytes, |clear_offset, clear_data| unsafe {
raw.set_push_constants(
pipeline.layout.raw(),
wgt::ShaderStages::COMPUTE,
clear_offset,
clear_data,
);
});
}
}
Ok(())
}
// Recording a compute pass.
impl Global {
pub fn compute_pass_set_bind_group<A: HalApi>(