[wgpu-core/-hal] move raytracing alignments into hal (#6563)

This commit is contained in:
Vecvec 2024-11-19 23:57:48 +13:00 committed by GitHub
parent 2389106a75
commit 6f5014f0a3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 102 additions and 64 deletions

View File

@ -104,6 +104,7 @@ By @ErichDonGubler in [#6456](https://github.com/gfx-rs/wgpu/pull/6456), [#6148]
#### General #### General
- Return submission index in `map_async` and `on_submitted_work_done` to track down completion of async callbacks. By @eliemichel in [#6360](https://github.com/gfx-rs/wgpu/pull/6360). - Return submission index in `map_async` and `on_submitted_work_done` to track down completion of async callbacks. By @eliemichel in [#6360](https://github.com/gfx-rs/wgpu/pull/6360).
- Move raytracing alignments into HAL instead of in core. By @Vecvec in [#6563](https://github.com/gfx-rs/wgpu/pull/6563).
### Changes ### Changes

View File

@ -5,7 +5,7 @@ use crate::{
id::CommandEncoderId, id::CommandEncoderId,
init_tracker::MemoryInitKind, init_tracker::MemoryInitKind,
ray_tracing::{ ray_tracing::{
tlas_instance_into_bytes, BlasAction, BlasBuildEntry, BlasGeometries, BlasTriangleGeometry, BlasAction, BlasBuildEntry, BlasGeometries, BlasTriangleGeometry,
BuildAccelerationStructureError, TlasAction, TlasBuildEntry, TlasInstance, TlasPackage, BuildAccelerationStructureError, TlasAction, TlasBuildEntry, TlasInstance, TlasPackage,
TraceBlasBuildEntry, TraceBlasGeometries, TraceBlasTriangleGeometry, TraceTlasInstance, TraceBlasBuildEntry, TraceBlasGeometries, TraceBlasTriangleGeometry, TraceTlasInstance,
TraceTlasPackage, ValidateBlasActionsError, ValidateTlasActionsError, TraceTlasPackage, ValidateBlasActionsError, ValidateTlasActionsError,
@ -60,9 +60,6 @@ struct TlasBufferStore {
entry: TlasBuildEntry, entry: TlasBuildEntry,
} }
// TODO: Get this from the device (e.g. VkPhysicalDeviceAccelerationStructurePropertiesKHR.minAccelerationStructureScratchOffsetAlignment) this is currently the largest possible some devices have 0, 64, 128 (lower limits) so this could create excess allocation (Note: dx12 has 256).
const SCRATCH_BUFFER_ALIGNMENT: u32 = 256;
impl Global { impl Global {
// Currently this function is very similar to its safe counterpart, however certain parts of it are very different, // Currently this function is very similar to its safe counterpart, however certain parts of it are very different,
// making for the two to be implemented differently, the main difference is this function has separate buffers for each // making for the two to be implemented differently, the main difference is this function has separate buffers for each
@ -193,6 +190,7 @@ impl Global {
&mut scratch_buffer_blas_size, &mut scratch_buffer_blas_size,
&mut blas_storage, &mut blas_storage,
hub, hub,
device.alignments.ray_tracing_scratch_buffer_alignment,
)?; )?;
let mut scratch_buffer_tlas_size = 0; let mut scratch_buffer_tlas_size = 0;
@ -260,7 +258,7 @@ impl Global {
let scratch_buffer_offset = scratch_buffer_tlas_size; let scratch_buffer_offset = scratch_buffer_tlas_size;
scratch_buffer_tlas_size += align_to( scratch_buffer_tlas_size += align_to(
tlas.size_info.build_scratch_size as u32, tlas.size_info.build_scratch_size as u32,
SCRATCH_BUFFER_ALIGNMENT, device.alignments.ray_tracing_scratch_buffer_alignment,
) as u64; ) as u64;
tlas_storage.push(UnsafeTlasStore { tlas_storage.push(UnsafeTlasStore {
@ -508,6 +506,7 @@ impl Global {
&mut scratch_buffer_blas_size, &mut scratch_buffer_blas_size,
&mut blas_storage, &mut blas_storage,
hub, hub,
device.alignments.ray_tracing_scratch_buffer_alignment,
)?; )?;
let mut tlas_lock_store = Vec::<(Option<TlasPackage>, Arc<Tlas>)>::new(); let mut tlas_lock_store = Vec::<(Option<TlasPackage>, Arc<Tlas>)>::new();
@ -535,7 +534,7 @@ impl Global {
let scratch_buffer_offset = scratch_buffer_tlas_size; let scratch_buffer_offset = scratch_buffer_tlas_size;
scratch_buffer_tlas_size += align_to( scratch_buffer_tlas_size += align_to(
tlas.size_info.build_scratch_size as u32, tlas.size_info.build_scratch_size as u32,
SCRATCH_BUFFER_ALIGNMENT, device.alignments.ray_tracing_scratch_buffer_alignment,
) as u64; ) as u64;
let first_byte_index = instance_buffer_staging_source.len(); let first_byte_index = instance_buffer_staging_source.len();
@ -558,10 +557,13 @@ impl Global {
cmd_buf_data.trackers.blas_s.set_single(blas.clone()); cmd_buf_data.trackers.blas_s.set_single(blas.clone());
instance_buffer_staging_source.extend(tlas_instance_into_bytes( instance_buffer_staging_source.extend(device.raw().tlas_instance_to_bytes(
&instance, hal::TlasInstance {
blas.handle, transform: *instance.transform,
device.backend(), custom_index: instance.custom_index,
mask: instance.mask,
blas_address: blas.handle,
},
)); ));
instance_count += 1; instance_count += 1;
@ -1013,6 +1015,7 @@ fn iter_buffers<'a, 'b>(
scratch_buffer_blas_size: &mut u64, scratch_buffer_blas_size: &mut u64,
blas_storage: &mut Vec<BlasStore<'a>>, blas_storage: &mut Vec<BlasStore<'a>>,
hub: &Hub, hub: &Hub,
ray_tracing_scratch_buffer_alignment: u32,
) -> Result<(), BuildAccelerationStructureError> { ) -> Result<(), BuildAccelerationStructureError> {
let mut triangle_entries = let mut triangle_entries =
Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::new(); Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::new();
@ -1192,7 +1195,7 @@ fn iter_buffers<'a, 'b>(
let scratch_buffer_offset = *scratch_buffer_blas_size; let scratch_buffer_offset = *scratch_buffer_blas_size;
*scratch_buffer_blas_size += align_to( *scratch_buffer_blas_size += align_to(
blas.size_info.build_scratch_size as u32, blas.size_info.build_scratch_size as u32,
SCRATCH_BUFFER_ALIGNMENT, ray_tracing_scratch_buffer_alignment,
) as u64; ) as u64;
blas_storage.push(BlasStore { blas_storage.push(BlasStore {

View File

@ -12,7 +12,7 @@ use crate::{
global::Global, global::Global,
id::{self, BlasId, TlasId}, id::{self, BlasId, TlasId},
lock::RwLock, lock::RwLock,
ray_tracing::{get_raw_tlas_instance_size, CreateBlasError, CreateTlasError}, ray_tracing::{CreateBlasError, CreateTlasError},
resource, LabelHelpers, resource, LabelHelpers,
}; };
use hal::AccelerationStructureTriangleIndices; use hal::AccelerationStructureTriangleIndices;
@ -135,7 +135,7 @@ impl Device {
.map_err(DeviceError::from_hal)?; .map_err(DeviceError::from_hal)?;
let instance_buffer_size = let instance_buffer_size =
get_raw_tlas_instance_size(self.backend()) * desc.max_instances.max(1) as usize; self.alignments.raw_tlas_instance_size * desc.max_instances.max(1) as usize;
let instance_buffer = unsafe { let instance_buffer = unsafe {
self.raw().create_buffer(&hal::BufferDescriptor { self.raw().create_buffer(&hal::BufferDescriptor {
label: Some("(wgpu-core) instances_buffer"), label: Some("(wgpu-core) instances_buffer"),

View File

@ -13,8 +13,8 @@ use crate::{
id::{BlasId, BufferId, TlasId}, id::{BlasId, BufferId, TlasId},
resource::CreateBufferError, resource::CreateBufferError,
}; };
use std::{mem::size_of, sync::Arc}; use std::num::NonZeroU64;
use std::{num::NonZeroU64, slice}; use std::sync::Arc;
use crate::resource::{Blas, ResourceErrorIdent, Tlas}; use crate::resource::{Blas, ResourceErrorIdent, Tlas};
use thiserror::Error; use thiserror::Error;
@ -276,48 +276,3 @@ pub struct TraceTlasPackage {
pub instances: Vec<Option<TraceTlasInstance>>, pub instances: Vec<Option<TraceTlasInstance>>,
pub lowest_unmodified: u32, pub lowest_unmodified: u32,
} }
pub(crate) fn get_raw_tlas_instance_size(backend: wgt::Backend) -> usize {
// TODO: this should be provided by the backend
match backend {
wgt::Backend::Empty => 0,
wgt::Backend::Vulkan => 64,
_ => unimplemented!(),
}
}
#[derive(Clone)]
#[repr(C)]
struct RawTlasInstance {
transform: [f32; 12],
custom_index_and_mask: u32,
shader_binding_table_record_offset_and_flags: u32,
acceleration_structure_reference: u64,
}
pub(crate) fn tlas_instance_into_bytes(
instance: &TlasInstance,
blas_address: u64,
backend: wgt::Backend,
) -> Vec<u8> {
// TODO: get the device to do this
match backend {
wgt::Backend::Empty => vec![],
wgt::Backend::Vulkan => {
const MAX_U24: u32 = (1u32 << 24u32) - 1u32;
let temp = RawTlasInstance {
transform: *instance.transform,
custom_index_and_mask: (instance.custom_index & MAX_U24)
| (u32::from(instance.mask) << 24),
shader_binding_table_record_offset_and_flags: 0,
acceleration_structure_reference: blas_address,
};
let temp: *const _ = &temp;
unsafe {
slice::from_raw_parts::<u8>(temp.cast::<u8>(), size_of::<RawTlasInstance>())
.to_vec()
}
}
_ => unimplemented!(),
}
}

View File

@ -522,6 +522,8 @@ impl super::Adapter {
// Direct3D correctly bounds-checks all array accesses: // Direct3D correctly bounds-checks all array accesses:
// https://microsoft.github.io/DirectX-Specs/d3d/archive/D3D11_3_FunctionalSpec.htm#18.6.8.2%20Device%20Memory%20Reads // https://microsoft.github.io/DirectX-Specs/d3d/archive/D3D11_3_FunctionalSpec.htm#18.6.8.2%20Device%20Memory%20Reads
uniform_bounds_check_alignment: wgt::BufferSize::new(1).unwrap(), uniform_bounds_check_alignment: wgt::BufferSize::new(1).unwrap(),
raw_tlas_instance_size: 0,
ray_tracing_scratch_buffer_alignment: 0,
}, },
downlevel, downlevel,
}, },

View File

@ -21,6 +21,7 @@ use super::{conv, descriptor, D3D12Lib};
use crate::{ use crate::{
auxil::{self, dxgi::result::HResult}, auxil::{self, dxgi::result::HResult},
dx12::{borrow_optional_interface_temporarily, shader_compilation, Event}, dx12::{borrow_optional_interface_temporarily, shader_compilation, Event},
TlasInstance,
}; };
// this has to match Naga's HLSL backend, and also needs to be null-terminated // this has to match Naga's HLSL backend, and also needs to be null-terminated
@ -1939,4 +1940,8 @@ impl crate::Device for super::Device {
total_reserved_bytes: upstream.total_reserved_bytes, total_reserved_bytes: upstream.total_reserved_bytes,
}) })
} }
fn tlas_instance_to_bytes(&self, _instance: TlasInstance) -> Vec<u8> {
todo!()
}
} }

View File

@ -5,7 +5,7 @@ use crate::{
GetAccelerationStructureBuildSizesDescriptor, Label, MemoryRange, PipelineCacheDescriptor, GetAccelerationStructureBuildSizesDescriptor, Label, MemoryRange, PipelineCacheDescriptor,
PipelineCacheError, PipelineError, PipelineLayoutDescriptor, RenderPipelineDescriptor, PipelineCacheError, PipelineError, PipelineLayoutDescriptor, RenderPipelineDescriptor,
SamplerDescriptor, ShaderError, ShaderInput, ShaderModuleDescriptor, TextureDescriptor, SamplerDescriptor, ShaderError, ShaderInput, ShaderModuleDescriptor, TextureDescriptor,
TextureViewDescriptor, TextureViewDescriptor, TlasInstance,
}; };
use super::{ use super::{
@ -158,6 +158,7 @@ pub trait DynDevice: DynResource {
&self, &self,
acceleration_structure: Box<dyn DynAccelerationStructure>, acceleration_structure: Box<dyn DynAccelerationStructure>,
); );
fn tlas_instance_to_bytes(&self, instance: TlasInstance) -> Vec<u8>;
fn get_internal_counters(&self) -> wgt::HalCounters; fn get_internal_counters(&self) -> wgt::HalCounters;
fn generate_allocator_report(&self) -> Option<wgt::AllocatorReport>; fn generate_allocator_report(&self) -> Option<wgt::AllocatorReport>;
@ -520,6 +521,10 @@ impl<D: Device + DynResource> DynDevice for D {
unsafe { D::destroy_acceleration_structure(self, acceleration_structure.unbox()) } unsafe { D::destroy_acceleration_structure(self, acceleration_structure.unbox()) }
} }
fn tlas_instance_to_bytes(&self, instance: TlasInstance) -> Vec<u8> {
D::tlas_instance_to_bytes(self, instance)
}
fn get_internal_counters(&self) -> wgt::HalCounters { fn get_internal_counters(&self) -> wgt::HalCounters {
D::get_internal_counters(self) D::get_internal_counters(self)
} }

View File

@ -1,5 +1,6 @@
#![allow(unused_variables)] #![allow(unused_variables)]
use crate::TlasInstance;
use std::ops::Range; use std::ops::Range;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -306,6 +307,10 @@ impl crate::Device for Context {
} }
unsafe fn destroy_acceleration_structure(&self, _acceleration_structure: Resource) {} unsafe fn destroy_acceleration_structure(&self, _acceleration_structure: Resource) {}
fn tlas_instance_to_bytes(&self, instance: TlasInstance) -> Vec<u8> {
vec![]
}
fn get_internal_counters(&self) -> wgt::HalCounters { fn get_internal_counters(&self) -> wgt::HalCounters {
Default::default() Default::default()
} }

View File

@ -851,6 +851,8 @@ impl super::Adapter {
// being, provide 1 as the value here, to cause as little // being, provide 1 as the value here, to cause as little
// trouble as possible. // trouble as possible.
uniform_bounds_check_alignment: wgt::BufferSize::new(1).unwrap(), uniform_bounds_check_alignment: wgt::BufferSize::new(1).unwrap(),
raw_tlas_instance_size: 0,
ray_tracing_scratch_buffer_alignment: 0,
}, },
}, },
}) })

View File

@ -8,7 +8,7 @@ use std::{
sync::{Arc, Mutex}, sync::{Arc, Mutex},
}; };
use crate::AtomicFenceValue; use crate::{AtomicFenceValue, TlasInstance};
use arrayvec::ArrayVec; use arrayvec::ArrayVec;
use std::sync::atomic::Ordering; use std::sync::atomic::Ordering;
@ -1633,6 +1633,10 @@ impl crate::Device for super::Device {
) { ) {
} }
fn tlas_instance_to_bytes(&self, _instance: TlasInstance) -> Vec<u8> {
unimplemented!()
}
fn get_internal_counters(&self) -> wgt::HalCounters { fn get_internal_counters(&self) -> wgt::HalCounters {
self.counters.clone() self.counters.clone()
} }

View File

@ -971,6 +971,7 @@ pub trait Device: WasmNotSendSync {
&self, &self,
acceleration_structure: <Self::A as Api>::AccelerationStructure, acceleration_structure: <Self::A as Api>::AccelerationStructure,
); );
fn tlas_instance_to_bytes(&self, instance: TlasInstance) -> Vec<u8>;
fn get_internal_counters(&self) -> wgt::HalCounters; fn get_internal_counters(&self) -> wgt::HalCounters;
@ -1771,6 +1772,12 @@ pub struct Alignments {
/// [`Uniform`]: wgt::BufferBindingType::Uniform /// [`Uniform`]: wgt::BufferBindingType::Uniform
/// [size]: BufferBinding::size /// [size]: BufferBinding::size
pub uniform_bounds_check_alignment: wgt::BufferSize, pub uniform_bounds_check_alignment: wgt::BufferSize,
/// The size of the raw TLAS instance
pub raw_tlas_instance_size: usize,
/// What the scratch buffer for building an acceleration structure must be aligned to
pub ray_tracing_scratch_buffer_alignment: u32,
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -2519,3 +2526,11 @@ bitflags::bitflags! {
pub struct AccelerationStructureBarrier { pub struct AccelerationStructureBarrier {
pub usage: Range<AccelerationStructureUses>, pub usage: Range<AccelerationStructureUses>,
} }
#[derive(Debug, Copy, Clone)]
pub struct TlasInstance {
pub transform: [f32; 12],
pub custom_index: u32,
pub mask: u8,
pub blas_address: u64,
}

View File

@ -1001,6 +1001,8 @@ impl super::PrivateCapabilities {
// Metal Shading Language it generates, so from `wgpu_hal`'s // Metal Shading Language it generates, so from `wgpu_hal`'s
// users' point of view, references are tightly checked. // users' point of view, references are tightly checked.
uniform_bounds_check_alignment: wgt::BufferSize::new(1).unwrap(), uniform_bounds_check_alignment: wgt::BufferSize::new(1).unwrap(),
raw_tlas_instance_size: 0,
ray_tracing_scratch_buffer_alignment: 0,
}, },
downlevel, downlevel,
} }

View File

@ -8,6 +8,7 @@ use std::{
use super::conv; use super::conv;
use crate::auxil::map_naga_stage; use crate::auxil::map_naga_stage;
use crate::TlasInstance;
type DeviceResult<T> = Result<T, crate::DeviceError>; type DeviceResult<T> = Result<T, crate::DeviceError>;
@ -1426,6 +1427,10 @@ impl crate::Device for super::Device {
unimplemented!() unimplemented!()
} }
fn tlas_instance_to_bytes(&self, _instance: TlasInstance) -> Vec<u8> {
unimplemented!()
}
fn get_internal_counters(&self) -> wgt::HalCounters { fn get_internal_counters(&self) -> wgt::HalCounters {
self.counters.clone() self.counters.clone()
} }

View File

@ -1140,6 +1140,13 @@ impl PhysicalDeviceProperties {
}; };
wgt::BufferSize::new(alignment).unwrap() wgt::BufferSize::new(alignment).unwrap()
}, },
raw_tlas_instance_size: 64,
ray_tracing_scratch_buffer_alignment: self.acceleration_structure.map_or(
0,
|acceleration_structure| {
acceleration_structure.min_acceleration_structure_scratch_offset_alignment
},
),
} }
} }
} }

View File

@ -1,16 +1,18 @@
use super::conv; use super::{conv, RawTlasInstance};
use arrayvec::ArrayVec; use arrayvec::ArrayVec;
use ash::{khr, vk}; use ash::{khr, vk};
use parking_lot::Mutex; use parking_lot::Mutex;
use crate::TlasInstance;
use std::{ use std::{
borrow::Cow, borrow::Cow,
collections::{hash_map::Entry, BTreeMap}, collections::{hash_map::Entry, BTreeMap},
ffi::{CStr, CString}, ffi::{CStr, CString},
mem,
mem::MaybeUninit, mem::MaybeUninit,
num::NonZeroU32, num::NonZeroU32,
ptr, ptr, slice,
sync::Arc, sync::Arc,
}; };
@ -2557,6 +2559,22 @@ impl crate::Device for super::Device {
self.counters.clone() self.counters.clone()
} }
fn tlas_instance_to_bytes(&self, instance: TlasInstance) -> Vec<u8> {
const MAX_U24: u32 = (1u32 << 24u32) - 1u32;
let temp = RawTlasInstance {
transform: instance.transform,
custom_index_and_mask: (instance.custom_index & MAX_U24)
| (u32::from(instance.mask) << 24),
shader_binding_table_record_offset_and_flags: 0,
acceleration_structure_reference: instance.blas_address,
};
let temp: *const _ = &temp;
unsafe {
slice::from_raw_parts::<u8>(temp.cast::<u8>(), mem::size_of::<RawTlasInstance>())
.to_vec()
}
}
} }
impl super::DeviceShared { impl super::DeviceShared {

View File

@ -1423,3 +1423,12 @@ fn get_lost_err() -> crate::DeviceError {
#[allow(unreachable_code)] #[allow(unreachable_code)]
crate::DeviceError::Lost crate::DeviceError::Lost
} }
#[derive(Clone)]
#[repr(C)]
struct RawTlasInstance {
transform: [f32; 12],
custom_index_and_mask: u32,
shader_binding_table_record_offset_and_flags: u32,
acceleration_structure_reference: u64,
}