Rework acceleration structure build tracking.

This commit is contained in:
Vecvec 2025-04-10 13:31:00 +12:00 committed by Connor Fitzgerald
parent 382a1e3c9b
commit 8010203281
8 changed files with 187 additions and 169 deletions

View File

@ -295,6 +295,7 @@ By @syl20bnr in [#7326](https://github.com/gfx-rs/wgpu/pull/7326).
- Reduce downlevel `max_color_attachments` limit from 8 to 4 for better GLES compatibility. By @adrian17 in [#6994](https://github.com/gfx-rs/wgpu/pull/6994).
- Fix building a BLAS with a transform buffer by adding a flag to indicate usage of the transform buffer. By @Vecvec in
[#7062](https://github.com/gfx-rs/wgpu/pull/7062).
- Move incrementation of `Device::last_acceleration_structure_build_command_index` into queue submit. By @Vecvec in [#7462](https://github.com/gfx-rs/wgpu/pull/7462).
#### Vulkan

View File

@ -234,6 +234,80 @@ fn out_of_order_as_build_use(ctx: TestingContext) {
},
None,
);
let as_ctx = AsBuildContext::new(
&ctx,
AccelerationStructureFlags::empty(),
AccelerationStructureFlags::empty(),
);
//
// Build in the right order, then rebuild the BLAS so the TLAS is invalid, then use the TLAS.
//
let mut encoder_blas = ctx
.device
.create_command_encoder(&CommandEncoderDescriptor {
label: Some("BLAS 3"),
});
encoder_blas.build_acceleration_structures([&as_ctx.blas_build_entry()], []);
let mut encoder_blas2 = ctx
.device
.create_command_encoder(&CommandEncoderDescriptor {
label: Some("BLAS 4"),
});
encoder_blas2.build_acceleration_structures([&as_ctx.blas_build_entry()], []);
let mut encoder_tlas = ctx
.device
.create_command_encoder(&CommandEncoderDescriptor {
label: Some("TLAS 2"),
});
encoder_tlas.build_acceleration_structures([], [&as_ctx.tlas_package]);
ctx.queue.submit([
encoder_blas.finish(),
encoder_tlas.finish(),
encoder_blas2.finish(),
]);
let bind_group = ctx.device.create_bind_group(&BindGroupDescriptor {
label: None,
layout: &compute_pipeline.get_bind_group_layout(0),
entries: &[BindGroupEntry {
binding: 0,
resource: BindingResource::AccelerationStructure(as_ctx.tlas_package.tlas()),
}],
});
//
// Use TLAS
//
let mut encoder_compute = ctx
.device
.create_command_encoder(&CommandEncoderDescriptor::default());
{
let mut pass = encoder_compute.begin_compute_pass(&ComputePassDescriptor {
label: None,
timestamp_writes: None,
});
pass.set_pipeline(&compute_pipeline);
pass.set_bind_group(0, Some(&bind_group), &[]);
pass.dispatch_workgroups(1, 1, 1)
}
fail(
&ctx.device,
|| {
ctx.queue.submit(Some(encoder_compute.finish()));
},
None,
);
}
#[gpu_test]

View File

@ -4,6 +4,7 @@ use wgt::{BufferAddress, DynamicOffset};
use alloc::{borrow::Cow, boxed::Box, sync::Arc, vec::Vec};
use core::{fmt, str};
use crate::ray_tracing::AsAction;
use crate::{
binding_model::{
BindError, BindGroup, LateMinBufferBindingSizeMismatch, PushConstantUploadError,
@ -24,7 +25,6 @@ use crate::{
hal_label, id,
init_tracker::{BufferInitTrackerAction, MemoryInitKind},
pipeline::ComputePipeline,
ray_tracing::TlasAction,
resource::{
self, Buffer, DestroyedResourceError, InvalidResourceError, Labeled,
MissingBufferUsageError, ParentDevice,
@ -208,7 +208,7 @@ struct State<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder> {
tracker: &'cmd_buf mut Tracker,
buffer_memory_init_actions: &'cmd_buf mut Vec<BufferInitTrackerAction>,
texture_memory_actions: &'cmd_buf mut CommandBufferTextureMemoryActions,
tlas_actions: &'cmd_buf mut Vec<TlasAction>,
as_actions: &'cmd_buf mut Vec<AsAction>,
temp_offsets: Vec<u32>,
dynamic_offset_count: usize,
@ -433,7 +433,7 @@ impl Global {
tracker: &mut cmd_buf_data.trackers,
buffer_memory_init_actions: &mut cmd_buf_data.buffer_memory_init_actions,
texture_memory_actions: &mut cmd_buf_data.texture_memory_actions,
tlas_actions: &mut cmd_buf_data.tlas_actions,
as_actions: &mut cmd_buf_data.as_actions,
temp_offsets: Vec::new(),
dynamic_offset_count: 0,
@ -680,12 +680,9 @@ fn set_bind_group(
.used
.acceleration_structures
.into_iter()
.map(|tlas| TlasAction {
tlas: tlas.clone(),
kind: crate::ray_tracing::TlasActionKind::Use,
});
.map(|tlas| AsAction::UseTlas(tlas.clone()));
state.tlas_actions.extend(used_resource);
state.as_actions.extend(used_resource);
let pipeline_layout = state.binder.pipeline_layout.clone();
let entries = state

View File

@ -36,7 +36,7 @@ use crate::lock::{rank, Mutex};
use crate::snatch::SnatchGuard;
use crate::init_tracker::BufferInitTrackerAction;
use crate::ray_tracing::{BlasAction, TlasAction};
use crate::ray_tracing::AsAction;
use crate::resource::{Fallible, InvalidResourceError, Labeled, ParentDevice as _, QuerySet};
use crate::storage::Storage;
use crate::track::{DeviceTracker, Tracker, UsageScope};
@ -463,8 +463,7 @@ pub struct CommandBufferMutable {
pub(crate) pending_query_resets: QueryResetMap,
blas_actions: Vec<BlasAction>,
tlas_actions: Vec<TlasAction>,
as_actions: Vec<AsAction>,
temp_resources: Vec<TempResource>,
indirect_draw_validation_resources: crate::indirect_validation::DrawResources,
@ -553,8 +552,7 @@ impl CommandBuffer {
buffer_memory_init_actions: Default::default(),
texture_memory_actions: Default::default(),
pending_query_resets: QueryResetMap::new(),
blas_actions: Default::default(),
tlas_actions: Default::default(),
as_actions: Default::default(),
temp_resources: Default::default(),
indirect_draw_validation_resources:
crate::indirect_validation::DrawResources::new(device.clone()),

View File

@ -3,11 +3,13 @@ use core::{
cmp::max,
num::NonZeroU64,
ops::{Deref, Range},
sync::atomic::Ordering,
};
use wgt::{math::align_to, BufferUsages, BufferUses, Features};
use crate::device::resource::CommandIndices;
use crate::lock::RwLockWriteGuard;
use crate::ray_tracing::{AsAction, AsBuild, TlasBuild, ValidateAsActionsError};
use crate::{
command::CommandBufferMutable,
device::queue::TempResource,
@ -16,16 +18,14 @@ use crate::{
id::CommandEncoderId,
init_tracker::MemoryInitKind,
ray_tracing::{
BlasAction, BlasBuildEntry, BlasGeometries, BlasTriangleGeometry,
BuildAccelerationStructureError, TlasAction, TlasBuildEntry, TlasInstance, TlasPackage,
TraceBlasBuildEntry, TraceBlasGeometries, TraceBlasTriangleGeometry, TraceTlasInstance,
TraceTlasPackage, ValidateBlasActionsError, ValidateTlasActionsError,
BlasBuildEntry, BlasGeometries, BlasTriangleGeometry, BuildAccelerationStructureError,
TlasBuildEntry, TlasInstance, TlasPackage, TraceBlasBuildEntry, TraceBlasGeometries,
TraceBlasTriangleGeometry, TraceTlasInstance, TraceTlasPackage,
},
resource::{AccelerationStructure, Blas, Buffer, Labeled, StagingBuffer, Tlas, Trackable},
resource::{AccelerationStructure, Blas, Buffer, Labeled, StagingBuffer, Tlas},
scratch::ScratchBuffer,
snatch::SnatchGuard,
track::PendingTransition,
FastHashSet,
};
use crate::id::{BlasId, TlasId};
@ -81,39 +81,28 @@ impl Global {
device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
let build_command_index = NonZeroU64::new(
device
.last_acceleration_structure_build_command_index
.fetch_add(1, Ordering::Relaxed),
)
.unwrap();
let mut build_command = AsBuild::default();
for blas in blas_ids {
let blas = hub.blas_s.get(*blas).get()?;
build_command.blas_s_built.push(blas);
}
for tlas in tlas_ids {
let tlas = hub.tlas_s.get(*tlas).get()?;
build_command.tlas_s_built.push(TlasBuild {
tlas,
dependencies: Vec::new(),
});
}
let mut cmd_buf_data = cmd_buf.data.lock();
let mut cmd_buf_data_guard = cmd_buf_data.record()?;
let cmd_buf_data = &mut *cmd_buf_data_guard;
cmd_buf_data.blas_actions.reserve(blas_ids.len());
cmd_buf_data.as_actions.push(AsAction::Build(build_command));
cmd_buf_data.tlas_actions.reserve(tlas_ids.len());
for blas in blas_ids {
let blas = hub.blas_s.get(*blas).get()?;
cmd_buf_data.blas_actions.push(BlasAction {
blas,
kind: crate::ray_tracing::BlasActionKind::Build(build_command_index),
});
}
for tlas in tlas_ids {
let tlas = hub.tlas_s.get(*tlas).get()?;
cmd_buf_data.tlas_actions.push(TlasAction {
tlas,
kind: crate::ray_tracing::TlasActionKind::Build {
build_index: build_command_index,
dependencies: Vec::new(),
},
});
}
cmd_buf_data_guard.mark_successful();
Ok(())
}
@ -139,12 +128,7 @@ impl Global {
device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
let build_command_index = NonZeroU64::new(
device
.last_acceleration_structure_build_command_index
.fetch_add(1, Ordering::Relaxed),
)
.unwrap();
let mut build_command = AsBuild::default();
#[cfg(feature = "trace")]
let trace_blas: Vec<TraceBlasBuildEntry> = blas_iter
@ -227,7 +211,7 @@ impl Global {
iter_blas(
blas_iter,
cmd_buf_data,
build_command_index,
&mut build_command,
&mut buf_storage,
hub,
)?;
@ -281,12 +265,9 @@ impl Global {
let tlas = hub.tlas_s.get(entry.tlas_id).get()?;
cmd_buf_data.trackers.tlas_s.insert_single(tlas.clone());
cmd_buf_data.tlas_actions.push(TlasAction {
build_command.tlas_s_built.push(TlasBuild {
tlas: tlas.clone(),
kind: crate::ray_tracing::TlasActionKind::Build {
build_index: build_command_index,
dependencies: Vec::new(),
},
dependencies: Vec::new(),
});
let scratch_buffer_offset = scratch_buffer_tlas_size;
@ -388,6 +369,8 @@ impl Global {
.temp_resources
.push(TempResource::ScratchBuffer(scratch_buffer));
cmd_buf_data.as_actions.push(AsAction::Build(build_command));
cmd_buf_data_guard.mark_successful();
Ok(())
}
@ -410,12 +393,7 @@ impl Global {
device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
let build_command_index = NonZeroU64::new(
device
.last_acceleration_structure_build_command_index
.fetch_add(1, Ordering::Relaxed),
)
.unwrap();
let mut build_command = AsBuild::default();
let trace_blas: Vec<TraceBlasBuildEntry> = blas_iter
.map(|blas_entry| {
@ -523,7 +501,7 @@ impl Global {
iter_blas(
blas_iter,
cmd_buf_data,
build_command_index,
&mut build_command,
&mut buf_storage,
hub,
)?;
@ -604,19 +582,11 @@ impl Global {
instance_count += 1;
dependencies.push(blas.clone());
cmd_buf_data.blas_actions.push(BlasAction {
blas,
kind: crate::ray_tracing::BlasActionKind::Use,
});
}
cmd_buf_data.tlas_actions.push(TlasAction {
build_command.tlas_s_built.push(TlasBuild {
tlas: tlas.clone(),
kind: crate::ray_tracing::TlasActionKind::Build {
build_index: build_command_index,
dependencies,
},
dependencies,
});
if instance_count > tlas.max_instance_count {
@ -800,72 +770,69 @@ impl Global {
.temp_resources
.push(TempResource::ScratchBuffer(scratch_buffer));
cmd_buf_data.as_actions.push(AsAction::Build(build_command));
cmd_buf_data_guard.mark_successful();
Ok(())
}
}
impl CommandBufferMutable {
// makes sure a blas is build before it is used
pub(crate) fn validate_blas_actions(&self) -> Result<(), ValidateBlasActionsError> {
profiling::scope!("CommandEncoder::[submission]::validate_blas_actions");
let mut built = FastHashSet::default();
for action in &self.blas_actions {
match &action.kind {
crate::ray_tracing::BlasActionKind::Build(id) => {
built.insert(action.blas.tracker_index());
*action.blas.built_index.write() = Some(*id);
}
crate::ray_tracing::BlasActionKind::Use => {
if !built.contains(&action.blas.tracker_index())
&& (*action.blas.built_index.read()).is_none()
{
return Err(ValidateBlasActionsError::UsedUnbuilt(
action.blas.error_ident(),
));
}
}
}
}
Ok(())
}
// makes sure a tlas is built before it is used
pub(crate) fn validate_tlas_actions(
pub(crate) fn validate_acceleration_structure_actions(
&self,
snatch_guard: &SnatchGuard,
) -> Result<(), ValidateTlasActionsError> {
profiling::scope!("CommandEncoder::[submission]::validate_tlas_actions");
for action in &self.tlas_actions {
match &action.kind {
crate::ray_tracing::TlasActionKind::Build {
build_index,
dependencies,
} => {
*action.tlas.built_index.write() = Some(*build_index);
action.tlas.dependencies.write().clone_from(dependencies);
command_index_guard: &mut RwLockWriteGuard<CommandIndices>,
) -> Result<(), ValidateAsActionsError> {
profiling::scope!("CommandEncoder::[submission]::validate_as_actions");
for action in &self.as_actions {
match action {
AsAction::Build(build) => {
let build_command_index = NonZeroU64::new(
command_index_guard.next_acceleration_structure_build_command_index,
)
.unwrap();
command_index_guard.next_acceleration_structure_build_command_index += 1;
for blas in build.blas_s_built.iter() {
*blas.built_index.write() = Some(build_command_index);
}
for tlas_build in build.tlas_s_built.iter() {
for blas in &tlas_build.dependencies {
if blas.built_index.read().is_none() {
return Err(ValidateAsActionsError::UsedUnbuiltBlas(
blas.error_ident(),
tlas_build.tlas.error_ident(),
));
}
}
*tlas_build.tlas.built_index.write() = Some(build_command_index);
tlas_build
.tlas
.dependencies
.write()
.clone_from(&tlas_build.dependencies)
}
}
crate::ray_tracing::TlasActionKind::Use => {
let tlas_build_index = action.tlas.built_index.read();
let dependencies = action.tlas.dependencies.read();
AsAction::UseTlas(tlas) => {
let tlas_build_index = tlas.built_index.read();
let dependencies = tlas.dependencies.read();
if (*tlas_build_index).is_none() {
return Err(ValidateTlasActionsError::UsedUnbuilt(
action.tlas.error_ident(),
));
return Err(ValidateAsActionsError::UsedUnbuiltTlas(tlas.error_ident()));
}
for blas in dependencies.deref() {
let blas_build_index = *blas.built_index.read();
if blas_build_index.is_none() {
return Err(ValidateTlasActionsError::UsedUnbuiltBlas(
action.tlas.error_ident(),
return Err(ValidateAsActionsError::UsedUnbuiltBlas(
tlas.error_ident(),
blas.error_ident(),
));
}
if blas_build_index.unwrap() > tlas_build_index.unwrap() {
return Err(ValidateTlasActionsError::BlasNewerThenTlas(
return Err(ValidateAsActionsError::BlasNewerThenTlas(
blas.error_ident(),
action.tlas.error_ident(),
tlas.error_ident(),
));
}
blas.try_raw(snatch_guard)?;
@ -881,7 +848,7 @@ impl CommandBufferMutable {
fn iter_blas<'a>(
blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
cmd_buf_data: &mut CommandBufferMutable,
build_command_index: NonZeroU64,
build_command: &mut AsBuild,
buf_storage: &mut Vec<TriangleBufferStore<'a>>,
hub: &Hub,
) -> Result<(), BuildAccelerationStructureError> {
@ -890,10 +857,7 @@ fn iter_blas<'a>(
let blas = hub.blas_s.get(entry.blas_id).get()?;
cmd_buf_data.trackers.blas_s.insert_single(blas.clone());
cmd_buf_data.blas_actions.push(BlasAction {
blas: blas.clone(),
kind: crate::ray_tracing::BlasActionKind::Build(build_command_index),
});
build_command.blas_s_built.push(blas.clone());
match entry.geometries {
BlasGeometries::TriangleGeometries(triangle_geometries) => {

View File

@ -10,6 +10,7 @@ use smallvec::SmallVec;
use thiserror::Error;
use super::{life::LifetimeTracker, Device};
use crate::device::resource::CommandIndices;
#[cfg(feature = "trace")]
use crate::device::trace::Action;
use crate::scratch::ScratchBuffer;
@ -447,9 +448,7 @@ pub enum QueueSubmitError {
#[error(transparent)]
CommandEncoder(#[from] CommandEncoderError),
#[error(transparent)]
ValidateBlasActionsError(#[from] crate::ray_tracing::ValidateBlasActionsError),
#[error(transparent)]
ValidateTlasActionsError(#[from] crate::ray_tracing::ValidateTlasActionsError),
ValidateAsActionsError(#[from] crate::ray_tracing::ValidateAsActionsError),
}
//TODO: move out common parts of write_xxx.
@ -1126,6 +1125,7 @@ impl Queue {
&snatch_guard,
&mut submit_surface_textures_owned,
&mut used_surface_textures,
&mut command_index_guard,
);
if let Err(err) = res {
first_error.get_or_insert(err);
@ -1518,6 +1518,7 @@ fn validate_command_buffer(
snatch_guard: &SnatchGuard,
submit_surface_textures_owned: &mut FastHashMap<*const Texture, Arc<Texture>>,
used_surface_textures: &mut track::TextureUsageScope,
command_index_guard: &mut RwLockWriteGuard<CommandIndices>,
) -> Result<(), QueueSubmitError> {
command_buffer.same_device_as(queue)?;
@ -1557,10 +1558,9 @@ fn validate_command_buffer(
}
}
if let Err(e) = cmd_buf_data.validate_blas_actions() {
return Err(e.into());
}
if let Err(e) = cmd_buf_data.validate_tlas_actions(snatch_guard) {
if let Err(e) =
cmd_buf_data.validate_acceleration_structure_actions(snatch_guard, command_index_guard)
{
return Err(e.into());
}
}

View File

@ -71,6 +71,7 @@ pub(crate) struct CommandIndices {
///
/// [`last_successful_submission_index`]: Device::last_successful_submission_index
pub(crate) active_submission_index: hal::FenceValue,
pub(crate) next_acceleration_structure_build_command_index: u64,
}
/// Structure describing a logical device. Some members are internally mutable,
@ -133,7 +134,6 @@ pub struct Device {
pub(crate) instance_flags: wgt::InstanceFlags,
pub(crate) deferred_destroy: Mutex<Vec<DeferredDestroy>>,
pub(crate) usage_scopes: UsageScopePool,
pub(crate) last_acceleration_structure_build_command_index: AtomicU64,
pub(crate) indirect_validation: Option<crate::indirect_validation::IndirectValidation>,
// Optional so that we can late-initialize this after the queue is created.
pub(crate) timestamp_normalizer:
@ -284,6 +284,8 @@ impl Device {
rank::DEVICE_COMMAND_INDICES,
CommandIndices {
active_submission_index: 0,
// By starting at one, we can put the result in a NonZeroU64.
next_acceleration_structure_build_command_index: 1,
},
),
last_successful_submission_index: AtomicU64::new(0),
@ -321,8 +323,6 @@ impl Device {
instance_flags,
deferred_destroy: Mutex::new(rank::DEVICE_DEFERRED_DESTROY, Vec::new()),
usage_scopes: Mutex::new(rank::DEVICE_USAGE_SCOPES, Default::default()),
// By starting at one, we can put the result in a NonZeroU64.
last_acceleration_structure_build_command_index: AtomicU64::new(1),
timestamp_normalizer: OnceCellOrLock::new(),
indirect_validation,
})

View File

@ -8,7 +8,6 @@
// - ([non performance] extract function in build (rust function extraction with guards is a pain))
use alloc::{boxed::Box, sync::Arc, vec::Vec};
use core::num::NonZeroU64;
use thiserror::Error;
use wgt::{AccelerationStructureGeometryFlags, BufferAddress, IndexFormat, VertexFormat};
@ -137,18 +136,12 @@ pub enum BuildAccelerationStructureError {
}
#[derive(Clone, Debug, Error)]
pub enum ValidateBlasActionsError {
#[error("Blas {0:?} is used before it is built")]
UsedUnbuilt(ResourceErrorIdent),
}
#[derive(Clone, Debug, Error)]
pub enum ValidateTlasActionsError {
pub enum ValidateAsActionsError {
#[error(transparent)]
DestroyedResource(#[from] DestroyedResourceError),
#[error("Tlas {0:?} is used before it is built")]
UsedUnbuilt(ResourceErrorIdent),
UsedUnbuiltTlas(ResourceErrorIdent),
#[error("Blas {0:?} is used before it is built (in Tlas {1:?})")]
UsedUnbuiltBlas(ResourceErrorIdent, ResourceErrorIdent),
@ -200,31 +193,22 @@ pub struct TlasPackage<'a> {
pub lowest_unmodified: u32,
}
#[derive(Debug, Copy, Clone)]
pub(crate) enum BlasActionKind {
Build(NonZeroU64),
Use,
}
#[derive(Debug, Clone)]
pub(crate) enum TlasActionKind {
Build {
build_index: NonZeroU64,
dependencies: Vec<Arc<Blas>>,
},
Use,
}
#[derive(Debug, Clone)]
pub(crate) struct BlasAction {
pub blas: Arc<Blas>,
pub kind: BlasActionKind,
}
#[derive(Debug, Clone)]
pub(crate) struct TlasAction {
pub(crate) struct TlasBuild {
pub tlas: Arc<Tlas>,
pub kind: TlasActionKind,
pub dependencies: Vec<Arc<Blas>>,
}
#[derive(Debug, Clone, Default)]
pub(crate) struct AsBuild {
pub blas_s_built: Vec<Arc<Blas>>,
pub tlas_s_built: Vec<TlasBuild>,
}
#[derive(Debug, Clone)]
pub(crate) enum AsAction {
Build(AsBuild),
UseTlas(Arc<Tlas>),
}
#[derive(Debug, Clone)]