move the raw encoder in State

Authored by teoxoy on 2024-06-25 15:22:13 +02:00; committed by Teodor Tanasoaia
parent 8ee9df9eb3
commit f0f61d9bb6
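
The change is mechanical: instead of threading a separate raw: &mut A::CommandEncoder argument through every pass helper (set_bind_group, set_pipeline, dispatch, the debug-marker functions, and so on), the mutable borrow of the HAL command encoder now lives in State behind a new 'raw_encoder lifetime, and helpers reach it as state.raw_encoder. A minimal sketch of the pattern, using made-up stand-in types (Encoder and a trimmed-down State) rather than the real wgpu-core ones:

// Sketch only: hypothetical stand-in types, not wgpu-core's actual encoder/state.
struct Encoder {
    commands: Vec<String>,
}

impl Encoder {
    fn dispatch(&mut self, groups: [u32; 3]) {
        self.commands.push(format!("dispatch {groups:?}"));
    }
}

// The extra lifetime parameter plays the role of 'raw_encoder in the diff:
// State borrows the encoder mutably for the whole duration of the pass.
struct State<'raw_encoder> {
    pipeline_set: bool,
    raw_encoder: &'raw_encoder mut Encoder,
}

// Helpers now take only &mut State instead of (&mut State, &mut Encoder).
fn dispatch(state: &mut State<'_>, groups: [u32; 3]) -> Result<(), String> {
    if !state.pipeline_set {
        return Err("missing pipeline".to_string());
    }
    state.raw_encoder.dispatch(groups);
    Ok(())
}

fn main() -> Result<(), String> {
    let mut encoder = Encoder { commands: Vec::new() };
    {
        // The pass-scoped mutable borrow of the encoder lives inside this block.
        let mut state = State {
            pipeline_set: true,
            raw_encoder: &mut encoder,
        };
        dispatch(&mut state, [8, 8, 1])?;
        // Dropping (or destructuring) state ends the borrow.
    }
    // Once the borrow is gone, the encoder itself can be used again.
    assert_eq!(encoder.commands, ["dispatch [8, 8, 1]"]);
    Ok(())
}

That is also what the new let State { snatch_guard, tracker, .. } = state; destructuring near the end of the pass appears to be for: consuming state there ends the 'raw_encoder borrow so encoder.close() and encoder.open() can take the encoder again for the barrier/fixup pass, while the fields still needed afterwards (tracker, intermediate_trackers, snatch_guard, pending_discard_init_fixups) are moved out first.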


@@ -252,7 +252,7 @@ where
}
}
-struct State<'scope, 'snatch_guard, 'cmd_buf, A: HalApi> {
+struct State<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder, A: HalApi> {
binder: Binder<A>,
pipeline: Option<id::ComputePipelineId>,
scope: UsageScope<'scope, A>,
@@ -261,6 +261,9 @@ struct State<'scope, 'snatch_guard, 'cmd_buf, A: HalApi> {
snatch_guard: SnatchGuard<'snatch_guard>,
device: &'cmd_buf Arc<Device<A>>,
+raw_encoder: &'raw_encoder mut A::CommandEncoder,
tracker: &'cmd_buf mut Tracker<A>,
buffer_memory_init_actions: &'cmd_buf mut Vec<BufferInitTrackerAction<A>>,
texture_memory_actions: &'cmd_buf mut CommandBufferTextureMemoryActions<A>,
@@ -277,7 +280,9 @@ struct State<'scope, 'snatch_guard, 'cmd_buf, A: HalApi> {
pending_discard_init_fixups: SurfacesInDiscardState<A>,
}
-impl<'scope, 'snatch_guard, 'cmd_buf, A: HalApi> State<'scope, 'snatch_guard, 'cmd_buf, A> {
+impl<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder, A: HalApi>
+State<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder, A>
+{
fn is_ready(&self) -> Result<(), DispatchError> {
let bind_mask = self.binder.invalid_mask();
if bind_mask != 0 {
@@ -301,7 +306,6 @@ impl<'scope, 'snatch_guard, 'cmd_buf, A: HalApi> State<'scope, 'snatch_guard, 'c
// part of the usage scope.
fn flush_states(
&mut self,
-raw_encoder: &mut A::CommandEncoder,
indirect_buffer: Option<TrackerIndex>,
) -> Result<(), ResourceUsageCompatibilityError> {
for bind_group in self.binder.list_active() {
@@ -327,7 +331,7 @@ impl<'scope, 'snatch_guard, 'cmd_buf, A: HalApi> State<'scope, 'snatch_guard, 'c
log::trace!("Encoding dispatch barriers");
CommandBuffer::drain_barriers(
-raw_encoder,
+self.raw_encoder,
&mut self.intermediate_trackers,
&self.snatch_guard,
);
@@ -505,7 +509,7 @@ impl Global {
encoder.close().map_pass_err(pass_scope)?;
// will be reset to true if recording is done without errors
*status = CommandEncoderStatus::Error;
-let raw = encoder.open().map_pass_err(pass_scope)?;
+let raw_encoder = encoder.open().map_pass_err(pass_scope)?;
let mut state = State {
binder: Binder::new(),
@@ -516,6 +520,7 @@ impl Global {
snatch_guard: device.snatchable_lock.read(),
device,
+raw_encoder,
tracker: &mut cmd_buf_data.trackers,
buffer_memory_init_actions: &mut cmd_buf_data.buffer_memory_init_actions,
texture_memory_actions: &mut cmd_buf_data.texture_memory_actions,
@@ -561,7 +566,9 @@ impl Global {
// But no point in erroring over that nuance here!
if let Some(range) = range {
unsafe {
-raw.reset_queries(query_set.raw.as_ref().unwrap(), range);
+state
+.raw_encoder
+.reset_queries(query_set.raw.as_ref().unwrap(), range);
}
}
@@ -580,7 +587,7 @@ impl Global {
};
unsafe {
-raw.begin_compute_pass(&hal_desc);
+state.raw_encoder.begin_compute_pass(&hal_desc);
}
// TODO: We should be draining the commands here, avoiding extra copies in the process.
@@ -595,7 +602,6 @@ impl Global {
let scope = PassErrorScope::SetBindGroup(bind_group.as_info().id());
set_bind_group(
&mut state,
-raw,
cmd_buf,
&base.dynamic_offsets,
index,
@@ -606,7 +612,7 @@ impl Global {
}
ArcComputeCommand::SetPipeline(pipeline) => {
let scope = PassErrorScope::SetPipelineCompute(pipeline.as_info().id());
-set_pipeline(&mut state, raw, cmd_buf, pipeline).map_pass_err(scope)?;
+set_pipeline(&mut state, cmd_buf, pipeline).map_pass_err(scope)?;
}
ArcComputeCommand::SetPushConstant {
offset,
@@ -615,8 +621,7 @@ impl Global {
} => {
let scope = PassErrorScope::SetPushConstant;
set_push_constant(
-&state,
-raw,
+&mut state,
&base.push_constant_data,
offset,
size_bytes,
@@ -629,32 +634,31 @@ impl Global {
indirect: false,
pipeline: state.pipeline,
};
-dispatch(&mut state, raw, groups).map_pass_err(scope)?;
+dispatch(&mut state, groups).map_pass_err(scope)?;
}
ArcComputeCommand::DispatchIndirect { buffer, offset } => {
let scope = PassErrorScope::Dispatch {
indirect: true,
pipeline: state.pipeline,
};
-dispatch_indirect(&mut state, raw, cmd_buf, buffer, offset)
-.map_pass_err(scope)?;
+dispatch_indirect(&mut state, cmd_buf, buffer, offset).map_pass_err(scope)?;
}
ArcComputeCommand::PushDebugGroup { color: _, len } => {
-push_debug_group(&mut state, raw, &base.string_data, len);
+push_debug_group(&mut state, &base.string_data, len);
}
ArcComputeCommand::PopDebugGroup => {
let scope = PassErrorScope::PopDebugGroup;
-pop_debug_group(&mut state, raw).map_pass_err(scope)?;
+pop_debug_group(&mut state).map_pass_err(scope)?;
}
ArcComputeCommand::InsertDebugMarker { color: _, len } => {
-insert_debug_marker(&mut state, raw, &base.string_data, len);
+insert_debug_marker(&mut state, &base.string_data, len);
}
ArcComputeCommand::WriteTimestamp {
query_set,
query_index,
} => {
let scope = PassErrorScope::WriteTimestamp;
-write_timestamp(&mut state, raw, cmd_buf, query_set, query_index)
+write_timestamp(&mut state, cmd_buf, query_set, query_index)
.map_pass_err(scope)?;
}
ArcComputeCommand::BeginPipelineStatisticsQuery {
@@ -664,7 +668,7 @@ impl Global {
let scope = PassErrorScope::BeginPipelineStatisticsQuery;
validate_and_begin_pipeline_statistics_query(
query_set,
-raw,
+state.raw_encoder,
&mut state.tracker.query_sets,
cmd_buf,
query_index,
@@ -675,20 +679,28 @@ impl Global {
}
ArcComputeCommand::EndPipelineStatisticsQuery => {
let scope = PassErrorScope::EndPipelineStatisticsQuery;
-end_pipeline_statistics_query(raw, &mut state.active_query)
+end_pipeline_statistics_query(state.raw_encoder, &mut state.active_query)
.map_pass_err(scope)?;
}
}
}
unsafe {
-raw.end_compute_pass();
+state.raw_encoder.end_compute_pass();
}
// We've successfully recorded the compute pass, bring the
// command buffer out of the error state.
*status = CommandEncoderStatus::Recording;
+let State {
+snatch_guard,
+tracker,
+intermediate_trackers,
+pending_discard_init_fixups,
+..
+} = state;
// Stop the current command buffer.
encoder.close().map_pass_err(pass_scope)?;
@@ -697,17 +709,17 @@ impl Global {
// Use that buffer to insert barriers and clear discarded images.
let transit = encoder.open().map_pass_err(pass_scope)?;
fixup_discarded_surfaces(
-state.pending_discard_init_fixups.into_iter(),
+pending_discard_init_fixups.into_iter(),
transit,
-&mut state.tracker.textures,
-state.device,
-&state.snatch_guard,
+&mut tracker.textures,
+device,
+&snatch_guard,
);
CommandBuffer::insert_barriers_from_tracker(
transit,
-state.tracker,
-&state.intermediate_trackers,
-&state.snatch_guard,
+tracker,
+&intermediate_trackers,
+&snatch_guard,
);
// Close the command buffer, and swap it with the previous.
encoder.close_and_swap().map_pass_err(pass_scope)?;
@@ -718,7 +730,6 @@ impl Global {
fn set_bind_group<A: HalApi>(
state: &mut State<A>,
-raw: &mut A::CommandEncoder,
cmd_buf: &CommandBuffer<A>,
dynamic_offsets: &[DynamicOffset],
index: u32,
@@ -771,7 +782,7 @@ fn set_bind_group<A: HalApi>(
if let Some(group) = e.group.as_ref() {
let raw_bg = group.try_raw(&state.snatch_guard)?;
unsafe {
-raw.set_bind_group(
+state.raw_encoder.set_bind_group(
pipeline_layout,
index + i as u32,
raw_bg,
@@ -786,7 +797,6 @@ fn set_bind_group<A: HalApi>(
fn set_pipeline<A: HalApi>(
state: &mut State<A>,
-raw: &mut A::CommandEncoder,
cmd_buf: &CommandBuffer<A>,
pipeline: Arc<crate::pipeline::ComputePipeline<A>>,
) -> Result<(), ComputePassErrorInner> {
@@ -797,7 +807,7 @@ fn set_pipeline<A: HalApi>(
let pipeline = state.tracker.compute_pipelines.insert_single(pipeline);
unsafe {
-raw.set_compute_pipeline(pipeline.raw());
+state.raw_encoder.set_compute_pipeline(pipeline.raw());
}
// Rebind resources
@@ -817,7 +827,7 @@ fn set_pipeline<A: HalApi>(
if let Some(group) = e.group.as_ref() {
let raw_bg = group.try_raw(&state.snatch_guard)?;
unsafe {
-raw.set_bind_group(
+state.raw_encoder.set_bind_group(
pipeline.layout.raw(),
start_index as u32 + i as u32,
raw_bg,
@@ -835,7 +845,7 @@ fn set_pipeline<A: HalApi>(
let offset = range.range.start;
let size_bytes = range.range.end - offset;
super::push_constant_clear(offset, size_bytes, |clear_offset, clear_data| unsafe {
-raw.set_push_constants(
+state.raw_encoder.set_push_constants(
pipeline.layout.raw(),
wgt::ShaderStages::COMPUTE,
clear_offset,
@@ -848,8 +858,7 @@ fn set_pipeline<A: HalApi>(
}
fn set_push_constant<A: HalApi>(
-state: &State<A>,
-raw: &mut A::CommandEncoder,
+state: &mut State<A>,
push_constant_data: &[u32],
offset: u32,
size_bytes: u32,
@@ -875,7 +884,7 @@ fn set_push_constant<A: HalApi>(
)?;
unsafe {
-raw.set_push_constants(
+state.raw_encoder.set_push_constants(
pipeline_layout.raw(),
wgt::ShaderStages::COMPUTE,
offset,
@@ -887,12 +896,11 @@ fn set_push_constant<A: HalApi>(
fn dispatch<A: HalApi>(
state: &mut State<A>,
-raw: &mut A::CommandEncoder,
groups: [u32; 3],
) -> Result<(), ComputePassErrorInner> {
state.is_ready()?;
-state.flush_states(raw, None)?;
+state.flush_states(None)?;
let groups_size_limit = state.device.limits.max_compute_workgroups_per_dimension;
@@ -909,14 +917,13 @@ fn dispatch<A: HalApi>(
}
unsafe {
-raw.dispatch(groups);
+state.raw_encoder.dispatch(groups);
}
Ok(())
}
fn dispatch_indirect<A: HalApi>(
state: &mut State<A>,
-raw: &mut A::CommandEncoder,
cmd_buf: &CommandBuffer<A>,
buffer: Arc<Buffer<A>>,
offset: u64,
@@ -954,21 +961,16 @@ fn dispatch_indirect<A: HalApi>(
MemoryInitKind::NeedsInitializedMemory,
));
-state.flush_states(raw, Some(buffer.as_info().tracker_index()))?;
+state.flush_states(Some(buffer.as_info().tracker_index()))?;
let buf_raw = buffer.try_raw(&state.snatch_guard)?;
unsafe {
-raw.dispatch_indirect(buf_raw, offset);
+state.raw_encoder.dispatch_indirect(buf_raw, offset);
}
Ok(())
}
-fn push_debug_group<A: HalApi>(
-state: &mut State<A>,
-raw: &mut A::CommandEncoder,
-string_data: &[u8],
-len: usize,
-) {
+fn push_debug_group<A: HalApi>(state: &mut State<A>, string_data: &[u8], len: usize) {
state.debug_scope_depth += 1;
if !state
.device
@@ -978,16 +980,13 @@ fn push_debug_group<A: HalApi>(
let label =
str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
unsafe {
-raw.begin_debug_marker(label);
+state.raw_encoder.begin_debug_marker(label);
}
}
state.string_offset += len;
}
-fn pop_debug_group<A: HalApi>(
-state: &mut State<A>,
-raw: &mut A::CommandEncoder,
-) -> Result<(), ComputePassErrorInner> {
+fn pop_debug_group<A: HalApi>(state: &mut State<A>) -> Result<(), ComputePassErrorInner> {
if state.debug_scope_depth == 0 {
return Err(ComputePassErrorInner::InvalidPopDebugGroup);
}
@@ -998,18 +997,13 @@ fn pop_debug_group<A: HalApi>(
.contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
{
unsafe {
-raw.end_debug_marker();
+state.raw_encoder.end_debug_marker();
}
}
Ok(())
}
-fn insert_debug_marker<A: HalApi>(
-state: &mut State<A>,
-raw: &mut A::CommandEncoder,
-string_data: &[u8],
-len: usize,
-) {
+fn insert_debug_marker<A: HalApi>(state: &mut State<A>, string_data: &[u8], len: usize) {
if !state
.device
.instance_flags
@@ -1017,14 +1011,13 @@ fn insert_debug_marker<A: HalApi>(
{
let label =
str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
-unsafe { raw.insert_debug_marker(label) }
+unsafe { state.raw_encoder.insert_debug_marker(label) }
}
state.string_offset += len;
}
fn write_timestamp<A: HalApi>(
state: &mut State<A>,
-raw: &mut A::CommandEncoder,
cmd_buf: &CommandBuffer<A>,
query_set: Arc<resource::QuerySet<A>>,
query_index: u32,
@@ -1037,7 +1030,7 @@ fn write_timestamp<A: HalApi>(
let query_set = state.tracker.query_sets.insert_single(query_set);
-query_set.validate_and_write_timestamp(raw, query_index, None)?;
+query_set.validate_and_write_timestamp(state.raw_encoder, query_index, None)?;
Ok(())
}