diff --git a/Cargo.lock b/Cargo.lock
index ca76c0ed4..9efcda4dd 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1918,6 +1918,7 @@ dependencies = [
  "bitflags",
  "block",
  "foreign-types",
+ "fxhash",
  "log",
  "metal",
  "naga",
diff --git a/wgpu-hal/Cargo.toml b/wgpu-hal/Cargo.toml
index d7e29bf48..3ecbfe01f 100644
--- a/wgpu-hal/Cargo.toml
+++ b/wgpu-hal/Cargo.toml
@@ -19,6 +19,7 @@ metal = ["block", "foreign-types", "mtl", "objc", "parking_lot", "naga/msl-out"]
 [dependencies]
 arrayvec = "0.5"
 bitflags = "1.0"
+fxhash = "0.2.1"
 log = "0.4"
 parking_lot = { version = "0.11", optional = true }
 raw-window-handle = "0.3"
diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs
index 387c37c26..499d84a99 100644
--- a/wgpu-hal/src/metal/command.rs
+++ b/wgpu-hal/src/metal/command.rs
@@ -1,6 +1,8 @@
 use super::{conv, AsNative};
 use std::{mem, ops::Range};
 
+const WORD_SIZE: usize = 4;
+
 impl super::CommandBuffer {
     fn enter_blit(&mut self) -> &mtl::BlitCommandEncoderRef {
         if self.blit.is_none() {
@@ -25,6 +27,34 @@ impl super::CommandBuffer {
             self.enter_blit()
         }
     }
+
+    fn begin_pass(&mut self) {
+        self.state.storage_buffer_length_map.clear();
+        self.state.stage_infos.vs.clear();
+        self.state.stage_infos.fs.clear();
+        self.state.stage_infos.cs.clear();
+        self.leave_blit();
+    }
+}
+
+impl super::CommandState {
+    fn make_sizes_buffer_update<'a>(
+        &self,
+        stage: naga::ShaderStage,
+        result_sizes: &'a mut Vec<wgt::BufferSize>,
+    ) -> Option<(u32, &'a [wgt::BufferSize])> {
+        let stage_info = &self.stage_infos[stage];
+        let slot = stage_info.sizes_slot?;
+        result_sizes.clear();
+        for br in stage_info.sized_bindings.iter() {
+            // If it's None, this isn't the right time to update the sizes
+            let size = self
+                .storage_buffer_length_map
+                .get(&(br.group, br.binding))?;
+            result_sizes.push(*size);
+        }
+        Some((slot as _, result_sizes))
+    }
 }
 
 impl crate::CommandBuffer for super::CommandBuffer {
@@ -228,7 +258,9 @@ impl crate::CommandBuffer for super::CommandBuffer {
     // render
 
     unsafe fn begin_render_pass(&mut self, desc: &crate::RenderPassDescriptor) {
-        self.leave_blit();
+        self.begin_pass();
+        self.state.index = None;
+
         let descriptor = mtl::RenderPassDescriptor::new();
         //TODO: set visibility results buffer
 
@@ -311,13 +343,14 @@ impl crate::CommandBuffer for super::CommandBuffer {
     unsafe fn set_bind_group(
         &mut self,
         layout: &super::PipelineLayout,
-        index: u32,
+        group_index: u32,
         group: &super::BindGroup,
         dynamic_offsets: &[wgt::DynamicOffset],
     ) {
-        let bg_info = &layout.bind_group_infos[index as usize];
+        let bg_info = &layout.bind_group_infos[group_index as usize];
 
         if let Some(ref encoder) = self.render {
+            let mut changes_sizes_buffer = false;
             for index in 0..group.counters.vs.buffers {
                 let buf = &group.buffers[index as usize];
                 let mut offset = buf.offset;
@@ -329,7 +362,27 @@ impl crate::CommandBuffer for super::CommandBuffer {
                     Some(buf.ptr.as_native()),
                     offset,
                 );
+                if let Some(size) = buf.binding_size {
+                    self.state
+                        .storage_buffer_length_map
+                        .insert((group_index, buf.binding_location), size);
+                    changes_sizes_buffer = true;
+                }
             }
+            if changes_sizes_buffer {
+                if let Some((index, sizes)) = self.state.make_sizes_buffer_update(
+                    naga::ShaderStage::Vertex,
+                    &mut self.temp.binding_sizes,
+                ) {
+                    encoder.set_vertex_bytes(
+                        index as _,
+                        (sizes.len() * WORD_SIZE) as u64,
+                        sizes.as_ptr() as _,
+                    );
+                }
+            }
+
+            changes_sizes_buffer = false;
             for index in 0..group.counters.fs.buffers {
                 let buf = &group.buffers[(group.counters.vs.buffers + index) as usize];
                 let mut offset = buf.offset;
@@ -341,6 +394,24 @@ impl crate::CommandBuffer for super::CommandBuffer {
                     Some(buf.ptr.as_native()),
                     offset,
                 );
+                if let Some(size) = buf.binding_size {
+                    self.state
+                        .storage_buffer_length_map
+                        .insert((group_index, buf.binding_location), size);
+                    changes_sizes_buffer = true;
+                }
             }
+            if changes_sizes_buffer {
+                if let Some((index, sizes)) = self.state.make_sizes_buffer_update(
+                    naga::ShaderStage::Fragment,
+                    &mut self.temp.binding_sizes,
+                ) {
+                    encoder.set_fragment_bytes(
+                        index as _,
+                        (sizes.len() * WORD_SIZE) as u64,
+                        sizes.as_ptr() as _,
+                    );
+                }
+            }
 
             for index in 0..group.counters.vs.samplers {
@@ -380,6 +451,8 @@ impl crate::CommandBuffer for super::CommandBuffer {
                 samplers: group.counters.vs.samplers + group.counters.fs.samplers,
                 textures: group.counters.vs.textures + group.counters.fs.textures,
             };
+
+            let mut changes_sizes_buffer = false;
             for index in 0..group.counters.cs.buffers {
                 let buf = &group.buffers[(index_base.buffers + index) as usize];
                 let mut offset = buf.offset;
@@ -391,7 +464,26 @@ impl crate::CommandBuffer for super::CommandBuffer {
                     Some(buf.ptr.as_native()),
                     offset,
                 );
+                if let Some(size) = buf.binding_size {
+                    self.state
+                        .storage_buffer_length_map
+                        .insert((group_index, buf.binding_location), size);
+                    changes_sizes_buffer = true;
+                }
             }
+            if changes_sizes_buffer {
+                if let Some((index, sizes)) = self.state.make_sizes_buffer_update(
+                    naga::ShaderStage::Compute,
+                    &mut self.temp.binding_sizes,
+                ) {
+                    encoder.set_bytes(
+                        index as _,
+                        (sizes.len() * WORD_SIZE) as u64,
+                        sizes.as_ptr() as _,
+                    );
+                }
+            }
+
             for index in 0..group.counters.cs.samplers {
                 let res = group.samplers[(index_base.samplers + index) as usize];
                 encoder.set_sampler_state(
@@ -430,7 +522,10 @@ impl crate::CommandBuffer for super::CommandBuffer {
     }
 
     unsafe fn set_render_pipeline(&mut self, pipeline: &super::RenderPipeline) {
-        self.raw_primitive_type = pipeline.raw_primitive_type;
+        self.state.raw_primitive_type = pipeline.raw_primitive_type;
+        self.state.stage_infos.vs.assign_from(&pipeline.vs_info);
+        self.state.stage_infos.fs.assign_from(&pipeline.fs_info);
+
         let encoder = self.render.as_ref().unwrap();
         encoder.set_render_pipeline_state(&pipeline.raw);
         encoder.set_front_facing_winding(pipeline.raw_front_winding);
@@ -442,6 +537,31 @@ impl crate::CommandBuffer for super::CommandBuffer {
             encoder.set_depth_stencil_state(state);
             encoder.set_depth_bias(bias.constant as f32, bias.slope_scale, bias.clamp);
         }
+
+        {
+            if let Some((index, sizes)) = self
+                .state
+                .make_sizes_buffer_update(naga::ShaderStage::Vertex, &mut self.temp.binding_sizes)
+            {
+                encoder.set_vertex_bytes(
+                    index as _,
+                    (sizes.len() * WORD_SIZE) as u64,
+                    sizes.as_ptr() as _,
+                );
+            }
+        }
+        if pipeline.fs_lib.is_some() {
+            if let Some((index, sizes)) = self
+                .state
+                .make_sizes_buffer_update(naga::ShaderStage::Fragment, &mut self.temp.binding_sizes)
+            {
+                encoder.set_fragment_bytes(
+                    index as _,
+                    (sizes.len() * WORD_SIZE) as u64,
+                    sizes.as_ptr() as _,
+                );
+            }
+        }
     }
 
     unsafe fn set_index_buffer<'a>(
@@ -453,7 +573,7 @@ impl crate::CommandBuffer for super::CommandBuffer {
             wgt::IndexFormat::Uint16 => (2, mtl::MTLIndexType::UInt16),
             wgt::IndexFormat::Uint32 => (4, mtl::MTLIndexType::UInt32),
         };
-        self.index_state = Some(super::IndexState {
+        self.state.index = Some(super::IndexState {
             buffer_ptr: AsNative::from(binding.buffer.raw.as_ref()),
             offset: binding.offset,
             stride,
@@ -522,7 +642,7 @@ impl crate::CommandBuffer for super::CommandBuffer {
         let encoder = self.render.as_ref().unwrap();
         if start_instance != 0 {
             encoder.draw_primitives_instanced_base_instance(
-                self.raw_primitive_type,
+                self.state.raw_primitive_type,
                 start_vertex as _,
                 vertex_count as _,
                 instance_count as _,
@@ -530,14 +650,14 @@ impl crate::CommandBuffer for super::CommandBuffer {
             );
         } else if instance_count != 1 {
             encoder.draw_primitives_instanced(
-                self.raw_primitive_type,
+                self.state.raw_primitive_type,
                 start_vertex as _,
                 vertex_count as _,
                 instance_count as _,
             );
         } else {
             encoder.draw_primitives(
-                self.raw_primitive_type,
+                self.state.raw_primitive_type,
                 start_vertex as _,
                 vertex_count as _,
             );
@@ -553,11 +673,11 @@ impl crate::CommandBuffer for super::CommandBuffer {
         instance_count: u32,
     ) {
         let encoder = self.render.as_ref().unwrap();
-        let index = self.index_state.as_ref().unwrap();
+        let index = self.state.index.as_ref().unwrap();
         let offset = index.offset + index.stride * start_index as wgt::BufferAddress;
         if base_vertex != 0 || start_instance != 0 {
             encoder.draw_indexed_primitives_instanced_base_instance(
-                self.raw_primitive_type,
+                self.state.raw_primitive_type,
                 index_count as _,
                 index.raw_type,
                 index.buffer_ptr.as_native(),
@@ -568,7 +688,7 @@ impl crate::CommandBuffer for super::CommandBuffer {
             );
         } else if instance_count != 1 {
             encoder.draw_indexed_primitives_instanced(
-                self.raw_primitive_type,
+                self.state.raw_primitive_type,
                 index_count as _,
                 index.raw_type,
                 index.buffer_ptr.as_native(),
@@ -577,7 +697,7 @@ impl crate::CommandBuffer for super::CommandBuffer {
             );
         } else {
             encoder.draw_indexed_primitives(
-                self.raw_primitive_type,
+                self.state.raw_primitive_type,
                 index_count as _,
                 index.raw_type,
                 index.buffer_ptr.as_native(),
@@ -594,7 +714,7 @@ impl crate::CommandBuffer for super::CommandBuffer {
     ) {
         let encoder = self.render.as_ref().unwrap();
         for _ in 0..draw_count {
-            encoder.draw_primitives_indirect(self.raw_primitive_type, &buffer.raw, offset);
+            encoder.draw_primitives_indirect(self.state.raw_primitive_type, &buffer.raw, offset);
             offset += mem::size_of::<wgt::DrawIndirectArgs>() as wgt::BufferAddress;
         }
     }
@@ -606,10 +726,10 @@ impl crate::CommandBuffer for super::CommandBuffer {
     ) {
         let encoder = self.render.as_ref().unwrap();
-        let index = self.index_state.as_ref().unwrap();
+        let index = self.state.index.as_ref().unwrap();
         for _ in 0..draw_count {
             encoder.draw_indexed_primitives_indirect(
-                self.raw_primitive_type,
+                self.state.raw_primitive_type,
                 index.raw_type,
                 index.buffer_ptr.as_native(),
                 index.offset,
                 &buffer.raw,
@@ -644,7 +764,8 @@ impl crate::CommandBuffer for super::CommandBuffer {
     // compute
 
     unsafe fn begin_compute_pass(&mut self, desc: &crate::ComputePassDescriptor) {
-        self.leave_blit();
+        self.begin_pass();
+
         let encoder = self.raw.new_compute_command_encoder();
         if let Some(label) = desc.label {
             encoder.set_label(label);
@@ -656,9 +777,22 @@ impl crate::CommandBuffer for super::CommandBuffer {
     }
 
     unsafe fn set_compute_pipeline(&mut self, pipeline: &super::ComputePipeline) {
-        self.raw_wg_size = pipeline.work_group_size;
+        self.state.raw_wg_size = pipeline.work_group_size;
+        self.state.stage_infos.cs.assign_from(&pipeline.cs_info);
+
         let encoder = self.compute.as_ref().unwrap();
         encoder.set_compute_pipeline_state(&pipeline.raw);
+
+        if let Some((index, sizes)) = self
+            .state
+            .make_sizes_buffer_update(naga::ShaderStage::Compute, &mut self.temp.binding_sizes)
+        {
+            encoder.set_bytes(
+                index as _,
+                (sizes.len() * WORD_SIZE) as u64,
+                sizes.as_ptr() as _,
+            );
+        }
     }
 
     unsafe fn dispatch(&mut self, count: [u32; 3]) {
@@ -668,11 +802,11 @@ impl crate::CommandBuffer for super::CommandBuffer {
             height: count[1] as u64,
             depth: count[2] as u64,
         };
-        encoder.dispatch_thread_groups(raw_count, self.raw_wg_size);
+        encoder.dispatch_thread_groups(raw_count, self.state.raw_wg_size);
     }
 
     unsafe fn dispatch_indirect(&mut self, buffer: &super::Buffer, offset: wgt::BufferAddress) {
         let encoder = self.compute.as_ref().unwrap();
-        encoder.dispatch_thread_groups_indirect(&buffer.raw, offset, self.raw_wg_size);
+        encoder.dispatch_thread_groups_indirect(&buffer.raw, offset, self.state.raw_wg_size);
     }
 }
diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs
index 8b192c5d9..07cd36dd4 100644
--- a/wgpu-hal/src/metal/device.rs
+++ b/wgpu-hal/src/metal/device.rs
@@ -354,11 +354,16 @@ impl crate::Device for super::Device {
             blit: None,
             render: None,
             compute: None,
-            raw_primitive_type: mtl::MTLPrimitiveType::Point,
-            index_state: None,
-            raw_wg_size: mtl::MTLSize::new(0, 0, 0),
-            max_buffers_per_stage: self.shared.private_caps.max_buffers_per_stage,
             disabilities: self.shared.disabilities.clone(),
+            max_buffers_per_stage: self.shared.private_caps.max_buffers_per_stage,
+            state: super::CommandState {
+                raw_primitive_type: mtl::MTLPrimitiveType::Point,
+                index: None,
+                raw_wg_size: mtl::MTLSize::new(0, 0, 0),
+                stage_infos: Default::default(),
+                storage_buffer_length_map: Default::default(),
+            },
+            temp: super::Temp::default(),
         })
     }
     unsafe fn destroy_command_buffer(&self, mut cmd_buf: super::CommandBuffer) {
@@ -431,7 +436,6 @@ impl crate::Device for super::Device {
         for (group_index, &bgl) in desc.bind_group_layouts.iter().enumerate() {
             // remember where the resources for this set start at each shader stage
             let mut dynamic_buffers = Vec::new();
-            let mut sized_buffer_bindings = Vec::new();
             let base_resource_indices = stage_data.map(|info| info.counters.clone());
 
             for entry in bgl.entries.iter() {
@@ -451,7 +455,6 @@ impl crate::Device for super::Device {
                     }));
                 }
                 if let wgt::BufferBindingType::Storage { .. } = ty {
-                    sized_buffer_bindings.push((entry.binding, entry.visibility));
                     for info in stage_data.iter_mut() {
                         if entry.visibility.contains(map_naga_stage(info.stage)) {
                             info.sizes_count += 1;
@@ -506,8 +509,6 @@ impl crate::Device for super::Device {
 
             bind_group_infos.push(super::BindGroupLayoutInfo {
                 base_resource_indices,
-                //dynamic_buffers,
-                sized_buffer_bindings,
             });
         }
 
@@ -592,9 +593,17 @@ impl crate::Device for super::Device {
             }
             match layout.ty {
                 wgt::BindingType::Buffer {
-                    has_dynamic_offset, ..
+                    ty,
+                    has_dynamic_offset,
+                    ..
                 } => {
                     let source = &desc.buffers[entry.resource_index as usize];
+                    let binding_size = match ty {
+                        wgt::BufferBindingType::Storage { .. } => source
+                            .size
+                            .or(wgt::BufferSize::new(source.buffer.size - source.offset)),
+                        _ => None,
+                    };
                     bg.buffers.push(super::BufferResource {
                         ptr: source.buffer.as_raw(),
                         offset: source.offset,
@@ -603,6 +612,8 @@ impl crate::Device for super::Device {
                         } else {
                             None
                         },
+                        binding_size,
+                        binding_location: layout.binding,
                     });
                     counter.buffers += 1;
                 }
diff --git a/wgpu-hal/src/metal/mod.rs b/wgpu-hal/src/metal/mod.rs
index 2bb58fe0e..6c8652505 100644
--- a/wgpu-hal/src/metal/mod.rs
+++ b/wgpu-hal/src/metal/mod.rs
@@ -463,8 +463,6 @@ type MultiStageResourceCounters = MultiStageData<ResourceData<ResourceIndex>>;
 #[derive(Debug)]
 struct BindGroupLayoutInfo {
     base_resource_indices: MultiStageResourceCounters,
-    //dynamic_buffers: Vec<MultiStageData<ResourceIndex>>,
-    sized_buffer_bindings: Vec<(u32, wgt::ShaderStage)>,
 }
 
 #[derive(Copy, Clone, Debug, Eq, PartialEq)]
@@ -532,6 +530,8 @@ struct BufferResource {
     ptr: BufferPtr,
     offset: wgt::BufferAddress,
     dynamic_index: Option<u32>,
+    binding_size: Option<wgt::BufferSize>,
+    binding_location: u32,
 }
 
 #[derive(Debug, Default)]
@@ -557,10 +557,26 @@ struct PipelineStageInfo {
     sized_bindings: Vec<naga::ResourceBinding>,
 }
 
-#[allow(dead_code)] // silence xx_lib and xx_info warnings
+impl PipelineStageInfo {
+    fn clear(&mut self) {
+        self.push_constants = None;
+        self.sizes_slot = None;
+        self.sized_bindings.clear();
+    }
+
+    fn assign_from(&mut self, other: &Self) {
+        self.push_constants = other.push_constants;
+        self.sizes_slot = other.sizes_slot;
+        self.sized_bindings.clear();
+        self.sized_bindings.extend_from_slice(&other.sized_bindings);
+    }
+}
+
 pub struct RenderPipeline {
     raw: mtl::RenderPipelineState,
+    #[allow(dead_code)]
     vs_lib: mtl::Library,
+    #[allow(dead_code)]
     fs_lib: Option<mtl::Library>,
     vs_info: PipelineStageInfo,
     fs_info: PipelineStageInfo,
@@ -574,9 +590,9 @@
 unsafe impl Send for RenderPipeline {}
 unsafe impl Sync for RenderPipeline {}
 
-#[allow(dead_code)] // silence xx_lib and xx_info warnings
 pub struct ComputePipeline {
     raw: mtl::ComputePipelineState,
+    #[allow(dead_code)]
     cs_lib: mtl::Library,
     cs_info: PipelineStageInfo,
     work_group_size: mtl::MTLSize,
@@ -628,16 +644,29 @@ struct IndexState {
     raw_type: mtl::MTLIndexType,
 }
 
+#[derive(Default)]
+struct Temp {
+    binding_sizes: Vec<wgt::BufferSize>,
+}
+
+struct CommandState {
+    raw_primitive_type: mtl::MTLPrimitiveType,
+    index: Option<IndexState>,
+    raw_wg_size: mtl::MTLSize,
+    stage_infos: MultiStageData<PipelineStageInfo>,
+    //TODO: use `naga::ResourceBinding` for keys
+    storage_buffer_length_map: fxhash::FxHashMap<(u32, u32), wgt::BufferSize>,
+}
+
 pub struct CommandBuffer {
     raw: mtl::CommandBuffer,
     blit: Option<mtl::BlitCommandEncoder>,
     render: Option<mtl::RenderCommandEncoder>,
     compute: Option<mtl::ComputeCommandEncoder>,
-    raw_primitive_type: mtl::MTLPrimitiveType,
-    index_state: Option<IndexState>,
-    raw_wg_size: mtl::MTLSize,
     max_buffers_per_stage: u32,
     disabilities: PrivateDisabilities,
+    state: CommandState,
+    temp: Temp,
 }
 
 unsafe impl Send for CommandBuffer {}
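
Note on the pattern above: `make_sizes_buffer_update` returns `Some` only once every sized binding recorded for a stage has a known length in `storage_buffer_length_map`, which is why both `set_bind_group` and the pipeline setters each retry the `set_*_bytes` upload. The following is a standalone mini-model of that bookkeeping, not wgpu-hal code: it uses `std::collections::HashMap` in place of `fxhash::FxHashMap`, and the names (`StageInfo`, `make_sizes_update`, slot 29) are illustrative only.

use std::collections::HashMap;

/// (bind group index, binding index) -> bound size in bytes.
type LengthMap = HashMap<(u32, u32), u64>;

/// Per-stage info: which buffer slot receives the sizes buffer, and which
/// (group, binding) pairs must have their sizes reported to the shader.
struct StageInfo {
    sizes_slot: Option<u32>,
    sized_bindings: Vec<(u32, u32)>,
}

/// Mirrors `make_sizes_buffer_update`: succeeds only once every sized
/// binding of the stage has a recorded length; otherwise returns `None`
/// so the caller skips the upload for now and retries later.
fn make_sizes_update<'a>(
    stage: &StageInfo,
    lengths: &LengthMap,
    out: &'a mut Vec<u64>,
) -> Option<(u32, &'a [u64])> {
    let slot = stage.sizes_slot?;
    out.clear();
    for key in &stage.sized_bindings {
        // Bail out if any sized binding has no recorded length yet.
        out.push(*lengths.get(key)?);
    }
    Some((slot, out))
}

fn main() {
    let stage = StageInfo {
        sizes_slot: Some(29),
        sized_bindings: vec![(0, 1), (0, 3)],
    };
    let mut lengths = LengthMap::new();
    let mut tmp = Vec::new();

    // Only one of the two sized bindings is bound: no update yet.
    lengths.insert((0, 1), 256);
    assert!(make_sizes_update(&stage, &lengths, &mut tmp).is_none());

    // Once both are bound, the update becomes available.
    lengths.insert((0, 3), 1024);
    let (slot, sizes) = make_sizes_update(&stage, &lengths, &mut tmp).unwrap();
    assert_eq!(slot, 29);
    assert_eq!(sizes, &[256, 1024][..]);
}

Keeping the scratch vector outside the function mirrors `Temp::binding_sizes` in the diff: the same allocation is reused across bind-group and pipeline changes instead of reallocating on every call.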