From 77a83fb0dd9f2295f25e99a850b9a031738925c3 Mon Sep 17 00:00:00 2001 From: Andreas Reich Date: Tue, 14 May 2024 22:05:17 +0200 Subject: [PATCH] Remove lifetime constraints from `wgpu::ComputePass` methods (#5570) * basic test setup * remove lifetime and drop resources on test - test fails now just as expected * compute pass recording is now hub dependent (needs gfx_select) * compute pass recording now bumps reference count of uses resources directly on recording TODO: * bind groups don't work because the Binder gets an id only * wgpu level error handling is missing * simplify compute pass state flush, compute pass execution no longer needs to lock bind_group storage * wgpu sided error handling * make ComputePass hal dependent, removing command cast hack. Introduce DynComputePass on wgpu side * remove stray repr(C) * changelog entry * fix deno issues -> move DynComputePass into wgc * split out resources setup from test --- CHANGELOG.md | 8 + deno_webgpu/command_encoder.rs | 5 +- deno_webgpu/compute_pass.rs | 67 ++-- .../tests/compute_pass_resource_ownership.rs | 174 +++++++++ tests/tests/root.rs | 1 + wgpu-core/src/command/bind.rs | 5 +- wgpu-core/src/command/compute.rs | 347 +++++++++++------- wgpu-core/src/command/dyn_compute_pass.rs | 136 +++++++ wgpu-core/src/command/mod.rs | 5 +- wgpu/src/backend/wgpu_core.rs | 85 +++-- wgpu/src/lib.rs | 6 +- 11 files changed, 631 insertions(+), 208 deletions(-) create mode 100644 tests/tests/compute_pass_resource_ownership.rs create mode 100644 wgpu-core/src/command/dyn_compute_pass.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index e3ba44fe5..f26392b38 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -41,6 +41,14 @@ Bottom level categories: ### Major Changes +#### Remove lifetime bounds on `wgpu::ComputePass` + +TODO(wumpf): This is still work in progress. Should write a bit more about it. Also will very likely extend to `wgpu::RenderPass` before release. + +`wgpu::ComputePass` recording methods (e.g. 
`wgpu::ComputePass::set_pipeline`) no longer impose a lifetime constraint on passed in resources. + +By @wumpf in [#5569](https://github.com/gfx-rs/wgpu/pull/5569). + #### Querying shader compilation errors Wgpu now supports querying [shader compilation info](https://www.w3.org/TR/webgpu/#dom-gpushadermodule-getcompilationinfo). diff --git a/deno_webgpu/command_encoder.rs b/deno_webgpu/command_encoder.rs index 20dfe0db0..b82fba92e 100644 --- a/deno_webgpu/command_encoder.rs +++ b/deno_webgpu/command_encoder.rs @@ -254,13 +254,14 @@ pub fn op_webgpu_command_encoder_begin_compute_pass( None }; + let instance = state.borrow::(); + let command_encoder = &command_encoder_resource.1; let descriptor = wgpu_core::command::ComputePassDescriptor { label: Some(label), timestamp_writes: timestamp_writes.as_ref(), }; - let compute_pass = - wgpu_core::command::ComputePass::new(command_encoder_resource.1, &descriptor); + let compute_pass = gfx_select!(command_encoder => instance.command_encoder_create_compute_pass_dyn(*command_encoder, &descriptor)); let rid = state .resource_table diff --git a/deno_webgpu/compute_pass.rs b/deno_webgpu/compute_pass.rs index 2cdea2c8f..fb499e7e0 100644 --- a/deno_webgpu/compute_pass.rs +++ b/deno_webgpu/compute_pass.rs @@ -10,7 +10,9 @@ use std::cell::RefCell; use super::error::WebGpuResult; -pub(crate) struct WebGpuComputePass(pub(crate) RefCell); +pub(crate) struct WebGpuComputePass( + pub(crate) RefCell>, +); impl Resource for WebGpuComputePass { fn name(&self) -> Cow { "webGPUComputePass".into() @@ -31,10 +33,10 @@ pub fn op_webgpu_compute_pass_set_pipeline( .resource_table .get::(compute_pass_rid)?; - wgpu_core::command::compute_commands::wgpu_compute_pass_set_pipeline( - &mut compute_pass_resource.0.borrow_mut(), - compute_pipeline_resource.1, - ); + compute_pass_resource + .0 + .borrow_mut() + .set_pipeline(state.borrow(), compute_pipeline_resource.1)?; Ok(WebGpuResult::empty()) } @@ -52,12 +54,10 @@ pub fn 
op_webgpu_compute_pass_dispatch_workgroups( .resource_table .get::(compute_pass_rid)?; - wgpu_core::command::compute_commands::wgpu_compute_pass_dispatch_workgroups( - &mut compute_pass_resource.0.borrow_mut(), - x, - y, - z, - ); + compute_pass_resource + .0 + .borrow_mut() + .dispatch_workgroups(state.borrow(), x, y, z); Ok(WebGpuResult::empty()) } @@ -77,11 +77,10 @@ pub fn op_webgpu_compute_pass_dispatch_workgroups_indirect( .resource_table .get::(compute_pass_rid)?; - wgpu_core::command::compute_commands::wgpu_compute_pass_dispatch_workgroups_indirect( - &mut compute_pass_resource.0.borrow_mut(), - buffer_resource.1, - indirect_offset, - ); + compute_pass_resource + .0 + .borrow_mut() + .dispatch_workgroups_indirect(state.borrow(), buffer_resource.1, indirect_offset)?; Ok(WebGpuResult::empty()) } @@ -90,24 +89,15 @@ pub fn op_webgpu_compute_pass_dispatch_workgroups_indirect( #[serde] pub fn op_webgpu_compute_pass_end( state: &mut OpState, - #[smi] command_encoder_rid: ResourceId, #[smi] compute_pass_rid: ResourceId, ) -> Result { - let command_encoder_resource = - state - .resource_table - .get::(command_encoder_rid)?; - let command_encoder = command_encoder_resource.1; let compute_pass_resource = state .resource_table .take::(compute_pass_rid)?; - let compute_pass = &compute_pass_resource.0.borrow(); - let instance = state.borrow::(); - gfx_ok!(command_encoder => instance.command_encoder_run_compute_pass( - command_encoder, - compute_pass - )) + compute_pass_resource.0.borrow_mut().run(state.borrow())?; + + Ok(WebGpuResult::empty()) } #[op2] @@ -137,12 +127,12 @@ pub fn op_webgpu_compute_pass_set_bind_group( let dynamic_offsets_data: &[u32] = &dynamic_offsets_data[start..start + len]; - wgpu_core::command::compute_commands::wgpu_compute_pass_set_bind_group( - &mut compute_pass_resource.0.borrow_mut(), + compute_pass_resource.0.borrow_mut().set_bind_group( + state.borrow(), index, bind_group_resource.1, dynamic_offsets_data, - ); + )?; 
Ok(WebGpuResult::empty()) } @@ -158,8 +148,8 @@ pub fn op_webgpu_compute_pass_push_debug_group( .resource_table .get::(compute_pass_rid)?; - wgpu_core::command::compute_commands::wgpu_compute_pass_push_debug_group( - &mut compute_pass_resource.0.borrow_mut(), + compute_pass_resource.0.borrow_mut().push_debug_group( + state.borrow(), group_label, 0, // wgpu#975 ); @@ -177,9 +167,10 @@ pub fn op_webgpu_compute_pass_pop_debug_group( .resource_table .get::(compute_pass_rid)?; - wgpu_core::command::compute_commands::wgpu_compute_pass_pop_debug_group( - &mut compute_pass_resource.0.borrow_mut(), - ); + compute_pass_resource + .0 + .borrow_mut() + .pop_debug_group(state.borrow()); Ok(WebGpuResult::empty()) } @@ -195,8 +186,8 @@ pub fn op_webgpu_compute_pass_insert_debug_marker( .resource_table .get::(compute_pass_rid)?; - wgpu_core::command::compute_commands::wgpu_compute_pass_insert_debug_marker( - &mut compute_pass_resource.0.borrow_mut(), + compute_pass_resource.0.borrow_mut().insert_debug_marker( + state.borrow(), marker_label, 0, // wgpu#975 ); diff --git a/tests/tests/compute_pass_resource_ownership.rs b/tests/tests/compute_pass_resource_ownership.rs new file mode 100644 index 000000000..6612ad006 --- /dev/null +++ b/tests/tests/compute_pass_resource_ownership.rs @@ -0,0 +1,174 @@ +//! Tests that compute passes take ownership of resources that are associated with. +//! I.e. once a resource is passed in to a compute pass, it can be dropped. +//! +//! TODO: Test doesn't check on timestamp writes & pipeline statistics queries yet. +//! (Not important as long as they are lifetime constrained to the command encoder, +//! but once we lift this constraint, we should add tests for this as well!) +//! TODO: Also should test resource ownership for: +//! * write_timestamp +//! 
* begin_pipeline_statistics_query + +use std::num::NonZeroU64; + +use wgpu::util::DeviceExt as _; +use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters, TestingContext}; + +const SHADER_SRC: &str = " +@group(0) @binding(0) +var buffer: array; + +@compute @workgroup_size(1, 1, 1) fn main() { + buffer[0] *= 2.0; +} +"; + +#[gpu_test] +static COMPUTE_PASS_RESOURCE_OWNERSHIP: GpuTestConfiguration = GpuTestConfiguration::new() + .parameters(TestParameters::default().test_features_limits()) + .run_async(compute_pass_resource_ownership); + +async fn compute_pass_resource_ownership(ctx: TestingContext) { + let ResourceSetup { + gpu_buffer, + cpu_buffer, + buffer_size, + indirect_buffer, + bind_group, + pipeline, + } = resource_setup(&ctx); + + let mut encoder = ctx + .device + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("encoder"), + }); + + { + let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("compute_pass"), + timestamp_writes: None, // TODO: See description above, we should test this as well once we lift the lifetime bound. + }); + cpass.set_pipeline(&pipeline); + cpass.set_bind_group(0, &bind_group, &[]); + cpass.dispatch_workgroups_indirect(&indirect_buffer, 0); + + // Now drop all resources we set. Then do a device poll to make sure the resources are really not dropped too early, no matter what. + drop(pipeline); + drop(bind_group); + drop(indirect_buffer); + ctx.async_poll(wgpu::Maintain::wait()) + .await + .panic_on_timeout(); + } + + // Ensure that the compute pass still executed normally. 
+ encoder.copy_buffer_to_buffer(&gpu_buffer, 0, &cpu_buffer, 0, buffer_size); + ctx.queue.submit([encoder.finish()]); + cpu_buffer.slice(..).map_async(wgpu::MapMode::Read, |_| ()); + ctx.async_poll(wgpu::Maintain::wait()) + .await + .panic_on_timeout(); + + let data = cpu_buffer.slice(..).get_mapped_range(); + + let floats: &[f32] = bytemuck::cast_slice(&data); + assert_eq!(floats, [2.0, 4.0, 6.0, 8.0]); +} + +// Setup ------------------------------------------------------------ + +struct ResourceSetup { + gpu_buffer: wgpu::Buffer, + cpu_buffer: wgpu::Buffer, + buffer_size: u64, + + indirect_buffer: wgpu::Buffer, + bind_group: wgpu::BindGroup, + pipeline: wgpu::ComputePipeline, +} + +fn resource_setup(ctx: &TestingContext) -> ResourceSetup { + let sm = ctx + .device + .create_shader_module(wgpu::ShaderModuleDescriptor { + label: Some("shader"), + source: wgpu::ShaderSource::Wgsl(SHADER_SRC.into()), + }); + + let buffer_size = 4 * std::mem::size_of::() as u64; + + let bgl = ctx + .device + .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + label: Some("bind_group_layout"), + entries: &[wgpu::BindGroupLayoutEntry { + binding: 0, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: NonZeroU64::new(buffer_size), + }, + count: None, + }], + }); + + let gpu_buffer = ctx + .device + .create_buffer_init(&wgpu::util::BufferInitDescriptor { + label: Some("gpu_buffer"), + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, + contents: bytemuck::bytes_of(&[1.0_f32, 2.0, 3.0, 4.0]), + }); + + let cpu_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor { + label: Some("cpu_buffer"), + size: buffer_size, + usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ, + mapped_at_creation: false, + }); + + let indirect_buffer = ctx + .device + .create_buffer_init(&wgpu::util::BufferInitDescriptor { + label: 
Some("gpu_buffer"), + usage: wgpu::BufferUsages::INDIRECT, + contents: wgpu::util::DispatchIndirectArgs { x: 1, y: 1, z: 1 }.as_bytes(), + }); + + let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor { + label: Some("bind_group"), + layout: &bgl, + entries: &[wgpu::BindGroupEntry { + binding: 0, + resource: gpu_buffer.as_entire_binding(), + }], + }); + + let pipeline_layout = ctx + .device + .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: Some("pipeline_layout"), + bind_group_layouts: &[&bgl], + push_constant_ranges: &[], + }); + + let pipeline = ctx + .device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: Some("pipeline"), + layout: Some(&pipeline_layout), + module: &sm, + entry_point: "main", + compilation_options: Default::default(), + }); + + ResourceSetup { + gpu_buffer, + cpu_buffer, + buffer_size, + indirect_buffer, + bind_group, + pipeline, + } +} diff --git a/tests/tests/root.rs b/tests/tests/root.rs index 6dc7af56e..ba5e02079 100644 --- a/tests/tests/root.rs +++ b/tests/tests/root.rs @@ -11,6 +11,7 @@ mod buffer; mod buffer_copy; mod buffer_usages; mod clear_texture; +mod compute_pass_resource_ownership; mod create_surface_error; mod device; mod encoder; diff --git a/wgpu-core/src/command/bind.rs b/wgpu-core/src/command/bind.rs index 7b2ac5455..c643611a9 100644 --- a/wgpu-core/src/command/bind.rs +++ b/wgpu-core/src/command/bind.rs @@ -4,7 +4,6 @@ use crate::{ binding_model::{BindGroup, LateMinBufferBindingSizeMismatch, PipelineLayout}, device::SHADER_STAGE_COUNT, hal_api::HalApi, - id::BindGroupId, pipeline::LateSizedBufferGroup, resource::Resource, }; @@ -359,11 +358,11 @@ impl Binder { &self.payloads[bind_range] } - pub(super) fn list_active(&self) -> impl Iterator + '_ { + pub(super) fn list_active<'a>(&'a self) -> impl Iterator>> + '_ { let payloads = &self.payloads; self.manager .list_active() - .map(move |index| payloads[index].group.as_ref().unwrap().as_info().id()) + .map(move |index| 
payloads[index].group.as_ref().unwrap()) } pub(super) fn invalid_mask(&self) -> BindGroupMask { diff --git a/wgpu-core/src/command/compute.rs b/wgpu-core/src/command/compute.rs index 4ee48f008..997c62e8b 100644 --- a/wgpu-core/src/command/compute.rs +++ b/wgpu-core/src/command/compute.rs @@ -1,29 +1,23 @@ -use crate::command::compute_command::{ArcComputeCommand, ComputeCommand}; -use crate::device::DeviceError; -use crate::resource::Resource; -use crate::snatch::SnatchGuard; -use crate::track::TrackerIndex; use crate::{ - binding_model::{ - BindError, BindGroup, LateMinBufferBindingSizeMismatch, PushConstantUploadError, - }, + binding_model::{BindError, LateMinBufferBindingSizeMismatch, PushConstantUploadError}, command::{ bind::Binder, + compute_command::{ArcComputeCommand, ComputeCommand}, end_pipeline_statistics_query, memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState}, BasePass, BasePassRef, BindGroupStateChange, CommandBuffer, CommandEncoderError, CommandEncoderStatus, MapPassErr, PassErrorScope, QueryUseError, StateChange, }, - device::{MissingDownlevelFlags, MissingFeatures}, + device::{DeviceError, MissingDownlevelFlags, MissingFeatures}, error::{ErrorFormatter, PrettyError}, global::Global, hal_api::HalApi, - hal_label, id, - id::DeviceId, + hal_label, + id::{self, DeviceId}, init_tracker::MemoryInitKind, - resource::{self}, - storage::Storage, - track::{Tracker, UsageConflict, UsageScope}, + resource::{self, Resource}, + snatch::SnatchGuard, + track::{Tracker, TrackerIndex, UsageConflict, UsageScope}, validation::{check_buffer_usage, MissingBufferUsageError}, Label, }; @@ -35,27 +29,25 @@ use serde::Deserialize; use serde::Serialize; use thiserror::Error; +use wgt::{BufferAddress, DynamicOffset}; use std::sync::Arc; use std::{fmt, mem, str}; -#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] -pub struct ComputePass { - base: BasePass, +pub struct ComputePass { + base: BasePass>, parent_id: id::CommandEncoderId, 
timestamp_writes: Option, // Resource binding dedupe state. - #[cfg_attr(feature = "serde", serde(skip))] current_bind_groups: BindGroupStateChange, - #[cfg_attr(feature = "serde", serde(skip))] current_pipeline: StateChange, } -impl ComputePass { - pub fn new(parent_id: id::CommandEncoderId, desc: &ComputePassDescriptor) -> Self { +impl ComputePass { + fn new(parent_id: id::CommandEncoderId, desc: &ComputePassDescriptor) -> Self { Self { - base: BasePass::new(&desc.label), + base: BasePass::>::new(&desc.label), parent_id, timestamp_writes: desc.timestamp_writes.cloned(), @@ -67,30 +59,15 @@ impl ComputePass { pub fn parent_id(&self) -> id::CommandEncoderId { self.parent_id } - - #[cfg(feature = "trace")] - pub fn into_command(self) -> crate::device::trace::Command { - crate::device::trace::Command::RunComputePass { - base: self.base, - timestamp_writes: self.timestamp_writes, - } - } } -impl fmt::Debug for ComputePass { +impl fmt::Debug for ComputePass { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "ComputePass {{ encoder_id: {:?}, data: {:?} commands and {:?} dynamic offsets }}", - self.parent_id, - self.base.commands.len(), - self.base.dynamic_offsets.len() - ) + write!(f, "ComputePass {{ encoder_id: {:?} }}", self.parent_id) } } /// Describes the writing of timestamp values in a compute pass. -#[repr(C)] #[derive(Clone, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct ComputePassTimestampWrites { @@ -253,22 +230,19 @@ impl<'a, A: HalApi> State<'a, A> { &mut self, raw_encoder: &mut A::CommandEncoder, base_trackers: &mut Tracker, - bind_group_guard: &Storage>, indirect_buffer: Option, snatch_guard: &SnatchGuard, ) -> Result<(), UsageConflict> { - for id in self.binder.list_active() { - unsafe { self.scope.merge_bind_group(&bind_group_guard[id].used)? }; + for bind_group in self.binder.list_active() { + unsafe { self.scope.merge_bind_group(&bind_group.used)? 
}; // Note: stateless trackers are not merged: the lifetime reference // is held to the bind group itself. } - for id in self.binder.list_active() { + for bind_group in self.binder.list_active() { unsafe { - base_trackers.set_and_remove_from_usage_scope_sparse( - &mut self.scope, - &bind_group_guard[id].used, - ) + base_trackers + .set_and_remove_from_usage_scope_sparse(&mut self.scope, &bind_group.used) } } @@ -286,17 +260,31 @@ impl<'a, A: HalApi> State<'a, A> { } } -// Common routines between render/compute +// Running the compute pass. impl Global { + pub fn command_encoder_create_compute_pass( + &self, + parent_id: id::CommandEncoderId, + desc: &ComputePassDescriptor, + ) -> ComputePass { + ComputePass::new(parent_id, desc) + } + + pub fn command_encoder_create_compute_pass_dyn( + &self, + parent_id: id::CommandEncoderId, + desc: &ComputePassDescriptor, + ) -> Box { + Box::new(ComputePass::::new(parent_id, desc)) + } + pub fn command_encoder_run_compute_pass( &self, - encoder_id: id::CommandEncoderId, - pass: &ComputePass, + pass: &ComputePass, ) -> Result<(), ComputePassError> { - // TODO: This should go directly to `command_encoder_run_compute_pass_impl` by means of storing `ArcComputeCommand` internally. 
- self.command_encoder_run_compute_pass_with_unresolved_commands::( - encoder_id, + self.command_encoder_run_compute_pass_impl( + pass.parent_id, pass.base.as_ref(), pass.timestamp_writes.as_ref(), ) @@ -377,7 +365,6 @@ impl Global { *status = CommandEncoderStatus::Error; let raw = encoder.open().map_pass_err(pass_scope)?; - let bind_group_guard = hub.bind_groups.read(); let query_set_guard = hub.query_sets.read(); let mut state = State { @@ -642,13 +629,7 @@ impl Global { state.is_ready().map_pass_err(scope)?; state - .flush_states( - raw, - &mut intermediate_trackers, - &*bind_group_guard, - None, - &snatch_guard, - ) + .flush_states(raw, &mut intermediate_trackers, None, &snatch_guard) .map_pass_err(scope)?; let groups_size_limit = cmd_buf.limits.max_compute_workgroups_per_dimension; @@ -721,7 +702,6 @@ impl Global { .flush_states( raw, &mut intermediate_trackers, - &*bind_group_guard, Some(buffer.as_info().tracker_index()), &snatch_guard, ) @@ -845,18 +825,15 @@ impl Global { } } -pub mod compute_commands { - use super::{ComputeCommand, ComputePass}; - use crate::id; - use std::convert::TryInto; - use wgt::{BufferAddress, DynamicOffset}; - - pub fn wgpu_compute_pass_set_bind_group( - pass: &mut ComputePass, +// Recording a compute pass. +impl Global { + pub fn compute_pass_set_bind_group( + &self, + pass: &mut ComputePass, index: u32, bind_group_id: id::BindGroupId, offsets: &[DynamicOffset], - ) { + ) -> Result<(), ComputePassError> { let redundant = pass.current_bind_groups.set_and_check_redundant( bind_group_id, index, @@ -865,30 +842,62 @@ pub mod compute_commands { ); if redundant { - return; + return Ok(()); } - pass.base.commands.push(ComputeCommand::SetBindGroup { + let hub = A::hub(self); + let bind_group = hub + .bind_groups + .read() + .get(bind_group_id) + .map_err(|_| ComputePassError { + scope: PassErrorScope::SetBindGroup(bind_group_id), + inner: ComputePassErrorInner::InvalidBindGroup(index), + })? 
+ .clone(); + + pass.base.commands.push(ArcComputeCommand::SetBindGroup { index, num_dynamic_offsets: offsets.len(), - bind_group_id, + bind_group, }); + + Ok(()) } - pub fn wgpu_compute_pass_set_pipeline( - pass: &mut ComputePass, + pub fn compute_pass_set_pipeline( + &self, + pass: &mut ComputePass, pipeline_id: id::ComputePipelineId, - ) { + ) -> Result<(), ComputePassError> { if pass.current_pipeline.set_and_check_redundant(pipeline_id) { - return; + return Ok(()); } + let hub = A::hub(self); + let pipeline = hub + .compute_pipelines + .read() + .get(pipeline_id) + .map_err(|_| ComputePassError { + scope: PassErrorScope::SetPipelineCompute(pipeline_id), + inner: ComputePassErrorInner::InvalidPipeline(pipeline_id), + })? + .clone(); + pass.base .commands - .push(ComputeCommand::SetPipeline(pipeline_id)); + .push(ArcComputeCommand::SetPipeline(pipeline)); + + Ok(()) } - pub fn wgpu_compute_pass_set_push_constant(pass: &mut ComputePass, offset: u32, data: &[u8]) { + pub fn compute_pass_set_push_constant( + &self, + pass: &mut ComputePass, + offset: u32, + data: &[u8], + ) { assert_eq!( offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1), 0, @@ -901,92 +910,156 @@ pub mod compute_commands { ); let value_offset = pass.base.push_constant_data.len().try_into().expect( "Ran out of push constant space. 
Don't set 4gb of push constants per ComputePass.", - ); + ); // TODO: make this an error that can be handled pass.base.push_constant_data.extend( data.chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize) .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])), ); - pass.base.commands.push(ComputeCommand::SetPushConstant { - offset, - size_bytes: data.len() as u32, - values_offset: value_offset, - }); + pass.base + .commands + .push(ArcComputeCommand::::SetPushConstant { + offset, + size_bytes: data.len() as u32, + values_offset: value_offset, + }); } - pub fn wgpu_compute_pass_dispatch_workgroups( - pass: &mut ComputePass, + pub fn compute_pass_dispatch_workgroups( + &self, + pass: &mut ComputePass, groups_x: u32, groups_y: u32, groups_z: u32, ) { - pass.base - .commands - .push(ComputeCommand::Dispatch([groups_x, groups_y, groups_z])); + pass.base.commands.push(ArcComputeCommand::::Dispatch([ + groups_x, groups_y, groups_z, + ])); } - pub fn wgpu_compute_pass_dispatch_workgroups_indirect( - pass: &mut ComputePass, + pub fn compute_pass_dispatch_workgroups_indirect( + &self, + pass: &mut ComputePass, buffer_id: id::BufferId, offset: BufferAddress, - ) { + ) -> Result<(), ComputePassError> { + let hub = A::hub(self); + let buffer = hub + .buffers + .read() + .get(buffer_id) + .map_err(|_| ComputePassError { + scope: PassErrorScope::Dispatch { + indirect: true, + pipeline: pass.current_pipeline.last_state, + }, + inner: ComputePassErrorInner::InvalidBuffer(buffer_id), + })? 
+ .clone(); + pass.base .commands - .push(ComputeCommand::DispatchIndirect { buffer_id, offset }); + .push(ArcComputeCommand::::DispatchIndirect { buffer, offset }); + + Ok(()) } - pub fn wgpu_compute_pass_push_debug_group(pass: &mut ComputePass, label: &str, color: u32) { + pub fn compute_pass_push_debug_group( + &self, + pass: &mut ComputePass, + label: &str, + color: u32, + ) { let bytes = label.as_bytes(); pass.base.string_data.extend_from_slice(bytes); - pass.base.commands.push(ComputeCommand::PushDebugGroup { - color, - len: bytes.len(), - }); - } - - pub fn wgpu_compute_pass_pop_debug_group(pass: &mut ComputePass) { - pass.base.commands.push(ComputeCommand::PopDebugGroup); - } - - pub fn wgpu_compute_pass_insert_debug_marker(pass: &mut ComputePass, label: &str, color: u32) { - let bytes = label.as_bytes(); - pass.base.string_data.extend_from_slice(bytes); - - pass.base.commands.push(ComputeCommand::InsertDebugMarker { - color, - len: bytes.len(), - }); - } - - pub fn wgpu_compute_pass_write_timestamp( - pass: &mut ComputePass, - query_set_id: id::QuerySetId, - query_index: u32, - ) { - pass.base.commands.push(ComputeCommand::WriteTimestamp { - query_set_id, - query_index, - }); - } - - pub fn wgpu_compute_pass_begin_pipeline_statistics_query( - pass: &mut ComputePass, - query_set_id: id::QuerySetId, - query_index: u32, - ) { pass.base .commands - .push(ComputeCommand::BeginPipelineStatisticsQuery { - query_set_id, - query_index, + .push(ArcComputeCommand::::PushDebugGroup { + color, + len: bytes.len(), }); } - pub fn wgpu_compute_pass_end_pipeline_statistics_query(pass: &mut ComputePass) { + pub fn compute_pass_pop_debug_group(&self, pass: &mut ComputePass) { pass.base .commands - .push(ComputeCommand::EndPipelineStatisticsQuery); + .push(ArcComputeCommand::::PopDebugGroup); + } + + pub fn compute_pass_insert_debug_marker( + &self, + pass: &mut ComputePass, + label: &str, + color: u32, + ) { + let bytes = label.as_bytes(); + 
pass.base.string_data.extend_from_slice(bytes); + + pass.base + .commands + .push(ArcComputeCommand::::InsertDebugMarker { + color, + len: bytes.len(), + }); + } + + pub fn compute_pass_write_timestamp( + &self, + pass: &mut ComputePass, + query_set_id: id::QuerySetId, + query_index: u32, + ) -> Result<(), ComputePassError> { + let hub = A::hub(self); + let query_set = hub + .query_sets + .read() + .get(query_set_id) + .map_err(|_| ComputePassError { + scope: PassErrorScope::WriteTimestamp, + inner: ComputePassErrorInner::InvalidQuerySet(query_set_id), + })? + .clone(); + + pass.base.commands.push(ArcComputeCommand::WriteTimestamp { + query_set, + query_index, + }); + + Ok(()) + } + + pub fn compute_pass_begin_pipeline_statistics_query( + &self, + pass: &mut ComputePass, + query_set_id: id::QuerySetId, + query_index: u32, + ) -> Result<(), ComputePassError> { + let hub = A::hub(self); + let query_set = hub + .query_sets + .read() + .get(query_set_id) + .map_err(|_| ComputePassError { + scope: PassErrorScope::WriteTimestamp, + inner: ComputePassErrorInner::InvalidQuerySet(query_set_id), + })? + .clone(); + + pass.base + .commands + .push(ArcComputeCommand::BeginPipelineStatisticsQuery { + query_set, + query_index, + }); + + Ok(()) + } + + pub fn compute_pass_end_pipeline_statistics_query(&self, pass: &mut ComputePass) { + pass.base + .commands + .push(ArcComputeCommand::::EndPipelineStatisticsQuery); } } diff --git a/wgpu-core/src/command/dyn_compute_pass.rs b/wgpu-core/src/command/dyn_compute_pass.rs new file mode 100644 index 000000000..b7ffea3d4 --- /dev/null +++ b/wgpu-core/src/command/dyn_compute_pass.rs @@ -0,0 +1,136 @@ +use wgt::WasmNotSendSync; + +use crate::{global, hal_api::HalApi, id}; + +use super::{ComputePass, ComputePassError}; + +/// Trait for type erasing ComputePass. +// TODO(#5124): wgpu-core's ComputePass trait should not be hal type dependent. 
+// Practically speaking this allows us merge gfx_select with type erasure: +// The alternative would be to introduce ComputePassId which then first needs to be looked up and then dispatch via gfx_select. +pub trait DynComputePass: std::fmt::Debug + WasmNotSendSync { + fn run(&mut self, context: &global::Global) -> Result<(), ComputePassError>; + fn set_bind_group( + &mut self, + context: &global::Global, + index: u32, + bind_group_id: id::BindGroupId, + offsets: &[wgt::DynamicOffset], + ) -> Result<(), ComputePassError>; + fn set_pipeline( + &mut self, + context: &global::Global, + pipeline_id: id::ComputePipelineId, + ) -> Result<(), ComputePassError>; + fn set_push_constant(&mut self, context: &global::Global, offset: u32, data: &[u8]); + fn dispatch_workgroups( + &mut self, + context: &global::Global, + groups_x: u32, + groups_y: u32, + groups_z: u32, + ); + fn dispatch_workgroups_indirect( + &mut self, + context: &global::Global, + buffer_id: id::BufferId, + offset: wgt::BufferAddress, + ) -> Result<(), ComputePassError>; + fn push_debug_group(&mut self, context: &global::Global, label: &str, color: u32); + fn pop_debug_group(&mut self, context: &global::Global); + fn insert_debug_marker(&mut self, context: &global::Global, label: &str, color: u32); + fn write_timestamp( + &mut self, + context: &global::Global, + query_set_id: id::QuerySetId, + query_index: u32, + ) -> Result<(), ComputePassError>; + fn begin_pipeline_statistics_query( + &mut self, + context: &global::Global, + query_set_id: id::QuerySetId, + query_index: u32, + ) -> Result<(), ComputePassError>; + fn end_pipeline_statistics_query(&mut self, context: &global::Global); +} + +impl DynComputePass for ComputePass { + fn run(&mut self, context: &global::Global) -> Result<(), ComputePassError> { + context.command_encoder_run_compute_pass(self) + } + + fn set_bind_group( + &mut self, + context: &global::Global, + index: u32, + bind_group_id: id::BindGroupId, + offsets: &[wgt::DynamicOffset], + ) -> 
Result<(), ComputePassError> { + context.compute_pass_set_bind_group(self, index, bind_group_id, offsets) + } + + fn set_pipeline( + &mut self, + context: &global::Global, + pipeline_id: id::ComputePipelineId, + ) -> Result<(), ComputePassError> { + context.compute_pass_set_pipeline(self, pipeline_id) + } + + fn set_push_constant(&mut self, context: &global::Global, offset: u32, data: &[u8]) { + context.compute_pass_set_push_constant(self, offset, data) + } + + fn dispatch_workgroups( + &mut self, + context: &global::Global, + groups_x: u32, + groups_y: u32, + groups_z: u32, + ) { + context.compute_pass_dispatch_workgroups(self, groups_x, groups_y, groups_z) + } + + fn dispatch_workgroups_indirect( + &mut self, + context: &global::Global, + buffer_id: id::BufferId, + offset: wgt::BufferAddress, + ) -> Result<(), ComputePassError> { + context.compute_pass_dispatch_workgroups_indirect(self, buffer_id, offset) + } + + fn push_debug_group(&mut self, context: &global::Global, label: &str, color: u32) { + context.compute_pass_push_debug_group(self, label, color) + } + + fn pop_debug_group(&mut self, context: &global::Global) { + context.compute_pass_pop_debug_group(self) + } + + fn insert_debug_marker(&mut self, context: &global::Global, label: &str, color: u32) { + context.compute_pass_insert_debug_marker(self, label, color) + } + + fn write_timestamp( + &mut self, + context: &global::Global, + query_set_id: id::QuerySetId, + query_index: u32, + ) -> Result<(), ComputePassError> { + context.compute_pass_write_timestamp(self, query_set_id, query_index) + } + + fn begin_pipeline_statistics_query( + &mut self, + context: &global::Global, + query_set_id: id::QuerySetId, + query_index: u32, + ) -> Result<(), ComputePassError> { + context.compute_pass_begin_pipeline_statistics_query(self, query_set_id, query_index) + } + + fn end_pipeline_statistics_query(&mut self, context: &global::Global) { + context.compute_pass_end_pipeline_statistics_query(self) + } +} diff --git 
a/wgpu-core/src/command/mod.rs b/wgpu-core/src/command/mod.rs index 5730960e5..5159d6fa8 100644 --- a/wgpu-core/src/command/mod.rs +++ b/wgpu-core/src/command/mod.rs @@ -5,6 +5,7 @@ mod clear; mod compute; mod compute_command; mod draw; +mod dyn_compute_pass; mod memory_init; mod query; mod render; @@ -14,8 +15,8 @@ use std::sync::Arc; pub(crate) use self::clear::clear_texture; pub use self::{ - bundle::*, clear::ClearError, compute::*, compute_command::ComputeCommand, draw::*, query::*, - render::*, transfer::*, + bundle::*, clear::ClearError, compute::*, compute_command::ComputeCommand, draw::*, + dyn_compute_pass::DynComputePass, query::*, render::*, transfer::*, }; pub(crate) use allocator::CommandAllocator; diff --git a/wgpu/src/backend/wgpu_core.rs b/wgpu/src/backend/wgpu_core.rs index 65a3e3997..f03e4a569 100644 --- a/wgpu/src/backend/wgpu_core.rs +++ b/wgpu/src/backend/wgpu_core.rs @@ -24,9 +24,11 @@ use std::{ sync::Arc, }; use wgc::{ - command::{bundle_ffi::*, compute_commands::*, render_commands::*}, + command::{bundle_ffi::*, render_commands::*}, device::DeviceLostClosure, - id::{CommandEncoderId, TextureViewId}, + gfx_select, + id::CommandEncoderId, + id::TextureViewId, pipeline::CreateShaderModuleError, }; use wgt::WasmNotSendSync; @@ -476,6 +478,12 @@ impl Queue { } } +#[derive(Debug)] +pub struct ComputePass { + pass: Box, + error_sink: ErrorSink, +} + #[derive(Debug)] pub struct CommandEncoder { error_sink: ErrorSink, @@ -514,7 +522,7 @@ impl crate::Context for ContextWgpuCore { type CommandEncoderId = wgc::id::CommandEncoderId; type CommandEncoderData = CommandEncoder; type ComputePassId = Unused; - type ComputePassData = wgc::command::ComputePass; + type ComputePassData = ComputePass; type RenderPassId = Unused; type RenderPassData = wgc::command::RenderPass; type CommandBufferId = wgc::id::CommandBufferId; @@ -1841,7 +1849,7 @@ impl crate::Context for ContextWgpuCore { fn command_encoder_begin_compute_pass( &self, encoder: 
&Self::CommandEncoderId, - _encoder_data: &Self::CommandEncoderData, + encoder_data: &Self::CommandEncoderData, desc: &ComputePassDescriptor<'_>, ) -> (Self::ComputePassId, Self::ComputePassData) { let timestamp_writes = @@ -1852,15 +1860,16 @@ impl crate::Context for ContextWgpuCore { beginning_of_pass_write_index: tw.beginning_of_pass_write_index, end_of_pass_write_index: tw.end_of_pass_write_index, }); + ( Unused, - wgc::command::ComputePass::new( - *encoder, - &wgc::command::ComputePassDescriptor { + Self::ComputePassData { + pass: gfx_select!(encoder => self.0.command_encoder_create_compute_pass_dyn(*encoder, &wgc::command::ComputePassDescriptor { label: desc.label.map(Borrowed), timestamp_writes: timestamp_writes.as_ref(), - }, - ), + })), + error_sink: encoder_data.error_sink.clone(), + }, ) } @@ -1871,9 +1880,7 @@ impl crate::Context for ContextWgpuCore { _pass: &mut Self::ComputePassId, pass_data: &mut Self::ComputePassData, ) { - if let Err(cause) = wgc::gfx_select!( - encoder => self.0.command_encoder_run_compute_pass(*encoder, pass_data) - ) { + if let Err(cause) = pass_data.pass.run(&self.0) { let name = wgc::gfx_select!(encoder => self.0.command_buffer_label(encoder.into_command_buffer_id())); self.handle_error( &encoder_data.error_sink, @@ -2336,7 +2343,9 @@ impl crate::Context for ContextWgpuCore { pipeline: &Self::ComputePipelineId, _pipeline_data: &Self::ComputePipelineData, ) { - wgpu_compute_pass_set_pipeline(pass_data, *pipeline) + if let Err(cause) = pass_data.pass.set_pipeline(&self.0, *pipeline) { + self.handle_error_nolabel(&pass_data.error_sink, cause, "ComputePass::set_pipeline"); + } } fn compute_pass_set_bind_group( @@ -2348,7 +2357,12 @@ impl crate::Context for ContextWgpuCore { _bind_group_data: &Self::BindGroupData, offsets: &[wgt::DynamicOffset], ) { - wgpu_compute_pass_set_bind_group(pass_data, index, *bind_group, offsets); + if let Err(cause) = pass_data + .pass + .set_bind_group(&self.0, index, *bind_group, offsets) + { + 
self.handle_error_nolabel(&pass_data.error_sink, cause, "ComputePass::set_bind_group"); + } } fn compute_pass_set_push_constants( @@ -2358,7 +2372,7 @@ impl crate::Context for ContextWgpuCore { offset: u32, data: &[u8], ) { - wgpu_compute_pass_set_push_constant(pass_data, offset, data); + pass_data.pass.set_push_constant(&self.0, offset, data); } fn compute_pass_insert_debug_marker( @@ -2367,7 +2381,7 @@ impl crate::Context for ContextWgpuCore { pass_data: &mut Self::ComputePassData, label: &str, ) { - wgpu_compute_pass_insert_debug_marker(pass_data, label, 0); + pass_data.pass.insert_debug_marker(&self.0, label, 0); } fn compute_pass_push_debug_group( @@ -2376,7 +2390,7 @@ impl crate::Context for ContextWgpuCore { pass_data: &mut Self::ComputePassData, group_label: &str, ) { - wgpu_compute_pass_push_debug_group(pass_data, group_label, 0); + pass_data.pass.push_debug_group(&self.0, group_label, 0); } fn compute_pass_pop_debug_group( @@ -2384,7 +2398,7 @@ impl crate::Context for ContextWgpuCore { _pass: &mut Self::ComputePassId, pass_data: &mut Self::ComputePassData, ) { - wgpu_compute_pass_pop_debug_group(pass_data); + pass_data.pass.pop_debug_group(&self.0); } fn compute_pass_write_timestamp( @@ -2395,7 +2409,12 @@ impl crate::Context for ContextWgpuCore { _query_set_data: &Self::QuerySetData, query_index: u32, ) { - wgpu_compute_pass_write_timestamp(pass_data, *query_set, query_index) + if let Err(cause) = pass_data + .pass + .write_timestamp(&self.0, *query_set, query_index) + { + self.handle_error_nolabel(&pass_data.error_sink, cause, "ComputePass::write_timestamp"); + } } fn compute_pass_begin_pipeline_statistics_query( @@ -2406,7 +2425,17 @@ impl crate::Context for ContextWgpuCore { _query_set_data: &Self::QuerySetData, query_index: u32, ) { - wgpu_compute_pass_begin_pipeline_statistics_query(pass_data, *query_set, query_index) + if let Err(cause) = + pass_data + .pass + .begin_pipeline_statistics_query(&self.0, *query_set, query_index) + { + 
self.handle_error_nolabel( + &pass_data.error_sink, + cause, + "ComputePass::begin_pipeline_statistics_query", + ); + } } fn compute_pass_end_pipeline_statistics_query( @@ -2414,7 +2443,7 @@ impl crate::Context for ContextWgpuCore { _pass: &mut Self::ComputePassId, pass_data: &mut Self::ComputePassData, ) { - wgpu_compute_pass_end_pipeline_statistics_query(pass_data) + pass_data.pass.end_pipeline_statistics_query(&self.0); } fn compute_pass_dispatch_workgroups( @@ -2425,7 +2454,7 @@ impl crate::Context for ContextWgpuCore { y: u32, z: u32, ) { - wgpu_compute_pass_dispatch_workgroups(pass_data, x, y, z) + pass_data.pass.dispatch_workgroups(&self.0, x, y, z); } fn compute_pass_dispatch_workgroups_indirect( @@ -2436,7 +2465,17 @@ impl crate::Context for ContextWgpuCore { _indirect_buffer_data: &Self::BufferData, indirect_offset: wgt::BufferAddress, ) { - wgpu_compute_pass_dispatch_workgroups_indirect(pass_data, *indirect_buffer, indirect_offset) + if let Err(cause) = + pass_data + .pass + .dispatch_workgroups_indirect(&self.0, *indirect_buffer, indirect_offset) + { + self.handle_error_nolabel( + &pass_data.error_sink, + cause, + "ComputePass::dispatch_workgroups_indirect", + ); + } } fn render_bundle_encoder_set_pipeline( diff --git a/wgpu/src/lib.rs b/wgpu/src/lib.rs index 5cbfc04bf..ed5694173 100644 --- a/wgpu/src/lib.rs +++ b/wgpu/src/lib.rs @@ -4538,7 +4538,7 @@ impl<'a> ComputePass<'a> { pub fn set_bind_group( &mut self, index: u32, - bind_group: &'a BindGroup, + bind_group: &BindGroup, offsets: &[DynamicOffset], ) { DynContext::compute_pass_set_bind_group( @@ -4553,7 +4553,7 @@ impl<'a> ComputePass<'a> { } /// Sets the active compute pipeline. 
- pub fn set_pipeline(&mut self, pipeline: &'a ComputePipeline) { + pub fn set_pipeline(&mut self, pipeline: &ComputePipeline) { DynContext::compute_pass_set_pipeline( &*self.parent.context, &mut self.id, @@ -4611,7 +4611,7 @@ impl<'a> ComputePass<'a> { /// The structure expected in `indirect_buffer` must conform to [`DispatchIndirectArgs`](crate::util::DispatchIndirectArgs). pub fn dispatch_workgroups_indirect( &mut self, - indirect_buffer: &'a Buffer, + indirect_buffer: &Buffer, indirect_offset: BufferAddress, ) { DynContext::compute_pass_dispatch_workgroups_indirect(