From 9b85882ff82cd54071b2be82582f4d267fec1307 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Tue, 17 Aug 2021 22:45:55 -0400 Subject: [PATCH] [dx12] implement num_workgroups --- Cargo.lock | 2 +- wgpu-core/Cargo.toml | 2 +- wgpu-hal/Cargo.toml | 4 ++-- wgpu-hal/src/dx12/command.rs | 37 ++++++++++++++++++++++++++++++++---- wgpu-hal/src/dx12/device.rs | 12 ++++++------ wgpu-hal/src/dx12/mod.rs | 1 + wgpu-hal/src/lib.rs | 2 ++ wgpu/Cargo.toml | 4 ++-- 8 files changed, 48 insertions(+), 16 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3a5bdc698..a622c81c3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -964,7 +964,7 @@ dependencies = [ [[package]] name = "naga" version = "0.5.0" -source = "git+https://github.com/gfx-rs/naga?rev=7613798#7613798aed17c83f55375c91aafe53ab55106444" +source = "git+https://github.com/gfx-rs/naga?rev=4e181d6#4e181d6af4b758c11482b633b3a54b58bfbd4a59" dependencies = [ "bit-set", "bitflags", diff --git a/wgpu-core/Cargo.toml b/wgpu-core/Cargo.toml index e34d99608..aa9ddb4e5 100644 --- a/wgpu-core/Cargo.toml +++ b/wgpu-core/Cargo.toml @@ -36,7 +36,7 @@ thiserror = "1" [dependencies.naga] git = "https://github.com/gfx-rs/naga" -rev = "7613798" +rev = "4e181d6" features = ["wgsl-in"] [dependencies.wgt] diff --git a/wgpu-hal/Cargo.toml b/wgpu-hal/Cargo.toml index 1b8cf5fb2..d5092cc0d 100644 --- a/wgpu-hal/Cargo.toml +++ b/wgpu-hal/Cargo.toml @@ -65,11 +65,11 @@ core-graphics-types = "0.1" [dependencies.naga] git = "https://github.com/gfx-rs/naga" -rev = "7613798" +rev = "4e181d6" [dev-dependencies.naga] git = "https://github.com/gfx-rs/naga" -rev = "7613798" +rev = "4e181d6" features = ["wgsl-in"] [dev-dependencies] diff --git a/wgpu-hal/src/dx12/command.rs b/wgpu-hal/src/dx12/command.rs index 83385a370..21ec5ed4e 100644 --- a/wgpu-hal/src/dx12/command.rs +++ b/wgpu-hal/src/dx12/command.rs @@ -46,8 +46,8 @@ impl super::CommandEncoder { } unsafe fn prepare_draw(&mut self, base_vertex: i32, base_instance: u32) { - let list = self.list.unwrap(); while self.pass.dirty_vertex_buffers != 0 { + let list = self.list.unwrap(); let index = self.pass.dirty_vertex_buffers.trailing_zeros(); self.pass.dirty_vertex_buffers ^= 1 << index; list.IASetVertexBuffers( @@ -61,6 +61,7 @@ impl super::CommandEncoder { super::RootElement::SpecialConstantBuffer { base_vertex: other_vertex, base_instance: other_instance, + other: _, } => base_vertex != other_vertex || base_instance != other_instance, _ => true, }; @@ -70,13 +71,33 @@ impl super::CommandEncoder { super::RootElement::SpecialConstantBuffer { base_vertex, base_instance, + other: 0, }; } } self.update_root_elements(); } - fn prepare_dispatch(&mut self) { + fn prepare_dispatch(&mut self, count: [u32; 3]) { + if let Some(root_index) = self.pass.layout.special_constants_root_index { + let needs_update = match self.pass.root_elements[root_index as usize] { + super::RootElement::SpecialConstantBuffer { + base_vertex, + base_instance, + other, + } => [base_vertex as u32, base_instance, other] != count, + _ => true, + }; + if needs_update { + self.pass.dirty_root_elements |= 1 << root_index; + self.pass.root_elements[root_index as usize] = + super::RootElement::SpecialConstantBuffer { + base_vertex: count[0] as i32, + base_instance: count[1], + other: count[2], + }; + } + } self.update_root_elements(); } @@ -95,12 +116,17 @@ impl super::CommandEncoder { super::RootElement::SpecialConstantBuffer { base_vertex, base_instance, + other, } => match self.pass.kind { Pk::Render => { list.set_graphics_root_constant(index, base_vertex as u32, 0); list.set_graphics_root_constant(index, base_instance, 1); } - Pk::Compute => (), + Pk::Compute => { + list.set_compute_root_constant(index, base_vertex as u32, 0); + list.set_compute_root_constant(index, base_instance, 1); + list.set_compute_root_constant(index, other, 2); + } Pk::Transfer => (), }, super::RootElement::Table(descriptor) => match self.pass.kind { @@ -141,6 +167,7 @@ impl super::CommandEncoder { super::RootElement::SpecialConstantBuffer { base_vertex: 0, base_instance: 0, + other: 0, }; } self.pass.layout = layout.clone(); @@ -934,10 +961,12 @@ impl crate::CommandEncoder for super::CommandEncoder { } unsafe fn dispatch(&mut self, count: [u32; 3]) { - self.prepare_dispatch(); + self.prepare_dispatch(count); self.list.unwrap().dispatch(count); } unsafe fn dispatch_indirect(&mut self, buffer: &super::Buffer, offset: wgt::BufferAddress) { + self.prepare_dispatch([0; 3]); + //TODO: update special constants indirectly self.list.unwrap().ExecuteIndirect( self.shared.cmd_signatures.dispatch.as_mut_ptr(), 1, diff --git a/wgpu-hal/src/dx12/device.rs b/wgpu-hal/src/dx12/device.rs index cccfc72fa..44375285e 100644 --- a/wgpu-hal/src/dx12/device.rs +++ b/wgpu-hal/src/dx12/device.rs @@ -745,15 +745,15 @@ impl crate::Device for super::Device { ); let mut parameters = Vec::new(); - let (special_constants_root_index, special_constants_binding) = if desc - .flags - .contains(crate::PipelineLayoutFlags::BASE_VERTEX_INSTANCE) - { + let (special_constants_root_index, special_constants_binding) = if desc.flags.intersects( + crate::PipelineLayoutFlags::BASE_VERTEX_INSTANCE + | crate::PipelineLayoutFlags::NUM_WORK_GROUPS, + ) { let parameter_index = parameters.len(); parameters.push(native::RootParameter::constants( - native::ShaderVisibility::VS, + native::ShaderVisibility::All, // really needed for VS and CS only native_binding(&bind_cbv), - 2, // 0 = base vertex, 1 = base instance + 3, // 0 = base vertex, 1 = base instance, 2 = other )); let binding = bind_cbv.clone(); bind_cbv.register += 1; diff --git a/wgpu-hal/src/dx12/mod.rs b/wgpu-hal/src/dx12/mod.rs index 21be83671..6d5931457 100644 --- a/wgpu-hal/src/dx12/mod.rs +++ b/wgpu-hal/src/dx12/mod.rs @@ -283,6 +283,7 @@ enum RootElement { SpecialConstantBuffer { base_vertex: i32, base_instance: u32, + other: u32, }, /// Descriptor table. Table(native::GpuDescriptor), diff --git a/wgpu-hal/src/lib.rs b/wgpu-hal/src/lib.rs index aa9677ee1..f3bfa9729 100644 --- a/wgpu-hal/src/lib.rs +++ b/wgpu-hal/src/lib.rs @@ -510,6 +510,8 @@ bitflags!( pub struct PipelineLayoutFlags: u32 { /// Include support for base vertex/instance drawing. const BASE_VERTEX_INSTANCE = 1 << 0; + /// Include support for num work groups builtin. + const NUM_WORK_GROUPS = 1 << 1; } ); diff --git a/wgpu/Cargo.toml b/wgpu/Cargo.toml index 3af544239..a62ad172f 100644 --- a/wgpu/Cargo.toml +++ b/wgpu/Cargo.toml @@ -75,13 +75,13 @@ env_logger = "0.8" [dependencies.naga] git = "https://github.com/gfx-rs/naga" -rev = "7613798" +rev = "4e181d6" optional = true # used to test all the example shaders [dev-dependencies.naga] git = "https://github.com/gfx-rs/naga" -rev = "7613798" +rev = "4e181d6" features = ["wgsl-in"] [[example]]