mirror of
https://github.com/gfx-rs/wgpu.git
synced 2024-11-21 22:33:49 +00:00
[d3d12] get num_workgroups
builtin working for indirect dispatches
This commit is contained in:
parent
7f708edd1f
commit
bbee35b145
@ -185,6 +185,7 @@ By @bradwerth [#6216](https://github.com/gfx-rs/wgpu/pull/6216).
|
||||
#### DX12
|
||||
|
||||
- Replace `winapi` code to use the `windows` crate. By @MarijnS95 in [#5956](https://github.com/gfx-rs/wgpu/pull/5956) and [#6173](https://github.com/gfx-rs/wgpu/pull/6173)
|
||||
- Get `num_workgroups` builtin working for indirect dispatches. By @teoxoy in [#5730](https://github.com/gfx-rs/wgpu/pull/5730)
|
||||
|
||||
#### HAL
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
use wgpu_test::{gpu_test, FailureCase, GpuTestConfiguration, TestParameters, TestingContext};
|
||||
use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters, TestingContext};
|
||||
|
||||
/// Make sure that the num_workgroups builtin works properly (it requires a workaround on D3D12).
|
||||
#[gpu_test]
|
||||
@ -12,8 +12,7 @@ static NUM_WORKGROUPS_BUILTIN: GpuTestConfiguration = GpuTestConfiguration::new(
|
||||
.limits(wgpu::Limits {
|
||||
max_push_constant_size: 4,
|
||||
..wgpu::Limits::downlevel_defaults()
|
||||
})
|
||||
.expect_fail(FailureCase::backend(wgt::Backends::DX12)),
|
||||
}),
|
||||
)
|
||||
.run_async(|ctx| async move {
|
||||
let num_workgroups = [1, 2, 3];
|
||||
@ -34,8 +33,7 @@ static DISCARD_DISPATCH: GpuTestConfiguration = GpuTestConfiguration::new()
|
||||
max_compute_workgroups_per_dimension: 10,
|
||||
max_push_constant_size: 4,
|
||||
..wgpu::Limits::downlevel_defaults()
|
||||
})
|
||||
.expect_fail(FailureCase::backend(wgt::Backends::DX12)),
|
||||
}),
|
||||
)
|
||||
.run_async(|ctx| async move {
|
||||
let max = ctx.device.limits().max_compute_workgroups_per_dimension;
|
||||
|
@ -2638,7 +2638,8 @@ impl Device {
|
||||
|
||||
let hal_desc = hal::PipelineLayoutDescriptor {
|
||||
label: desc.label.to_hal(self.instance_flags),
|
||||
flags: hal::PipelineLayoutFlags::FIRST_VERTEX_INSTANCE,
|
||||
flags: hal::PipelineLayoutFlags::FIRST_VERTEX_INSTANCE
|
||||
| hal::PipelineLayoutFlags::NUM_WORK_GROUPS,
|
||||
bind_group_layouts: &raw_bind_group_layouts,
|
||||
push_constant_ranges: desc.push_constant_ranges.as_ref(),
|
||||
};
|
||||
|
@ -1,3 +1,6 @@
|
||||
use std::mem::size_of;
|
||||
use std::num::NonZeroU64;
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::{
|
||||
@ -57,7 +60,7 @@ impl IndirectValidation {
|
||||
let src = format!(
|
||||
"
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> dst: array<u32, 3>;
|
||||
var<storage, read_write> dst: array<u32, 6>;
|
||||
@group(1) @binding(0)
|
||||
var<storage, read> src: array<u32>;
|
||||
struct OffsetPc {{
|
||||
@ -74,14 +77,25 @@ impl IndirectValidation {
|
||||
src.y > max_compute_workgroups_per_dimension ||
|
||||
src.z > max_compute_workgroups_per_dimension
|
||||
) {{
|
||||
dst = array(0u, 0u, 0u);
|
||||
dst = array(0u, 0u, 0u, 0u, 0u, 0u);
|
||||
}} else {{
|
||||
dst = array(src.x, src.y, src.z);
|
||||
dst = array(src.x, src.y, src.z, src.x, src.y, src.z);
|
||||
}}
|
||||
}}
|
||||
"
|
||||
);
|
||||
|
||||
// SAFETY: The value we are passing to `new_unchecked` is not zero, so this is safe.
|
||||
const SRC_BUFFER_SIZE: NonZeroU64 =
|
||||
unsafe { NonZeroU64::new_unchecked(size_of::<u32>() as u64 * 3) };
|
||||
|
||||
// SAFETY: The value we are passing to `new_unchecked` is not zero, so this is safe.
|
||||
const DST_BUFFER_SIZE: NonZeroU64 = unsafe {
|
||||
NonZeroU64::new_unchecked(
|
||||
SRC_BUFFER_SIZE.get() * 2, // From above: `dst: array<u32, 6>`
|
||||
)
|
||||
};
|
||||
|
||||
let module = naga::front::wgsl::parse_str(&src).map_err(|inner| {
|
||||
CreateShaderModuleError::Parsing(naga::error::ShaderError {
|
||||
source: src.clone(),
|
||||
@ -133,7 +147,7 @@ impl IndirectValidation {
|
||||
ty: wgt::BindingType::Buffer {
|
||||
ty: wgt::BufferBindingType::Storage { read_only: false },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
|
||||
min_binding_size: Some(DST_BUFFER_SIZE),
|
||||
},
|
||||
count: None,
|
||||
}],
|
||||
@ -153,7 +167,7 @@ impl IndirectValidation {
|
||||
ty: wgt::BindingType::Buffer {
|
||||
ty: wgt::BufferBindingType::Storage { read_only: true },
|
||||
has_dynamic_offset: true,
|
||||
min_binding_size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
|
||||
min_binding_size: Some(SRC_BUFFER_SIZE),
|
||||
},
|
||||
count: None,
|
||||
}],
|
||||
@ -211,7 +225,7 @@ impl IndirectValidation {
|
||||
|
||||
let dst_buffer_desc = hal::BufferDescriptor {
|
||||
label: None,
|
||||
size: 4 * 3,
|
||||
size: DST_BUFFER_SIZE.get(),
|
||||
usage: hal::BufferUses::INDIRECT | hal::BufferUses::STORAGE_READ_WRITE,
|
||||
memory_flags: hal::MemoryFlags::empty(),
|
||||
};
|
||||
@ -229,7 +243,7 @@ impl IndirectValidation {
|
||||
buffers: &[hal::BufferBinding {
|
||||
buffer: dst_buffer.as_ref(),
|
||||
offset: 0,
|
||||
size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
|
||||
size: Some(DST_BUFFER_SIZE),
|
||||
}],
|
||||
samplers: &[],
|
||||
textures: &[],
|
||||
@ -271,7 +285,7 @@ impl IndirectValidation {
|
||||
buffers: &[hal::BufferBinding {
|
||||
buffer,
|
||||
offset: 0,
|
||||
size: Some(std::num::NonZeroU64::new(binding_size).unwrap()),
|
||||
size: Some(NonZeroU64::new(binding_size).unwrap()),
|
||||
}],
|
||||
samplers: &[],
|
||||
textures: &[],
|
||||
|
@ -1210,11 +1210,17 @@ impl crate::CommandEncoder for super::CommandEncoder {
|
||||
}
|
||||
|
||||
unsafe fn dispatch_indirect(&mut self, buffer: &super::Buffer, offset: wgt::BufferAddress) {
|
||||
self.prepare_dispatch([0; 3]);
|
||||
//TODO: update special constants indirectly
|
||||
self.update_root_elements();
|
||||
let cmd_signature = &self
|
||||
.pass
|
||||
.layout
|
||||
.special_constants_cmd_signatures
|
||||
.as_ref()
|
||||
.unwrap_or_else(|| &self.shared.cmd_signatures)
|
||||
.dispatch;
|
||||
unsafe {
|
||||
self.list.as_ref().unwrap().ExecuteIndirect(
|
||||
&self.shared.cmd_signatures.dispatch,
|
||||
cmd_signature,
|
||||
1,
|
||||
&buffer.resource,
|
||||
offset,
|
||||
|
@ -1,6 +1,6 @@
|
||||
use std::{
|
||||
ffi,
|
||||
mem::{self, size_of},
|
||||
mem::{self, size_of, size_of_val},
|
||||
num::NonZeroU32,
|
||||
ptr,
|
||||
sync::Arc,
|
||||
@ -94,34 +94,12 @@ impl super::Device {
|
||||
let capacity_views = limits.max_non_sampler_bindings as u64;
|
||||
let capacity_samplers = 2_048;
|
||||
|
||||
fn create_command_signature(
|
||||
raw: &Direct3D12::ID3D12Device,
|
||||
byte_stride: usize,
|
||||
arguments: &[Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC],
|
||||
node_mask: u32,
|
||||
) -> Result<Direct3D12::ID3D12CommandSignature, crate::DeviceError> {
|
||||
let mut signature = None;
|
||||
unsafe {
|
||||
raw.CreateCommandSignature(
|
||||
&Direct3D12::D3D12_COMMAND_SIGNATURE_DESC {
|
||||
ByteStride: byte_stride as u32,
|
||||
NumArgumentDescs: arguments.len() as u32,
|
||||
pArgumentDescs: arguments.as_ptr(),
|
||||
NodeMask: node_mask,
|
||||
},
|
||||
None,
|
||||
&mut signature,
|
||||
)
|
||||
}
|
||||
.into_device_result("Command signature creation")?;
|
||||
signature.ok_or(crate::DeviceError::Unexpected)
|
||||
}
|
||||
|
||||
let shared = super::DeviceShared {
|
||||
zero_buffer,
|
||||
cmd_signatures: super::CommandSignatures {
|
||||
draw: create_command_signature(
|
||||
draw: Self::create_command_signature(
|
||||
&raw,
|
||||
None,
|
||||
size_of::<wgt::DrawIndirectArgs>(),
|
||||
&[Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
|
||||
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_DRAW,
|
||||
@ -129,8 +107,9 @@ impl super::Device {
|
||||
}],
|
||||
0,
|
||||
)?,
|
||||
draw_indexed: create_command_signature(
|
||||
draw_indexed: Self::create_command_signature(
|
||||
&raw,
|
||||
None,
|
||||
size_of::<wgt::DrawIndexedIndirectArgs>(),
|
||||
&[Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
|
||||
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_DRAW_INDEXED,
|
||||
@ -138,8 +117,9 @@ impl super::Device {
|
||||
}],
|
||||
0,
|
||||
)?,
|
||||
dispatch: create_command_signature(
|
||||
dispatch: Self::create_command_signature(
|
||||
&raw,
|
||||
None,
|
||||
size_of::<wgt::DispatchIndirectArgs>(),
|
||||
&[Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
|
||||
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_DISPATCH,
|
||||
@ -214,6 +194,30 @@ impl super::Device {
|
||||
})
|
||||
}
|
||||
|
||||
fn create_command_signature(
|
||||
raw: &Direct3D12::ID3D12Device,
|
||||
root_signature: Option<&Direct3D12::ID3D12RootSignature>,
|
||||
byte_stride: usize,
|
||||
arguments: &[Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC],
|
||||
node_mask: u32,
|
||||
) -> Result<Direct3D12::ID3D12CommandSignature, crate::DeviceError> {
|
||||
let mut signature = None;
|
||||
unsafe {
|
||||
raw.CreateCommandSignature(
|
||||
&Direct3D12::D3D12_COMMAND_SIGNATURE_DESC {
|
||||
ByteStride: byte_stride as u32,
|
||||
NumArgumentDescs: arguments.len() as u32,
|
||||
pArgumentDescs: arguments.as_ptr(),
|
||||
NodeMask: node_mask,
|
||||
},
|
||||
root_signature,
|
||||
&mut signature,
|
||||
)
|
||||
}
|
||||
.into_device_result("Command signature creation")?;
|
||||
signature.ok_or(crate::DeviceError::Unexpected)
|
||||
}
|
||||
|
||||
// Blocks until the dedicated present queue is finished with all of its work.
|
||||
//
|
||||
// Once this method completes, the surface is able to be resized or deleted.
|
||||
@ -1119,6 +1123,81 @@ impl crate::Device for super::Device {
|
||||
}
|
||||
.into_device_result("Root signature creation")?;
|
||||
|
||||
let special_constants_cmd_signatures = if let Some(root_index) =
|
||||
special_constants_root_index
|
||||
{
|
||||
let constant_indirect_argument_desc = Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
|
||||
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_CONSTANT,
|
||||
Anonymous: Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC_0 {
|
||||
Constant: Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC_0_1 {
|
||||
RootParameterIndex: root_index,
|
||||
DestOffsetIn32BitValues: 0,
|
||||
Num32BitValuesToSet: 3,
|
||||
},
|
||||
},
|
||||
};
|
||||
let special_constant_buffer_args_len = {
|
||||
// Hack: construct a dummy value of the special constants buffer value we need to
|
||||
// fill, and calculate the size of each member.
|
||||
let super::RootElement::SpecialConstantBuffer {
|
||||
first_vertex,
|
||||
first_instance,
|
||||
other,
|
||||
} = (super::RootElement::SpecialConstantBuffer {
|
||||
first_vertex: 0,
|
||||
first_instance: 0,
|
||||
other: 0,
|
||||
})
|
||||
else {
|
||||
unreachable!();
|
||||
};
|
||||
size_of_val(&first_vertex) + size_of_val(&first_instance) + size_of_val(&other)
|
||||
};
|
||||
Some(super::CommandSignatures {
|
||||
draw: Self::create_command_signature(
|
||||
&self.raw,
|
||||
Some(&raw),
|
||||
special_constant_buffer_args_len + size_of::<wgt::DrawIndirectArgs>(),
|
||||
&[
|
||||
constant_indirect_argument_desc,
|
||||
Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
|
||||
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_DRAW,
|
||||
..Default::default()
|
||||
},
|
||||
],
|
||||
0,
|
||||
)?,
|
||||
draw_indexed: Self::create_command_signature(
|
||||
&self.raw,
|
||||
Some(&raw),
|
||||
special_constant_buffer_args_len + size_of::<wgt::DrawIndexedIndirectArgs>(),
|
||||
&[
|
||||
constant_indirect_argument_desc,
|
||||
Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
|
||||
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_DRAW_INDEXED,
|
||||
..Default::default()
|
||||
},
|
||||
],
|
||||
0,
|
||||
)?,
|
||||
dispatch: Self::create_command_signature(
|
||||
&self.raw,
|
||||
Some(&raw),
|
||||
special_constant_buffer_args_len + size_of::<wgt::DispatchIndirectArgs>(),
|
||||
&[
|
||||
constant_indirect_argument_desc,
|
||||
Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
|
||||
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_DISPATCH,
|
||||
..Default::default()
|
||||
},
|
||||
],
|
||||
0,
|
||||
)?,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let Some(label) = desc.label {
|
||||
unsafe { raw.SetName(&windows::core::HSTRING::from(label)) }
|
||||
.into_device_result("SetName")?;
|
||||
@ -1131,6 +1210,7 @@ impl crate::Device for super::Device {
|
||||
signature: Some(raw),
|
||||
total_root_elements: parameters.len() as super::RootIndex,
|
||||
special_constants_root_index,
|
||||
special_constants_cmd_signatures,
|
||||
root_constant_info,
|
||||
},
|
||||
bind_group_infos,
|
||||
|
@ -564,6 +564,7 @@ struct Idler {
|
||||
event: Event,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct CommandSignatures {
|
||||
draw: Direct3D12::ID3D12CommandSignature,
|
||||
draw_indexed: Direct3D12::ID3D12CommandSignature,
|
||||
@ -636,8 +637,11 @@ enum RootElement {
|
||||
Empty,
|
||||
Constant,
|
||||
SpecialConstantBuffer {
|
||||
/// The first vertex in an indirect draw call, _or_ the `x` of a compute dispatch.
|
||||
first_vertex: i32,
|
||||
/// The first instance in an indirect draw call, _or_ the `y` of a compute dispatch.
|
||||
first_instance: u32,
|
||||
/// Unused in an indirect draw call, _or_ the `z` of a compute dispatch.
|
||||
other: u32,
|
||||
},
|
||||
/// Descriptor table.
|
||||
@ -682,6 +686,7 @@ impl PassState {
|
||||
signature: None,
|
||||
total_root_elements: 0,
|
||||
special_constants_root_index: None,
|
||||
special_constants_cmd_signatures: None,
|
||||
root_constant_info: None,
|
||||
},
|
||||
root_elements: [RootElement::Empty; MAX_ROOT_ELEMENTS],
|
||||
@ -919,6 +924,7 @@ struct PipelineLayoutShared {
|
||||
signature: Option<Direct3D12::ID3D12RootSignature>,
|
||||
total_root_elements: RootIndex,
|
||||
special_constants_root_index: Option<RootIndex>,
|
||||
special_constants_cmd_signatures: Option<CommandSignatures>,
|
||||
root_constant_info: Option<RootConstantInfo>,
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user