mirror of
https://github.com/gfx-rs/wgpu.git
synced 2024-11-25 00:03:29 +00:00
Ensure safety of indirect dispatch (#5714)
by injecting a compute shader that validates the content of the indirect buffer
This commit is contained in:
parent
c0e39721a2
commit
7f708edd1f
@ -133,6 +133,7 @@ By @bradwerth [#6216](https://github.com/gfx-rs/wgpu/pull/6216).
|
||||
- Call `flush_mapped_ranges` when unmapping write-mapped buffers. By @teoxoy in [#6089](https://github.com/gfx-rs/wgpu/pull/6089).
|
||||
- When mapping buffers for reading, mark buffers as initialized only when they have `MAP_WRITE` usage. By @teoxoy in [#6178](https://github.com/gfx-rs/wgpu/pull/6178).
|
||||
- Add a separate pipeline constants error. By @teoxoy in [#6094](https://github.com/gfx-rs/wgpu/pull/6094).
|
||||
- Ensure safety of indirect dispatch by injecting a compute shader that validates the content of the indirect buffer. By @teoxoy in [#5714](https://github.com/gfx-rs/wgpu/pull/5714)
|
||||
|
||||
#### GLES / OpenGL
|
||||
|
||||
|
241
tests/tests/dispatch_workgroups_indirect.rs
Normal file
241
tests/tests/dispatch_workgroups_indirect.rs
Normal file
@ -0,0 +1,241 @@
|
||||
use wgpu_test::{gpu_test, FailureCase, GpuTestConfiguration, TestParameters, TestingContext};
|
||||
|
||||
/// Make sure that the num_workgroups builtin works properly (it requires a workaround on D3D12).
|
||||
#[gpu_test]
|
||||
static NUM_WORKGROUPS_BUILTIN: GpuTestConfiguration = GpuTestConfiguration::new()
|
||||
.parameters(
|
||||
TestParameters::default()
|
||||
.features(wgpu::Features::PUSH_CONSTANTS)
|
||||
.downlevel_flags(
|
||||
wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION,
|
||||
)
|
||||
.limits(wgpu::Limits {
|
||||
max_push_constant_size: 4,
|
||||
..wgpu::Limits::downlevel_defaults()
|
||||
})
|
||||
.expect_fail(FailureCase::backend(wgt::Backends::DX12)),
|
||||
)
|
||||
.run_async(|ctx| async move {
|
||||
let num_workgroups = [1, 2, 3];
|
||||
let res = run_test(&ctx, &num_workgroups, false).await;
|
||||
assert_eq!(res, num_workgroups);
|
||||
});
|
||||
|
||||
/// Make sure that we discard (don't run) the dispatch if its size exceeds the device limit.
|
||||
#[gpu_test]
|
||||
static DISCARD_DISPATCH: GpuTestConfiguration = GpuTestConfiguration::new()
|
||||
.parameters(
|
||||
TestParameters::default()
|
||||
.features(wgpu::Features::PUSH_CONSTANTS)
|
||||
.downlevel_flags(
|
||||
wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION,
|
||||
)
|
||||
.limits(wgpu::Limits {
|
||||
max_compute_workgroups_per_dimension: 10,
|
||||
max_push_constant_size: 4,
|
||||
..wgpu::Limits::downlevel_defaults()
|
||||
})
|
||||
.expect_fail(FailureCase::backend(wgt::Backends::DX12)),
|
||||
)
|
||||
.run_async(|ctx| async move {
|
||||
let max = ctx.device.limits().max_compute_workgroups_per_dimension;
|
||||
|
||||
let res = run_test(&ctx, &[max, max, max], false).await;
|
||||
assert_eq!(res, [max; 3]);
|
||||
|
||||
let res = run_test(&ctx, &[max + 1, 1, 1], false).await;
|
||||
assert_eq!(res, [0; 3]);
|
||||
|
||||
let res = run_test(&ctx, &[1, max + 1, 1], false).await;
|
||||
assert_eq!(res, [0; 3]);
|
||||
|
||||
let res = run_test(&ctx, &[1, 1, max + 1], false).await;
|
||||
assert_eq!(res, [0; 3]);
|
||||
});
|
||||
|
||||
/// Make sure that resetting the bind groups set by the validation code works properly.
|
||||
#[gpu_test]
|
||||
static RESET_BIND_GROUPS: GpuTestConfiguration = GpuTestConfiguration::new()
|
||||
.parameters(
|
||||
TestParameters::default()
|
||||
.features(wgpu::Features::PUSH_CONSTANTS)
|
||||
.downlevel_flags(
|
||||
wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION,
|
||||
)
|
||||
.limits(wgpu::Limits {
|
||||
max_push_constant_size: 4,
|
||||
..wgpu::Limits::downlevel_defaults()
|
||||
}),
|
||||
)
|
||||
.run_async(|ctx| async move {
|
||||
ctx.device.push_error_scope(wgpu::ErrorFilter::Validation);
|
||||
|
||||
let _ = run_test(&ctx, &[0, 0, 0], true).await;
|
||||
|
||||
let error = pollster::block_on(ctx.device.pop_error_scope());
|
||||
assert!(error.map_or(false, |error| {
|
||||
format!("{error}").contains("The current set ComputePipeline with '' label expects a BindGroup to be set at index 0")
|
||||
}));
|
||||
});
|
||||
|
||||
async fn run_test(
|
||||
ctx: &TestingContext,
|
||||
num_workgroups: &[u32; 3],
|
||||
forget_to_set_bind_group: bool,
|
||||
) -> [u32; 3] {
|
||||
const SHADER_SRC: &str = "
|
||||
struct TestOffsetPc {
|
||||
inner: u32,
|
||||
}
|
||||
|
||||
// `test_offset.inner` should always be 0; we test that resetting the push constant set by the validation code works properly.
|
||||
var<push_constant> test_offset: TestOffsetPc;
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> out: array<u32, 3>;
|
||||
|
||||
@compute @workgroup_size(1)
|
||||
fn main(@builtin(num_workgroups) num_workgroups: vec3u, @builtin(workgroup_id) workgroup_id: vec3u) {
|
||||
if (all(workgroup_id == vec3u())) {
|
||||
out[0] = num_workgroups.x + test_offset.inner;
|
||||
out[1] = num_workgroups.y + test_offset.inner;
|
||||
out[2] = num_workgroups.z + test_offset.inner;
|
||||
}
|
||||
}
|
||||
";
|
||||
|
||||
let module = ctx
|
||||
.device
|
||||
.create_shader_module(wgpu::ShaderModuleDescriptor {
|
||||
label: None,
|
||||
source: wgpu::ShaderSource::Wgsl(SHADER_SRC.into()),
|
||||
});
|
||||
|
||||
let bgl = ctx
|
||||
.device
|
||||
.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
|
||||
label: None,
|
||||
entries: &[wgpu::BindGroupLayoutEntry {
|
||||
binding: 0,
|
||||
visibility: wgt::ShaderStages::COMPUTE,
|
||||
ty: wgpu::BindingType::Buffer {
|
||||
ty: wgpu::BufferBindingType::Storage { read_only: false },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
}],
|
||||
});
|
||||
|
||||
let layout = ctx
|
||||
.device
|
||||
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
|
||||
label: None,
|
||||
bind_group_layouts: &[&bgl],
|
||||
push_constant_ranges: &[wgt::PushConstantRange {
|
||||
stages: wgt::ShaderStages::COMPUTE,
|
||||
range: 0..4,
|
||||
}],
|
||||
});
|
||||
|
||||
let pipeline = ctx
|
||||
.device
|
||||
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
|
||||
label: None,
|
||||
layout: Some(&layout),
|
||||
module: &module,
|
||||
entry_point: Some("main"),
|
||||
compilation_options: Default::default(),
|
||||
cache: None,
|
||||
});
|
||||
|
||||
let out_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: None,
|
||||
size: 12,
|
||||
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
|
||||
let readback_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: None,
|
||||
size: 12,
|
||||
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
|
||||
let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
|
||||
label: None,
|
||||
layout: &pipeline.get_bind_group_layout(0),
|
||||
entries: &[wgpu::BindGroupEntry {
|
||||
binding: 0,
|
||||
resource: out_buffer.as_entire_binding(),
|
||||
}],
|
||||
});
|
||||
|
||||
let mut res = None;
|
||||
|
||||
for (indirect_offset, indirect_buffer_size) in [
|
||||
// internal src buffer binding size will be buffer.size
|
||||
(0, 12),
|
||||
(4, 4 + 12),
|
||||
(4, 8 + 12),
|
||||
(256 * 2 - 4 - 12, 256 * 2 - 4),
|
||||
// internal src buffer binding size will be 256 * 2 + x
|
||||
(0, 256 * 2 * 2 + 4),
|
||||
(256, 256 * 2 * 2 + 8),
|
||||
(256 + 4, 256 * 2 * 2 + 12),
|
||||
(256 * 2 + 16, 256 * 2 * 2 + 16),
|
||||
(256 * 2 * 2, 256 * 2 * 2 + 32),
|
||||
(256 + 12, 256 * 2 * 2 + 64),
|
||||
] {
|
||||
let indirect_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: None,
|
||||
size: indirect_buffer_size,
|
||||
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::INDIRECT,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
|
||||
ctx.queue.write_buffer(
|
||||
&indirect_buffer,
|
||||
indirect_offset,
|
||||
bytemuck::bytes_of(num_workgroups),
|
||||
);
|
||||
|
||||
let mut encoder = ctx.device.create_command_encoder(&Default::default());
|
||||
{
|
||||
let mut compute_pass = encoder.begin_compute_pass(&Default::default());
|
||||
compute_pass.set_pipeline(&pipeline);
|
||||
compute_pass.set_push_constants(0, &[0, 0, 0, 0]);
|
||||
if !forget_to_set_bind_group {
|
||||
compute_pass.set_bind_group(0, Some(&bind_group), &[]);
|
||||
}
|
||||
compute_pass.dispatch_workgroups_indirect(&indirect_buffer, indirect_offset);
|
||||
}
|
||||
|
||||
encoder.copy_buffer_to_buffer(&out_buffer, 0, &readback_buffer, 0, 12);
|
||||
|
||||
ctx.queue.submit(Some(encoder.finish()));
|
||||
|
||||
readback_buffer
|
||||
.slice(..)
|
||||
.map_async(wgpu::MapMode::Read, |_| {});
|
||||
|
||||
ctx.async_poll(wgpu::Maintain::wait())
|
||||
.await
|
||||
.panic_on_timeout();
|
||||
|
||||
let view = readback_buffer.slice(..).get_mapped_range();
|
||||
|
||||
let current_res = *bytemuck::from_bytes(&view);
|
||||
drop(view);
|
||||
readback_buffer.unmap();
|
||||
|
||||
if let Some(past_res) = res {
|
||||
assert_eq!(past_res, current_res);
|
||||
} else {
|
||||
res = Some(current_res);
|
||||
}
|
||||
}
|
||||
|
||||
res.unwrap()
|
||||
}
|
@ -19,6 +19,7 @@ mod clear_texture;
|
||||
mod compute_pass_ownership;
|
||||
mod create_surface_error;
|
||||
mod device;
|
||||
mod dispatch_workgroups_indirect;
|
||||
mod encoder;
|
||||
mod external_texture;
|
||||
mod float32_filterable;
|
||||
|
@ -51,6 +51,10 @@ renderdoc = ["hal/renderdoc"]
|
||||
## to the validation carried out at public APIs in all builds.
|
||||
strict_asserts = ["wgt/strict_asserts"]
|
||||
|
||||
## Validates indirect draw/dispatch calls. This will also enable naga's
|
||||
## WGSL frontend since we use a WGSL compute shader to do the validation.
|
||||
indirect-validation = ["naga/wgsl-in"]
|
||||
|
||||
## Enables serialization via `serde` on common wgpu types.
|
||||
serde = ["dep:serde", "wgt/serde", "arrayvec/serde"]
|
||||
|
||||
|
@ -200,13 +200,17 @@ mod compat {
|
||||
entries: (0..hal::MAX_BIND_GROUPS).map(|_| Entry::empty()).collect(),
|
||||
}
|
||||
}
|
||||
fn make_range(&self, start_index: usize) -> Range<usize> {
|
||||
|
||||
pub fn num_valid_entries(&self) -> usize {
|
||||
// find first incompatible entry
|
||||
let end = self
|
||||
.entries
|
||||
self.entries
|
||||
.iter()
|
||||
.position(|e| e.is_incompatible())
|
||||
.unwrap_or(self.entries.len());
|
||||
.unwrap_or(self.entries.len())
|
||||
}
|
||||
|
||||
fn make_range(&self, start_index: usize) -> Range<usize> {
|
||||
let end = self.num_valid_entries();
|
||||
start_index..end.max(start_index)
|
||||
}
|
||||
|
||||
@ -406,6 +410,14 @@ impl Binder {
|
||||
.map(move |index| payloads[index].group.as_ref().unwrap())
|
||||
}
|
||||
|
||||
#[cfg(feature = "indirect-validation")]
|
||||
pub(super) fn list_valid<'a>(&'a self) -> impl Iterator<Item = (usize, &'a EntryPayload)> + '_ {
|
||||
self.payloads
|
||||
.iter()
|
||||
.take(self.manager.num_valid_entries())
|
||||
.enumerate()
|
||||
}
|
||||
|
||||
pub(super) fn check_compatibility<T: Labeled>(
|
||||
&self,
|
||||
pipeline: &T,
|
||||
|
@ -18,7 +18,7 @@ use crate::{
|
||||
pipeline::ComputePipeline,
|
||||
resource::{
|
||||
self, Buffer, DestroyedResourceError, InvalidResourceError, Labeled,
|
||||
MissingBufferUsageError, ParentDevice, Trackable,
|
||||
MissingBufferUsageError, ParentDevice,
|
||||
},
|
||||
snatch::SnatchGuard,
|
||||
track::{ResourceUsageCompatibilityError, Tracker, TrackerIndex, UsageScope},
|
||||
@ -216,6 +216,8 @@ struct State<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder> {
|
||||
string_offset: usize,
|
||||
active_query: Option<(Arc<resource::QuerySet>, u32)>,
|
||||
|
||||
push_constants: Vec<u32>,
|
||||
|
||||
intermediate_trackers: Tracker,
|
||||
|
||||
/// Immediate texture inits required because of prior discards. Need to
|
||||
@ -443,6 +445,8 @@ impl Global {
|
||||
string_offset: 0,
|
||||
active_query: None,
|
||||
|
||||
push_constants: Vec::new(),
|
||||
|
||||
intermediate_trackers: Tracker::new(),
|
||||
|
||||
pending_discard_init_fixups: SurfacesInDiscardState::new(),
|
||||
@ -746,6 +750,21 @@ fn set_pipeline(
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: integrate this in the code below once we simplify push constants
|
||||
state.push_constants.clear();
|
||||
// Note that can only be one range for each stage. See the `MoreThanOnePushConstantRangePerStage` error.
|
||||
if let Some(push_constant_range) =
|
||||
pipeline.layout.push_constant_ranges.iter().find_map(|pcr| {
|
||||
pcr.stages
|
||||
.contains(wgt::ShaderStages::COMPUTE)
|
||||
.then_some(pcr.range.clone())
|
||||
})
|
||||
{
|
||||
// Note that non-0 range start doesn't work anyway https://github.com/gfx-rs/wgpu/issues/4502
|
||||
let len = push_constant_range.len() / wgt::PUSH_CONSTANT_ALIGNMENT as usize;
|
||||
state.push_constants.extend(core::iter::repeat(0).take(len));
|
||||
}
|
||||
|
||||
// Clear push constant ranges
|
||||
let non_overlapping =
|
||||
super::bind::compute_nonoverlapping_ranges(&pipeline.layout.push_constant_ranges);
|
||||
@ -791,6 +810,10 @@ fn set_push_constant(
|
||||
end_offset_bytes,
|
||||
)?;
|
||||
|
||||
let offset_in_elements = (offset / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
|
||||
let size_in_elements = (size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
|
||||
state.push_constants[offset_in_elements..][..size_in_elements].copy_from_slice(data_slice);
|
||||
|
||||
unsafe {
|
||||
state.raw_encoder.set_push_constants(
|
||||
pipeline_layout.raw(),
|
||||
@ -841,10 +864,6 @@ fn dispatch_indirect(
|
||||
.device
|
||||
.require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;
|
||||
|
||||
state
|
||||
.scope
|
||||
.buffers
|
||||
.merge_single(&buffer, hal::BufferUses::INDIRECT)?;
|
||||
buffer.check_usage(wgt::BufferUsages::INDIRECT)?;
|
||||
|
||||
if offset % 4 != 0 {
|
||||
@ -861,7 +880,6 @@ fn dispatch_indirect(
|
||||
}
|
||||
|
||||
let stride = 3 * 4; // 3 integers, x/y/z group size
|
||||
|
||||
state
|
||||
.buffer_memory_init_actions
|
||||
.extend(buffer.initialization_status.read().create_action(
|
||||
@ -870,12 +888,132 @@ fn dispatch_indirect(
|
||||
MemoryInitKind::NeedsInitializedMemory,
|
||||
));
|
||||
|
||||
#[cfg(feature = "indirect-validation")]
|
||||
{
|
||||
let params = state.device.indirect_validation.as_ref().unwrap().params(
|
||||
&state.device.limits,
|
||||
offset,
|
||||
buffer.size,
|
||||
);
|
||||
|
||||
unsafe {
|
||||
state.raw_encoder.set_compute_pipeline(params.pipeline);
|
||||
}
|
||||
|
||||
unsafe {
|
||||
state.raw_encoder.set_push_constants(
|
||||
params.pipeline_layout,
|
||||
wgt::ShaderStages::COMPUTE,
|
||||
0,
|
||||
&[params.offset_remainder as u32 / 4],
|
||||
);
|
||||
}
|
||||
|
||||
unsafe {
|
||||
state.raw_encoder.set_bind_group(
|
||||
params.pipeline_layout,
|
||||
0,
|
||||
Some(params.dst_bind_group),
|
||||
&[],
|
||||
);
|
||||
}
|
||||
unsafe {
|
||||
state.raw_encoder.set_bind_group(
|
||||
params.pipeline_layout,
|
||||
1,
|
||||
Some(
|
||||
buffer
|
||||
.raw_indirect_validation_bind_group
|
||||
.get(&state.snatch_guard)
|
||||
.unwrap()
|
||||
.as_ref(),
|
||||
),
|
||||
&[params.aligned_offset as u32],
|
||||
);
|
||||
}
|
||||
|
||||
let src_transition = state
|
||||
.intermediate_trackers
|
||||
.buffers
|
||||
.set_single(&buffer, hal::BufferUses::STORAGE_READ);
|
||||
let src_barrier =
|
||||
src_transition.map(|transition| transition.into_hal(&buffer, &state.snatch_guard));
|
||||
unsafe {
|
||||
state.raw_encoder.transition_buffers(src_barrier.as_slice());
|
||||
}
|
||||
|
||||
unsafe {
|
||||
state.raw_encoder.transition_buffers(&[hal::BufferBarrier {
|
||||
buffer: params.dst_buffer,
|
||||
usage: hal::BufferUses::INDIRECT..hal::BufferUses::STORAGE_READ_WRITE,
|
||||
}]);
|
||||
}
|
||||
|
||||
unsafe {
|
||||
state.raw_encoder.dispatch([1, 1, 1]);
|
||||
}
|
||||
|
||||
// reset state
|
||||
{
|
||||
let pipeline = state.pipeline.as_ref().unwrap();
|
||||
|
||||
unsafe {
|
||||
state.raw_encoder.set_compute_pipeline(pipeline.raw());
|
||||
}
|
||||
|
||||
if !state.push_constants.is_empty() {
|
||||
unsafe {
|
||||
state.raw_encoder.set_push_constants(
|
||||
pipeline.layout.raw(),
|
||||
wgt::ShaderStages::COMPUTE,
|
||||
0,
|
||||
&state.push_constants,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
for (i, e) in state.binder.list_valid() {
|
||||
let group = e.group.as_ref().unwrap();
|
||||
let raw_bg = group.try_raw(&state.snatch_guard)?;
|
||||
unsafe {
|
||||
state.raw_encoder.set_bind_group(
|
||||
pipeline.layout.raw(),
|
||||
i as u32,
|
||||
Some(raw_bg),
|
||||
&e.dynamic_offsets,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe {
|
||||
state.raw_encoder.transition_buffers(&[hal::BufferBarrier {
|
||||
buffer: params.dst_buffer,
|
||||
usage: hal::BufferUses::STORAGE_READ_WRITE..hal::BufferUses::INDIRECT,
|
||||
}]);
|
||||
}
|
||||
|
||||
state.flush_states(None)?;
|
||||
unsafe {
|
||||
state.raw_encoder.dispatch_indirect(params.dst_buffer, 0);
|
||||
}
|
||||
};
|
||||
#[cfg(not(feature = "indirect-validation"))]
|
||||
{
|
||||
state
|
||||
.scope
|
||||
.buffers
|
||||
.merge_single(&buffer, hal::BufferUses::INDIRECT)?;
|
||||
|
||||
use crate::resource::Trackable;
|
||||
state.flush_states(Some(buffer.tracker_index()))?;
|
||||
|
||||
let buf_raw = buffer.try_raw(&state.snatch_guard)?;
|
||||
unsafe {
|
||||
state.raw_encoder.dispatch_indirect(buf_raw, offset);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -406,12 +406,12 @@ impl Global {
|
||||
trace.add(trace::Action::CreateBuffer(fid.id(), desc.clone()));
|
||||
}
|
||||
|
||||
let buffer = device.create_buffer_from_hal(Box::new(hal_buffer), desc);
|
||||
let (buffer, err) = device.create_buffer_from_hal(Box::new(hal_buffer), desc);
|
||||
|
||||
let id = fid.assign(buffer);
|
||||
api_log!("Device::create_buffer -> {id:?}");
|
||||
|
||||
(id, None)
|
||||
(id, err)
|
||||
}
|
||||
|
||||
pub fn texture_destroy(&self, texture_id: id::TextureId) -> Result<(), resource::DestroyError> {
|
||||
|
@ -36,7 +36,7 @@ pub(crate) const ZERO_BUFFER_SIZE: BufferAddress = 512 << 10;
|
||||
// See https://github.com/gfx-rs/wgpu/issues/4589. 60s to reduce the chances of this.
|
||||
const CLEANUP_WAIT_MS: u32 = 60000;
|
||||
|
||||
const ENTRYPOINT_FAILURE_ERROR: &str = "The given EntryPoint is Invalid";
|
||||
pub(crate) const ENTRYPOINT_FAILURE_ERROR: &str = "The given EntryPoint is Invalid";
|
||||
|
||||
pub type DeviceDescriptor<'a> = wgt::DeviceDescriptor<Label<'a>>;
|
||||
|
||||
|
@ -31,7 +31,7 @@ use crate::{
|
||||
UsageScopePool,
|
||||
},
|
||||
validation::{self, validate_color_attachment_bytes_per_sample},
|
||||
FastHashMap, LabelHelpers as _, PreHashedKey, PreHashedMap,
|
||||
FastHashMap, LabelHelpers, PreHashedKey, PreHashedMap,
|
||||
};
|
||||
|
||||
use arrayvec::ArrayVec;
|
||||
@ -144,6 +144,9 @@ pub struct Device {
|
||||
#[cfg(feature = "trace")]
|
||||
pub(crate) trace: Mutex<Option<trace::Trace>>,
|
||||
pub(crate) usage_scopes: UsageScopePool,
|
||||
|
||||
#[cfg(feature = "indirect-validation")]
|
||||
pub(crate) indirect_validation: Option<crate::indirect_validation::IndirectValidation>,
|
||||
}
|
||||
|
||||
pub(crate) enum DeferredDestroy {
|
||||
@ -175,6 +178,11 @@ impl Drop for Device {
|
||||
let fence = unsafe { ManuallyDrop::take(&mut self.fence.write()) };
|
||||
pending_writes.dispose(raw.as_ref());
|
||||
self.command_allocator.dispose(raw.as_ref());
|
||||
#[cfg(feature = "indirect-validation")]
|
||||
self.indirect_validation
|
||||
.take()
|
||||
.unwrap()
|
||||
.dispose(raw.as_ref());
|
||||
unsafe {
|
||||
raw.destroy_buffer(zero_buffer);
|
||||
raw.destroy_fence(fence);
|
||||
@ -261,6 +269,25 @@ impl Device {
|
||||
let alignments = adapter.raw.capabilities.alignments.clone();
|
||||
let downlevel = adapter.raw.capabilities.downlevel.clone();
|
||||
|
||||
#[cfg(feature = "indirect-validation")]
|
||||
let indirect_validation = if downlevel
|
||||
.flags
|
||||
.contains(wgt::DownlevelFlags::INDIRECT_EXECUTION)
|
||||
{
|
||||
match crate::indirect_validation::IndirectValidation::new(
|
||||
raw_device.as_ref(),
|
||||
&desc.required_limits,
|
||||
) {
|
||||
Ok(indirect_validation) => Some(indirect_validation),
|
||||
Err(e) => {
|
||||
log::error!("indirect-validation error: {e:?}");
|
||||
return Err(DeviceError::Lost);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
raw: ManuallyDrop::new(raw_device),
|
||||
adapter: adapter.clone(),
|
||||
@ -306,6 +333,8 @@ impl Device {
|
||||
),
|
||||
deferred_destroy: Mutex::new(rank::DEVICE_DEFERRED_DESTROY, Vec::new()),
|
||||
usage_scopes: Mutex::new(rank::DEVICE_USAGE_SCOPES, Default::default()),
|
||||
#[cfg(feature = "indirect-validation")]
|
||||
indirect_validation,
|
||||
})
|
||||
}
|
||||
|
||||
@ -547,6 +576,13 @@ impl Device {
|
||||
|
||||
let mut usage = conv::map_buffer_usage(desc.usage);
|
||||
|
||||
if desc.usage.contains(wgt::BufferUsages::INDIRECT) {
|
||||
self.require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;
|
||||
// We are going to be reading from it, internally;
|
||||
// when validating the content of the buffer
|
||||
usage |= hal::BufferUses::STORAGE_READ | hal::BufferUses::STORAGE_READ_WRITE;
|
||||
}
|
||||
|
||||
if desc.mapped_at_creation {
|
||||
if desc.size % wgt::COPY_BUFFER_ALIGNMENT != 0 {
|
||||
return Err(resource::CreateBufferError::UnalignedSize);
|
||||
@ -586,6 +622,10 @@ impl Device {
|
||||
let buffer =
|
||||
unsafe { self.raw().create_buffer(&hal_desc) }.map_err(|e| self.handle_hal_error(e))?;
|
||||
|
||||
#[cfg(feature = "indirect-validation")]
|
||||
let raw_indirect_validation_bind_group =
|
||||
self.create_indirect_validation_bind_group(buffer.as_ref(), desc.size, desc.usage)?;
|
||||
|
||||
let buffer = Buffer {
|
||||
raw: Snatchable::new(buffer),
|
||||
device: self.clone(),
|
||||
@ -599,6 +639,8 @@ impl Device {
|
||||
label: desc.label.to_string(),
|
||||
tracking_data: TrackingData::new(self.tracker_indices.buffers.clone()),
|
||||
bind_groups: Mutex::new(rank::BUFFER_BIND_GROUPS, Vec::new()),
|
||||
#[cfg(feature = "indirect-validation")]
|
||||
raw_indirect_validation_bind_group,
|
||||
};
|
||||
|
||||
let buffer = Arc::new(buffer);
|
||||
@ -686,7 +728,17 @@ impl Device {
|
||||
self: &Arc<Self>,
|
||||
hal_buffer: Box<dyn hal::DynBuffer>,
|
||||
desc: &resource::BufferDescriptor,
|
||||
) -> Fallible<Buffer> {
|
||||
) -> (Fallible<Buffer>, Option<resource::CreateBufferError>) {
|
||||
#[cfg(feature = "indirect-validation")]
|
||||
let raw_indirect_validation_bind_group = match self.create_indirect_validation_bind_group(
|
||||
hal_buffer.as_ref(),
|
||||
desc.size,
|
||||
desc.usage,
|
||||
) {
|
||||
Ok(ok) => ok,
|
||||
Err(e) => return (Fallible::Invalid(Arc::new(desc.label.to_string())), Some(e)),
|
||||
};
|
||||
|
||||
unsafe { self.raw().add_raw_buffer(&*hal_buffer) };
|
||||
|
||||
let buffer = Buffer {
|
||||
@ -702,6 +754,8 @@ impl Device {
|
||||
label: desc.label.to_string(),
|
||||
tracking_data: TrackingData::new(self.tracker_indices.buffers.clone()),
|
||||
bind_groups: Mutex::new(rank::BUFFER_BIND_GROUPS, Vec::new()),
|
||||
#[cfg(feature = "indirect-validation")]
|
||||
raw_indirect_validation_bind_group,
|
||||
};
|
||||
|
||||
let buffer = Arc::new(buffer);
|
||||
@ -711,7 +765,25 @@ impl Device {
|
||||
.buffers
|
||||
.insert_single(&buffer, hal::BufferUses::empty());
|
||||
|
||||
Fallible::Valid(buffer)
|
||||
(Fallible::Valid(buffer), None)
|
||||
}
|
||||
|
||||
#[cfg(feature = "indirect-validation")]
|
||||
fn create_indirect_validation_bind_group(
|
||||
&self,
|
||||
raw_buffer: &dyn hal::DynBuffer,
|
||||
buffer_size: u64,
|
||||
usage: wgt::BufferUsages,
|
||||
) -> Result<Snatchable<Box<dyn hal::DynBindGroup>>, resource::CreateBufferError> {
|
||||
if usage.contains(wgt::BufferUsages::INDIRECT) {
|
||||
let indirect_validation = self.indirect_validation.as_ref().unwrap();
|
||||
let bind_group = indirect_validation
|
||||
.create_src_bind_group(self.raw(), &self.limits, buffer_size, raw_buffer)
|
||||
.map_err(resource::CreateBufferError::IndirectValidationBindGroup)?;
|
||||
Ok(Snatchable::new(bind_group))
|
||||
} else {
|
||||
Ok(Snatchable::empty())
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn create_texture(
|
||||
|
378
wgpu-core/src/indirect_validation.rs
Normal file
378
wgpu-core/src/indirect_validation.rs
Normal file
@ -0,0 +1,378 @@
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::{
|
||||
device::DeviceError,
|
||||
pipeline::{CreateComputePipelineError, CreateShaderModuleError},
|
||||
};
|
||||
|
||||
#[derive(Clone, Debug, Error)]
|
||||
#[non_exhaustive]
|
||||
pub enum CreateDispatchIndirectValidationPipelineError {
|
||||
#[error(transparent)]
|
||||
DeviceError(#[from] DeviceError),
|
||||
#[error(transparent)]
|
||||
ShaderModule(#[from] CreateShaderModuleError),
|
||||
#[error(transparent)]
|
||||
ComputePipeline(#[from] CreateComputePipelineError),
|
||||
}
|
||||
|
||||
/// This machinery requires the following limits:
|
||||
///
|
||||
/// - max_bind_groups: 2,
|
||||
/// - max_dynamic_storage_buffers_per_pipeline_layout: 1,
|
||||
/// - max_storage_buffers_per_shader_stage: 2,
|
||||
/// - max_storage_buffer_binding_size: 3 * min_storage_buffer_offset_alignment,
|
||||
/// - max_push_constant_size: 4,
|
||||
/// - max_compute_invocations_per_workgroup 1
|
||||
///
|
||||
/// These are all indirectly satisfied by `DownlevelFlags::INDIRECT_EXECUTION`, which is also
|
||||
/// required for this module's functionality to work.
|
||||
#[derive(Debug)]
|
||||
pub struct IndirectValidation {
|
||||
module: Box<dyn hal::DynShaderModule>,
|
||||
dst_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
|
||||
src_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
|
||||
pipeline_layout: Box<dyn hal::DynPipelineLayout>,
|
||||
pipeline: Box<dyn hal::DynComputePipeline>,
|
||||
dst_buffer: Box<dyn hal::DynBuffer>,
|
||||
dst_bind_group: Box<dyn hal::DynBindGroup>,
|
||||
}
|
||||
|
||||
pub struct Params<'a> {
|
||||
pub pipeline_layout: &'a dyn hal::DynPipelineLayout,
|
||||
pub pipeline: &'a dyn hal::DynComputePipeline,
|
||||
pub dst_buffer: &'a dyn hal::DynBuffer,
|
||||
pub dst_bind_group: &'a dyn hal::DynBindGroup,
|
||||
pub aligned_offset: u64,
|
||||
pub offset_remainder: u64,
|
||||
}
|
||||
|
||||
impl IndirectValidation {
|
||||
pub fn new(
|
||||
device: &dyn hal::DynDevice,
|
||||
limits: &wgt::Limits,
|
||||
) -> Result<Self, CreateDispatchIndirectValidationPipelineError> {
|
||||
let max_compute_workgroups_per_dimension = limits.max_compute_workgroups_per_dimension;
|
||||
|
||||
let src = format!(
|
||||
"
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> dst: array<u32, 3>;
|
||||
@group(1) @binding(0)
|
||||
var<storage, read> src: array<u32>;
|
||||
struct OffsetPc {{
|
||||
inner: u32,
|
||||
}}
|
||||
var<push_constant> offset: OffsetPc;
|
||||
|
||||
@compute @workgroup_size(1)
|
||||
fn main() {{
|
||||
let src = vec3(src[offset.inner], src[offset.inner + 1], src[offset.inner + 2]);
|
||||
let max_compute_workgroups_per_dimension = {max_compute_workgroups_per_dimension}u;
|
||||
if (
|
||||
src.x > max_compute_workgroups_per_dimension ||
|
||||
src.y > max_compute_workgroups_per_dimension ||
|
||||
src.z > max_compute_workgroups_per_dimension
|
||||
) {{
|
||||
dst = array(0u, 0u, 0u);
|
||||
}} else {{
|
||||
dst = array(src.x, src.y, src.z);
|
||||
}}
|
||||
}}
|
||||
"
|
||||
);
|
||||
|
||||
let module = naga::front::wgsl::parse_str(&src).map_err(|inner| {
|
||||
CreateShaderModuleError::Parsing(naga::error::ShaderError {
|
||||
source: src.clone(),
|
||||
label: None,
|
||||
inner: Box::new(inner),
|
||||
})
|
||||
})?;
|
||||
let info = crate::device::create_validator(
|
||||
wgt::Features::PUSH_CONSTANTS,
|
||||
wgt::DownlevelFlags::empty(),
|
||||
naga::valid::ValidationFlags::all(),
|
||||
)
|
||||
.validate(&module)
|
||||
.map_err(|inner| {
|
||||
CreateShaderModuleError::Validation(naga::error::ShaderError {
|
||||
source: src,
|
||||
label: None,
|
||||
inner: Box::new(inner),
|
||||
})
|
||||
})?;
|
||||
let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
|
||||
module: std::borrow::Cow::Owned(module),
|
||||
info,
|
||||
debug_source: None,
|
||||
});
|
||||
let hal_desc = hal::ShaderModuleDescriptor {
|
||||
label: None,
|
||||
runtime_checks: false,
|
||||
};
|
||||
let module =
|
||||
unsafe { device.create_shader_module(&hal_desc, hal_shader) }.map_err(|error| {
|
||||
match error {
|
||||
hal::ShaderError::Device(error) => {
|
||||
CreateShaderModuleError::Device(DeviceError::from_hal(error))
|
||||
}
|
||||
hal::ShaderError::Compilation(ref msg) => {
|
||||
log::error!("Shader error: {}", msg);
|
||||
CreateShaderModuleError::Generation
|
||||
}
|
||||
}
|
||||
})?;
|
||||
|
||||
let dst_bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
|
||||
label: None,
|
||||
flags: hal::BindGroupLayoutFlags::empty(),
|
||||
entries: &[wgt::BindGroupLayoutEntry {
|
||||
binding: 0,
|
||||
visibility: wgt::ShaderStages::COMPUTE,
|
||||
ty: wgt::BindingType::Buffer {
|
||||
ty: wgt::BufferBindingType::Storage { read_only: false },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
|
||||
},
|
||||
count: None,
|
||||
}],
|
||||
};
|
||||
let dst_bind_group_layout = unsafe {
|
||||
device
|
||||
.create_bind_group_layout(&dst_bind_group_layout_desc)
|
||||
.map_err(DeviceError::from_hal)?
|
||||
};
|
||||
|
||||
let src_bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
|
||||
label: None,
|
||||
flags: hal::BindGroupLayoutFlags::empty(),
|
||||
entries: &[wgt::BindGroupLayoutEntry {
|
||||
binding: 0,
|
||||
visibility: wgt::ShaderStages::COMPUTE,
|
||||
ty: wgt::BindingType::Buffer {
|
||||
ty: wgt::BufferBindingType::Storage { read_only: true },
|
||||
has_dynamic_offset: true,
|
||||
min_binding_size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
|
||||
},
|
||||
count: None,
|
||||
}],
|
||||
};
|
||||
let src_bind_group_layout = unsafe {
|
||||
device
|
||||
.create_bind_group_layout(&src_bind_group_layout_desc)
|
||||
.map_err(DeviceError::from_hal)?
|
||||
};
|
||||
|
||||
let pipeline_layout_desc = hal::PipelineLayoutDescriptor {
|
||||
label: None,
|
||||
flags: hal::PipelineLayoutFlags::FIRST_VERTEX_INSTANCE,
|
||||
bind_group_layouts: &[
|
||||
dst_bind_group_layout.as_ref(),
|
||||
src_bind_group_layout.as_ref(),
|
||||
],
|
||||
push_constant_ranges: &[wgt::PushConstantRange {
|
||||
stages: wgt::ShaderStages::COMPUTE,
|
||||
range: 0..4,
|
||||
}],
|
||||
};
|
||||
let pipeline_layout = unsafe {
|
||||
device
|
||||
.create_pipeline_layout(&pipeline_layout_desc)
|
||||
.map_err(DeviceError::from_hal)?
|
||||
};
|
||||
|
||||
let pipeline_desc = hal::ComputePipelineDescriptor {
|
||||
label: None,
|
||||
layout: pipeline_layout.as_ref(),
|
||||
stage: hal::ProgrammableStage {
|
||||
module: module.as_ref(),
|
||||
entry_point: "main",
|
||||
constants: &Default::default(),
|
||||
zero_initialize_workgroup_memory: false,
|
||||
},
|
||||
cache: None,
|
||||
};
|
||||
let pipeline =
|
||||
unsafe { device.create_compute_pipeline(&pipeline_desc) }.map_err(|err| match err {
|
||||
hal::PipelineError::Device(error) => {
|
||||
CreateComputePipelineError::Device(DeviceError::from_hal(error))
|
||||
}
|
||||
hal::PipelineError::Linkage(_stages, msg) => {
|
||||
CreateComputePipelineError::Internal(msg)
|
||||
}
|
||||
hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
|
||||
crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
|
||||
),
|
||||
hal::PipelineError::PipelineConstants(_, error) => {
|
||||
CreateComputePipelineError::PipelineConstants(error)
|
||||
}
|
||||
})?;
|
||||
|
||||
let dst_buffer_desc = hal::BufferDescriptor {
|
||||
label: None,
|
||||
size: 4 * 3,
|
||||
usage: hal::BufferUses::INDIRECT | hal::BufferUses::STORAGE_READ_WRITE,
|
||||
memory_flags: hal::MemoryFlags::empty(),
|
||||
};
|
||||
let dst_buffer =
|
||||
unsafe { device.create_buffer(&dst_buffer_desc) }.map_err(DeviceError::from_hal)?;
|
||||
|
||||
let dst_bind_group_desc = hal::BindGroupDescriptor {
|
||||
label: None,
|
||||
layout: dst_bind_group_layout.as_ref(),
|
||||
entries: &[hal::BindGroupEntry {
|
||||
binding: 0,
|
||||
resource_index: 0,
|
||||
count: 1,
|
||||
}],
|
||||
buffers: &[hal::BufferBinding {
|
||||
buffer: dst_buffer.as_ref(),
|
||||
offset: 0,
|
||||
size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
|
||||
}],
|
||||
samplers: &[],
|
||||
textures: &[],
|
||||
acceleration_structures: &[],
|
||||
};
|
||||
let dst_bind_group = unsafe {
|
||||
device
|
||||
.create_bind_group(&dst_bind_group_desc)
|
||||
.map_err(DeviceError::from_hal)
|
||||
}?;
|
||||
|
||||
Ok(Self {
|
||||
module,
|
||||
dst_bind_group_layout,
|
||||
src_bind_group_layout,
|
||||
pipeline_layout,
|
||||
pipeline,
|
||||
dst_buffer,
|
||||
dst_bind_group,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn create_src_bind_group(
|
||||
&self,
|
||||
device: &dyn hal::DynDevice,
|
||||
limits: &wgt::Limits,
|
||||
buffer_size: u64,
|
||||
buffer: &dyn hal::DynBuffer,
|
||||
) -> Result<Box<dyn hal::DynBindGroup>, DeviceError> {
|
||||
let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
|
||||
let hal_desc = hal::BindGroupDescriptor {
|
||||
label: None,
|
||||
layout: self.src_bind_group_layout.as_ref(),
|
||||
entries: &[hal::BindGroupEntry {
|
||||
binding: 0,
|
||||
resource_index: 0,
|
||||
count: 1,
|
||||
}],
|
||||
buffers: &[hal::BufferBinding {
|
||||
buffer,
|
||||
offset: 0,
|
||||
size: Some(std::num::NonZeroU64::new(binding_size).unwrap()),
|
||||
}],
|
||||
samplers: &[],
|
||||
textures: &[],
|
||||
acceleration_structures: &[],
|
||||
};
|
||||
unsafe {
|
||||
device
|
||||
.create_bind_group(&hal_desc)
|
||||
.map_err(DeviceError::from_hal)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn params<'a>(&'a self, limits: &wgt::Limits, offset: u64, buffer_size: u64) -> Params<'a> {
|
||||
// The offset we receive is only required to be aligned to 4 bytes.
|
||||
//
|
||||
// Binding offsets and dynamic offsets are required to be aligned to
|
||||
// min_storage_buffer_offset_alignment (256 bytes by default).
|
||||
//
|
||||
// So, we work around this limitation by calculating an aligned offset
|
||||
// and pass the remainder through a push constant.
|
||||
//
|
||||
// We could bind the whole buffer and only have to pass the offset
|
||||
// through a push constant but we might run into the
|
||||
// max_storage_buffer_binding_size limit.
|
||||
//
|
||||
// See the inner docs of `calculate_src_buffer_binding_size` to
|
||||
// see how we get the appropriate `binding_size`.
|
||||
let alignment = limits.min_storage_buffer_offset_alignment as u64;
|
||||
let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
|
||||
let aligned_offset = offset - offset % alignment;
|
||||
// This works because `binding_size` is either `buffer_size` or `alignment * 2 + buffer_size % alignment`.
|
||||
let max_aligned_offset = buffer_size - binding_size;
|
||||
let aligned_offset = aligned_offset.min(max_aligned_offset);
|
||||
let offset_remainder = offset - aligned_offset;
|
||||
|
||||
Params {
|
||||
pipeline_layout: self.pipeline_layout.as_ref(),
|
||||
pipeline: self.pipeline.as_ref(),
|
||||
dst_buffer: self.dst_buffer.as_ref(),
|
||||
dst_bind_group: self.dst_bind_group.as_ref(),
|
||||
aligned_offset,
|
||||
offset_remainder,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dispose(self, device: &dyn hal::DynDevice) {
|
||||
let IndirectValidation {
|
||||
module,
|
||||
dst_bind_group_layout,
|
||||
src_bind_group_layout,
|
||||
pipeline_layout,
|
||||
pipeline,
|
||||
dst_buffer,
|
||||
dst_bind_group,
|
||||
} = self;
|
||||
|
||||
unsafe {
|
||||
device.destroy_bind_group(dst_bind_group);
|
||||
device.destroy_buffer(dst_buffer);
|
||||
device.destroy_compute_pipeline(pipeline);
|
||||
device.destroy_pipeline_layout(pipeline_layout);
|
||||
device.destroy_bind_group_layout(src_bind_group_layout);
|
||||
device.destroy_bind_group_layout(dst_bind_group_layout);
|
||||
device.destroy_shader_module(module);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn calculate_src_buffer_binding_size(buffer_size: u64, limits: &wgt::Limits) -> u64 {
|
||||
let alignment = limits.min_storage_buffer_offset_alignment as u64;
|
||||
|
||||
// We need to choose a binding size that can address all possible sets of 12 contiguous bytes in the buffer taking
|
||||
// into account that the dynamic offset needs to be a multiple of `min_storage_buffer_offset_alignment`.
|
||||
|
||||
// Given the know variables: `offset`, `buffer_size`, `alignment` and the rule `offset + 12 <= buffer_size`.
|
||||
|
||||
// Let `chunks = floor(buffer_size / alignment)`.
|
||||
// Let `chunk` be the interval `[0, chunks]`.
|
||||
// Let `offset = alignment * chunk + r` where `r` is the interval [0, alignment - 4].
|
||||
// Let `binding` be the interval `[offset, offset + 12]`.
|
||||
// Let `aligned_offset = alignment * chunk`.
|
||||
// Let `aligned_binding` be the interval `[aligned_offset, aligned_offset + r + 12]`.
|
||||
// Let `aligned_binding_size = r + 12 = [12, alignment + 8]`.
|
||||
// Let `min_aligned_binding_size = alignment + 8`.
|
||||
|
||||
// `min_aligned_binding_size` is the minimum binding size required to address all 12 contiguous bytes in the buffer
|
||||
// but the last aligned_offset + min_aligned_binding_size might overflow the buffer. In order to avoid this we must
|
||||
// pick a larger `binding_size` that satisfies: `last_aligned_offset + binding_size = buffer_size` and
|
||||
// `binding_size >= min_aligned_binding_size`.
|
||||
|
||||
// Let `buffer_size = alignment * chunks + sr` where `sr` is the interval [0, alignment - 4].
|
||||
// Let `last_aligned_offset = alignment * (chunks - u)` where `u` is the interval [0, chunks].
|
||||
// => `binding_size = buffer_size - last_aligned_offset`
|
||||
// => `binding_size = alignment * chunks + sr - alignment * (chunks - u)`
|
||||
// => `binding_size = alignment * chunks + sr - alignment * chunks + alignment * u`
|
||||
// => `binding_size = sr + alignment * u`
|
||||
// => `min_aligned_binding_size <= sr + alignment * u`
|
||||
// => `alignment + 8 <= sr + alignment * u`
|
||||
// => `u` must be at least 2
|
||||
// => `binding_size = sr + alignment * 2`
|
||||
|
||||
let binding_size = 2 * alignment + (buffer_size % alignment);
|
||||
binding_size.min(buffer_size)
|
||||
}
|
@ -67,6 +67,8 @@ mod hash_utils;
|
||||
pub mod hub;
|
||||
pub mod id;
|
||||
pub mod identity;
|
||||
#[cfg(feature = "indirect-validation")]
|
||||
mod indirect_validation;
|
||||
mod init_tracker;
|
||||
pub mod instance;
|
||||
mod lock;
|
||||
|
@ -92,7 +92,7 @@ impl ShaderModule {
|
||||
#[derive(Clone, Debug, Error)]
|
||||
#[non_exhaustive]
|
||||
pub enum CreateShaderModuleError {
|
||||
#[cfg(feature = "wgsl")]
|
||||
#[cfg(any(feature = "wgsl", feature = "indirect-validation"))]
|
||||
#[error(transparent)]
|
||||
Parsing(#[from] ShaderError<naga::front::wgsl::ParseError>),
|
||||
#[cfg(feature = "glsl")]
|
||||
|
@ -475,10 +475,18 @@ pub struct Buffer {
|
||||
pub(crate) tracking_data: TrackingData,
|
||||
pub(crate) map_state: Mutex<BufferMapState>,
|
||||
pub(crate) bind_groups: Mutex<Vec<Weak<BindGroup>>>,
|
||||
#[cfg(feature = "indirect-validation")]
|
||||
pub(crate) raw_indirect_validation_bind_group: Snatchable<Box<dyn hal::DynBindGroup>>,
|
||||
}
|
||||
|
||||
impl Drop for Buffer {
|
||||
fn drop(&mut self) {
|
||||
#[cfg(feature = "indirect-validation")]
|
||||
if let Some(raw) = self.raw_indirect_validation_bind_group.take() {
|
||||
unsafe {
|
||||
self.device.raw().destroy_bind_group(raw);
|
||||
}
|
||||
}
|
||||
if let Some(raw) = self.raw.take() {
|
||||
resource_log!("Destroy raw {}", self.error_ident());
|
||||
unsafe {
|
||||
@ -737,13 +745,22 @@ impl Buffer {
|
||||
let device = &self.device;
|
||||
|
||||
let temp = {
|
||||
let raw = match self.raw.snatch(&mut device.snatchable_lock.write()) {
|
||||
let mut snatch_guard = device.snatchable_lock.write();
|
||||
|
||||
let raw = match self.raw.snatch(&mut snatch_guard) {
|
||||
Some(raw) => raw,
|
||||
None => {
|
||||
return Err(DestroyError::AlreadyDestroyed);
|
||||
}
|
||||
};
|
||||
|
||||
#[cfg(feature = "indirect-validation")]
|
||||
let raw_indirect_validation_bind_group = self
|
||||
.raw_indirect_validation_bind_group
|
||||
.snatch(&mut snatch_guard);
|
||||
|
||||
drop(snatch_guard);
|
||||
|
||||
let bind_groups = {
|
||||
let mut guard = self.bind_groups.lock();
|
||||
mem::take(&mut *guard)
|
||||
@ -754,6 +771,8 @@ impl Buffer {
|
||||
device: Arc::clone(&self.device),
|
||||
label: self.label().to_owned(),
|
||||
bind_groups,
|
||||
#[cfg(feature = "indirect-validation")]
|
||||
raw_indirect_validation_bind_group,
|
||||
})
|
||||
};
|
||||
|
||||
@ -789,6 +808,8 @@ pub enum CreateBufferError {
|
||||
MaxBufferSize { requested: u64, maximum: u64 },
|
||||
#[error(transparent)]
|
||||
MissingDownlevelFlags(#[from] MissingDownlevelFlags),
|
||||
#[error("Failed to create bind group for indirect buffer validation: {0}")]
|
||||
IndirectValidationBindGroup(DeviceError),
|
||||
}
|
||||
|
||||
crate::impl_resource_type!(Buffer);
|
||||
@ -804,6 +825,8 @@ pub struct DestroyedBuffer {
|
||||
device: Arc<Device>,
|
||||
label: String,
|
||||
bind_groups: Vec<Weak<BindGroup>>,
|
||||
#[cfg(feature = "indirect-validation")]
|
||||
raw_indirect_validation_bind_group: Option<Box<dyn hal::DynBindGroup>>,
|
||||
}
|
||||
|
||||
impl DestroyedBuffer {
|
||||
@ -820,6 +843,13 @@ impl Drop for DestroyedBuffer {
|
||||
}
|
||||
drop(deferred);
|
||||
|
||||
#[cfg(feature = "indirect-validation")]
|
||||
if let Some(raw) = self.raw_indirect_validation_bind_group.take() {
|
||||
unsafe {
|
||||
self.device.raw().destroy_bind_group(raw);
|
||||
}
|
||||
}
|
||||
|
||||
resource_log!("Destroy raw Buffer (destroyed) {:?}", self.label());
|
||||
// SAFETY: We are in the Drop impl and we don't use self.raw anymore after this point.
|
||||
let raw = unsafe { ManuallyDrop::take(&mut self.raw) };
|
||||
|
@ -32,6 +32,12 @@ impl<T> Snatchable<T> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn empty() -> Self {
|
||||
Snatchable {
|
||||
value: UnsafeCell::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get read access to the value. Requires a the snatchable lock's read guard.
|
||||
pub fn get<'a>(&'a self, _guard: &'a SnatchGuard) -> Option<&'a T> {
|
||||
unsafe { (*self.value.get()).as_ref() }
|
||||
|
@ -130,6 +130,12 @@ features = ["raw-window-handle"]
|
||||
workspace = true
|
||||
features = ["raw-window-handle"]
|
||||
|
||||
# If we are not targeting WebGL, enable indirect-validation.
|
||||
# WebGL doesn't support indirect execution so this is not needed.
|
||||
[target.'cfg(not(target_arch = "wasm32"))'.dependencies.wgc]
|
||||
workspace = true
|
||||
features = ["indirect-validation"]
|
||||
|
||||
# Enable `wgc` by default on macOS and iOS to allow the `metal` crate feature to
|
||||
# enable the Metal backend while being no-op on other targets.
|
||||
[target.'cfg(any(target_os = "macos", target_os = "ios"))'.dependencies.wgc]
|
||||
|
Loading…
Reference in New Issue
Block a user