add more fields to State and cleanup fn params

This commit is contained in:
teoxoy 2024-06-21 21:59:05 +02:00 committed by Teodor Tanasoaia
parent 9791b6ccb9
commit ebb930e9fd

View File

@ -366,9 +366,14 @@ impl RenderBundleEncoder {
vertex: (0..hal::MAX_VERTEX_BUFFERS).map(|_| None).collect(), vertex: (0..hal::MAX_VERTEX_BUFFERS).map(|_| None).collect(),
index: None, index: None,
flat_dynamic_offsets: Vec::new(), flat_dynamic_offsets: Vec::new(),
device: device.clone(),
commands: Vec::new(),
buffer_memory_init_actions: Vec::new(),
texture_memory_init_actions: Vec::new(),
next_dynamic_offset: 0,
}; };
let indices = &device.tracker_indices; let indices = &state.device.tracker_indices;
state state
.trackers .trackers
.buffers .buffers
@ -395,12 +400,6 @@ impl RenderBundleEncoder {
.write() .write()
.set_size(indices.query_sets.size()); .set_size(indices.query_sets.size());
let mut commands = Vec::new();
let mut buffer_memory_init_actions = Vec::new();
let mut texture_memory_init_actions = Vec::new();
let mut next_dynamic_offset = 0;
let base = &self.base; let base = &self.base;
for &command in &base.commands { for &command in &base.commands {
@ -411,11 +410,11 @@ impl RenderBundleEncoder {
bind_group_id, bind_group_id,
} => { } => {
let scope = PassErrorScope::SetBindGroup(bind_group_id); let scope = PassErrorScope::SetBindGroup(bind_group_id);
set_bind_group(bind_group_id, &bind_group_guard, &mut state, device, index, &mut next_dynamic_offset, num_dynamic_offsets, base, &mut buffer_memory_init_actions, &mut texture_memory_init_actions).map_pass_err(scope)?; set_bind_group(&mut state, &bind_group_guard, &base.dynamic_offsets, index, num_dynamic_offsets, bind_group_id).map_pass_err(scope)?;
} }
RenderCommand::SetPipeline(pipeline_id) => { RenderCommand::SetPipeline(pipeline_id) => {
let scope = PassErrorScope::SetPipelineRender(pipeline_id); let scope = PassErrorScope::SetPipelineRender(pipeline_id);
set_pipeline(&self, pipeline_id, &pipeline_guard, &mut state, device, &mut commands).map_pass_err(scope)?; set_pipeline(&mut state, &pipeline_guard, &self.context, self.is_depth_read_only, self.is_stencil_read_only, pipeline_id).map_pass_err(scope)?;
} }
RenderCommand::SetIndexBuffer { RenderCommand::SetIndexBuffer {
buffer_id, buffer_id,
@ -424,7 +423,7 @@ impl RenderBundleEncoder {
size, size,
} => { } => {
let scope = PassErrorScope::SetIndexBuffer(buffer_id); let scope = PassErrorScope::SetIndexBuffer(buffer_id);
set_index_buffer(buffer_id, &buffer_guard, &mut state, device, size, offset, &mut buffer_memory_init_actions, index_format).map_pass_err(scope)?; set_index_buffer(&mut state, &buffer_guard, buffer_id, index_format, offset, size).map_pass_err(scope)?;
} }
RenderCommand::SetVertexBuffer { RenderCommand::SetVertexBuffer {
slot, slot,
@ -433,7 +432,7 @@ impl RenderBundleEncoder {
size, size,
} => { } => {
let scope = PassErrorScope::SetVertexBuffer(buffer_id); let scope = PassErrorScope::SetVertexBuffer(buffer_id);
set_vertex_buffer(buffer_id, device, slot, &buffer_guard, &mut state, size, offset, &mut buffer_memory_init_actions).map_pass_err(scope)?; set_vertex_buffer(&mut state, &buffer_guard, slot, buffer_id, offset, size).map_pass_err(scope)?;
} }
RenderCommand::SetPushConstant { RenderCommand::SetPushConstant {
stages, stages,
@ -442,7 +441,7 @@ impl RenderBundleEncoder {
values_offset, values_offset,
} => { } => {
let scope = PassErrorScope::SetPushConstant; let scope = PassErrorScope::SetPushConstant;
set_push_constant(offset, size_bytes, &state, stages, &mut commands, values_offset).map_pass_err(scope)?; set_push_constant(&mut state, stages, offset, size_bytes, values_offset).map_pass_err(scope)?;
} }
RenderCommand::Draw { RenderCommand::Draw {
vertex_count, vertex_count,
@ -455,7 +454,7 @@ impl RenderBundleEncoder {
indexed: false, indexed: false,
pipeline: state.pipeline_id(), pipeline: state.pipeline_id(),
}; };
draw(&mut state, first_vertex, vertex_count, first_instance, instance_count, &mut commands, base).map_pass_err(scope)?; draw(&mut state, &base.dynamic_offsets, vertex_count, instance_count, first_vertex, first_instance).map_pass_err(scope)?;
} }
RenderCommand::DrawIndexed { RenderCommand::DrawIndexed {
index_count, index_count,
@ -469,7 +468,7 @@ impl RenderBundleEncoder {
indexed: true, indexed: true,
pipeline: state.pipeline_id(), pipeline: state.pipeline_id(),
}; };
draw_indexed(&mut state, first_index, index_count, first_instance, instance_count, &mut commands, base, base_vertex).map_pass_err(scope)?; draw_indexed(&mut state, &base.dynamic_offsets, index_count, instance_count, first_index, base_vertex, first_instance).map_pass_err(scope)?;
} }
RenderCommand::MultiDrawIndirect { RenderCommand::MultiDrawIndirect {
buffer_id, buffer_id,
@ -482,7 +481,7 @@ impl RenderBundleEncoder {
indexed: false, indexed: false,
pipeline: state.pipeline_id(), pipeline: state.pipeline_id(),
}; };
multi_draw_indirect(&mut state, device, &buffer_guard, buffer_id, &mut buffer_memory_init_actions, offset, &mut commands, base).map_pass_err(scope)?; multi_draw_indirect(&mut state, &base.dynamic_offsets, &buffer_guard, buffer_id, offset).map_pass_err(scope)?;
} }
RenderCommand::MultiDrawIndirect { RenderCommand::MultiDrawIndirect {
buffer_id, buffer_id,
@ -495,7 +494,7 @@ impl RenderBundleEncoder {
indexed: true, indexed: true,
pipeline: state.pipeline_id(), pipeline: state.pipeline_id(),
}; };
multi_draw_indirect2(&mut state, device, &buffer_guard, buffer_id, &mut buffer_memory_init_actions, offset, &mut commands, base).map_pass_err(scope)?; multi_draw_indirect2(&mut state, &base.dynamic_offsets, &buffer_guard, buffer_id, offset).map_pass_err(scope)?;
} }
RenderCommand::MultiDrawIndirect { .. } RenderCommand::MultiDrawIndirect { .. }
| RenderCommand::MultiDrawIndirectCount { .. } => unimplemented!(), | RenderCommand::MultiDrawIndirectCount { .. } => unimplemented!(),
@ -515,25 +514,38 @@ impl RenderBundleEncoder {
} }
} }
let State {
trackers,
flat_dynamic_offsets,
device,
commands,
buffer_memory_init_actions,
texture_memory_init_actions,
..
} = state;
let tracker_indices = device.tracker_indices.bundles.clone();
let discard_hal_labels = device
.instance_flags
.contains(wgt::InstanceFlags::DISCARD_HAL_LABELS);
Ok(RenderBundle { Ok(RenderBundle {
base: BasePass { base: BasePass {
label: desc.label.as_ref().map(|cow| cow.to_string()), label: desc.label.as_ref().map(|cow| cow.to_string()),
commands, commands,
dynamic_offsets: state.flat_dynamic_offsets, dynamic_offsets: flat_dynamic_offsets,
string_data: Vec::new(), string_data: Vec::new(),
push_constant_data: Vec::new(), push_constant_data: Vec::new(),
}, },
is_depth_read_only: self.is_depth_read_only, is_depth_read_only: self.is_depth_read_only,
is_stencil_read_only: self.is_stencil_read_only, is_stencil_read_only: self.is_stencil_read_only,
device: device.clone(), device,
used: state.trackers, used: trackers,
buffer_memory_init_actions, buffer_memory_init_actions,
texture_memory_init_actions, texture_memory_init_actions,
context: self.context, context: self.context,
info: ResourceInfo::new(&desc.label, Some(device.tracker_indices.bundles.clone())), info: ResourceInfo::new(&desc.label, Some(tracker_indices)),
discard_hal_labels: device discard_hal_labels,
.instance_flags
.contains(wgt::InstanceFlags::DISCARD_HAL_LABELS),
}) })
} }
@ -554,16 +566,12 @@ impl RenderBundleEncoder {
} }
fn set_bind_group<A: HalApi>( fn set_bind_group<A: HalApi>(
bind_group_id: id::Id<id::markers::BindGroup>,
bind_group_guard: &crate::lock::RwLockReadGuard<crate::storage::Storage<BindGroup<A>>>,
state: &mut State<A>, state: &mut State<A>,
device: &Arc<Device<A>>, bind_group_guard: &crate::lock::RwLockReadGuard<crate::storage::Storage<BindGroup<A>>>,
dynamic_offsets: &[u32],
index: u32, index: u32,
next_dynamic_offset: &mut usize,
num_dynamic_offsets: usize, num_dynamic_offsets: usize,
base: &BasePass<RenderCommand>, bind_group_id: id::Id<id::markers::BindGroup>,
buffer_memory_init_actions: &mut Vec<BufferInitTrackerAction<A>>,
texture_memory_init_actions: &mut Vec<TextureInitTrackerAction<A>>,
) -> Result<(), RenderBundleErrorInner> { ) -> Result<(), RenderBundleErrorInner> {
let bind_group = bind_group_guard let bind_group = bind_group_guard
.get(bind_group_id) .get(bind_group_id)
@ -571,9 +579,9 @@ fn set_bind_group<A: HalApi>(
state.trackers.bind_groups.write().add_single(bind_group); state.trackers.bind_groups.write().add_single(bind_group);
bind_group.same_device(device)?; bind_group.same_device(&state.device)?;
let max_bind_groups = device.limits.max_bind_groups; let max_bind_groups = state.device.limits.max_bind_groups;
if index >= max_bind_groups { if index >= max_bind_groups {
return Err(RenderCommandError::BindGroupIndexOutOfRange { return Err(RenderCommandError::BindGroupIndexOutOfRange {
index, index,
@ -582,10 +590,10 @@ fn set_bind_group<A: HalApi>(
.into()); .into());
} }
// Identify the next `num_dynamic_offsets` entries from `base.dynamic_offsets`. // Identify the next `num_dynamic_offsets` entries from `dynamic_offsets`.
let offsets_range = *next_dynamic_offset..*next_dynamic_offset + num_dynamic_offsets; let offsets_range = state.next_dynamic_offset..state.next_dynamic_offset + num_dynamic_offsets;
*next_dynamic_offset = offsets_range.end; state.next_dynamic_offset = offsets_range.end;
let offsets = &base.dynamic_offsets[offsets_range.clone()]; let offsets = &dynamic_offsets[offsets_range.clone()];
if bind_group.dynamic_binding_info.len() != offsets.len() { if bind_group.dynamic_binding_info.len() != offsets.len() {
return Err(RenderCommandError::InvalidDynamicOffsetCount { return Err(RenderCommandError::InvalidDynamicOffsetCount {
@ -602,7 +610,7 @@ fn set_bind_group<A: HalApi>(
.zip(bind_group.dynamic_binding_info.iter()) .zip(bind_group.dynamic_binding_info.iter())
{ {
let (alignment, limit_name) = let (alignment, limit_name) =
buffer_binding_type_alignment(&device.limits, info.binding_type); buffer_binding_type_alignment(&state.device.limits, info.binding_type);
if offset % alignment as u64 != 0 { if offset % alignment as u64 != 0 {
return Err( return Err(
RenderCommandError::UnalignedBufferOffset(offset, limit_name, alignment).into(), RenderCommandError::UnalignedBufferOffset(offset, limit_name, alignment).into(),
@ -610,8 +618,12 @@ fn set_bind_group<A: HalApi>(
} }
} }
buffer_memory_init_actions.extend_from_slice(&bind_group.used_buffer_ranges); state
texture_memory_init_actions.extend_from_slice(&bind_group.used_texture_ranges); .buffer_memory_init_actions
.extend_from_slice(&bind_group.used_buffer_ranges);
state
.texture_memory_init_actions
.extend_from_slice(&bind_group.used_texture_ranges);
state.set_bind_group( state.set_bind_group(
index, index,
@ -626,12 +638,12 @@ fn set_bind_group<A: HalApi>(
} }
fn set_pipeline<A: HalApi>( fn set_pipeline<A: HalApi>(
encoder: &RenderBundleEncoder,
pipeline_id: id::Id<id::markers::RenderPipeline>,
pipeline_guard: &crate::lock::RwLockReadGuard<crate::storage::Storage<RenderPipeline<A>>>,
state: &mut State<A>, state: &mut State<A>,
device: &Arc<Device<A>>, pipeline_guard: &crate::lock::RwLockReadGuard<crate::storage::Storage<RenderPipeline<A>>>,
commands: &mut Vec<ArcRenderCommand<A>>, context: &RenderPassContext,
is_depth_read_only: bool,
is_stencil_read_only: bool,
pipeline_id: id::Id<id::markers::RenderPipeline>,
) -> Result<(), RenderBundleErrorInner> { ) -> Result<(), RenderBundleErrorInner> {
let pipeline = pipeline_guard let pipeline = pipeline_guard
.get(pipeline_id) .get(pipeline_id)
@ -639,29 +651,30 @@ fn set_pipeline<A: HalApi>(
state.trackers.render_pipelines.write().add_single(pipeline); state.trackers.render_pipelines.write().add_single(pipeline);
pipeline.same_device(device)?; pipeline.same_device(&state.device)?;
encoder context
.context
.check_compatible( .check_compatible(
&pipeline.pass_context, &pipeline.pass_context,
RenderPassCompatibilityCheckType::RenderPipeline, RenderPassCompatibilityCheckType::RenderPipeline,
) )
.map_err(RenderCommandError::IncompatiblePipelineTargets)?; .map_err(RenderCommandError::IncompatiblePipelineTargets)?;
if (pipeline.flags.contains(PipelineFlags::WRITES_DEPTH) && encoder.is_depth_read_only) if (pipeline.flags.contains(PipelineFlags::WRITES_DEPTH) && is_depth_read_only)
|| (pipeline.flags.contains(PipelineFlags::WRITES_STENCIL) && encoder.is_stencil_read_only) || (pipeline.flags.contains(PipelineFlags::WRITES_STENCIL) && is_stencil_read_only)
{ {
return Err(RenderCommandError::IncompatiblePipelineRods.into()); return Err(RenderCommandError::IncompatiblePipelineRods.into());
} }
let pipeline_state = PipelineState::new(pipeline, pipeline_id); let pipeline_state = PipelineState::new(pipeline, pipeline_id);
commands.push(ArcRenderCommand::SetPipeline(pipeline.clone())); state
.commands
.push(ArcRenderCommand::SetPipeline(pipeline.clone()));
// If this pipeline uses push constants, zero out their values. // If this pipeline uses push constants, zero out their values.
if let Some(iter) = pipeline_state.zero_push_constants() { if let Some(iter) = pipeline_state.zero_push_constants() {
commands.extend(iter) state.commands.extend(iter)
} }
state.invalidate_bind_groups(&pipeline_state, &pipeline.layout); state.invalidate_bind_groups(&pipeline_state, &pipeline.layout);
@ -670,14 +683,12 @@ fn set_pipeline<A: HalApi>(
} }
fn set_index_buffer<A: HalApi>( fn set_index_buffer<A: HalApi>(
buffer_id: id::Id<id::markers::Buffer>,
buffer_guard: &crate::lock::RwLockReadGuard<crate::storage::Storage<Buffer<A>>>,
state: &mut State<A>, state: &mut State<A>,
device: &Arc<Device<A>>, buffer_guard: &crate::lock::RwLockReadGuard<crate::storage::Storage<Buffer<A>>>,
size: Option<std::num::NonZeroU64>, buffer_id: id::Id<id::markers::Buffer>,
offset: u64,
buffer_memory_init_actions: &mut Vec<BufferInitTrackerAction<A>>,
index_format: wgt::IndexFormat, index_format: wgt::IndexFormat,
offset: u64,
size: Option<std::num::NonZeroU64>,
) -> Result<(), RenderBundleErrorInner> { ) -> Result<(), RenderBundleErrorInner> {
let buffer = buffer_guard let buffer = buffer_guard
.get(buffer_id) .get(buffer_id)
@ -689,14 +700,16 @@ fn set_index_buffer<A: HalApi>(
.write() .write()
.merge_single(buffer, hal::BufferUses::INDEX)?; .merge_single(buffer, hal::BufferUses::INDEX)?;
buffer.same_device(device)?; buffer.same_device(&state.device)?;
buffer.check_usage(wgt::BufferUsages::INDEX)?; buffer.check_usage(wgt::BufferUsages::INDEX)?;
let end = match size { let end = match size {
Some(s) => offset + s.get(), Some(s) => offset + s.get(),
None => buffer.size, None => buffer.size,
}; };
buffer_memory_init_actions.extend(buffer.initialization_status.read().create_action( state
.buffer_memory_init_actions
.extend(buffer.initialization_status.read().create_action(
buffer, buffer,
offset..end, offset..end,
MemoryInitKind::NeedsInitializedMemory, MemoryInitKind::NeedsInitializedMemory,
@ -706,16 +719,14 @@ fn set_index_buffer<A: HalApi>(
} }
fn set_vertex_buffer<A: HalApi>( fn set_vertex_buffer<A: HalApi>(
buffer_id: id::Id<id::markers::Buffer>,
device: &Arc<Device<A>>,
slot: u32,
buffer_guard: &crate::lock::RwLockReadGuard<crate::storage::Storage<Buffer<A>>>,
state: &mut State<A>, state: &mut State<A>,
size: Option<std::num::NonZeroU64>, buffer_guard: &crate::lock::RwLockReadGuard<crate::storage::Storage<Buffer<A>>>,
slot: u32,
buffer_id: id::Id<id::markers::Buffer>,
offset: u64, offset: u64,
buffer_memory_init_actions: &mut Vec<BufferInitTrackerAction<A>>, size: Option<std::num::NonZeroU64>,
) -> Result<(), RenderBundleErrorInner> { ) -> Result<(), RenderBundleErrorInner> {
let max_vertex_buffers = device.limits.max_vertex_buffers; let max_vertex_buffers = state.device.limits.max_vertex_buffers;
if slot >= max_vertex_buffers { if slot >= max_vertex_buffers {
return Err(RenderCommandError::VertexBufferIndexOutOfRange { return Err(RenderCommandError::VertexBufferIndexOutOfRange {
index: slot, index: slot,
@ -734,14 +745,16 @@ fn set_vertex_buffer<A: HalApi>(
.write() .write()
.merge_single(buffer, hal::BufferUses::VERTEX)?; .merge_single(buffer, hal::BufferUses::VERTEX)?;
buffer.same_device(device)?; buffer.same_device(&state.device)?;
buffer.check_usage(wgt::BufferUsages::VERTEX)?; buffer.check_usage(wgt::BufferUsages::VERTEX)?;
let end = match size { let end = match size {
Some(s) => offset + s.get(), Some(s) => offset + s.get(),
None => buffer.size, None => buffer.size,
}; };
buffer_memory_init_actions.extend(buffer.initialization_status.read().create_action( state
.buffer_memory_init_actions
.extend(buffer.initialization_status.read().create_action(
buffer, buffer,
offset..end, offset..end,
MemoryInitKind::NeedsInitializedMemory, MemoryInitKind::NeedsInitializedMemory,
@ -751,11 +764,10 @@ fn set_vertex_buffer<A: HalApi>(
} }
fn set_push_constant<A: HalApi>( fn set_push_constant<A: HalApi>(
state: &mut State<A>,
stages: wgt::ShaderStages,
offset: u32, offset: u32,
size_bytes: u32, size_bytes: u32,
state: &State<A>,
stages: wgt::ShaderStages,
commands: &mut Vec<ArcRenderCommand<A>>,
values_offset: Option<u32>, values_offset: Option<u32>,
) -> Result<(), RenderBundleErrorInner> { ) -> Result<(), RenderBundleErrorInner> {
let end_offset = offset + size_bytes; let end_offset = offset + size_bytes;
@ -767,7 +779,7 @@ fn set_push_constant<A: HalApi>(
.layout .layout
.validate_push_constant_ranges(stages, offset, end_offset)?; .validate_push_constant_ranges(stages, offset, end_offset)?;
commands.push(ArcRenderCommand::SetPushConstant { state.commands.push(ArcRenderCommand::SetPushConstant {
stages, stages,
offset, offset,
size_bytes, size_bytes,
@ -778,12 +790,11 @@ fn set_push_constant<A: HalApi>(
fn draw<A: HalApi>( fn draw<A: HalApi>(
state: &mut State<A>, state: &mut State<A>,
first_vertex: u32, dynamic_offsets: &[u32],
vertex_count: u32, vertex_count: u32,
first_instance: u32,
instance_count: u32, instance_count: u32,
commands: &mut Vec<ArcRenderCommand<A>>, first_vertex: u32,
base: &BasePass<RenderCommand>, first_instance: u32,
) -> Result<(), RenderBundleErrorInner> { ) -> Result<(), RenderBundleErrorInner> {
let pipeline = state.pipeline()?; let pipeline = state.pipeline()?;
let used_bind_groups = pipeline.used_bind_groups; let used_bind_groups = pipeline.used_bind_groups;
@ -798,9 +809,9 @@ fn draw<A: HalApi>(
)?; )?;
if instance_count > 0 && vertex_count > 0 { if instance_count > 0 && vertex_count > 0 {
commands.extend(state.flush_vertices()); state.flush_vertices();
commands.extend(state.flush_binds(used_bind_groups, &base.dynamic_offsets)); state.flush_binds(used_bind_groups, dynamic_offsets);
commands.push(ArcRenderCommand::Draw { state.commands.push(ArcRenderCommand::Draw {
vertex_count, vertex_count,
instance_count, instance_count,
first_vertex, first_vertex,
@ -812,13 +823,12 @@ fn draw<A: HalApi>(
fn draw_indexed<A: HalApi>( fn draw_indexed<A: HalApi>(
state: &mut State<A>, state: &mut State<A>,
first_index: u32, dynamic_offsets: &[u32],
index_count: u32, index_count: u32,
first_instance: u32,
instance_count: u32, instance_count: u32,
commands: &mut Vec<ArcRenderCommand<A>>, first_index: u32,
base: &BasePass<RenderCommand>,
base_vertex: i32, base_vertex: i32,
first_instance: u32,
) -> Result<(), RenderBundleErrorInner> { ) -> Result<(), RenderBundleErrorInner> {
let pipeline = state.pipeline()?; let pipeline = state.pipeline()?;
let used_bind_groups = pipeline.used_bind_groups; let used_bind_groups = pipeline.used_bind_groups;
@ -838,10 +848,10 @@ fn draw_indexed<A: HalApi>(
)?; )?;
if instance_count > 0 && index_count > 0 { if instance_count > 0 && index_count > 0 {
commands.extend(state.flush_index()); state.flush_index();
commands.extend(state.flush_vertices()); state.flush_vertices();
commands.extend(state.flush_binds(used_bind_groups, &base.dynamic_offsets)); state.flush_binds(used_bind_groups, dynamic_offsets);
commands.push(ArcRenderCommand::DrawIndexed { state.commands.push(ArcRenderCommand::DrawIndexed {
index_count, index_count,
instance_count, instance_count,
first_index, first_index,
@ -854,15 +864,14 @@ fn draw_indexed<A: HalApi>(
fn multi_draw_indirect<A: HalApi>( fn multi_draw_indirect<A: HalApi>(
state: &mut State<A>, state: &mut State<A>,
device: &Arc<Device<A>>, dynamic_offsets: &[u32],
buffer_guard: &crate::lock::RwLockReadGuard<crate::storage::Storage<Buffer<A>>>, buffer_guard: &crate::lock::RwLockReadGuard<crate::storage::Storage<Buffer<A>>>,
buffer_id: id::Id<id::markers::Buffer>, buffer_id: id::Id<id::markers::Buffer>,
buffer_memory_init_actions: &mut Vec<BufferInitTrackerAction<A>>,
offset: u64, offset: u64,
commands: &mut Vec<ArcRenderCommand<A>>,
base: &BasePass<RenderCommand>,
) -> Result<(), RenderBundleErrorInner> { ) -> Result<(), RenderBundleErrorInner> {
device.require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?; state
.device
.require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;
let pipeline = state.pipeline()?; let pipeline = state.pipeline()?;
let used_bind_groups = pipeline.used_bind_groups; let used_bind_groups = pipeline.used_bind_groups;
@ -877,18 +886,20 @@ fn multi_draw_indirect<A: HalApi>(
.write() .write()
.merge_single(buffer, hal::BufferUses::INDIRECT)?; .merge_single(buffer, hal::BufferUses::INDIRECT)?;
buffer.same_device(device)?; buffer.same_device(&state.device)?;
buffer.check_usage(wgt::BufferUsages::INDIRECT)?; buffer.check_usage(wgt::BufferUsages::INDIRECT)?;
buffer_memory_init_actions.extend(buffer.initialization_status.read().create_action( state
.buffer_memory_init_actions
.extend(buffer.initialization_status.read().create_action(
buffer, buffer,
offset..(offset + mem::size_of::<wgt::DrawIndirectArgs>() as u64), offset..(offset + mem::size_of::<wgt::DrawIndirectArgs>() as u64),
MemoryInitKind::NeedsInitializedMemory, MemoryInitKind::NeedsInitializedMemory,
)); ));
commands.extend(state.flush_vertices()); state.flush_vertices();
commands.extend(state.flush_binds(used_bind_groups, &base.dynamic_offsets)); state.flush_binds(used_bind_groups, dynamic_offsets);
commands.push(ArcRenderCommand::MultiDrawIndirect { state.commands.push(ArcRenderCommand::MultiDrawIndirect {
buffer: buffer.clone(), buffer: buffer.clone(),
offset, offset,
count: None, count: None,
@ -899,15 +910,14 @@ fn multi_draw_indirect<A: HalApi>(
fn multi_draw_indirect2<A: HalApi>( fn multi_draw_indirect2<A: HalApi>(
state: &mut State<A>, state: &mut State<A>,
device: &Arc<Device<A>>, dynamic_offsets: &[u32],
buffer_guard: &crate::lock::RwLockReadGuard<crate::storage::Storage<Buffer<A>>>, buffer_guard: &crate::lock::RwLockReadGuard<crate::storage::Storage<Buffer<A>>>,
buffer_id: id::Id<id::markers::Buffer>, buffer_id: id::Id<id::markers::Buffer>,
buffer_memory_init_actions: &mut Vec<BufferInitTrackerAction<A>>,
offset: u64, offset: u64,
commands: &mut Vec<ArcRenderCommand<A>>,
base: &BasePass<RenderCommand>,
) -> Result<(), RenderBundleErrorInner> { ) -> Result<(), RenderBundleErrorInner> {
device.require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?; state
.device
.require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;
let pipeline = state.pipeline()?; let pipeline = state.pipeline()?;
let used_bind_groups = pipeline.used_bind_groups; let used_bind_groups = pipeline.used_bind_groups;
@ -922,10 +932,12 @@ fn multi_draw_indirect2<A: HalApi>(
.write() .write()
.merge_single(buffer, hal::BufferUses::INDIRECT)?; .merge_single(buffer, hal::BufferUses::INDIRECT)?;
buffer.same_device(device)?; buffer.same_device(&state.device)?;
buffer.check_usage(wgt::BufferUsages::INDIRECT)?; buffer.check_usage(wgt::BufferUsages::INDIRECT)?;
buffer_memory_init_actions.extend(buffer.initialization_status.read().create_action( state
.buffer_memory_init_actions
.extend(buffer.initialization_status.read().create_action(
buffer, buffer,
offset..(offset + mem::size_of::<wgt::DrawIndirectArgs>() as u64), offset..(offset + mem::size_of::<wgt::DrawIndirectArgs>() as u64),
MemoryInitKind::NeedsInitializedMemory, MemoryInitKind::NeedsInitializedMemory,
@ -936,10 +948,10 @@ fn multi_draw_indirect2<A: HalApi>(
None => return Err(DrawError::MissingIndexBuffer.into()), None => return Err(DrawError::MissingIndexBuffer.into()),
}; };
commands.extend(index.flush()); state.commands.extend(index.flush());
commands.extend(state.flush_vertices()); state.flush_vertices();
commands.extend(state.flush_binds(used_bind_groups, &base.dynamic_offsets)); state.flush_binds(used_bind_groups, dynamic_offsets);
commands.push(ArcRenderCommand::MultiDrawIndirect { state.commands.push(ArcRenderCommand::MultiDrawIndirect {
buffer: buffer.clone(), buffer: buffer.clone(),
offset, offset,
count: None, count: None,
@ -1416,6 +1428,12 @@ struct State<A: HalApi> {
/// ///
/// [`dynamic_offsets`]: BasePass::dynamic_offsets /// [`dynamic_offsets`]: BasePass::dynamic_offsets
flat_dynamic_offsets: Vec<wgt::DynamicOffset>, flat_dynamic_offsets: Vec<wgt::DynamicOffset>,
device: Arc<Device<A>>,
commands: Vec<ArcRenderCommand<A>>,
buffer_memory_init_actions: Vec<BufferInitTrackerAction<A>>,
texture_memory_init_actions: Vec<TextureInitTrackerAction<A>>,
next_dynamic_offset: usize,
} }
impl<A: HalApi> State<A> { impl<A: HalApi> State<A> {
@ -1541,23 +1559,22 @@ impl<A: HalApi> State<A> {
/// Generate a `SetIndexBuffer` command to prepare for an indexed draw /// Generate a `SetIndexBuffer` command to prepare for an indexed draw
/// command, if needed. /// command, if needed.
fn flush_index(&mut self) -> Option<ArcRenderCommand<A>> { fn flush_index(&mut self) {
self.index.as_mut().and_then(|index| index.flush()) let commands = self.index.as_mut().and_then(|index| index.flush());
self.commands.extend(commands);
} }
fn flush_vertices(&mut self) -> impl Iterator<Item = ArcRenderCommand<A>> + '_ { fn flush_vertices(&mut self) {
self.vertex let commands = self
.vertex
.iter_mut() .iter_mut()
.enumerate() .enumerate()
.flat_map(|(i, vs)| vs.as_mut().and_then(|vs| vs.flush(i as u32))) .flat_map(|(i, vs)| vs.as_mut().and_then(|vs| vs.flush(i as u32)));
self.commands.extend(commands);
} }
/// Generate `SetBindGroup` commands for any bind groups that need to be updated. /// Generate `SetBindGroup` commands for any bind groups that need to be updated.
fn flush_binds( fn flush_binds(&mut self, used_bind_groups: usize, dynamic_offsets: &[wgt::DynamicOffset]) {
&mut self,
used_bind_groups: usize,
dynamic_offsets: &[wgt::DynamicOffset],
) -> impl Iterator<Item = ArcRenderCommand<A>> + '_ {
// Append each dirty bind group's dynamic offsets to `flat_dynamic_offsets`. // Append each dirty bind group's dynamic offsets to `flat_dynamic_offsets`.
for contents in self.bind[..used_bind_groups].iter().flatten() { for contents in self.bind[..used_bind_groups].iter().flatten() {
if contents.is_dirty { if contents.is_dirty {
@ -1568,7 +1585,7 @@ impl<A: HalApi> State<A> {
// Then, generate `SetBindGroup` commands to update the dirty bind // Then, generate `SetBindGroup` commands to update the dirty bind
// groups. After this, all bind groups are clean. // groups. After this, all bind groups are clean.
self.bind[..used_bind_groups] let commands = self.bind[..used_bind_groups]
.iter_mut() .iter_mut()
.enumerate() .enumerate()
.flat_map(|(i, entry)| { .flat_map(|(i, entry)| {
@ -1584,7 +1601,9 @@ impl<A: HalApi> State<A> {
} }
} }
None None
}) });
self.commands.extend(commands);
} }
} }