Implement as_hal for BLASes and TLASes (#7303)

This commit is contained in:
Vecvec 2025-04-10 06:50:43 +12:00 committed by GitHub
parent 6ef9b8256a
commit 1c4b73c098
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 240 additions and 5 deletions

View File

@ -227,7 +227,8 @@ By @syl20bnr in [#7326](https://github.com/gfx-rs/wgpu/pull/7326).
- Added `CommandEncoder::transition_resources()` for native API interop, and allowing users to slightly optimize barriers. By @JMS55 in [#6678](https://github.com/gfx-rs/wgpu/pull/6678).
- Add `wgpu_hal::vulkan::Adapter::texture_format_as_raw` for native API interop. By @JMS55 in [#7228](https://github.com/gfx-rs/wgpu/pull/7228).
- Support getting vertices of the hit triangle when raytracing. By @Vecvec in [#7183](https://github.com/gfx-rs/wgpu/pull/7183) .
- Support getting vertices of the hit triangle when raytracing. By @Vecvec in [#7183](https://github.com/gfx-rs/wgpu/pull/7183).
- Add `as_hal` for both acceleration structures. By @Vecvec in [#7303](https://github.com/gfx-rs/wgpu/pull/7303).
- Add Metal compute shader passthrough. Use `create_shader_module_passthrough` on device. By @syl20bnr in [#7326](https://github.com/gfx-rs/wgpu/pull/7326).

View File

@ -28,6 +28,8 @@ use crate::{
FastHashSet,
};
use crate::id::{BlasId, TlasId};
struct TriangleBufferStore<'a> {
vertex_buffer: Arc<Buffer>,
vertex_transition: Option<PendingTransition<BufferUses>>,
@ -61,6 +63,60 @@ struct TlasBufferStore {
}
impl Global {
pub fn command_encoder_mark_acceleration_structures_built(
&self,
command_encoder_id: CommandEncoderId,
blas_ids: &[BlasId],
tlas_ids: &[TlasId],
) -> Result<(), BuildAccelerationStructureError> {
profiling::scope!("CommandEncoder::mark_acceleration_structures_built");
let hub = &self.hub;
let cmd_buf = hub
.command_buffers
.get(command_encoder_id.into_command_buffer_id());
let device = &cmd_buf.device;
device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
let build_command_index = NonZeroU64::new(
device
.last_acceleration_structure_build_command_index
.fetch_add(1, Ordering::Relaxed),
)
.unwrap();
let mut cmd_buf_data = cmd_buf.data.lock();
let mut cmd_buf_data_guard = cmd_buf_data.record()?;
let cmd_buf_data = &mut *cmd_buf_data_guard;
cmd_buf_data.blas_actions.reserve(blas_ids.len());
cmd_buf_data.tlas_actions.reserve(tlas_ids.len());
for blas in blas_ids {
let blas = hub.blas_s.get(*blas).get()?;
cmd_buf_data.blas_actions.push(BlasAction {
blas,
kind: crate::ray_tracing::BlasActionKind::Build(build_command_index),
});
}
for tlas in tlas_ids {
let tlas = hub.tlas_s.get(*tlas).get()?;
cmd_buf_data.tlas_actions.push(TlasAction {
tlas,
kind: crate::ray_tracing::TlasActionKind::Build {
build_index: build_command_index,
dependencies: Vec::new(),
},
});
}
Ok(())
}
// Currently this function is very similar to its safe counterpart, however certain parts of it are very different,
// making for the two to be implemented differently, the main difference is this function has separate buffers for each
// of the TLAS instances while the other has one large buffer

View File

@ -36,6 +36,8 @@ use crate::{
Label, LabelHelpers, SubmissionIndex,
};
use crate::id::{BlasId, TlasId};
/// Information about the wgpu-core resource.
///
/// Each type representing a `wgpu-core` resource, like [`Device`],
@ -1412,6 +1414,54 @@ impl Global {
hal_queue_callback(hal_queue)
}
/// # Safety
///
/// - The raw blas handle must not be manually destroyed
pub unsafe fn blas_as_hal<A: HalApi, F: FnOnce(Option<&A::AccelerationStructure>) -> R, R>(
&self,
id: BlasId,
hal_blas_callback: F,
) -> R {
profiling::scope!("Blas::as_hal");
let hub = &self.hub;
if let Ok(blas) = hub.blas_s.get(id).get() {
let snatch_guard = blas.device.snatchable_lock.read();
let hal_blas = blas
.try_raw(&snatch_guard)
.ok()
.and_then(|b| b.as_any().downcast_ref());
hal_blas_callback(hal_blas)
} else {
hal_blas_callback(None)
}
}
/// # Safety
///
/// - The raw tlas handle must not be manually destroyed
pub unsafe fn tlas_as_hal<A: HalApi, F: FnOnce(Option<&A::AccelerationStructure>) -> R, R>(
&self,
id: TlasId,
hal_tlas_callback: F,
) -> R {
profiling::scope!("Blas::as_hal");
let hub = &self.hub;
if let Ok(tlas) = hub.tlas_s.get(id).get() {
let snatch_guard = tlas.device.snatchable_lock.read();
let hal_tlas = tlas
.try_raw(&snatch_guard)
.ok()
.and_then(|t| t.as_any().downcast_ref());
hal_tlas_callback(hal_tlas)
} else {
hal_tlas_callback(None)
}
}
}
/// A texture that has been marked as destroyed and is staged for actual deletion soon.

View File

@ -150,6 +150,30 @@ impl Blas {
pub fn handle(&self) -> Option<u64> {
self.handle
}
/// Returns the inner hal Acceleration Structure using a callback. The hal acceleration structure
/// will be `None` if the backend type argument does not match with this wgpu Blas
///
/// This method will start the wgpu_core level command recording.
///
/// # Safety
///
/// - The raw handle obtained from the hal Acceleration Structure must not be manually destroyed
#[cfg(wgpu_core)]
pub unsafe fn as_hal<
A: wgc::hal_api::HalApi,
F: FnOnce(Option<&A::AccelerationStructure>) -> R,
R,
>(
&mut self,
hal_blas_callback: F,
) -> R {
if let Some(blas) = self.inner.as_core_opt() {
unsafe { blas.context.blas_as_hal::<A, F, R>(blas, hal_blas_callback) }
} else {
hal_blas_callback(None)
}
}
}
/// Context version of [BlasTriangleGeometry].

View File

@ -285,6 +285,22 @@ impl CommandEncoder {
/// [`Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE`] must be enabled on the device in order to call these functions.
impl CommandEncoder {
/// Mark acceleration structures as being built. ***Should only*** be used with wgpu-hal
/// functions, all wgpu functions already mark acceleration structures as built.
///
/// # Safety
///
/// - All acceleration structures must have been build in this command encoder.
/// - All BLASes inputted must have been built before all TLASes that were inputted here and
/// which use them.
pub unsafe fn mark_acceleration_structures_built<'a>(
&self,
blas: impl IntoIterator<Item = &'a Blas>,
tlas: impl IntoIterator<Item = &'a Tlas>,
) {
self.inner
.mark_acceleration_structures_built(&mut blas.into_iter(), &mut tlas.into_iter())
}
/// Build bottom and top level acceleration structures.
///
/// Builds the BLASes then the TLASes, but does ***not*** build the BLASes into the TLASes,

View File

@ -32,6 +32,33 @@ static_assertions::assert_impl_all!(Tlas: WasmNotSendSync);
crate::cmp::impl_eq_ord_hash_proxy!(Tlas => .shared.inner);
impl Tlas {
/// Returns the inner hal Acceleration Structure using a callback. The hal acceleration structure
/// will be `None` if the backend type argument does not match with this wgpu Tlas
///
/// This method will start the wgpu_core level command recording.
///
/// # Safety
///
/// - The raw handle obtained from the hal Acceleration Structure must not be manually destroyed
/// - If the raw handle is build,
#[cfg(wgpu_core)]
pub unsafe fn as_hal<
A: wgc::hal_api::HalApi,
F: FnOnce(Option<&A::AccelerationStructure>) -> R,
R,
>(
&mut self,
hal_tlas_callback: F,
) -> R {
if let Some(tlas) = self.shared.inner.as_core_opt() {
unsafe { tlas.context.tlas_as_hal::<A, F, R>(tlas, hal_tlas_callback) }
} else {
hal_tlas_callback(None)
}
}
}
/// Entry for a top level acceleration structure build.
/// Used with raw instance buffers for an unvalidated builds.
/// See [`TlasPackage`] for the safe version.

View File

@ -26,7 +26,7 @@ use wgt::Backends;
use js_sys::Promise;
use wasm_bindgen::{prelude::*, JsCast};
use crate::{dispatch, SurfaceTargetUnsafe};
use crate::{dispatch, Blas, SurfaceTargetUnsafe, Tlas};
use defined_non_null_js_value::DefinedNonNullJsValue;
@ -3084,6 +3084,14 @@ impl dispatch::CommandEncoderInterface for WebCommandEncoder {
);
}
fn mark_acceleration_structures_built<'a>(
&self,
_blas: &mut dyn Iterator<Item = &'a Blas>,
_tlas: &mut dyn Iterator<Item = &'a Tlas>,
) {
unimplemented!("Raytracing not implemented for web");
}
fn build_acceleration_structures_unsafe_tlas<'a>(
&self,
_blas: &mut dyn Iterator<Item = &'a crate::BlasBuildEntry<'a>>,

View File

@ -18,9 +18,9 @@ use wgt::WasmNotSendSync;
use crate::{
api,
dispatch::{self, BufferMappedRangeInterface},
BindingResource, BufferBinding, BufferDescriptor, CompilationInfo, CompilationMessage,
BindingResource, Blas, BufferBinding, BufferDescriptor, CompilationInfo, CompilationMessage,
CompilationMessageType, ErrorSource, Features, Label, LoadOp, MapMode, Operations,
ShaderSource, SurfaceTargetUnsafe, TextureDescriptor,
ShaderSource, SurfaceTargetUnsafe, TextureDescriptor, Tlas,
};
#[derive(Clone)]
@ -267,6 +267,30 @@ impl ContextWgpuCore {
}
}
pub unsafe fn blas_as_hal<
A: wgc::hal_api::HalApi,
F: FnOnce(Option<&A::AccelerationStructure>) -> R,
R,
>(
&self,
blas: &CoreBlas,
hal_blas_callback: F,
) -> R {
unsafe { self.0.blas_as_hal::<A, F, R>(blas.id, hal_blas_callback) }
}
pub unsafe fn tlas_as_hal<
A: wgc::hal_api::HalApi,
F: FnOnce(Option<&A::AccelerationStructure>) -> R,
R,
>(
&self,
tlas: &CoreTlas,
hal_tlas_callback: F,
) -> R {
unsafe { self.0.tlas_as_hal::<A, F, R>(tlas.id, hal_tlas_callback) }
}
pub fn generate_report(&self) -> wgc::global::GlobalReport {
self.0.generate_report()
}
@ -2480,6 +2504,30 @@ impl dispatch::CommandEncoderInterface for CoreCommandEncoder {
}
}
fn mark_acceleration_structures_built<'a>(
&self,
blas: &mut dyn Iterator<Item = &'a Blas>,
tlas: &mut dyn Iterator<Item = &'a Tlas>,
) {
let blas = blas
.map(|b| b.inner.as_core().id)
.collect::<SmallVec<[_; 4]>>();
let tlas = tlas
.map(|t| t.shared.inner.as_core().id)
.collect::<SmallVec<[_; 4]>>();
if let Err(cause) = self
.context
.0
.command_encoder_mark_acceleration_structures_built(self.id, &blas, &tlas)
{
self.context.handle_error_nolabel(
&self.error_sink,
cause,
"CommandEncoder::build_acceleration_structures_unsafe_tlas",
);
}
}
fn build_acceleration_structures_unsafe_tlas<'a>(
&self,
blas: &mut dyn Iterator<Item = &'a crate::BlasBuildEntry<'a>>,

View File

@ -12,7 +12,7 @@
#![allow(clippy::too_many_arguments)] // It's fine.
#![allow(missing_docs, clippy::missing_safety_doc)] // Interfaces are not documented
use crate::{WasmNotSend, WasmNotSendSync};
use crate::{Blas, Tlas, WasmNotSend, WasmNotSendSync};
use alloc::{boxed::Box, string::String, sync::Arc, vec::Vec};
use core::{any::Any, fmt::Debug, future::Future, hash::Hash, ops::Range, pin::Pin};
@ -313,6 +313,11 @@ pub trait CommandEncoderInterface: CommonTraits {
destination: &DispatchBuffer,
destination_offset: crate::BufferAddress,
);
fn mark_acceleration_structures_built<'a>(
&self,
blas: &mut dyn Iterator<Item = &'a Blas>,
tlas: &mut dyn Iterator<Item = &'a Tlas>,
);
fn build_acceleration_structures_unsafe_tlas<'a>(
&self,