[d3d12] get num_workgroups builtin working for indirect dispatches

This commit is contained in:
teoxoy 2024-05-22 16:24:52 +02:00 committed by Teodor Tanasoaia
parent 7f708edd1f
commit bbee35b145
7 changed files with 150 additions and 44 deletions

View File

@ -185,6 +185,7 @@ By @bradwerth [#6216](https://github.com/gfx-rs/wgpu/pull/6216).
#### DX12 #### DX12
- Replace `winapi` code to use the `windows` crate. By @MarijnS95 in [#5956](https://github.com/gfx-rs/wgpu/pull/5956) and [#6173](https://github.com/gfx-rs/wgpu/pull/6173) - Replace `winapi` code to use the `windows` crate. By @MarijnS95 in [#5956](https://github.com/gfx-rs/wgpu/pull/5956) and [#6173](https://github.com/gfx-rs/wgpu/pull/6173)
- Get `num_workgroups` builtin working for indirect dispatches. By @teoxoy in [#5730](https://github.com/gfx-rs/wgpu/pull/5730)
#### HAL #### HAL

View File

@ -1,4 +1,4 @@
use wgpu_test::{gpu_test, FailureCase, GpuTestConfiguration, TestParameters, TestingContext}; use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters, TestingContext};
/// Make sure that the num_workgroups builtin works properly (it requires a workaround on D3D12). /// Make sure that the num_workgroups builtin works properly (it requires a workaround on D3D12).
#[gpu_test] #[gpu_test]
@ -12,8 +12,7 @@ static NUM_WORKGROUPS_BUILTIN: GpuTestConfiguration = GpuTestConfiguration::new(
.limits(wgpu::Limits { .limits(wgpu::Limits {
max_push_constant_size: 4, max_push_constant_size: 4,
..wgpu::Limits::downlevel_defaults() ..wgpu::Limits::downlevel_defaults()
}) }),
.expect_fail(FailureCase::backend(wgt::Backends::DX12)),
) )
.run_async(|ctx| async move { .run_async(|ctx| async move {
let num_workgroups = [1, 2, 3]; let num_workgroups = [1, 2, 3];
@ -34,8 +33,7 @@ static DISCARD_DISPATCH: GpuTestConfiguration = GpuTestConfiguration::new()
max_compute_workgroups_per_dimension: 10, max_compute_workgroups_per_dimension: 10,
max_push_constant_size: 4, max_push_constant_size: 4,
..wgpu::Limits::downlevel_defaults() ..wgpu::Limits::downlevel_defaults()
}) }),
.expect_fail(FailureCase::backend(wgt::Backends::DX12)),
) )
.run_async(|ctx| async move { .run_async(|ctx| async move {
let max = ctx.device.limits().max_compute_workgroups_per_dimension; let max = ctx.device.limits().max_compute_workgroups_per_dimension;

View File

@ -2638,7 +2638,8 @@ impl Device {
let hal_desc = hal::PipelineLayoutDescriptor { let hal_desc = hal::PipelineLayoutDescriptor {
label: desc.label.to_hal(self.instance_flags), label: desc.label.to_hal(self.instance_flags),
flags: hal::PipelineLayoutFlags::FIRST_VERTEX_INSTANCE, flags: hal::PipelineLayoutFlags::FIRST_VERTEX_INSTANCE
| hal::PipelineLayoutFlags::NUM_WORK_GROUPS,
bind_group_layouts: &raw_bind_group_layouts, bind_group_layouts: &raw_bind_group_layouts,
push_constant_ranges: desc.push_constant_ranges.as_ref(), push_constant_ranges: desc.push_constant_ranges.as_ref(),
}; };

View File

@ -1,3 +1,6 @@
use std::mem::size_of;
use std::num::NonZeroU64;
use thiserror::Error; use thiserror::Error;
use crate::{ use crate::{
@ -57,7 +60,7 @@ impl IndirectValidation {
let src = format!( let src = format!(
" "
@group(0) @binding(0) @group(0) @binding(0)
var<storage, read_write> dst: array<u32, 3>; var<storage, read_write> dst: array<u32, 6>;
@group(1) @binding(0) @group(1) @binding(0)
var<storage, read> src: array<u32>; var<storage, read> src: array<u32>;
struct OffsetPc {{ struct OffsetPc {{
@ -74,14 +77,25 @@ impl IndirectValidation {
src.y > max_compute_workgroups_per_dimension || src.y > max_compute_workgroups_per_dimension ||
src.z > max_compute_workgroups_per_dimension src.z > max_compute_workgroups_per_dimension
) {{ ) {{
dst = array(0u, 0u, 0u); dst = array(0u, 0u, 0u, 0u, 0u, 0u);
}} else {{ }} else {{
dst = array(src.x, src.y, src.z); dst = array(src.x, src.y, src.z, src.x, src.y, src.z);
}} }}
}} }}
" "
); );
// SAFETY: The value we are passing to `new_unchecked` is not zero, so this is safe.
const SRC_BUFFER_SIZE: NonZeroU64 =
unsafe { NonZeroU64::new_unchecked(size_of::<u32>() as u64 * 3) };
// SAFETY: The value we are passing to `new_unchecked` is not zero, so this is safe.
const DST_BUFFER_SIZE: NonZeroU64 = unsafe {
NonZeroU64::new_unchecked(
SRC_BUFFER_SIZE.get() * 2, // From above: `dst: array<u32, 6>`
)
};
let module = naga::front::wgsl::parse_str(&src).map_err(|inner| { let module = naga::front::wgsl::parse_str(&src).map_err(|inner| {
CreateShaderModuleError::Parsing(naga::error::ShaderError { CreateShaderModuleError::Parsing(naga::error::ShaderError {
source: src.clone(), source: src.clone(),
@ -133,7 +147,7 @@ impl IndirectValidation {
ty: wgt::BindingType::Buffer { ty: wgt::BindingType::Buffer {
ty: wgt::BufferBindingType::Storage { read_only: false }, ty: wgt::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false, has_dynamic_offset: false,
min_binding_size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()), min_binding_size: Some(DST_BUFFER_SIZE),
}, },
count: None, count: None,
}], }],
@ -153,7 +167,7 @@ impl IndirectValidation {
ty: wgt::BindingType::Buffer { ty: wgt::BindingType::Buffer {
ty: wgt::BufferBindingType::Storage { read_only: true }, ty: wgt::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: true, has_dynamic_offset: true,
min_binding_size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()), min_binding_size: Some(SRC_BUFFER_SIZE),
}, },
count: None, count: None,
}], }],
@ -211,7 +225,7 @@ impl IndirectValidation {
let dst_buffer_desc = hal::BufferDescriptor { let dst_buffer_desc = hal::BufferDescriptor {
label: None, label: None,
size: 4 * 3, size: DST_BUFFER_SIZE.get(),
usage: hal::BufferUses::INDIRECT | hal::BufferUses::STORAGE_READ_WRITE, usage: hal::BufferUses::INDIRECT | hal::BufferUses::STORAGE_READ_WRITE,
memory_flags: hal::MemoryFlags::empty(), memory_flags: hal::MemoryFlags::empty(),
}; };
@ -229,7 +243,7 @@ impl IndirectValidation {
buffers: &[hal::BufferBinding { buffers: &[hal::BufferBinding {
buffer: dst_buffer.as_ref(), buffer: dst_buffer.as_ref(),
offset: 0, offset: 0,
size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()), size: Some(DST_BUFFER_SIZE),
}], }],
samplers: &[], samplers: &[],
textures: &[], textures: &[],
@ -271,7 +285,7 @@ impl IndirectValidation {
buffers: &[hal::BufferBinding { buffers: &[hal::BufferBinding {
buffer, buffer,
offset: 0, offset: 0,
size: Some(std::num::NonZeroU64::new(binding_size).unwrap()), size: Some(NonZeroU64::new(binding_size).unwrap()),
}], }],
samplers: &[], samplers: &[],
textures: &[], textures: &[],

View File

@ -1210,11 +1210,17 @@ impl crate::CommandEncoder for super::CommandEncoder {
} }
unsafe fn dispatch_indirect(&mut self, buffer: &super::Buffer, offset: wgt::BufferAddress) { unsafe fn dispatch_indirect(&mut self, buffer: &super::Buffer, offset: wgt::BufferAddress) {
self.prepare_dispatch([0; 3]); self.update_root_elements();
//TODO: update special constants indirectly let cmd_signature = &self
.pass
.layout
.special_constants_cmd_signatures
.as_ref()
.unwrap_or_else(|| &self.shared.cmd_signatures)
.dispatch;
unsafe { unsafe {
self.list.as_ref().unwrap().ExecuteIndirect( self.list.as_ref().unwrap().ExecuteIndirect(
&self.shared.cmd_signatures.dispatch, cmd_signature,
1, 1,
&buffer.resource, &buffer.resource,
offset, offset,

View File

@ -1,6 +1,6 @@
use std::{ use std::{
ffi, ffi,
mem::{self, size_of}, mem::{self, size_of, size_of_val},
num::NonZeroU32, num::NonZeroU32,
ptr, ptr,
sync::Arc, sync::Arc,
@ -94,34 +94,12 @@ impl super::Device {
let capacity_views = limits.max_non_sampler_bindings as u64; let capacity_views = limits.max_non_sampler_bindings as u64;
let capacity_samplers = 2_048; let capacity_samplers = 2_048;
fn create_command_signature(
raw: &Direct3D12::ID3D12Device,
byte_stride: usize,
arguments: &[Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC],
node_mask: u32,
) -> Result<Direct3D12::ID3D12CommandSignature, crate::DeviceError> {
let mut signature = None;
unsafe {
raw.CreateCommandSignature(
&Direct3D12::D3D12_COMMAND_SIGNATURE_DESC {
ByteStride: byte_stride as u32,
NumArgumentDescs: arguments.len() as u32,
pArgumentDescs: arguments.as_ptr(),
NodeMask: node_mask,
},
None,
&mut signature,
)
}
.into_device_result("Command signature creation")?;
signature.ok_or(crate::DeviceError::Unexpected)
}
let shared = super::DeviceShared { let shared = super::DeviceShared {
zero_buffer, zero_buffer,
cmd_signatures: super::CommandSignatures { cmd_signatures: super::CommandSignatures {
draw: create_command_signature( draw: Self::create_command_signature(
&raw, &raw,
None,
size_of::<wgt::DrawIndirectArgs>(), size_of::<wgt::DrawIndirectArgs>(),
&[Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC { &[Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_DRAW, Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_DRAW,
@ -129,8 +107,9 @@ impl super::Device {
}], }],
0, 0,
)?, )?,
draw_indexed: create_command_signature( draw_indexed: Self::create_command_signature(
&raw, &raw,
None,
size_of::<wgt::DrawIndexedIndirectArgs>(), size_of::<wgt::DrawIndexedIndirectArgs>(),
&[Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC { &[Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_DRAW_INDEXED, Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_DRAW_INDEXED,
@ -138,8 +117,9 @@ impl super::Device {
}], }],
0, 0,
)?, )?,
dispatch: create_command_signature( dispatch: Self::create_command_signature(
&raw, &raw,
None,
size_of::<wgt::DispatchIndirectArgs>(), size_of::<wgt::DispatchIndirectArgs>(),
&[Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC { &[Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_DISPATCH, Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_DISPATCH,
@ -214,6 +194,30 @@ impl super::Device {
}) })
} }
fn create_command_signature(
raw: &Direct3D12::ID3D12Device,
root_signature: Option<&Direct3D12::ID3D12RootSignature>,
byte_stride: usize,
arguments: &[Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC],
node_mask: u32,
) -> Result<Direct3D12::ID3D12CommandSignature, crate::DeviceError> {
let mut signature = None;
unsafe {
raw.CreateCommandSignature(
&Direct3D12::D3D12_COMMAND_SIGNATURE_DESC {
ByteStride: byte_stride as u32,
NumArgumentDescs: arguments.len() as u32,
pArgumentDescs: arguments.as_ptr(),
NodeMask: node_mask,
},
root_signature,
&mut signature,
)
}
.into_device_result("Command signature creation")?;
signature.ok_or(crate::DeviceError::Unexpected)
}
// Blocks until the dedicated present queue is finished with all of its work. // Blocks until the dedicated present queue is finished with all of its work.
// //
// Once this method completes, the surface is able to be resized or deleted. // Once this method completes, the surface is able to be resized or deleted.
@ -1119,6 +1123,81 @@ impl crate::Device for super::Device {
} }
.into_device_result("Root signature creation")?; .into_device_result("Root signature creation")?;
let special_constants_cmd_signatures = if let Some(root_index) =
special_constants_root_index
{
let constant_indirect_argument_desc = Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_CONSTANT,
Anonymous: Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC_0 {
Constant: Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC_0_1 {
RootParameterIndex: root_index,
DestOffsetIn32BitValues: 0,
Num32BitValuesToSet: 3,
},
},
};
let special_constant_buffer_args_len = {
// Hack: construct a dummy value of the special constants buffer value we need to
// fill, and calculate the size of each member.
let super::RootElement::SpecialConstantBuffer {
first_vertex,
first_instance,
other,
} = (super::RootElement::SpecialConstantBuffer {
first_vertex: 0,
first_instance: 0,
other: 0,
})
else {
unreachable!();
};
size_of_val(&first_vertex) + size_of_val(&first_instance) + size_of_val(&other)
};
Some(super::CommandSignatures {
draw: Self::create_command_signature(
&self.raw,
Some(&raw),
special_constant_buffer_args_len + size_of::<wgt::DrawIndirectArgs>(),
&[
constant_indirect_argument_desc,
Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_DRAW,
..Default::default()
},
],
0,
)?,
draw_indexed: Self::create_command_signature(
&self.raw,
Some(&raw),
special_constant_buffer_args_len + size_of::<wgt::DrawIndexedIndirectArgs>(),
&[
constant_indirect_argument_desc,
Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_DRAW_INDEXED,
..Default::default()
},
],
0,
)?,
dispatch: Self::create_command_signature(
&self.raw,
Some(&raw),
special_constant_buffer_args_len + size_of::<wgt::DispatchIndirectArgs>(),
&[
constant_indirect_argument_desc,
Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_DISPATCH,
..Default::default()
},
],
0,
)?,
})
} else {
None
};
if let Some(label) = desc.label { if let Some(label) = desc.label {
unsafe { raw.SetName(&windows::core::HSTRING::from(label)) } unsafe { raw.SetName(&windows::core::HSTRING::from(label)) }
.into_device_result("SetName")?; .into_device_result("SetName")?;
@ -1131,6 +1210,7 @@ impl crate::Device for super::Device {
signature: Some(raw), signature: Some(raw),
total_root_elements: parameters.len() as super::RootIndex, total_root_elements: parameters.len() as super::RootIndex,
special_constants_root_index, special_constants_root_index,
special_constants_cmd_signatures,
root_constant_info, root_constant_info,
}, },
bind_group_infos, bind_group_infos,

View File

@ -564,6 +564,7 @@ struct Idler {
event: Event, event: Event,
} }
#[derive(Debug, Clone)]
struct CommandSignatures { struct CommandSignatures {
draw: Direct3D12::ID3D12CommandSignature, draw: Direct3D12::ID3D12CommandSignature,
draw_indexed: Direct3D12::ID3D12CommandSignature, draw_indexed: Direct3D12::ID3D12CommandSignature,
@ -636,8 +637,11 @@ enum RootElement {
Empty, Empty,
Constant, Constant,
SpecialConstantBuffer { SpecialConstantBuffer {
/// The first vertex in an indirect draw call, _or_ the `x` of a compute dispatch.
first_vertex: i32, first_vertex: i32,
/// The first instance in an indirect draw call, _or_ the `y` of a compute dispatch.
first_instance: u32, first_instance: u32,
/// Unused in an indirect draw call, _or_ the `z` of a compute dispatch.
other: u32, other: u32,
}, },
/// Descriptor table. /// Descriptor table.
@ -682,6 +686,7 @@ impl PassState {
signature: None, signature: None,
total_root_elements: 0, total_root_elements: 0,
special_constants_root_index: None, special_constants_root_index: None,
special_constants_cmd_signatures: None,
root_constant_info: None, root_constant_info: None,
}, },
root_elements: [RootElement::Empty; MAX_ROOT_ELEMENTS], root_elements: [RootElement::Empty; MAX_ROOT_ELEMENTS],
@ -919,6 +924,7 @@ struct PipelineLayoutShared {
signature: Option<Direct3D12::ID3D12RootSignature>, signature: Option<Direct3D12::ID3D12RootSignature>,
total_root_elements: RootIndex, total_root_elements: RootIndex,
special_constants_root_index: Option<RootIndex>, special_constants_root_index: Option<RootIndex>,
special_constants_cmd_signatures: Option<CommandSignatures>,
root_constant_info: Option<RootConstantInfo>, root_constant_info: Option<RootConstantInfo>,
} }