extract dispatch_indirect from compute_pass_end_impl

This commit is contained in:
teoxoy 2024-06-25 14:26:37 +02:00 committed by Teodor Tanasoaia
parent fefb9c2453
commit 868b9cd866

View File

@ -17,7 +17,9 @@ use crate::{
hal_api::HalApi,
hal_label, id,
init_tracker::{BufferInitTrackerAction, MemoryInitKind},
resource::{self, DestroyedResourceError, MissingBufferUsageError, ParentDevice, Resource},
resource::{
self, Buffer, DestroyedResourceError, MissingBufferUsageError, ParentDevice, Resource,
},
snatch::SnatchGuard,
track::{ResourceUsageCompatibilityError, Tracker, TrackerIndex, UsageScope},
Label,
@ -634,53 +636,8 @@ impl Global {
indirect: true,
pipeline: state.pipeline,
};
buffer.same_device_as(cmd_buf).map_pass_err(scope)?;
state.is_ready().map_pass_err(scope)?;
state
.device
.require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)
dispatch_indirect(&mut state, raw, cmd_buf, buffer, offset)
.map_pass_err(scope)?;
state
.scope
.buffers
.merge_single(&buffer, hal::BufferUses::INDIRECT)
.map_pass_err(scope)?;
buffer
.check_usage(wgt::BufferUsages::INDIRECT)
.map_pass_err(scope)?;
let end_offset = offset + mem::size_of::<wgt::DispatchIndirectArgs>() as u64;
if end_offset > buffer.size {
return Err(ComputePassErrorInner::IndirectBufferOverrun {
offset,
end_offset,
buffer_size: buffer.size,
})
.map_pass_err(scope);
}
let stride = 3 * 4; // 3 integers, x/y/z group size
state.buffer_memory_init_actions.extend(
buffer.initialization_status.read().create_action(
&buffer,
offset..(offset + stride),
MemoryInitKind::NeedsInitializedMemory,
),
);
state
.flush_states(raw, Some(buffer.as_info().tracker_index()))
.map_pass_err(scope)?;
let buf_raw = buffer.try_raw(&state.snatch_guard).map_pass_err(scope)?;
unsafe {
raw.dispatch_indirect(buf_raw, offset);
}
}
ArcComputeCommand::PushDebugGroup { color: _, len } => {
state.debug_scope_depth += 1;
@ -1010,6 +967,55 @@ fn dispatch<A: HalApi>(
Ok(())
}
fn dispatch_indirect<A: HalApi>(
state: &mut State<A>,
raw: &mut A::CommandEncoder,
cmd_buf: &CommandBuffer<A>,
buffer: Arc<Buffer<A>>,
offset: u64,
) -> Result<(), ComputePassErrorInner> {
buffer.same_device_as(cmd_buf)?;
state.is_ready()?;
state
.device
.require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;
state
.scope
.buffers
.merge_single(&buffer, hal::BufferUses::INDIRECT)?;
buffer.check_usage(wgt::BufferUsages::INDIRECT)?;
let end_offset = offset + mem::size_of::<wgt::DispatchIndirectArgs>() as u64;
if end_offset > buffer.size {
return Err(ComputePassErrorInner::IndirectBufferOverrun {
offset,
end_offset,
buffer_size: buffer.size,
});
}
let stride = 3 * 4; // 3 integers, x/y/z group size
state
.buffer_memory_init_actions
.extend(buffer.initialization_status.read().create_action(
&buffer,
offset..(offset + stride),
MemoryInitKind::NeedsInitializedMemory,
));
state.flush_states(raw, Some(buffer.as_info().tracker_index()))?;
let buf_raw = buffer.try_raw(&state.snatch_guard)?;
unsafe {
raw.dispatch_indirect(buf_raw, offset);
}
Ok(())
}
// Recording a compute pass.
impl Global {
pub fn compute_pass_set_bind_group<A: HalApi>(