Remove lifetime constraints from wgpu::ComputePass methods (#5570)

* basic test setup

* remove lifetime and drop resources on test - test fails now just as expected

* compute pass recording is now hub dependent (needs gfx_select)

* compute pass recording now bumps reference count of uses resources directly on recording

TODO:
* bind groups don't work because the Binder gets an id only
* wgpu level error handling is missing

* simplify compute pass state flush, compute pass execution no longer needs to lock bind_group storage

* wgpu sided error handling

* make ComputePass hal dependent, removing command cast hack. Introduce DynComputePass on wgpu side

* remove stray repr(C)

* changelog entry

* fix deno issues -> move DynComputePass into wgc

* split out resources setup from test
This commit is contained in:
Andreas Reich 2024-05-14 22:05:17 +02:00 committed by GitHub
parent 00456cfb37
commit 77a83fb0dd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 631 additions and 208 deletions

View File

@ -41,6 +41,14 @@ Bottom level categories:
### Major Changes
#### Remove lifetime bounds on `wgpu::ComputePass`
TODO(wumpf): This is still work in progress. Should write a bit more about it. Also will very likely extend to `wgpu::RenderPass` before release.
`wgpu::ComputePass` recording methods (e.g. `wgpu::ComputePass:set_render_pipeline`) no longer impose a lifetime constraint passed in resources.
By @wumpf in [#5569](https://github.com/gfx-rs/wgpu/pull/5569).
#### Querying shader compilation errors
Wgpu now supports querying [shader compilation info](https://www.w3.org/TR/webgpu/#dom-gpushadermodule-getcompilationinfo).

View File

@ -254,13 +254,14 @@ pub fn op_webgpu_command_encoder_begin_compute_pass(
None
};
let instance = state.borrow::<super::Instance>();
let command_encoder = &command_encoder_resource.1;
let descriptor = wgpu_core::command::ComputePassDescriptor {
label: Some(label),
timestamp_writes: timestamp_writes.as_ref(),
};
let compute_pass =
wgpu_core::command::ComputePass::new(command_encoder_resource.1, &descriptor);
let compute_pass = gfx_select!(command_encoder => instance.command_encoder_create_compute_pass_dyn(*command_encoder, &descriptor));
let rid = state
.resource_table

View File

@ -10,7 +10,9 @@ use std::cell::RefCell;
use super::error::WebGpuResult;
pub(crate) struct WebGpuComputePass(pub(crate) RefCell<wgpu_core::command::ComputePass>);
pub(crate) struct WebGpuComputePass(
pub(crate) RefCell<Box<dyn wgpu_core::command::DynComputePass>>,
);
impl Resource for WebGpuComputePass {
fn name(&self) -> Cow<str> {
"webGPUComputePass".into()
@ -31,10 +33,10 @@ pub fn op_webgpu_compute_pass_set_pipeline(
.resource_table
.get::<WebGpuComputePass>(compute_pass_rid)?;
wgpu_core::command::compute_commands::wgpu_compute_pass_set_pipeline(
&mut compute_pass_resource.0.borrow_mut(),
compute_pipeline_resource.1,
);
compute_pass_resource
.0
.borrow_mut()
.set_pipeline(state.borrow(), compute_pipeline_resource.1)?;
Ok(WebGpuResult::empty())
}
@ -52,12 +54,10 @@ pub fn op_webgpu_compute_pass_dispatch_workgroups(
.resource_table
.get::<WebGpuComputePass>(compute_pass_rid)?;
wgpu_core::command::compute_commands::wgpu_compute_pass_dispatch_workgroups(
&mut compute_pass_resource.0.borrow_mut(),
x,
y,
z,
);
compute_pass_resource
.0
.borrow_mut()
.dispatch_workgroups(state.borrow(), x, y, z);
Ok(WebGpuResult::empty())
}
@ -77,11 +77,10 @@ pub fn op_webgpu_compute_pass_dispatch_workgroups_indirect(
.resource_table
.get::<WebGpuComputePass>(compute_pass_rid)?;
wgpu_core::command::compute_commands::wgpu_compute_pass_dispatch_workgroups_indirect(
&mut compute_pass_resource.0.borrow_mut(),
buffer_resource.1,
indirect_offset,
);
compute_pass_resource
.0
.borrow_mut()
.dispatch_workgroups_indirect(state.borrow(), buffer_resource.1, indirect_offset)?;
Ok(WebGpuResult::empty())
}
@ -90,24 +89,15 @@ pub fn op_webgpu_compute_pass_dispatch_workgroups_indirect(
#[serde]
pub fn op_webgpu_compute_pass_end(
state: &mut OpState,
#[smi] command_encoder_rid: ResourceId,
#[smi] compute_pass_rid: ResourceId,
) -> Result<WebGpuResult, AnyError> {
let command_encoder_resource =
state
.resource_table
.get::<super::command_encoder::WebGpuCommandEncoder>(command_encoder_rid)?;
let command_encoder = command_encoder_resource.1;
let compute_pass_resource = state
.resource_table
.take::<WebGpuComputePass>(compute_pass_rid)?;
let compute_pass = &compute_pass_resource.0.borrow();
let instance = state.borrow::<super::Instance>();
gfx_ok!(command_encoder => instance.command_encoder_run_compute_pass(
command_encoder,
compute_pass
))
compute_pass_resource.0.borrow_mut().run(state.borrow())?;
Ok(WebGpuResult::empty())
}
#[op2]
@ -137,12 +127,12 @@ pub fn op_webgpu_compute_pass_set_bind_group(
let dynamic_offsets_data: &[u32] = &dynamic_offsets_data[start..start + len];
wgpu_core::command::compute_commands::wgpu_compute_pass_set_bind_group(
&mut compute_pass_resource.0.borrow_mut(),
compute_pass_resource.0.borrow_mut().set_bind_group(
state.borrow(),
index,
bind_group_resource.1,
dynamic_offsets_data,
);
)?;
Ok(WebGpuResult::empty())
}
@ -158,8 +148,8 @@ pub fn op_webgpu_compute_pass_push_debug_group(
.resource_table
.get::<WebGpuComputePass>(compute_pass_rid)?;
wgpu_core::command::compute_commands::wgpu_compute_pass_push_debug_group(
&mut compute_pass_resource.0.borrow_mut(),
compute_pass_resource.0.borrow_mut().push_debug_group(
state.borrow(),
group_label,
0, // wgpu#975
);
@ -177,9 +167,10 @@ pub fn op_webgpu_compute_pass_pop_debug_group(
.resource_table
.get::<WebGpuComputePass>(compute_pass_rid)?;
wgpu_core::command::compute_commands::wgpu_compute_pass_pop_debug_group(
&mut compute_pass_resource.0.borrow_mut(),
);
compute_pass_resource
.0
.borrow_mut()
.pop_debug_group(state.borrow());
Ok(WebGpuResult::empty())
}
@ -195,8 +186,8 @@ pub fn op_webgpu_compute_pass_insert_debug_marker(
.resource_table
.get::<WebGpuComputePass>(compute_pass_rid)?;
wgpu_core::command::compute_commands::wgpu_compute_pass_insert_debug_marker(
&mut compute_pass_resource.0.borrow_mut(),
compute_pass_resource.0.borrow_mut().insert_debug_marker(
state.borrow(),
marker_label,
0, // wgpu#975
);

View File

@ -0,0 +1,174 @@
//! Tests that compute passes take ownership of resources that are associated with.
//! I.e. once a resource is passed in to a compute pass, it can be dropped.
//!
//! TODO: Test doesn't check on timestamp writes & pipeline statistics queries yet.
//! (Not important as long as they are lifetime constrained to the command encoder,
//! but once we lift this constraint, we should add tests for this as well!)
//! TODO: Also should test resource ownership for:
//! * write_timestamp
//! * begin_pipeline_statistics_query
use std::num::NonZeroU64;
use wgpu::util::DeviceExt as _;
use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters, TestingContext};
const SHADER_SRC: &str = "
@group(0) @binding(0)
var<storage, read_write> buffer: array<vec4f>;
@compute @workgroup_size(1, 1, 1) fn main() {
buffer[0] *= 2.0;
}
";
#[gpu_test]
static COMPUTE_PASS_RESOURCE_OWNERSHIP: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(TestParameters::default().test_features_limits())
.run_async(compute_pass_resource_ownership);
async fn compute_pass_resource_ownership(ctx: TestingContext) {
let ResourceSetup {
gpu_buffer,
cpu_buffer,
buffer_size,
indirect_buffer,
bind_group,
pipeline,
} = resource_setup(&ctx);
let mut encoder = ctx
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("encoder"),
});
{
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("compute_pass"),
timestamp_writes: None, // TODO: See description above, we should test this as well once we lift the lifetime bound.
});
cpass.set_pipeline(&pipeline);
cpass.set_bind_group(0, &bind_group, &[]);
cpass.dispatch_workgroups_indirect(&indirect_buffer, 0);
// Now drop all resources we set. Then do a device poll to make sure the resources are really not dropped too early, no matter what.
drop(pipeline);
drop(bind_group);
drop(indirect_buffer);
ctx.async_poll(wgpu::Maintain::wait())
.await
.panic_on_timeout();
}
// Ensure that the compute pass still executed normally.
encoder.copy_buffer_to_buffer(&gpu_buffer, 0, &cpu_buffer, 0, buffer_size);
ctx.queue.submit([encoder.finish()]);
cpu_buffer.slice(..).map_async(wgpu::MapMode::Read, |_| ());
ctx.async_poll(wgpu::Maintain::wait())
.await
.panic_on_timeout();
let data = cpu_buffer.slice(..).get_mapped_range();
let floats: &[f32] = bytemuck::cast_slice(&data);
assert_eq!(floats, [2.0, 4.0, 6.0, 8.0]);
}
// Setup ------------------------------------------------------------
struct ResourceSetup {
gpu_buffer: wgpu::Buffer,
cpu_buffer: wgpu::Buffer,
buffer_size: u64,
indirect_buffer: wgpu::Buffer,
bind_group: wgpu::BindGroup,
pipeline: wgpu::ComputePipeline,
}
fn resource_setup(ctx: &TestingContext) -> ResourceSetup {
let sm = ctx
.device
.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("shader"),
source: wgpu::ShaderSource::Wgsl(SHADER_SRC.into()),
});
let buffer_size = 4 * std::mem::size_of::<f32>() as u64;
let bgl = ctx
.device
.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("bind_group_layout"),
entries: &[wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: NonZeroU64::new(buffer_size),
},
count: None,
}],
});
let gpu_buffer = ctx
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("gpu_buffer"),
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
contents: bytemuck::bytes_of(&[1.0_f32, 2.0, 3.0, 4.0]),
});
let cpu_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("cpu_buffer"),
size: buffer_size,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});
let indirect_buffer = ctx
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("gpu_buffer"),
usage: wgpu::BufferUsages::INDIRECT,
contents: wgpu::util::DispatchIndirectArgs { x: 1, y: 1, z: 1 }.as_bytes(),
});
let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("bind_group"),
layout: &bgl,
entries: &[wgpu::BindGroupEntry {
binding: 0,
resource: gpu_buffer.as_entire_binding(),
}],
});
let pipeline_layout = ctx
.device
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("pipeline_layout"),
bind_group_layouts: &[&bgl],
push_constant_ranges: &[],
});
let pipeline = ctx
.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("pipeline"),
layout: Some(&pipeline_layout),
module: &sm,
entry_point: "main",
compilation_options: Default::default(),
});
ResourceSetup {
gpu_buffer,
cpu_buffer,
buffer_size,
indirect_buffer,
bind_group,
pipeline,
}
}

View File

@ -11,6 +11,7 @@ mod buffer;
mod buffer_copy;
mod buffer_usages;
mod clear_texture;
mod compute_pass_resource_ownership;
mod create_surface_error;
mod device;
mod encoder;

View File

@ -4,7 +4,6 @@ use crate::{
binding_model::{BindGroup, LateMinBufferBindingSizeMismatch, PipelineLayout},
device::SHADER_STAGE_COUNT,
hal_api::HalApi,
id::BindGroupId,
pipeline::LateSizedBufferGroup,
resource::Resource,
};
@ -359,11 +358,11 @@ impl<A: HalApi> Binder<A> {
&self.payloads[bind_range]
}
pub(super) fn list_active(&self) -> impl Iterator<Item = BindGroupId> + '_ {
pub(super) fn list_active<'a>(&'a self) -> impl Iterator<Item = &'a Arc<BindGroup<A>>> + '_ {
let payloads = &self.payloads;
self.manager
.list_active()
.map(move |index| payloads[index].group.as_ref().unwrap().as_info().id())
.map(move |index| payloads[index].group.as_ref().unwrap())
}
pub(super) fn invalid_mask(&self) -> BindGroupMask {

View File

@ -1,29 +1,23 @@
use crate::command::compute_command::{ArcComputeCommand, ComputeCommand};
use crate::device::DeviceError;
use crate::resource::Resource;
use crate::snatch::SnatchGuard;
use crate::track::TrackerIndex;
use crate::{
binding_model::{
BindError, BindGroup, LateMinBufferBindingSizeMismatch, PushConstantUploadError,
},
binding_model::{BindError, LateMinBufferBindingSizeMismatch, PushConstantUploadError},
command::{
bind::Binder,
compute_command::{ArcComputeCommand, ComputeCommand},
end_pipeline_statistics_query,
memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
BasePass, BasePassRef, BindGroupStateChange, CommandBuffer, CommandEncoderError,
CommandEncoderStatus, MapPassErr, PassErrorScope, QueryUseError, StateChange,
},
device::{MissingDownlevelFlags, MissingFeatures},
device::{DeviceError, MissingDownlevelFlags, MissingFeatures},
error::{ErrorFormatter, PrettyError},
global::Global,
hal_api::HalApi,
hal_label, id,
id::DeviceId,
hal_label,
id::{self, DeviceId},
init_tracker::MemoryInitKind,
resource::{self},
storage::Storage,
track::{Tracker, UsageConflict, UsageScope},
resource::{self, Resource},
snatch::SnatchGuard,
track::{Tracker, TrackerIndex, UsageConflict, UsageScope},
validation::{check_buffer_usage, MissingBufferUsageError},
Label,
};
@ -35,27 +29,25 @@ use serde::Deserialize;
use serde::Serialize;
use thiserror::Error;
use wgt::{BufferAddress, DynamicOffset};
use std::sync::Arc;
use std::{fmt, mem, str};
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct ComputePass {
base: BasePass<ComputeCommand>,
pub struct ComputePass<A: HalApi> {
base: BasePass<ArcComputeCommand<A>>,
parent_id: id::CommandEncoderId,
timestamp_writes: Option<ComputePassTimestampWrites>,
// Resource binding dedupe state.
#[cfg_attr(feature = "serde", serde(skip))]
current_bind_groups: BindGroupStateChange,
#[cfg_attr(feature = "serde", serde(skip))]
current_pipeline: StateChange<id::ComputePipelineId>,
}
impl ComputePass {
pub fn new(parent_id: id::CommandEncoderId, desc: &ComputePassDescriptor) -> Self {
impl<A: HalApi> ComputePass<A> {
fn new(parent_id: id::CommandEncoderId, desc: &ComputePassDescriptor) -> Self {
Self {
base: BasePass::new(&desc.label),
base: BasePass::<ArcComputeCommand<A>>::new(&desc.label),
parent_id,
timestamp_writes: desc.timestamp_writes.cloned(),
@ -67,30 +59,15 @@ impl ComputePass {
pub fn parent_id(&self) -> id::CommandEncoderId {
self.parent_id
}
#[cfg(feature = "trace")]
pub fn into_command(self) -> crate::device::trace::Command {
crate::device::trace::Command::RunComputePass {
base: self.base,
timestamp_writes: self.timestamp_writes,
}
}
}
impl fmt::Debug for ComputePass {
impl<A: HalApi> fmt::Debug for ComputePass<A> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"ComputePass {{ encoder_id: {:?}, data: {:?} commands and {:?} dynamic offsets }}",
self.parent_id,
self.base.commands.len(),
self.base.dynamic_offsets.len()
)
write!(f, "ComputePass {{ encoder_id: {:?} }}", self.parent_id)
}
}
/// Describes the writing of timestamp values in a compute pass.
#[repr(C)]
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ComputePassTimestampWrites {
@ -253,22 +230,19 @@ impl<'a, A: HalApi> State<'a, A> {
&mut self,
raw_encoder: &mut A::CommandEncoder,
base_trackers: &mut Tracker<A>,
bind_group_guard: &Storage<BindGroup<A>>,
indirect_buffer: Option<TrackerIndex>,
snatch_guard: &SnatchGuard,
) -> Result<(), UsageConflict> {
for id in self.binder.list_active() {
unsafe { self.scope.merge_bind_group(&bind_group_guard[id].used)? };
for bind_group in self.binder.list_active() {
unsafe { self.scope.merge_bind_group(&bind_group.used)? };
// Note: stateless trackers are not merged: the lifetime reference
// is held to the bind group itself.
}
for id in self.binder.list_active() {
for bind_group in self.binder.list_active() {
unsafe {
base_trackers.set_and_remove_from_usage_scope_sparse(
&mut self.scope,
&bind_group_guard[id].used,
)
base_trackers
.set_and_remove_from_usage_scope_sparse(&mut self.scope, &bind_group.used)
}
}
@ -286,17 +260,31 @@ impl<'a, A: HalApi> State<'a, A> {
}
}
// Common routines between render/compute
// Running the compute pass.
impl Global {
pub fn command_encoder_create_compute_pass<A: HalApi>(
&self,
parent_id: id::CommandEncoderId,
desc: &ComputePassDescriptor,
) -> ComputePass<A> {
ComputePass::new(parent_id, desc)
}
pub fn command_encoder_create_compute_pass_dyn<A: HalApi>(
&self,
parent_id: id::CommandEncoderId,
desc: &ComputePassDescriptor,
) -> Box<dyn super::DynComputePass> {
Box::new(ComputePass::<A>::new(parent_id, desc))
}
pub fn command_encoder_run_compute_pass<A: HalApi>(
&self,
encoder_id: id::CommandEncoderId,
pass: &ComputePass,
pass: &ComputePass<A>,
) -> Result<(), ComputePassError> {
// TODO: This should go directly to `command_encoder_run_compute_pass_impl` by means of storing `ArcComputeCommand` internally.
self.command_encoder_run_compute_pass_with_unresolved_commands::<A>(
encoder_id,
self.command_encoder_run_compute_pass_impl(
pass.parent_id,
pass.base.as_ref(),
pass.timestamp_writes.as_ref(),
)
@ -377,7 +365,6 @@ impl Global {
*status = CommandEncoderStatus::Error;
let raw = encoder.open().map_pass_err(pass_scope)?;
let bind_group_guard = hub.bind_groups.read();
let query_set_guard = hub.query_sets.read();
let mut state = State {
@ -642,13 +629,7 @@ impl Global {
state.is_ready().map_pass_err(scope)?;
state
.flush_states(
raw,
&mut intermediate_trackers,
&*bind_group_guard,
None,
&snatch_guard,
)
.flush_states(raw, &mut intermediate_trackers, None, &snatch_guard)
.map_pass_err(scope)?;
let groups_size_limit = cmd_buf.limits.max_compute_workgroups_per_dimension;
@ -721,7 +702,6 @@ impl Global {
.flush_states(
raw,
&mut intermediate_trackers,
&*bind_group_guard,
Some(buffer.as_info().tracker_index()),
&snatch_guard,
)
@ -845,18 +825,15 @@ impl Global {
}
}
pub mod compute_commands {
use super::{ComputeCommand, ComputePass};
use crate::id;
use std::convert::TryInto;
use wgt::{BufferAddress, DynamicOffset};
pub fn wgpu_compute_pass_set_bind_group(
pass: &mut ComputePass,
// Recording a compute pass.
impl Global {
pub fn compute_pass_set_bind_group<A: HalApi>(
&self,
pass: &mut ComputePass<A>,
index: u32,
bind_group_id: id::BindGroupId,
offsets: &[DynamicOffset],
) {
) -> Result<(), ComputePassError> {
let redundant = pass.current_bind_groups.set_and_check_redundant(
bind_group_id,
index,
@ -865,30 +842,62 @@ pub mod compute_commands {
);
if redundant {
return;
return Ok(());
}
pass.base.commands.push(ComputeCommand::SetBindGroup {
let hub = A::hub(self);
let bind_group = hub
.bind_groups
.read()
.get(bind_group_id)
.map_err(|_| ComputePassError {
scope: PassErrorScope::SetBindGroup(bind_group_id),
inner: ComputePassErrorInner::InvalidBindGroup(index),
})?
.clone();
pass.base.commands.push(ArcComputeCommand::SetBindGroup {
index,
num_dynamic_offsets: offsets.len(),
bind_group_id,
bind_group,
});
Ok(())
}
pub fn wgpu_compute_pass_set_pipeline(
pass: &mut ComputePass,
pub fn compute_pass_set_pipeline<A: HalApi>(
&self,
pass: &mut ComputePass<A>,
pipeline_id: id::ComputePipelineId,
) {
) -> Result<(), ComputePassError> {
if pass.current_pipeline.set_and_check_redundant(pipeline_id) {
return;
return Ok(());
}
let hub = A::hub(self);
let pipeline = hub
.compute_pipelines
.read()
.get(pipeline_id)
.map_err(|_| ComputePassError {
scope: PassErrorScope::SetPipelineCompute(pipeline_id),
inner: ComputePassErrorInner::InvalidPipeline(pipeline_id),
})?
.clone();
pass.base
.commands
.push(ComputeCommand::SetPipeline(pipeline_id));
.push(ArcComputeCommand::SetPipeline(pipeline));
Ok(())
}
pub fn wgpu_compute_pass_set_push_constant(pass: &mut ComputePass, offset: u32, data: &[u8]) {
pub fn compute_pass_set_push_constant<A: HalApi>(
&self,
pass: &mut ComputePass<A>,
offset: u32,
data: &[u8],
) {
assert_eq!(
offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1),
0,
@ -901,92 +910,156 @@ pub mod compute_commands {
);
let value_offset = pass.base.push_constant_data.len().try_into().expect(
"Ran out of push constant space. Don't set 4gb of push constants per ComputePass.",
);
); // TODO: make this an error that can be handled
pass.base.push_constant_data.extend(
data.chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize)
.map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
);
pass.base.commands.push(ComputeCommand::SetPushConstant {
pass.base
.commands
.push(ArcComputeCommand::<A>::SetPushConstant {
offset,
size_bytes: data.len() as u32,
values_offset: value_offset,
});
}
pub fn wgpu_compute_pass_dispatch_workgroups(
pass: &mut ComputePass,
pub fn compute_pass_dispatch_workgroups<A: HalApi>(
&self,
pass: &mut ComputePass<A>,
groups_x: u32,
groups_y: u32,
groups_z: u32,
) {
pass.base
.commands
.push(ComputeCommand::Dispatch([groups_x, groups_y, groups_z]));
pass.base.commands.push(ArcComputeCommand::<A>::Dispatch([
groups_x, groups_y, groups_z,
]));
}
pub fn wgpu_compute_pass_dispatch_workgroups_indirect(
pass: &mut ComputePass,
pub fn compute_pass_dispatch_workgroups_indirect<A: HalApi>(
&self,
pass: &mut ComputePass<A>,
buffer_id: id::BufferId,
offset: BufferAddress,
) {
) -> Result<(), ComputePassError> {
let hub = A::hub(self);
let buffer = hub
.buffers
.read()
.get(buffer_id)
.map_err(|_| ComputePassError {
scope: PassErrorScope::Dispatch {
indirect: true,
pipeline: pass.current_pipeline.last_state,
},
inner: ComputePassErrorInner::InvalidBuffer(buffer_id),
})?
.clone();
pass.base
.commands
.push(ComputeCommand::DispatchIndirect { buffer_id, offset });
.push(ArcComputeCommand::<A>::DispatchIndirect { buffer, offset });
Ok(())
}
pub fn wgpu_compute_pass_push_debug_group(pass: &mut ComputePass, label: &str, color: u32) {
pub fn compute_pass_push_debug_group<A: HalApi>(
&self,
pass: &mut ComputePass<A>,
label: &str,
color: u32,
) {
let bytes = label.as_bytes();
pass.base.string_data.extend_from_slice(bytes);
pass.base.commands.push(ComputeCommand::PushDebugGroup {
pass.base
.commands
.push(ArcComputeCommand::<A>::PushDebugGroup {
color,
len: bytes.len(),
});
}
pub fn wgpu_compute_pass_pop_debug_group(pass: &mut ComputePass) {
pass.base.commands.push(ComputeCommand::PopDebugGroup);
pub fn compute_pass_pop_debug_group<A: HalApi>(&self, pass: &mut ComputePass<A>) {
pass.base
.commands
.push(ArcComputeCommand::<A>::PopDebugGroup);
}
pub fn wgpu_compute_pass_insert_debug_marker(pass: &mut ComputePass, label: &str, color: u32) {
pub fn compute_pass_insert_debug_marker<A: HalApi>(
&self,
pass: &mut ComputePass<A>,
label: &str,
color: u32,
) {
let bytes = label.as_bytes();
pass.base.string_data.extend_from_slice(bytes);
pass.base.commands.push(ComputeCommand::InsertDebugMarker {
pass.base
.commands
.push(ArcComputeCommand::<A>::InsertDebugMarker {
color,
len: bytes.len(),
});
}
pub fn wgpu_compute_pass_write_timestamp(
pass: &mut ComputePass,
pub fn compute_pass_write_timestamp<A: HalApi>(
&self,
pass: &mut ComputePass<A>,
query_set_id: id::QuerySetId,
query_index: u32,
) {
pass.base.commands.push(ComputeCommand::WriteTimestamp {
query_set_id,
) -> Result<(), ComputePassError> {
let hub = A::hub(self);
let query_set = hub
.query_sets
.read()
.get(query_set_id)
.map_err(|_| ComputePassError {
scope: PassErrorScope::WriteTimestamp,
inner: ComputePassErrorInner::InvalidQuerySet(query_set_id),
})?
.clone();
pass.base.commands.push(ArcComputeCommand::WriteTimestamp {
query_set,
query_index,
});
Ok(())
}
pub fn wgpu_compute_pass_begin_pipeline_statistics_query(
pass: &mut ComputePass,
pub fn compute_pass_begin_pipeline_statistics_query<A: HalApi>(
&self,
pass: &mut ComputePass<A>,
query_set_id: id::QuerySetId,
query_index: u32,
) {
) -> Result<(), ComputePassError> {
let hub = A::hub(self);
let query_set = hub
.query_sets
.read()
.get(query_set_id)
.map_err(|_| ComputePassError {
scope: PassErrorScope::WriteTimestamp,
inner: ComputePassErrorInner::InvalidQuerySet(query_set_id),
})?
.clone();
pass.base
.commands
.push(ComputeCommand::BeginPipelineStatisticsQuery {
query_set_id,
.push(ArcComputeCommand::BeginPipelineStatisticsQuery {
query_set,
query_index,
});
Ok(())
}
pub fn wgpu_compute_pass_end_pipeline_statistics_query(pass: &mut ComputePass) {
pub fn compute_pass_end_pipeline_statistics_query<A: HalApi>(&self, pass: &mut ComputePass<A>) {
pass.base
.commands
.push(ComputeCommand::EndPipelineStatisticsQuery);
.push(ArcComputeCommand::<A>::EndPipelineStatisticsQuery);
}
}

View File

@ -0,0 +1,136 @@
use wgt::WasmNotSendSync;
use crate::{global, hal_api::HalApi, id};
use super::{ComputePass, ComputePassError};
/// Trait for type erasing ComputePass.
// TODO(#5124): wgpu-core's ComputePass trait should not be hal type dependent.
// Practically speaking this allows us merge gfx_select with type erasure:
// The alternative would be to introduce ComputePassId which then first needs to be looked up and then dispatch via gfx_select.
pub trait DynComputePass: std::fmt::Debug + WasmNotSendSync {
fn run(&mut self, context: &global::Global) -> Result<(), ComputePassError>;
fn set_bind_group(
&mut self,
context: &global::Global,
index: u32,
bind_group_id: id::BindGroupId,
offsets: &[wgt::DynamicOffset],
) -> Result<(), ComputePassError>;
fn set_pipeline(
&mut self,
context: &global::Global,
pipeline_id: id::ComputePipelineId,
) -> Result<(), ComputePassError>;
fn set_push_constant(&mut self, context: &global::Global, offset: u32, data: &[u8]);
fn dispatch_workgroups(
&mut self,
context: &global::Global,
groups_x: u32,
groups_y: u32,
groups_z: u32,
);
fn dispatch_workgroups_indirect(
&mut self,
context: &global::Global,
buffer_id: id::BufferId,
offset: wgt::BufferAddress,
) -> Result<(), ComputePassError>;
fn push_debug_group(&mut self, context: &global::Global, label: &str, color: u32);
fn pop_debug_group(&mut self, context: &global::Global);
fn insert_debug_marker(&mut self, context: &global::Global, label: &str, color: u32);
fn write_timestamp(
&mut self,
context: &global::Global,
query_set_id: id::QuerySetId,
query_index: u32,
) -> Result<(), ComputePassError>;
fn begin_pipeline_statistics_query(
&mut self,
context: &global::Global,
query_set_id: id::QuerySetId,
query_index: u32,
) -> Result<(), ComputePassError>;
fn end_pipeline_statistics_query(&mut self, context: &global::Global);
}
impl<A: HalApi> DynComputePass for ComputePass<A> {
fn run(&mut self, context: &global::Global) -> Result<(), ComputePassError> {
context.command_encoder_run_compute_pass(self)
}
fn set_bind_group(
&mut self,
context: &global::Global,
index: u32,
bind_group_id: id::BindGroupId,
offsets: &[wgt::DynamicOffset],
) -> Result<(), ComputePassError> {
context.compute_pass_set_bind_group(self, index, bind_group_id, offsets)
}
fn set_pipeline(
&mut self,
context: &global::Global,
pipeline_id: id::ComputePipelineId,
) -> Result<(), ComputePassError> {
context.compute_pass_set_pipeline(self, pipeline_id)
}
fn set_push_constant(&mut self, context: &global::Global, offset: u32, data: &[u8]) {
context.compute_pass_set_push_constant(self, offset, data)
}
fn dispatch_workgroups(
&mut self,
context: &global::Global,
groups_x: u32,
groups_y: u32,
groups_z: u32,
) {
context.compute_pass_dispatch_workgroups(self, groups_x, groups_y, groups_z)
}
fn dispatch_workgroups_indirect(
&mut self,
context: &global::Global,
buffer_id: id::BufferId,
offset: wgt::BufferAddress,
) -> Result<(), ComputePassError> {
context.compute_pass_dispatch_workgroups_indirect(self, buffer_id, offset)
}
fn push_debug_group(&mut self, context: &global::Global, label: &str, color: u32) {
context.compute_pass_push_debug_group(self, label, color)
}
fn pop_debug_group(&mut self, context: &global::Global) {
context.compute_pass_pop_debug_group(self)
}
fn insert_debug_marker(&mut self, context: &global::Global, label: &str, color: u32) {
context.compute_pass_insert_debug_marker(self, label, color)
}
fn write_timestamp(
&mut self,
context: &global::Global,
query_set_id: id::QuerySetId,
query_index: u32,
) -> Result<(), ComputePassError> {
context.compute_pass_write_timestamp(self, query_set_id, query_index)
}
fn begin_pipeline_statistics_query(
&mut self,
context: &global::Global,
query_set_id: id::QuerySetId,
query_index: u32,
) -> Result<(), ComputePassError> {
context.compute_pass_begin_pipeline_statistics_query(self, query_set_id, query_index)
}
fn end_pipeline_statistics_query(&mut self, context: &global::Global) {
context.compute_pass_end_pipeline_statistics_query(self)
}
}

View File

@ -5,6 +5,7 @@ mod clear;
mod compute;
mod compute_command;
mod draw;
mod dyn_compute_pass;
mod memory_init;
mod query;
mod render;
@ -14,8 +15,8 @@ use std::sync::Arc;
pub(crate) use self::clear::clear_texture;
pub use self::{
bundle::*, clear::ClearError, compute::*, compute_command::ComputeCommand, draw::*, query::*,
render::*, transfer::*,
bundle::*, clear::ClearError, compute::*, compute_command::ComputeCommand, draw::*,
dyn_compute_pass::DynComputePass, query::*, render::*, transfer::*,
};
pub(crate) use allocator::CommandAllocator;

View File

@ -24,9 +24,11 @@ use std::{
sync::Arc,
};
use wgc::{
command::{bundle_ffi::*, compute_commands::*, render_commands::*},
command::{bundle_ffi::*, render_commands::*},
device::DeviceLostClosure,
id::{CommandEncoderId, TextureViewId},
gfx_select,
id::CommandEncoderId,
id::TextureViewId,
pipeline::CreateShaderModuleError,
};
use wgt::WasmNotSendSync;
@ -476,6 +478,12 @@ impl Queue {
}
}
#[derive(Debug)]
pub struct ComputePass {
pass: Box<dyn wgc::command::DynComputePass>,
error_sink: ErrorSink,
}
#[derive(Debug)]
pub struct CommandEncoder {
error_sink: ErrorSink,
@ -514,7 +522,7 @@ impl crate::Context for ContextWgpuCore {
type CommandEncoderId = wgc::id::CommandEncoderId;
type CommandEncoderData = CommandEncoder;
type ComputePassId = Unused;
type ComputePassData = wgc::command::ComputePass;
type ComputePassData = ComputePass;
type RenderPassId = Unused;
type RenderPassData = wgc::command::RenderPass;
type CommandBufferId = wgc::id::CommandBufferId;
@ -1841,7 +1849,7 @@ impl crate::Context for ContextWgpuCore {
fn command_encoder_begin_compute_pass(
&self,
encoder: &Self::CommandEncoderId,
_encoder_data: &Self::CommandEncoderData,
encoder_data: &Self::CommandEncoderData,
desc: &ComputePassDescriptor<'_>,
) -> (Self::ComputePassId, Self::ComputePassData) {
let timestamp_writes =
@ -1852,15 +1860,16 @@ impl crate::Context for ContextWgpuCore {
beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
end_of_pass_write_index: tw.end_of_pass_write_index,
});
(
Unused,
wgc::command::ComputePass::new(
*encoder,
&wgc::command::ComputePassDescriptor {
Self::ComputePassData {
pass: gfx_select!(encoder => self.0.command_encoder_create_compute_pass_dyn(*encoder, &wgc::command::ComputePassDescriptor {
label: desc.label.map(Borrowed),
timestamp_writes: timestamp_writes.as_ref(),
})),
error_sink: encoder_data.error_sink.clone(),
},
),
)
}
@ -1871,9 +1880,7 @@ impl crate::Context for ContextWgpuCore {
_pass: &mut Self::ComputePassId,
pass_data: &mut Self::ComputePassData,
) {
if let Err(cause) = wgc::gfx_select!(
encoder => self.0.command_encoder_run_compute_pass(*encoder, pass_data)
) {
if let Err(cause) = pass_data.pass.run(&self.0) {
let name = wgc::gfx_select!(encoder => self.0.command_buffer_label(encoder.into_command_buffer_id()));
self.handle_error(
&encoder_data.error_sink,
@ -2336,7 +2343,9 @@ impl crate::Context for ContextWgpuCore {
pipeline: &Self::ComputePipelineId,
_pipeline_data: &Self::ComputePipelineData,
) {
wgpu_compute_pass_set_pipeline(pass_data, *pipeline)
if let Err(cause) = pass_data.pass.set_pipeline(&self.0, *pipeline) {
self.handle_error_nolabel(&pass_data.error_sink, cause, "ComputePass::set_pipeline");
}
}
fn compute_pass_set_bind_group(
@ -2348,7 +2357,12 @@ impl crate::Context for ContextWgpuCore {
_bind_group_data: &Self::BindGroupData,
offsets: &[wgt::DynamicOffset],
) {
wgpu_compute_pass_set_bind_group(pass_data, index, *bind_group, offsets);
if let Err(cause) = pass_data
.pass
.set_bind_group(&self.0, index, *bind_group, offsets)
{
self.handle_error_nolabel(&pass_data.error_sink, cause, "ComputePass::set_bind_group");
}
}
fn compute_pass_set_push_constants(
@ -2358,7 +2372,7 @@ impl crate::Context for ContextWgpuCore {
offset: u32,
data: &[u8],
) {
wgpu_compute_pass_set_push_constant(pass_data, offset, data);
pass_data.pass.set_push_constant(&self.0, offset, data);
}
fn compute_pass_insert_debug_marker(
@ -2367,7 +2381,7 @@ impl crate::Context for ContextWgpuCore {
pass_data: &mut Self::ComputePassData,
label: &str,
) {
wgpu_compute_pass_insert_debug_marker(pass_data, label, 0);
pass_data.pass.insert_debug_marker(&self.0, label, 0);
}
fn compute_pass_push_debug_group(
@ -2376,7 +2390,7 @@ impl crate::Context for ContextWgpuCore {
pass_data: &mut Self::ComputePassData,
group_label: &str,
) {
wgpu_compute_pass_push_debug_group(pass_data, group_label, 0);
pass_data.pass.push_debug_group(&self.0, group_label, 0);
}
fn compute_pass_pop_debug_group(
@ -2384,7 +2398,7 @@ impl crate::Context for ContextWgpuCore {
_pass: &mut Self::ComputePassId,
pass_data: &mut Self::ComputePassData,
) {
wgpu_compute_pass_pop_debug_group(pass_data);
pass_data.pass.pop_debug_group(&self.0);
}
fn compute_pass_write_timestamp(
@ -2395,7 +2409,12 @@ impl crate::Context for ContextWgpuCore {
_query_set_data: &Self::QuerySetData,
query_index: u32,
) {
wgpu_compute_pass_write_timestamp(pass_data, *query_set, query_index)
if let Err(cause) = pass_data
.pass
.write_timestamp(&self.0, *query_set, query_index)
{
self.handle_error_nolabel(&pass_data.error_sink, cause, "ComputePass::write_timestamp");
}
}
fn compute_pass_begin_pipeline_statistics_query(
@ -2406,7 +2425,17 @@ impl crate::Context for ContextWgpuCore {
_query_set_data: &Self::QuerySetData,
query_index: u32,
) {
wgpu_compute_pass_begin_pipeline_statistics_query(pass_data, *query_set, query_index)
if let Err(cause) =
pass_data
.pass
.begin_pipeline_statistics_query(&self.0, *query_set, query_index)
{
self.handle_error_nolabel(
&pass_data.error_sink,
cause,
"ComputePass::begin_pipeline_statistics_query",
);
}
}
fn compute_pass_end_pipeline_statistics_query(
@ -2414,7 +2443,7 @@ impl crate::Context for ContextWgpuCore {
_pass: &mut Self::ComputePassId,
pass_data: &mut Self::ComputePassData,
) {
wgpu_compute_pass_end_pipeline_statistics_query(pass_data)
pass_data.pass.end_pipeline_statistics_query(&self.0);
}
fn compute_pass_dispatch_workgroups(
@ -2425,7 +2454,7 @@ impl crate::Context for ContextWgpuCore {
y: u32,
z: u32,
) {
wgpu_compute_pass_dispatch_workgroups(pass_data, x, y, z)
pass_data.pass.dispatch_workgroups(&self.0, x, y, z);
}
fn compute_pass_dispatch_workgroups_indirect(
@ -2436,7 +2465,17 @@ impl crate::Context for ContextWgpuCore {
_indirect_buffer_data: &Self::BufferData,
indirect_offset: wgt::BufferAddress,
) {
wgpu_compute_pass_dispatch_workgroups_indirect(pass_data, *indirect_buffer, indirect_offset)
if let Err(cause) =
pass_data
.pass
.dispatch_workgroups_indirect(&self.0, *indirect_buffer, indirect_offset)
{
self.handle_error_nolabel(
&pass_data.error_sink,
cause,
"ComputePass::dispatch_workgroups_indirect",
);
}
}
fn render_bundle_encoder_set_pipeline(

View File

@ -4538,7 +4538,7 @@ impl<'a> ComputePass<'a> {
pub fn set_bind_group(
&mut self,
index: u32,
bind_group: &'a BindGroup,
bind_group: &BindGroup,
offsets: &[DynamicOffset],
) {
DynContext::compute_pass_set_bind_group(
@ -4553,7 +4553,7 @@ impl<'a> ComputePass<'a> {
}
/// Sets the active compute pipeline.
pub fn set_pipeline(&mut self, pipeline: &'a ComputePipeline) {
pub fn set_pipeline(&mut self, pipeline: &ComputePipeline) {
DynContext::compute_pass_set_pipeline(
&*self.parent.context,
&mut self.id,
@ -4611,7 +4611,7 @@ impl<'a> ComputePass<'a> {
/// The structure expected in `indirect_buffer` must conform to [`DispatchIndirectArgs`](crate::util::DispatchIndirectArgs).
pub fn dispatch_workgroups_indirect(
&mut self,
indirect_buffer: &'a Buffer,
indirect_buffer: &Buffer,
indirect_offset: BufferAddress,
) {
DynContext::compute_pass_dispatch_workgroups_indirect(