Ensure safety of indirect dispatch (#5714)

by injecting a compute shader that validates the content of the indirect buffer
This commit is contained in:
Teodor Tanasoaia 2024-10-14 15:02:01 +02:00 committed by GitHub
parent c0e39721a2
commit 7f708edd1f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 913 additions and 22 deletions

View File

@ -133,6 +133,7 @@ By @bradwerth [#6216](https://github.com/gfx-rs/wgpu/pull/6216).
- Call `flush_mapped_ranges` when unmapping write-mapped buffers. By @teoxoy in [#6089](https://github.com/gfx-rs/wgpu/pull/6089). - Call `flush_mapped_ranges` when unmapping write-mapped buffers. By @teoxoy in [#6089](https://github.com/gfx-rs/wgpu/pull/6089).
- When mapping buffers for reading, mark buffers as initialized only when they have `MAP_WRITE` usage. By @teoxoy in [#6178](https://github.com/gfx-rs/wgpu/pull/6178). - When mapping buffers for reading, mark buffers as initialized only when they have `MAP_WRITE` usage. By @teoxoy in [#6178](https://github.com/gfx-rs/wgpu/pull/6178).
- Add a separate pipeline constants error. By @teoxoy in [#6094](https://github.com/gfx-rs/wgpu/pull/6094). - Add a separate pipeline constants error. By @teoxoy in [#6094](https://github.com/gfx-rs/wgpu/pull/6094).
- Ensure safety of indirect dispatch by injecting a compute shader that validates the content of the indirect buffer. By @teoxoy in [#5714](https://github.com/gfx-rs/wgpu/pull/5714)
#### GLES / OpenGL #### GLES / OpenGL

View File

@ -0,0 +1,241 @@
use wgpu_test::{gpu_test, FailureCase, GpuTestConfiguration, TestParameters, TestingContext};
/// Make sure that the num_workgroups builtin works properly (it requires a workaround on D3D12).
#[gpu_test]
static NUM_WORKGROUPS_BUILTIN: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.features(wgpu::Features::PUSH_CONSTANTS)
.downlevel_flags(
wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION,
)
.limits(wgpu::Limits {
max_push_constant_size: 4,
..wgpu::Limits::downlevel_defaults()
})
.expect_fail(FailureCase::backend(wgt::Backends::DX12)),
)
.run_async(|ctx| async move {
let num_workgroups = [1, 2, 3];
let res = run_test(&ctx, &num_workgroups, false).await;
assert_eq!(res, num_workgroups);
});
/// Make sure that we discard (don't run) the dispatch if its size exceeds the device limit.
#[gpu_test]
static DISCARD_DISPATCH: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.features(wgpu::Features::PUSH_CONSTANTS)
.downlevel_flags(
wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION,
)
.limits(wgpu::Limits {
max_compute_workgroups_per_dimension: 10,
max_push_constant_size: 4,
..wgpu::Limits::downlevel_defaults()
})
.expect_fail(FailureCase::backend(wgt::Backends::DX12)),
)
.run_async(|ctx| async move {
let max = ctx.device.limits().max_compute_workgroups_per_dimension;
let res = run_test(&ctx, &[max, max, max], false).await;
assert_eq!(res, [max; 3]);
let res = run_test(&ctx, &[max + 1, 1, 1], false).await;
assert_eq!(res, [0; 3]);
let res = run_test(&ctx, &[1, max + 1, 1], false).await;
assert_eq!(res, [0; 3]);
let res = run_test(&ctx, &[1, 1, max + 1], false).await;
assert_eq!(res, [0; 3]);
});
/// Make sure that resetting the bind groups set by the validation code works properly.
#[gpu_test]
static RESET_BIND_GROUPS: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.features(wgpu::Features::PUSH_CONSTANTS)
.downlevel_flags(
wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION,
)
.limits(wgpu::Limits {
max_push_constant_size: 4,
..wgpu::Limits::downlevel_defaults()
}),
)
.run_async(|ctx| async move {
ctx.device.push_error_scope(wgpu::ErrorFilter::Validation);
let _ = run_test(&ctx, &[0, 0, 0], true).await;
let error = pollster::block_on(ctx.device.pop_error_scope());
assert!(error.map_or(false, |error| {
format!("{error}").contains("The current set ComputePipeline with '' label expects a BindGroup to be set at index 0")
}));
});
async fn run_test(
ctx: &TestingContext,
num_workgroups: &[u32; 3],
forget_to_set_bind_group: bool,
) -> [u32; 3] {
const SHADER_SRC: &str = "
struct TestOffsetPc {
inner: u32,
}
// `test_offset.inner` should always be 0; we test that resetting the push constant set by the validation code works properly.
var<push_constant> test_offset: TestOffsetPc;
@group(0) @binding(0)
var<storage, read_write> out: array<u32, 3>;
@compute @workgroup_size(1)
fn main(@builtin(num_workgroups) num_workgroups: vec3u, @builtin(workgroup_id) workgroup_id: vec3u) {
if (all(workgroup_id == vec3u())) {
out[0] = num_workgroups.x + test_offset.inner;
out[1] = num_workgroups.y + test_offset.inner;
out[2] = num_workgroups.z + test_offset.inner;
}
}
";
let module = ctx
.device
.create_shader_module(wgpu::ShaderModuleDescriptor {
label: None,
source: wgpu::ShaderSource::Wgsl(SHADER_SRC.into()),
});
let bgl = ctx
.device
.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: None,
entries: &[wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgt::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
}],
});
let layout = ctx
.device
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: None,
bind_group_layouts: &[&bgl],
push_constant_ranges: &[wgt::PushConstantRange {
stages: wgt::ShaderStages::COMPUTE,
range: 0..4,
}],
});
let pipeline = ctx
.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: None,
layout: Some(&layout),
module: &module,
entry_point: Some("main"),
compilation_options: Default::default(),
cache: None,
});
let out_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: 12,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
let readback_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: 12,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});
let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &pipeline.get_bind_group_layout(0),
entries: &[wgpu::BindGroupEntry {
binding: 0,
resource: out_buffer.as_entire_binding(),
}],
});
let mut res = None;
for (indirect_offset, indirect_buffer_size) in [
// internal src buffer binding size will be buffer.size
(0, 12),
(4, 4 + 12),
(4, 8 + 12),
(256 * 2 - 4 - 12, 256 * 2 - 4),
// internal src buffer binding size will be 256 * 2 + x
(0, 256 * 2 * 2 + 4),
(256, 256 * 2 * 2 + 8),
(256 + 4, 256 * 2 * 2 + 12),
(256 * 2 + 16, 256 * 2 * 2 + 16),
(256 * 2 * 2, 256 * 2 * 2 + 32),
(256 + 12, 256 * 2 * 2 + 64),
] {
let indirect_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: indirect_buffer_size,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::INDIRECT,
mapped_at_creation: false,
});
ctx.queue.write_buffer(
&indirect_buffer,
indirect_offset,
bytemuck::bytes_of(num_workgroups),
);
let mut encoder = ctx.device.create_command_encoder(&Default::default());
{
let mut compute_pass = encoder.begin_compute_pass(&Default::default());
compute_pass.set_pipeline(&pipeline);
compute_pass.set_push_constants(0, &[0, 0, 0, 0]);
if !forget_to_set_bind_group {
compute_pass.set_bind_group(0, Some(&bind_group), &[]);
}
compute_pass.dispatch_workgroups_indirect(&indirect_buffer, indirect_offset);
}
encoder.copy_buffer_to_buffer(&out_buffer, 0, &readback_buffer, 0, 12);
ctx.queue.submit(Some(encoder.finish()));
readback_buffer
.slice(..)
.map_async(wgpu::MapMode::Read, |_| {});
ctx.async_poll(wgpu::Maintain::wait())
.await
.panic_on_timeout();
let view = readback_buffer.slice(..).get_mapped_range();
let current_res = *bytemuck::from_bytes(&view);
drop(view);
readback_buffer.unmap();
if let Some(past_res) = res {
assert_eq!(past_res, current_res);
} else {
res = Some(current_res);
}
}
res.unwrap()
}

View File

@ -19,6 +19,7 @@ mod clear_texture;
mod compute_pass_ownership; mod compute_pass_ownership;
mod create_surface_error; mod create_surface_error;
mod device; mod device;
mod dispatch_workgroups_indirect;
mod encoder; mod encoder;
mod external_texture; mod external_texture;
mod float32_filterable; mod float32_filterable;

View File

@ -51,6 +51,10 @@ renderdoc = ["hal/renderdoc"]
## to the validation carried out at public APIs in all builds. ## to the validation carried out at public APIs in all builds.
strict_asserts = ["wgt/strict_asserts"] strict_asserts = ["wgt/strict_asserts"]
## Validates indirect draw/dispatch calls. This will also enable naga's
## WGSL frontend since we use a WGSL compute shader to do the validation.
indirect-validation = ["naga/wgsl-in"]
## Enables serialization via `serde` on common wgpu types. ## Enables serialization via `serde` on common wgpu types.
serde = ["dep:serde", "wgt/serde", "arrayvec/serde"] serde = ["dep:serde", "wgt/serde", "arrayvec/serde"]

View File

@ -200,13 +200,17 @@ mod compat {
entries: (0..hal::MAX_BIND_GROUPS).map(|_| Entry::empty()).collect(), entries: (0..hal::MAX_BIND_GROUPS).map(|_| Entry::empty()).collect(),
} }
} }
fn make_range(&self, start_index: usize) -> Range<usize> {
pub fn num_valid_entries(&self) -> usize {
// find first incompatible entry // find first incompatible entry
let end = self self.entries
.entries
.iter() .iter()
.position(|e| e.is_incompatible()) .position(|e| e.is_incompatible())
.unwrap_or(self.entries.len()); .unwrap_or(self.entries.len())
}
fn make_range(&self, start_index: usize) -> Range<usize> {
let end = self.num_valid_entries();
start_index..end.max(start_index) start_index..end.max(start_index)
} }
@ -406,6 +410,14 @@ impl Binder {
.map(move |index| payloads[index].group.as_ref().unwrap()) .map(move |index| payloads[index].group.as_ref().unwrap())
} }
#[cfg(feature = "indirect-validation")]
pub(super) fn list_valid<'a>(&'a self) -> impl Iterator<Item = (usize, &'a EntryPayload)> + '_ {
self.payloads
.iter()
.take(self.manager.num_valid_entries())
.enumerate()
}
pub(super) fn check_compatibility<T: Labeled>( pub(super) fn check_compatibility<T: Labeled>(
&self, &self,
pipeline: &T, pipeline: &T,

View File

@ -18,7 +18,7 @@ use crate::{
pipeline::ComputePipeline, pipeline::ComputePipeline,
resource::{ resource::{
self, Buffer, DestroyedResourceError, InvalidResourceError, Labeled, self, Buffer, DestroyedResourceError, InvalidResourceError, Labeled,
MissingBufferUsageError, ParentDevice, Trackable, MissingBufferUsageError, ParentDevice,
}, },
snatch::SnatchGuard, snatch::SnatchGuard,
track::{ResourceUsageCompatibilityError, Tracker, TrackerIndex, UsageScope}, track::{ResourceUsageCompatibilityError, Tracker, TrackerIndex, UsageScope},
@ -216,6 +216,8 @@ struct State<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder> {
string_offset: usize, string_offset: usize,
active_query: Option<(Arc<resource::QuerySet>, u32)>, active_query: Option<(Arc<resource::QuerySet>, u32)>,
push_constants: Vec<u32>,
intermediate_trackers: Tracker, intermediate_trackers: Tracker,
/// Immediate texture inits required because of prior discards. Need to /// Immediate texture inits required because of prior discards. Need to
@ -443,6 +445,8 @@ impl Global {
string_offset: 0, string_offset: 0,
active_query: None, active_query: None,
push_constants: Vec::new(),
intermediate_trackers: Tracker::new(), intermediate_trackers: Tracker::new(),
pending_discard_init_fixups: SurfacesInDiscardState::new(), pending_discard_init_fixups: SurfacesInDiscardState::new(),
@ -746,6 +750,21 @@ fn set_pipeline(
} }
} }
// TODO: integrate this in the code below once we simplify push constants
state.push_constants.clear();
// Note that can only be one range for each stage. See the `MoreThanOnePushConstantRangePerStage` error.
if let Some(push_constant_range) =
pipeline.layout.push_constant_ranges.iter().find_map(|pcr| {
pcr.stages
.contains(wgt::ShaderStages::COMPUTE)
.then_some(pcr.range.clone())
})
{
// Note that non-0 range start doesn't work anyway https://github.com/gfx-rs/wgpu/issues/4502
let len = push_constant_range.len() / wgt::PUSH_CONSTANT_ALIGNMENT as usize;
state.push_constants.extend(core::iter::repeat(0).take(len));
}
// Clear push constant ranges // Clear push constant ranges
let non_overlapping = let non_overlapping =
super::bind::compute_nonoverlapping_ranges(&pipeline.layout.push_constant_ranges); super::bind::compute_nonoverlapping_ranges(&pipeline.layout.push_constant_ranges);
@ -791,6 +810,10 @@ fn set_push_constant(
end_offset_bytes, end_offset_bytes,
)?; )?;
let offset_in_elements = (offset / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
let size_in_elements = (size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
state.push_constants[offset_in_elements..][..size_in_elements].copy_from_slice(data_slice);
unsafe { unsafe {
state.raw_encoder.set_push_constants( state.raw_encoder.set_push_constants(
pipeline_layout.raw(), pipeline_layout.raw(),
@ -841,10 +864,6 @@ fn dispatch_indirect(
.device .device
.require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?; .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;
state
.scope
.buffers
.merge_single(&buffer, hal::BufferUses::INDIRECT)?;
buffer.check_usage(wgt::BufferUsages::INDIRECT)?; buffer.check_usage(wgt::BufferUsages::INDIRECT)?;
if offset % 4 != 0 { if offset % 4 != 0 {
@ -861,7 +880,6 @@ fn dispatch_indirect(
} }
let stride = 3 * 4; // 3 integers, x/y/z group size let stride = 3 * 4; // 3 integers, x/y/z group size
state state
.buffer_memory_init_actions .buffer_memory_init_actions
.extend(buffer.initialization_status.read().create_action( .extend(buffer.initialization_status.read().create_action(
@ -870,12 +888,132 @@ fn dispatch_indirect(
MemoryInitKind::NeedsInitializedMemory, MemoryInitKind::NeedsInitializedMemory,
)); ));
state.flush_states(Some(buffer.tracker_index()))?; #[cfg(feature = "indirect-validation")]
{
let params = state.device.indirect_validation.as_ref().unwrap().params(
&state.device.limits,
offset,
buffer.size,
);
let buf_raw = buffer.try_raw(&state.snatch_guard)?; unsafe {
unsafe { state.raw_encoder.set_compute_pipeline(params.pipeline);
state.raw_encoder.dispatch_indirect(buf_raw, offset); }
unsafe {
state.raw_encoder.set_push_constants(
params.pipeline_layout,
wgt::ShaderStages::COMPUTE,
0,
&[params.offset_remainder as u32 / 4],
);
}
unsafe {
state.raw_encoder.set_bind_group(
params.pipeline_layout,
0,
Some(params.dst_bind_group),
&[],
);
}
unsafe {
state.raw_encoder.set_bind_group(
params.pipeline_layout,
1,
Some(
buffer
.raw_indirect_validation_bind_group
.get(&state.snatch_guard)
.unwrap()
.as_ref(),
),
&[params.aligned_offset as u32],
);
}
let src_transition = state
.intermediate_trackers
.buffers
.set_single(&buffer, hal::BufferUses::STORAGE_READ);
let src_barrier =
src_transition.map(|transition| transition.into_hal(&buffer, &state.snatch_guard));
unsafe {
state.raw_encoder.transition_buffers(src_barrier.as_slice());
}
unsafe {
state.raw_encoder.transition_buffers(&[hal::BufferBarrier {
buffer: params.dst_buffer,
usage: hal::BufferUses::INDIRECT..hal::BufferUses::STORAGE_READ_WRITE,
}]);
}
unsafe {
state.raw_encoder.dispatch([1, 1, 1]);
}
// reset state
{
let pipeline = state.pipeline.as_ref().unwrap();
unsafe {
state.raw_encoder.set_compute_pipeline(pipeline.raw());
}
if !state.push_constants.is_empty() {
unsafe {
state.raw_encoder.set_push_constants(
pipeline.layout.raw(),
wgt::ShaderStages::COMPUTE,
0,
&state.push_constants,
);
}
}
for (i, e) in state.binder.list_valid() {
let group = e.group.as_ref().unwrap();
let raw_bg = group.try_raw(&state.snatch_guard)?;
unsafe {
state.raw_encoder.set_bind_group(
pipeline.layout.raw(),
i as u32,
Some(raw_bg),
&e.dynamic_offsets,
);
}
}
}
unsafe {
state.raw_encoder.transition_buffers(&[hal::BufferBarrier {
buffer: params.dst_buffer,
usage: hal::BufferUses::STORAGE_READ_WRITE..hal::BufferUses::INDIRECT,
}]);
}
state.flush_states(None)?;
unsafe {
state.raw_encoder.dispatch_indirect(params.dst_buffer, 0);
}
};
#[cfg(not(feature = "indirect-validation"))]
{
state
.scope
.buffers
.merge_single(&buffer, hal::BufferUses::INDIRECT)?;
use crate::resource::Trackable;
state.flush_states(Some(buffer.tracker_index()))?;
let buf_raw = buffer.try_raw(&state.snatch_guard)?;
unsafe {
state.raw_encoder.dispatch_indirect(buf_raw, offset);
}
} }
Ok(()) Ok(())
} }

View File

@ -406,12 +406,12 @@ impl Global {
trace.add(trace::Action::CreateBuffer(fid.id(), desc.clone())); trace.add(trace::Action::CreateBuffer(fid.id(), desc.clone()));
} }
let buffer = device.create_buffer_from_hal(Box::new(hal_buffer), desc); let (buffer, err) = device.create_buffer_from_hal(Box::new(hal_buffer), desc);
let id = fid.assign(buffer); let id = fid.assign(buffer);
api_log!("Device::create_buffer -> {id:?}"); api_log!("Device::create_buffer -> {id:?}");
(id, None) (id, err)
} }
pub fn texture_destroy(&self, texture_id: id::TextureId) -> Result<(), resource::DestroyError> { pub fn texture_destroy(&self, texture_id: id::TextureId) -> Result<(), resource::DestroyError> {

View File

@ -36,7 +36,7 @@ pub(crate) const ZERO_BUFFER_SIZE: BufferAddress = 512 << 10;
// See https://github.com/gfx-rs/wgpu/issues/4589. 60s to reduce the chances of this. // See https://github.com/gfx-rs/wgpu/issues/4589. 60s to reduce the chances of this.
const CLEANUP_WAIT_MS: u32 = 60000; const CLEANUP_WAIT_MS: u32 = 60000;
const ENTRYPOINT_FAILURE_ERROR: &str = "The given EntryPoint is Invalid"; pub(crate) const ENTRYPOINT_FAILURE_ERROR: &str = "The given EntryPoint is Invalid";
pub type DeviceDescriptor<'a> = wgt::DeviceDescriptor<Label<'a>>; pub type DeviceDescriptor<'a> = wgt::DeviceDescriptor<Label<'a>>;

View File

@ -31,7 +31,7 @@ use crate::{
UsageScopePool, UsageScopePool,
}, },
validation::{self, validate_color_attachment_bytes_per_sample}, validation::{self, validate_color_attachment_bytes_per_sample},
FastHashMap, LabelHelpers as _, PreHashedKey, PreHashedMap, FastHashMap, LabelHelpers, PreHashedKey, PreHashedMap,
}; };
use arrayvec::ArrayVec; use arrayvec::ArrayVec;
@ -144,6 +144,9 @@ pub struct Device {
#[cfg(feature = "trace")] #[cfg(feature = "trace")]
pub(crate) trace: Mutex<Option<trace::Trace>>, pub(crate) trace: Mutex<Option<trace::Trace>>,
pub(crate) usage_scopes: UsageScopePool, pub(crate) usage_scopes: UsageScopePool,
#[cfg(feature = "indirect-validation")]
pub(crate) indirect_validation: Option<crate::indirect_validation::IndirectValidation>,
} }
pub(crate) enum DeferredDestroy { pub(crate) enum DeferredDestroy {
@ -175,6 +178,11 @@ impl Drop for Device {
let fence = unsafe { ManuallyDrop::take(&mut self.fence.write()) }; let fence = unsafe { ManuallyDrop::take(&mut self.fence.write()) };
pending_writes.dispose(raw.as_ref()); pending_writes.dispose(raw.as_ref());
self.command_allocator.dispose(raw.as_ref()); self.command_allocator.dispose(raw.as_ref());
#[cfg(feature = "indirect-validation")]
self.indirect_validation
.take()
.unwrap()
.dispose(raw.as_ref());
unsafe { unsafe {
raw.destroy_buffer(zero_buffer); raw.destroy_buffer(zero_buffer);
raw.destroy_fence(fence); raw.destroy_fence(fence);
@ -261,6 +269,25 @@ impl Device {
let alignments = adapter.raw.capabilities.alignments.clone(); let alignments = adapter.raw.capabilities.alignments.clone();
let downlevel = adapter.raw.capabilities.downlevel.clone(); let downlevel = adapter.raw.capabilities.downlevel.clone();
#[cfg(feature = "indirect-validation")]
let indirect_validation = if downlevel
.flags
.contains(wgt::DownlevelFlags::INDIRECT_EXECUTION)
{
match crate::indirect_validation::IndirectValidation::new(
raw_device.as_ref(),
&desc.required_limits,
) {
Ok(indirect_validation) => Some(indirect_validation),
Err(e) => {
log::error!("indirect-validation error: {e:?}");
return Err(DeviceError::Lost);
}
}
} else {
None
};
Ok(Self { Ok(Self {
raw: ManuallyDrop::new(raw_device), raw: ManuallyDrop::new(raw_device),
adapter: adapter.clone(), adapter: adapter.clone(),
@ -306,6 +333,8 @@ impl Device {
), ),
deferred_destroy: Mutex::new(rank::DEVICE_DEFERRED_DESTROY, Vec::new()), deferred_destroy: Mutex::new(rank::DEVICE_DEFERRED_DESTROY, Vec::new()),
usage_scopes: Mutex::new(rank::DEVICE_USAGE_SCOPES, Default::default()), usage_scopes: Mutex::new(rank::DEVICE_USAGE_SCOPES, Default::default()),
#[cfg(feature = "indirect-validation")]
indirect_validation,
}) })
} }
@ -547,6 +576,13 @@ impl Device {
let mut usage = conv::map_buffer_usage(desc.usage); let mut usage = conv::map_buffer_usage(desc.usage);
if desc.usage.contains(wgt::BufferUsages::INDIRECT) {
self.require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;
// We are going to be reading from it, internally;
// when validating the content of the buffer
usage |= hal::BufferUses::STORAGE_READ | hal::BufferUses::STORAGE_READ_WRITE;
}
if desc.mapped_at_creation { if desc.mapped_at_creation {
if desc.size % wgt::COPY_BUFFER_ALIGNMENT != 0 { if desc.size % wgt::COPY_BUFFER_ALIGNMENT != 0 {
return Err(resource::CreateBufferError::UnalignedSize); return Err(resource::CreateBufferError::UnalignedSize);
@ -586,6 +622,10 @@ impl Device {
let buffer = let buffer =
unsafe { self.raw().create_buffer(&hal_desc) }.map_err(|e| self.handle_hal_error(e))?; unsafe { self.raw().create_buffer(&hal_desc) }.map_err(|e| self.handle_hal_error(e))?;
#[cfg(feature = "indirect-validation")]
let raw_indirect_validation_bind_group =
self.create_indirect_validation_bind_group(buffer.as_ref(), desc.size, desc.usage)?;
let buffer = Buffer { let buffer = Buffer {
raw: Snatchable::new(buffer), raw: Snatchable::new(buffer),
device: self.clone(), device: self.clone(),
@ -599,6 +639,8 @@ impl Device {
label: desc.label.to_string(), label: desc.label.to_string(),
tracking_data: TrackingData::new(self.tracker_indices.buffers.clone()), tracking_data: TrackingData::new(self.tracker_indices.buffers.clone()),
bind_groups: Mutex::new(rank::BUFFER_BIND_GROUPS, Vec::new()), bind_groups: Mutex::new(rank::BUFFER_BIND_GROUPS, Vec::new()),
#[cfg(feature = "indirect-validation")]
raw_indirect_validation_bind_group,
}; };
let buffer = Arc::new(buffer); let buffer = Arc::new(buffer);
@ -686,7 +728,17 @@ impl Device {
self: &Arc<Self>, self: &Arc<Self>,
hal_buffer: Box<dyn hal::DynBuffer>, hal_buffer: Box<dyn hal::DynBuffer>,
desc: &resource::BufferDescriptor, desc: &resource::BufferDescriptor,
) -> Fallible<Buffer> { ) -> (Fallible<Buffer>, Option<resource::CreateBufferError>) {
#[cfg(feature = "indirect-validation")]
let raw_indirect_validation_bind_group = match self.create_indirect_validation_bind_group(
hal_buffer.as_ref(),
desc.size,
desc.usage,
) {
Ok(ok) => ok,
Err(e) => return (Fallible::Invalid(Arc::new(desc.label.to_string())), Some(e)),
};
unsafe { self.raw().add_raw_buffer(&*hal_buffer) }; unsafe { self.raw().add_raw_buffer(&*hal_buffer) };
let buffer = Buffer { let buffer = Buffer {
@ -702,6 +754,8 @@ impl Device {
label: desc.label.to_string(), label: desc.label.to_string(),
tracking_data: TrackingData::new(self.tracker_indices.buffers.clone()), tracking_data: TrackingData::new(self.tracker_indices.buffers.clone()),
bind_groups: Mutex::new(rank::BUFFER_BIND_GROUPS, Vec::new()), bind_groups: Mutex::new(rank::BUFFER_BIND_GROUPS, Vec::new()),
#[cfg(feature = "indirect-validation")]
raw_indirect_validation_bind_group,
}; };
let buffer = Arc::new(buffer); let buffer = Arc::new(buffer);
@ -711,7 +765,25 @@ impl Device {
.buffers .buffers
.insert_single(&buffer, hal::BufferUses::empty()); .insert_single(&buffer, hal::BufferUses::empty());
Fallible::Valid(buffer) (Fallible::Valid(buffer), None)
}
#[cfg(feature = "indirect-validation")]
fn create_indirect_validation_bind_group(
&self,
raw_buffer: &dyn hal::DynBuffer,
buffer_size: u64,
usage: wgt::BufferUsages,
) -> Result<Snatchable<Box<dyn hal::DynBindGroup>>, resource::CreateBufferError> {
if usage.contains(wgt::BufferUsages::INDIRECT) {
let indirect_validation = self.indirect_validation.as_ref().unwrap();
let bind_group = indirect_validation
.create_src_bind_group(self.raw(), &self.limits, buffer_size, raw_buffer)
.map_err(resource::CreateBufferError::IndirectValidationBindGroup)?;
Ok(Snatchable::new(bind_group))
} else {
Ok(Snatchable::empty())
}
} }
pub(crate) fn create_texture( pub(crate) fn create_texture(

View File

@ -0,0 +1,378 @@
use thiserror::Error;
use crate::{
device::DeviceError,
pipeline::{CreateComputePipelineError, CreateShaderModuleError},
};
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum CreateDispatchIndirectValidationPipelineError {
#[error(transparent)]
DeviceError(#[from] DeviceError),
#[error(transparent)]
ShaderModule(#[from] CreateShaderModuleError),
#[error(transparent)]
ComputePipeline(#[from] CreateComputePipelineError),
}
/// This machinery requires the following limits:
///
/// - max_bind_groups: 2,
/// - max_dynamic_storage_buffers_per_pipeline_layout: 1,
/// - max_storage_buffers_per_shader_stage: 2,
/// - max_storage_buffer_binding_size: 3 * min_storage_buffer_offset_alignment,
/// - max_push_constant_size: 4,
/// - max_compute_invocations_per_workgroup 1
///
/// These are all indirectly satisfied by `DownlevelFlags::INDIRECT_EXECUTION`, which is also
/// required for this module's functionality to work.
#[derive(Debug)]
pub struct IndirectValidation {
module: Box<dyn hal::DynShaderModule>,
dst_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
src_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
pipeline_layout: Box<dyn hal::DynPipelineLayout>,
pipeline: Box<dyn hal::DynComputePipeline>,
dst_buffer: Box<dyn hal::DynBuffer>,
dst_bind_group: Box<dyn hal::DynBindGroup>,
}
pub struct Params<'a> {
pub pipeline_layout: &'a dyn hal::DynPipelineLayout,
pub pipeline: &'a dyn hal::DynComputePipeline,
pub dst_buffer: &'a dyn hal::DynBuffer,
pub dst_bind_group: &'a dyn hal::DynBindGroup,
pub aligned_offset: u64,
pub offset_remainder: u64,
}
impl IndirectValidation {
pub fn new(
device: &dyn hal::DynDevice,
limits: &wgt::Limits,
) -> Result<Self, CreateDispatchIndirectValidationPipelineError> {
let max_compute_workgroups_per_dimension = limits.max_compute_workgroups_per_dimension;
let src = format!(
"
@group(0) @binding(0)
var<storage, read_write> dst: array<u32, 3>;
@group(1) @binding(0)
var<storage, read> src: array<u32>;
struct OffsetPc {{
inner: u32,
}}
var<push_constant> offset: OffsetPc;
@compute @workgroup_size(1)
fn main() {{
let src = vec3(src[offset.inner], src[offset.inner + 1], src[offset.inner + 2]);
let max_compute_workgroups_per_dimension = {max_compute_workgroups_per_dimension}u;
if (
src.x > max_compute_workgroups_per_dimension ||
src.y > max_compute_workgroups_per_dimension ||
src.z > max_compute_workgroups_per_dimension
) {{
dst = array(0u, 0u, 0u);
}} else {{
dst = array(src.x, src.y, src.z);
}}
}}
"
);
let module = naga::front::wgsl::parse_str(&src).map_err(|inner| {
CreateShaderModuleError::Parsing(naga::error::ShaderError {
source: src.clone(),
label: None,
inner: Box::new(inner),
})
})?;
let info = crate::device::create_validator(
wgt::Features::PUSH_CONSTANTS,
wgt::DownlevelFlags::empty(),
naga::valid::ValidationFlags::all(),
)
.validate(&module)
.map_err(|inner| {
CreateShaderModuleError::Validation(naga::error::ShaderError {
source: src,
label: None,
inner: Box::new(inner),
})
})?;
let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
module: std::borrow::Cow::Owned(module),
info,
debug_source: None,
});
let hal_desc = hal::ShaderModuleDescriptor {
label: None,
runtime_checks: false,
};
let module =
unsafe { device.create_shader_module(&hal_desc, hal_shader) }.map_err(|error| {
match error {
hal::ShaderError::Device(error) => {
CreateShaderModuleError::Device(DeviceError::from_hal(error))
}
hal::ShaderError::Compilation(ref msg) => {
log::error!("Shader error: {}", msg);
CreateShaderModuleError::Generation
}
}
})?;
let dst_bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
label: None,
flags: hal::BindGroupLayoutFlags::empty(),
entries: &[wgt::BindGroupLayoutEntry {
binding: 0,
visibility: wgt::ShaderStages::COMPUTE,
ty: wgt::BindingType::Buffer {
ty: wgt::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
},
count: None,
}],
};
let dst_bind_group_layout = unsafe {
device
.create_bind_group_layout(&dst_bind_group_layout_desc)
.map_err(DeviceError::from_hal)?
};
let src_bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
label: None,
flags: hal::BindGroupLayoutFlags::empty(),
entries: &[wgt::BindGroupLayoutEntry {
binding: 0,
visibility: wgt::ShaderStages::COMPUTE,
ty: wgt::BindingType::Buffer {
ty: wgt::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: true,
min_binding_size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
},
count: None,
}],
};
let src_bind_group_layout = unsafe {
device
.create_bind_group_layout(&src_bind_group_layout_desc)
.map_err(DeviceError::from_hal)?
};
let pipeline_layout_desc = hal::PipelineLayoutDescriptor {
label: None,
flags: hal::PipelineLayoutFlags::FIRST_VERTEX_INSTANCE,
bind_group_layouts: &[
dst_bind_group_layout.as_ref(),
src_bind_group_layout.as_ref(),
],
push_constant_ranges: &[wgt::PushConstantRange {
stages: wgt::ShaderStages::COMPUTE,
range: 0..4,
}],
};
let pipeline_layout = unsafe {
device
.create_pipeline_layout(&pipeline_layout_desc)
.map_err(DeviceError::from_hal)?
};
let pipeline_desc = hal::ComputePipelineDescriptor {
label: None,
layout: pipeline_layout.as_ref(),
stage: hal::ProgrammableStage {
module: module.as_ref(),
entry_point: "main",
constants: &Default::default(),
zero_initialize_workgroup_memory: false,
},
cache: None,
};
let pipeline =
unsafe { device.create_compute_pipeline(&pipeline_desc) }.map_err(|err| match err {
hal::PipelineError::Device(error) => {
CreateComputePipelineError::Device(DeviceError::from_hal(error))
}
hal::PipelineError::Linkage(_stages, msg) => {
CreateComputePipelineError::Internal(msg)
}
hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
),
hal::PipelineError::PipelineConstants(_, error) => {
CreateComputePipelineError::PipelineConstants(error)
}
})?;
let dst_buffer_desc = hal::BufferDescriptor {
label: None,
size: 4 * 3,
usage: hal::BufferUses::INDIRECT | hal::BufferUses::STORAGE_READ_WRITE,
memory_flags: hal::MemoryFlags::empty(),
};
let dst_buffer =
unsafe { device.create_buffer(&dst_buffer_desc) }.map_err(DeviceError::from_hal)?;
let dst_bind_group_desc = hal::BindGroupDescriptor {
label: None,
layout: dst_bind_group_layout.as_ref(),
entries: &[hal::BindGroupEntry {
binding: 0,
resource_index: 0,
count: 1,
}],
buffers: &[hal::BufferBinding {
buffer: dst_buffer.as_ref(),
offset: 0,
size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
}],
samplers: &[],
textures: &[],
acceleration_structures: &[],
};
let dst_bind_group = unsafe {
device
.create_bind_group(&dst_bind_group_desc)
.map_err(DeviceError::from_hal)
}?;
Ok(Self {
module,
dst_bind_group_layout,
src_bind_group_layout,
pipeline_layout,
pipeline,
dst_buffer,
dst_bind_group,
})
}
pub fn create_src_bind_group(
&self,
device: &dyn hal::DynDevice,
limits: &wgt::Limits,
buffer_size: u64,
buffer: &dyn hal::DynBuffer,
) -> Result<Box<dyn hal::DynBindGroup>, DeviceError> {
let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
let hal_desc = hal::BindGroupDescriptor {
label: None,
layout: self.src_bind_group_layout.as_ref(),
entries: &[hal::BindGroupEntry {
binding: 0,
resource_index: 0,
count: 1,
}],
buffers: &[hal::BufferBinding {
buffer,
offset: 0,
size: Some(std::num::NonZeroU64::new(binding_size).unwrap()),
}],
samplers: &[],
textures: &[],
acceleration_structures: &[],
};
unsafe {
device
.create_bind_group(&hal_desc)
.map_err(DeviceError::from_hal)
}
}
pub fn params<'a>(&'a self, limits: &wgt::Limits, offset: u64, buffer_size: u64) -> Params<'a> {
// The offset we receive is only required to be aligned to 4 bytes.
//
// Binding offsets and dynamic offsets are required to be aligned to
// min_storage_buffer_offset_alignment (256 bytes by default).
//
// So, we work around this limitation by calculating an aligned offset
// and pass the remainder through a push constant.
//
// We could bind the whole buffer and only have to pass the offset
// through a push constant but we might run into the
// max_storage_buffer_binding_size limit.
//
// See the inner docs of `calculate_src_buffer_binding_size` to
// see how we get the appropriate `binding_size`.
let alignment = limits.min_storage_buffer_offset_alignment as u64;
let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
let aligned_offset = offset - offset % alignment;
// This works because `binding_size` is either `buffer_size` or `alignment * 2 + buffer_size % alignment`.
let max_aligned_offset = buffer_size - binding_size;
let aligned_offset = aligned_offset.min(max_aligned_offset);
let offset_remainder = offset - aligned_offset;
Params {
pipeline_layout: self.pipeline_layout.as_ref(),
pipeline: self.pipeline.as_ref(),
dst_buffer: self.dst_buffer.as_ref(),
dst_bind_group: self.dst_bind_group.as_ref(),
aligned_offset,
offset_remainder,
}
}
pub fn dispose(self, device: &dyn hal::DynDevice) {
let IndirectValidation {
module,
dst_bind_group_layout,
src_bind_group_layout,
pipeline_layout,
pipeline,
dst_buffer,
dst_bind_group,
} = self;
unsafe {
device.destroy_bind_group(dst_bind_group);
device.destroy_buffer(dst_buffer);
device.destroy_compute_pipeline(pipeline);
device.destroy_pipeline_layout(pipeline_layout);
device.destroy_bind_group_layout(src_bind_group_layout);
device.destroy_bind_group_layout(dst_bind_group_layout);
device.destroy_shader_module(module);
}
}
}
fn calculate_src_buffer_binding_size(buffer_size: u64, limits: &wgt::Limits) -> u64 {
let alignment = limits.min_storage_buffer_offset_alignment as u64;
// We need to choose a binding size that can address all possible sets of 12 contiguous bytes in the buffer taking
// into account that the dynamic offset needs to be a multiple of `min_storage_buffer_offset_alignment`.
// Given the know variables: `offset`, `buffer_size`, `alignment` and the rule `offset + 12 <= buffer_size`.
// Let `chunks = floor(buffer_size / alignment)`.
// Let `chunk` be the interval `[0, chunks]`.
// Let `offset = alignment * chunk + r` where `r` is the interval [0, alignment - 4].
// Let `binding` be the interval `[offset, offset + 12]`.
// Let `aligned_offset = alignment * chunk`.
// Let `aligned_binding` be the interval `[aligned_offset, aligned_offset + r + 12]`.
// Let `aligned_binding_size = r + 12 = [12, alignment + 8]`.
// Let `min_aligned_binding_size = alignment + 8`.
// `min_aligned_binding_size` is the minimum binding size required to address all 12 contiguous bytes in the buffer
// but the last aligned_offset + min_aligned_binding_size might overflow the buffer. In order to avoid this we must
// pick a larger `binding_size` that satisfies: `last_aligned_offset + binding_size = buffer_size` and
// `binding_size >= min_aligned_binding_size`.
// Let `buffer_size = alignment * chunks + sr` where `sr` is the interval [0, alignment - 4].
// Let `last_aligned_offset = alignment * (chunks - u)` where `u` is the interval [0, chunks].
// => `binding_size = buffer_size - last_aligned_offset`
// => `binding_size = alignment * chunks + sr - alignment * (chunks - u)`
// => `binding_size = alignment * chunks + sr - alignment * chunks + alignment * u`
// => `binding_size = sr + alignment * u`
// => `min_aligned_binding_size <= sr + alignment * u`
// => `alignment + 8 <= sr + alignment * u`
// => `u` must be at least 2
// => `binding_size = sr + alignment * 2`
let binding_size = 2 * alignment + (buffer_size % alignment);
binding_size.min(buffer_size)
}

View File

@ -67,6 +67,8 @@ mod hash_utils;
pub mod hub; pub mod hub;
pub mod id; pub mod id;
pub mod identity; pub mod identity;
#[cfg(feature = "indirect-validation")]
mod indirect_validation;
mod init_tracker; mod init_tracker;
pub mod instance; pub mod instance;
mod lock; mod lock;

View File

@ -92,7 +92,7 @@ impl ShaderModule {
#[derive(Clone, Debug, Error)] #[derive(Clone, Debug, Error)]
#[non_exhaustive] #[non_exhaustive]
pub enum CreateShaderModuleError { pub enum CreateShaderModuleError {
#[cfg(feature = "wgsl")] #[cfg(any(feature = "wgsl", feature = "indirect-validation"))]
#[error(transparent)] #[error(transparent)]
Parsing(#[from] ShaderError<naga::front::wgsl::ParseError>), Parsing(#[from] ShaderError<naga::front::wgsl::ParseError>),
#[cfg(feature = "glsl")] #[cfg(feature = "glsl")]

View File

@ -475,10 +475,18 @@ pub struct Buffer {
pub(crate) tracking_data: TrackingData, pub(crate) tracking_data: TrackingData,
pub(crate) map_state: Mutex<BufferMapState>, pub(crate) map_state: Mutex<BufferMapState>,
pub(crate) bind_groups: Mutex<Vec<Weak<BindGroup>>>, pub(crate) bind_groups: Mutex<Vec<Weak<BindGroup>>>,
#[cfg(feature = "indirect-validation")]
pub(crate) raw_indirect_validation_bind_group: Snatchable<Box<dyn hal::DynBindGroup>>,
} }
impl Drop for Buffer { impl Drop for Buffer {
fn drop(&mut self) { fn drop(&mut self) {
#[cfg(feature = "indirect-validation")]
if let Some(raw) = self.raw_indirect_validation_bind_group.take() {
unsafe {
self.device.raw().destroy_bind_group(raw);
}
}
if let Some(raw) = self.raw.take() { if let Some(raw) = self.raw.take() {
resource_log!("Destroy raw {}", self.error_ident()); resource_log!("Destroy raw {}", self.error_ident());
unsafe { unsafe {
@ -737,13 +745,22 @@ impl Buffer {
let device = &self.device; let device = &self.device;
let temp = { let temp = {
let raw = match self.raw.snatch(&mut device.snatchable_lock.write()) { let mut snatch_guard = device.snatchable_lock.write();
let raw = match self.raw.snatch(&mut snatch_guard) {
Some(raw) => raw, Some(raw) => raw,
None => { None => {
return Err(DestroyError::AlreadyDestroyed); return Err(DestroyError::AlreadyDestroyed);
} }
}; };
#[cfg(feature = "indirect-validation")]
let raw_indirect_validation_bind_group = self
.raw_indirect_validation_bind_group
.snatch(&mut snatch_guard);
drop(snatch_guard);
let bind_groups = { let bind_groups = {
let mut guard = self.bind_groups.lock(); let mut guard = self.bind_groups.lock();
mem::take(&mut *guard) mem::take(&mut *guard)
@ -754,6 +771,8 @@ impl Buffer {
device: Arc::clone(&self.device), device: Arc::clone(&self.device),
label: self.label().to_owned(), label: self.label().to_owned(),
bind_groups, bind_groups,
#[cfg(feature = "indirect-validation")]
raw_indirect_validation_bind_group,
}) })
}; };
@ -789,6 +808,8 @@ pub enum CreateBufferError {
MaxBufferSize { requested: u64, maximum: u64 }, MaxBufferSize { requested: u64, maximum: u64 },
#[error(transparent)] #[error(transparent)]
MissingDownlevelFlags(#[from] MissingDownlevelFlags), MissingDownlevelFlags(#[from] MissingDownlevelFlags),
#[error("Failed to create bind group for indirect buffer validation: {0}")]
IndirectValidationBindGroup(DeviceError),
} }
crate::impl_resource_type!(Buffer); crate::impl_resource_type!(Buffer);
@ -804,6 +825,8 @@ pub struct DestroyedBuffer {
device: Arc<Device>, device: Arc<Device>,
label: String, label: String,
bind_groups: Vec<Weak<BindGroup>>, bind_groups: Vec<Weak<BindGroup>>,
#[cfg(feature = "indirect-validation")]
raw_indirect_validation_bind_group: Option<Box<dyn hal::DynBindGroup>>,
} }
impl DestroyedBuffer { impl DestroyedBuffer {
@ -820,6 +843,13 @@ impl Drop for DestroyedBuffer {
} }
drop(deferred); drop(deferred);
#[cfg(feature = "indirect-validation")]
if let Some(raw) = self.raw_indirect_validation_bind_group.take() {
unsafe {
self.device.raw().destroy_bind_group(raw);
}
}
resource_log!("Destroy raw Buffer (destroyed) {:?}", self.label()); resource_log!("Destroy raw Buffer (destroyed) {:?}", self.label());
// SAFETY: We are in the Drop impl and we don't use self.raw anymore after this point. // SAFETY: We are in the Drop impl and we don't use self.raw anymore after this point.
let raw = unsafe { ManuallyDrop::take(&mut self.raw) }; let raw = unsafe { ManuallyDrop::take(&mut self.raw) };

View File

@ -32,6 +32,12 @@ impl<T> Snatchable<T> {
} }
} }
pub fn empty() -> Self {
Snatchable {
value: UnsafeCell::new(None),
}
}
/// Get read access to the value. Requires a the snatchable lock's read guard. /// Get read access to the value. Requires a the snatchable lock's read guard.
pub fn get<'a>(&'a self, _guard: &'a SnatchGuard) -> Option<&'a T> { pub fn get<'a>(&'a self, _guard: &'a SnatchGuard) -> Option<&'a T> {
unsafe { (*self.value.get()).as_ref() } unsafe { (*self.value.get()).as_ref() }

View File

@ -130,6 +130,12 @@ features = ["raw-window-handle"]
workspace = true workspace = true
features = ["raw-window-handle"] features = ["raw-window-handle"]
# If we are not targeting WebGL, enable indirect-validation.
# WebGL doesn't support indirect execution so this is not needed.
[target.'cfg(not(target_arch = "wasm32"))'.dependencies.wgc]
workspace = true
features = ["indirect-validation"]
# Enable `wgc` by default on macOS and iOS to allow the `metal` crate feature to # Enable `wgc` by default on macOS and iOS to allow the `metal` crate feature to
# enable the Metal backend while being no-op on other targets. # enable the Metal backend while being no-op on other targets.
[target.'cfg(any(target_os = "macos", target_os = "ios"))'.dependencies.wgc] [target.'cfg(any(target_os = "macos", target_os = "ios"))'.dependencies.wgc]