Mirror of https://github.com/gfx-rs/wgpu.git, synced 2024-11-21 22:33:49 +00:00
Fix QuerySet ownership of ComputePass (#5671)
* Add new tests checking query set lifetime
* Fix ownership management of query sets on compute passes for write_timestamp, timestamp_writes (on the descriptor), and pipeline statistics queries
* Changelog entry
This commit is contained in:
parent
d258d6ce73
commit
9a27ba53ca
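The pattern this commit makes safe, as a minimal sketch (not code from the diff; `device` and `queue` are placeholder handles, and pipeline/bind-group setup is omitted): a `QuerySet` referenced by a compute pass can now be dropped before the pass ends, because the pass holds its own reference to it.

```rust
// Sketch only: requires Features::TIMESTAMP_QUERY.
let query_set = device.create_query_set(&wgpu::QuerySetDescriptor {
    label: Some("timestamps"),
    ty: wgpu::QueryType::Timestamp,
    count: 2,
});

let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
{
    let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("cpass"),
        timestamp_writes: Some(wgpu::ComputePassTimestampWrites {
            query_set: &query_set,
            beginning_of_pass_write_index: Some(0),
            end_of_pass_write_index: Some(1),
        }),
    });
    // ... set pipeline / bind groups and dispatch as usual ...

    // Previously, dropping the QuerySet here could crash when the pass was
    // ended or submitted; after this change the pass keeps it alive.
    drop(query_set);
}
queue.submit([encoder.finish()]);
```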
@@ -58,7 +58,7 @@ fn independent_cpass<'enc>(encoder: &'enc mut wgpu::CommandEncoder) -> wgpu::Com
This is very useful for library authors, but opens up an easy way for incorrect use, so use with care.
`forget_lifetime` is zero overhead and has no side effects on pass recording.

By @wumpf in [#5569](https://github.com/gfx-rs/wgpu/pull/5569), [#5575](https://github.com/gfx-rs/wgpu/pull/5575), [#5620](https://github.com/gfx-rs/wgpu/pull/5620), [#5768](https://github.com/gfx-rs/wgpu/pull/5768) (together with @kpreid).
By @wumpf in [#5569](https://github.com/gfx-rs/wgpu/pull/5569), [#5575](https://github.com/gfx-rs/wgpu/pull/5575), [#5620](https://github.com/gfx-rs/wgpu/pull/5620), [#5768](https://github.com/gfx-rs/wgpu/pull/5768) (together with @kpreid), [#5671](https://github.com/gfx-rs/wgpu/pull/5671).
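The changelog entry above refers to the new `forget_lifetime` escape hatch on passes. A hedged sketch of the intended usage (the wrapper function and its name are illustrative, not from this diff):

```rust
// Returns a compute pass that no longer borrows the encoder. The encoder is
// still locked while the pass records: finishing or reusing it before the
// pass is ended/dropped is a recording error.
fn begin_detached_pass(encoder: &mut wgpu::CommandEncoder) -> wgpu::ComputePass<'static> {
    encoder
        .begin_compute_pass(&wgpu::ComputePassDescriptor::default())
        .forget_lifetime()
}
```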

#### Querying shader compilation errors

@@ -116,6 +116,7 @@ By @stefnotch in [#5410](https://github.com/gfx-rs/wgpu/pull/5410)

#### General

- Ensure render pipelines have at least 1 target. By @ErichDonGubler in [#5715](https://github.com/gfx-rs/wgpu/pull/5715)
- `wgpu::ComputePass` now internally takes ownership of the `QuerySet`s used for `wgpu::ComputePassTimestampWrites` as well as for in-pass timestamp writes and pipeline statistics queries, fixing crashes when a `QuerySet` is destroyed before the pass ends. By @wumpf in [#5671](https://github.com/gfx-rs/wgpu/pull/5671)

#### Metal
@@ -1,9 +1,5 @@
//! Tests that compute passes take ownership of resources that are associated with.
//! I.e. once a resource is passed in to a compute pass, it can be dropped.
//!
//! TODO: Also should test resource ownership for:
//! * write_timestamp
//! * begin_pipeline_statistics_query

use std::num::NonZeroU64;

@@ -36,15 +32,10 @@ async fn compute_pass_resource_ownership(ctx: TestingContext) {

let mut encoder = ctx
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("encoder"),
});
.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());

{
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("compute_pass"),
timestamp_writes: None, // TODO: See description above, we should test this as well once we lift the lifetime bound.
});
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
cpass.set_pipeline(&pipeline);
cpass.set_bind_group(0, &bind_group, &[]);
cpass.dispatch_workgroups_indirect(&indirect_buffer, 0);
@@ -58,18 +49,115 @@ async fn compute_pass_resource_ownership(ctx: TestingContext) {
.panic_on_timeout();
}

// Ensure that the compute pass still executed normally.
encoder.copy_buffer_to_buffer(&gpu_buffer, 0, &cpu_buffer, 0, buffer_size);
ctx.queue.submit([encoder.finish()]);
cpu_buffer.slice(..).map_async(wgpu::MapMode::Read, |_| ());
ctx.async_poll(wgpu::Maintain::wait())
.await
.panic_on_timeout();
assert_compute_pass_executed_normally(encoder, gpu_buffer, cpu_buffer, buffer_size, ctx).await;
}

let data = cpu_buffer.slice(..).get_mapped_range();
#[gpu_test]
static COMPUTE_PASS_QUERY_SET_OWNERSHIP_PIPELINE_STATISTICS: GpuTestConfiguration =
GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.test_features_limits()
.features(wgpu::Features::PIPELINE_STATISTICS_QUERY),
)
.run_async(compute_pass_query_set_ownership_pipeline_statistics);

let floats: &[f32] = bytemuck::cast_slice(&data);
assert_eq!(floats, [2.0, 4.0, 6.0, 8.0]);
async fn compute_pass_query_set_ownership_pipeline_statistics(ctx: TestingContext) {
let ResourceSetup {
gpu_buffer,
cpu_buffer,
buffer_size,
indirect_buffer: _,
bind_group,
pipeline,
} = resource_setup(&ctx);

let query_set = ctx.device.create_query_set(&wgpu::QuerySetDescriptor {
label: Some("query_set"),
ty: wgpu::QueryType::PipelineStatistics(
wgpu::PipelineStatisticsTypes::COMPUTE_SHADER_INVOCATIONS,
),
count: 1,
});

let mut encoder = ctx
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());

{
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
cpass.set_pipeline(&pipeline);
cpass.set_bind_group(0, &bind_group, &[]);
cpass.begin_pipeline_statistics_query(&query_set, 0);
cpass.dispatch_workgroups(1, 1, 1);
cpass.end_pipeline_statistics_query();

// Drop the query set. Then do a device poll to make sure it's not dropped too early, no matter what.
drop(query_set);
ctx.async_poll(wgpu::Maintain::wait())
.await
.panic_on_timeout();
}

assert_compute_pass_executed_normally(encoder, gpu_buffer, cpu_buffer, buffer_size, ctx).await;
}

#[gpu_test]
static COMPUTE_PASS_QUERY_TIMESTAMPS: GpuTestConfiguration =
GpuTestConfiguration::new()
.parameters(TestParameters::default().test_features_limits().features(
wgpu::Features::TIMESTAMP_QUERY | wgpu::Features::TIMESTAMP_QUERY_INSIDE_PASSES,
))
.run_async(compute_pass_query_timestamps);

async fn compute_pass_query_timestamps(ctx: TestingContext) {
let ResourceSetup {
gpu_buffer,
cpu_buffer,
buffer_size,
indirect_buffer: _,
bind_group,
pipeline,
} = resource_setup(&ctx);

let query_set_timestamp_writes = ctx.device.create_query_set(&wgpu::QuerySetDescriptor {
label: Some("query_set_timestamp_writes"),
ty: wgpu::QueryType::Timestamp,
count: 2,
});
let query_set_write_timestamp = ctx.device.create_query_set(&wgpu::QuerySetDescriptor {
label: Some("query_set_write_timestamp"),
ty: wgpu::QueryType::Timestamp,
count: 1,
});

let mut encoder = ctx
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());

{
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("compute_pass"),
timestamp_writes: Some(wgpu::ComputePassTimestampWrites {
query_set: &query_set_timestamp_writes,
beginning_of_pass_write_index: Some(0),
end_of_pass_write_index: Some(1),
}),
});
cpass.set_pipeline(&pipeline);
cpass.set_bind_group(0, &bind_group, &[]);
cpass.write_timestamp(&query_set_write_timestamp, 0);
cpass.dispatch_workgroups(1, 1, 1);

// Drop the query sets. Then do a device poll to make sure they're not dropped too early, no matter what.
drop(query_set_timestamp_writes);
drop(query_set_write_timestamp);
ctx.async_poll(wgpu::Maintain::wait())
.await
.panic_on_timeout();
}

assert_compute_pass_executed_normally(encoder, gpu_buffer, cpu_buffer, buffer_size, ctx).await;
}

#[gpu_test]
@@ -89,9 +177,7 @@ async fn compute_pass_keep_encoder_alive(ctx: TestingContext) {

let mut encoder = ctx
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("encoder"),
});
.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());

let cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("compute_pass"),
@@ -119,6 +205,26 @@ async fn compute_pass_keep_encoder_alive(ctx: TestingContext) {
valid(&ctx.device, || drop(cpass));
}

async fn assert_compute_pass_executed_normally(
mut encoder: wgpu::CommandEncoder,
gpu_buffer: wgpu::Buffer,
cpu_buffer: wgpu::Buffer,
buffer_size: u64,
ctx: TestingContext,
) {
encoder.copy_buffer_to_buffer(&gpu_buffer, 0, &cpu_buffer, 0, buffer_size);
ctx.queue.submit([encoder.finish()]);
cpu_buffer.slice(..).map_async(wgpu::MapMode::Read, |_| ());
ctx.async_poll(wgpu::Maintain::wait())
.await
.panic_on_timeout();

let data = cpu_buffer.slice(..).get_mapped_range();

let floats: &[f32] = bytemuck::cast_slice(&data);
assert_eq!(floats, [2.0, 4.0, 6.0, 8.0]);
}

// Setup ------------------------------------------------------------

struct ResourceSetup {
@@ -5,15 +5,15 @@ use crate::{
compute_command::{ArcComputeCommand, ComputeCommand},
end_pipeline_statistics_query,
memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
BasePass, BindGroupStateChange, CommandBuffer, CommandEncoderError, CommandEncoderStatus,
MapPassErr, PassErrorScope, QueryUseError, StateChange,
validate_and_begin_pipeline_statistics_query, BasePass, BindGroupStateChange,
CommandBuffer, CommandEncoderError, CommandEncoderStatus, MapPassErr, PassErrorScope,
QueryUseError, StateChange,
},
device::{DeviceError, MissingDownlevelFlags, MissingFeatures},
error::{ErrorFormatter, PrettyError},
global::Global,
hal_api::HalApi,
hal_label,
id::{self},
hal_label, id,
init_tracker::MemoryInitKind,
resource::{self, Resource},
snatch::SnatchGuard,
@@ -48,7 +48,7 @@ pub struct ComputePass<A: HalApi> {
/// If it is none, this pass is invalid and any operation on it will return an error.
parent: Option<Arc<CommandBuffer<A>>>,

timestamp_writes: Option<ComputePassTimestampWrites>,
timestamp_writes: Option<ArcComputePassTimestampWrites<A>>,

// Resource binding dedupe state.
current_bind_groups: BindGroupStateChange,
@@ -57,11 +57,16 @@ pub struct ComputePass<A: HalApi> {

impl<A: HalApi> ComputePass<A> {
/// If the parent command buffer is invalid, the returned pass will be invalid.
fn new(parent: Option<Arc<CommandBuffer<A>>>, desc: &ComputePassDescriptor) -> Self {
fn new(parent: Option<Arc<CommandBuffer<A>>>, desc: ArcComputePassDescriptor<A>) -> Self {
let ArcComputePassDescriptor {
label,
timestamp_writes,
} = desc;

Self {
base: Some(BasePass::new(&desc.label)),
base: Some(BasePass::new(label)),
parent,
timestamp_writes: desc.timestamp_writes.cloned(),
timestamp_writes,

current_bind_groups: BindGroupStateChange::new(),
current_pipeline: StateChange::new(),
@@ -107,6 +112,16 @@ pub struct ComputePassTimestampWrites {
pub end_of_pass_write_index: Option<u32>,
}

/// Describes the writing of timestamp values in a compute pass with the query set resolved.
struct ArcComputePassTimestampWrites<A: HalApi> {
/// The query set to write the timestamps to.
pub query_set: Arc<resource::QuerySet<A>>,
/// The index of the query set at which a start timestamp of this pass is written, if any.
pub beginning_of_pass_write_index: Option<u32>,
/// The index of the query set at which an end timestamp of this pass is written, if any.
pub end_of_pass_write_index: Option<u32>,
}

#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a> {
pub label: Label<'a>,
@@ -114,6 +129,12 @@ pub struct ComputePassDescriptor<'a> {
pub timestamp_writes: Option<&'a ComputePassTimestampWrites>,
}

struct ArcComputePassDescriptor<'a, A: HalApi> {
pub label: &'a Label<'a>,
/// Defines where and when timestamp values will be written for this pass.
pub timestamp_writes: Option<ArcComputePassTimestampWrites<A>>,
}

#[derive(Clone, Debug, Error, Eq, PartialEq)]
#[non_exhaustive]
pub enum DispatchError {
@@ -310,13 +331,37 @@ impl Global {
pub fn command_encoder_create_compute_pass<A: HalApi>(
&self,
encoder_id: id::CommandEncoderId,
desc: &ComputePassDescriptor,
desc: &ComputePassDescriptor<'_>,
) -> (ComputePass<A>, Option<CommandEncoderError>) {
let hub = A::hub(self);

let mut arc_desc = ArcComputePassDescriptor {
label: &desc.label,
timestamp_writes: None, // Handle only once we resolved the encoder.
};

match CommandBuffer::lock_encoder(hub, encoder_id) {
Ok(cmd_buf) => (ComputePass::new(Some(cmd_buf), desc), None),
Err(err) => (ComputePass::new(None, desc), Some(err)),
Ok(cmd_buf) => {
arc_desc.timestamp_writes = if let Some(tw) = desc.timestamp_writes {
let Ok(query_set) = hub.query_sets.read().get_owned(tw.query_set) else {
return (
ComputePass::new(None, arc_desc),
Some(CommandEncoderError::InvalidTimestampWritesQuerySetId),
);
};

Some(ArcComputePassTimestampWrites {
query_set,
beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
end_of_pass_write_index: tw.end_of_pass_write_index,
})
} else {
None
};

(ComputePass::new(Some(cmd_buf), arc_desc), None)
}
Err(err) => (ComputePass::new(None, arc_desc), Some(err)),
}
}

@@ -349,7 +394,7 @@ impl Global {
.take()
.ok_or(ComputePassErrorInner::PassEnded)
.map_pass_err(scope)?;
self.compute_pass_end_impl(parent, base, pass.timestamp_writes.as_ref())
self.compute_pass_end_impl(parent, base, pass.timestamp_writes.take())
}

#[doc(hidden)]
@@ -360,11 +405,26 @@ impl Global {
timestamp_writes: Option<&ComputePassTimestampWrites>,
) -> Result<(), ComputePassError> {
let hub = A::hub(self);
let scope = PassErrorScope::PassEncoder(encoder_id);

let cmd_buf = CommandBuffer::get_encoder(hub, encoder_id)
.map_pass_err(PassErrorScope::PassEncoder(encoder_id))?;
let cmd_buf = CommandBuffer::get_encoder(hub, encoder_id).map_pass_err(scope)?;
let commands = ComputeCommand::resolve_compute_command_ids(A::hub(self), &base.commands)?;

let timestamp_writes = if let Some(tw) = timestamp_writes {
Some(ArcComputePassTimestampWrites {
query_set: hub
.query_sets
.read()
.get_owned(tw.query_set)
.map_err(|_| ComputePassErrorInner::InvalidQuerySet(tw.query_set))
.map_pass_err(scope)?,
beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
end_of_pass_write_index: tw.end_of_pass_write_index,
})
} else {
None
};

self.compute_pass_end_impl::<A>(
&cmd_buf,
BasePass {
@@ -382,13 +442,11 @@ impl Global {
&self,
cmd_buf: &CommandBuffer<A>,
base: BasePass<ArcComputeCommand<A>>,
timestamp_writes: Option<&ComputePassTimestampWrites>,
mut timestamp_writes: Option<ArcComputePassTimestampWrites<A>>,
) -> Result<(), ComputePassError> {
profiling::scope!("CommandEncoder::run_compute_pass");
let pass_scope = PassErrorScope::Pass(Some(cmd_buf.as_info().id()));

let hub = A::hub(self);

let device = &cmd_buf.device;
if !device.is_valid() {
return Err(ComputePassErrorInner::InvalidDevice(
@@ -410,7 +468,13 @@ impl Global {
string_data: base.string_data.to_vec(),
push_constant_data: base.push_constant_data.to_vec(),
},
timestamp_writes: timestamp_writes.cloned(),
timestamp_writes: timestamp_writes
.as_ref()
.map(|tw| ComputePassTimestampWrites {
query_set: tw.query_set.as_info().id(),
beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
end_of_pass_write_index: tw.end_of_pass_write_index,
}),
});
}

@@ -428,8 +492,6 @@ impl Global {
*status = CommandEncoderStatus::Error;
let raw = encoder.open().map_pass_err(pass_scope)?;

let query_set_guard = hub.query_sets.read();

let mut state = State {
binder: Binder::new(),
pipeline: None,
@@ -441,12 +503,19 @@ impl Global {
let mut string_offset = 0;
let mut active_query = None;

let timestamp_writes = if let Some(tw) = timestamp_writes {
let query_set: &resource::QuerySet<A> = tracker
.query_sets
.add_single(&*query_set_guard, tw.query_set)
.ok_or(ComputePassErrorInner::InvalidQuerySet(tw.query_set))
.map_pass_err(pass_scope)?;
let snatch_guard = device.snatchable_lock.read();

let indices = &device.tracker_indices;
tracker.buffers.set_size(indices.buffers.size());
tracker.textures.set_size(indices.textures.size());
tracker.bind_groups.set_size(indices.bind_groups.size());
tracker
.compute_pipelines
.set_size(indices.compute_pipelines.size());
tracker.query_sets.set_size(indices.query_sets.size());

let timestamp_writes = if let Some(tw) = timestamp_writes.take() {
let query_set = tracker.query_sets.insert_single(tw.query_set);

// Unlike in render passes we can't delay resetting the query sets since
// there is no auxiliary pass.
@@ -476,17 +545,6 @@ impl Global {
None
};

let snatch_guard = device.snatchable_lock.read();

let indices = &device.tracker_indices;
tracker.buffers.set_size(indices.buffers.size());
tracker.textures.set_size(indices.textures.size());
tracker.bind_groups.set_size(indices.bind_groups.size());
tracker
.compute_pipelines
.set_size(indices.compute_pipelines.size());
tracker.query_sets.set_size(indices.query_sets.size());

let discard_hal_labels = self
.instance
.flags
@@ -812,7 +870,6 @@ impl Global {
query_set,
query_index,
} => {
let query_set_id = query_set.as_info().id();
let scope = PassErrorScope::WriteTimestamp;

device
@@ -822,33 +879,29 @@ impl Global {
let query_set = tracker.query_sets.insert_single(query_set);

query_set
.validate_and_write_timestamp(raw, query_set_id, query_index, None)
.validate_and_write_timestamp(raw, query_index, None)
.map_pass_err(scope)?;
}
ArcComputeCommand::BeginPipelineStatisticsQuery {
query_set,
query_index,
} => {
let query_set_id = query_set.as_info().id();
let scope = PassErrorScope::BeginPipelineStatisticsQuery;

let query_set = tracker.query_sets.insert_single(query_set);

query_set
.validate_and_begin_pipeline_statistics_query(
raw,
query_set_id,
query_index,
None,
&mut active_query,
)
.map_pass_err(scope)?;
validate_and_begin_pipeline_statistics_query(
query_set.clone(),
raw,
query_index,
None,
&mut active_query,
)
.map_pass_err(scope)?;
}
ArcComputeCommand::EndPipelineStatisticsQuery => {
let scope = PassErrorScope::EndPipelineStatisticsQuery;

end_pipeline_statistics_query(raw, &*query_set_guard, &mut active_query)
.map_pass_err(scope)?;
end_pipeline_statistics_query(raw, &mut active_query).map_pass_err(scope)?;
}
}
}
@@ -919,10 +972,9 @@ impl Global {
let bind_group = hub
.bind_groups
.read()
.get(bind_group_id)
.get_owned(bind_group_id)
.map_err(|_| ComputePassErrorInner::InvalidBindGroup(index))
.map_pass_err(scope)?
.clone();
.map_pass_err(scope)?;

base.commands.push(ArcComputeCommand::SetBindGroup {
index,
@@ -952,10 +1004,9 @@ impl Global {
let pipeline = hub
.compute_pipelines
.read()
.get(pipeline_id)
.get_owned(pipeline_id)
.map_err(|_| ComputePassErrorInner::InvalidPipeline(pipeline_id))
.map_pass_err(scope)?
.clone();
.map_pass_err(scope)?;

base.commands.push(ArcComputeCommand::SetPipeline(pipeline));

@@ -1035,10 +1086,9 @@ impl Global {
let buffer = hub
.buffers
.read()
.get(buffer_id)
.get_owned(buffer_id)
.map_err(|_| ComputePassErrorInner::InvalidBuffer(buffer_id))
.map_pass_err(scope)?
.clone();
.map_pass_err(scope)?;

base.commands
.push(ArcComputeCommand::<A>::DispatchIndirect { buffer, offset });
@@ -1109,10 +1159,9 @@ impl Global {
let query_set = hub
.query_sets
.read()
.get(query_set_id)
.get_owned(query_set_id)
.map_err(|_| ComputePassErrorInner::InvalidQuerySet(query_set_id))
.map_pass_err(scope)?
.clone();
.map_pass_err(scope)?;

base.commands.push(ArcComputeCommand::WriteTimestamp {
query_set,
@@ -1135,10 +1184,9 @@ impl Global {
let query_set = hub
.query_sets
.read()
.get(query_set_id)
.get_owned(query_set_id)
.map_err(|_| ComputePassErrorInner::InvalidQuerySet(query_set_id))
.map_pass_err(scope)?
.clone();
.map_pass_err(scope)?;

base.commands
.push(ArcComputeCommand::BeginPipelineStatisticsQuery {
@@ -633,6 +633,8 @@ pub enum CommandEncoderError {
Device(#[from] DeviceError),
#[error("Command encoder is locked by a previously created render/compute pass. Before recording any new commands, the pass must be ended.")]
Locked,
#[error("QuerySet provided for pass timestamp writes is invalid.")]
InvalidTimestampWritesQuerySetId,
}

impl Global {
@@ -13,7 +13,7 @@ use crate::{
storage::Storage,
Epoch, FastHashMap, Index,
};
use std::{iter, marker::PhantomData};
use std::{iter, marker::PhantomData, sync::Arc};
use thiserror::Error;
use wgt::BufferAddress;

@@ -185,15 +185,14 @@ pub enum ResolveError {
impl<A: HalApi> QuerySet<A> {
fn validate_query(
&self,
query_set_id: id::QuerySetId,
query_type: SimplifiedQueryType,
query_index: u32,
reset_state: Option<&mut QueryResetMap<A>>,
) -> Result<&A::QuerySet, QueryUseError> {
) -> Result<(), QueryUseError> {
// We need to defer our resets because we are in a renderpass,
// add the usage to the reset map.
if let Some(reset) = reset_state {
let used = reset.use_query_set(query_set_id, self, query_index);
let used = reset.use_query_set(self.info.id(), self, query_index);
if used {
return Err(QueryUseError::UsedTwiceInsideRenderpass { query_index });
}
@@ -214,133 +213,110 @@ impl<A: HalApi> QuerySet<A> {
});
}

Ok(self.raw())
Ok(())
}

pub(super) fn validate_and_write_timestamp(
&self,
raw_encoder: &mut A::CommandEncoder,
query_set_id: id::QuerySetId,
query_index: u32,
reset_state: Option<&mut QueryResetMap<A>>,
) -> Result<(), QueryUseError> {
let needs_reset = reset_state.is_none();
let query_set = self.validate_query(
query_set_id,
SimplifiedQueryType::Timestamp,
query_index,
reset_state,
)?;
self.validate_query(SimplifiedQueryType::Timestamp, query_index, reset_state)?;

unsafe {
// If we don't have a reset state tracker which can defer resets, we must reset now.
if needs_reset {
raw_encoder.reset_queries(self.raw(), query_index..(query_index + 1));
}
raw_encoder.write_timestamp(query_set, query_index);
}

Ok(())
}

pub(super) fn validate_and_begin_occlusion_query(
&self,
raw_encoder: &mut A::CommandEncoder,
query_set_id: id::QuerySetId,
query_index: u32,
reset_state: Option<&mut QueryResetMap<A>>,
active_query: &mut Option<(id::QuerySetId, u32)>,
) -> Result<(), QueryUseError> {
let needs_reset = reset_state.is_none();
let query_set = self.validate_query(
query_set_id,
SimplifiedQueryType::Occlusion,
query_index,
reset_state,
)?;

if let Some((_old_id, old_idx)) = active_query.replace((query_set_id, query_index)) {
return Err(QueryUseError::AlreadyStarted {
active_query_index: old_idx,
new_query_index: query_index,
});
}

unsafe {
// If we don't have a reset state tracker which can defer resets, we must reset now.
if needs_reset {
raw_encoder
.reset_queries(self.raw.as_ref().unwrap(), query_index..(query_index + 1));
}
raw_encoder.begin_query(query_set, query_index);
}

Ok(())
}

pub(super) fn validate_and_begin_pipeline_statistics_query(
&self,
raw_encoder: &mut A::CommandEncoder,
query_set_id: id::QuerySetId,
query_index: u32,
reset_state: Option<&mut QueryResetMap<A>>,
active_query: &mut Option<(id::QuerySetId, u32)>,
) -> Result<(), QueryUseError> {
let needs_reset = reset_state.is_none();
let query_set = self.validate_query(
query_set_id,
SimplifiedQueryType::PipelineStatistics,
query_index,
reset_state,
)?;

if let Some((_old_id, old_idx)) = active_query.replace((query_set_id, query_index)) {
return Err(QueryUseError::AlreadyStarted {
active_query_index: old_idx,
new_query_index: query_index,
});
}

unsafe {
// If we don't have a reset state tracker which can defer resets, we must reset now.
if needs_reset {
raw_encoder.reset_queries(self.raw(), query_index..(query_index + 1));
}
raw_encoder.begin_query(query_set, query_index);
raw_encoder.write_timestamp(self.raw(), query_index);
}

Ok(())
}
}

pub(super) fn validate_and_begin_occlusion_query<A: HalApi>(
query_set: Arc<QuerySet<A>>,
raw_encoder: &mut A::CommandEncoder,
query_index: u32,
reset_state: Option<&mut QueryResetMap<A>>,
active_query: &mut Option<(Arc<QuerySet<A>>, u32)>,
) -> Result<(), QueryUseError> {
let needs_reset = reset_state.is_none();
query_set.validate_query(SimplifiedQueryType::Occlusion, query_index, reset_state)?;

if let Some((_old, old_idx)) = active_query.take() {
return Err(QueryUseError::AlreadyStarted {
active_query_index: old_idx,
new_query_index: query_index,
});
}
let (query_set, _) = &active_query.insert((query_set, query_index));

unsafe {
// If we don't have a reset state tracker which can defer resets, we must reset now.
if needs_reset {
raw_encoder.reset_queries(query_set.raw(), query_index..(query_index + 1));
}
raw_encoder.begin_query(query_set.raw(), query_index);
}

Ok(())
}

pub(super) fn end_occlusion_query<A: HalApi>(
raw_encoder: &mut A::CommandEncoder,
storage: &Storage<QuerySet<A>>,
active_query: &mut Option<(id::QuerySetId, u32)>,
active_query: &mut Option<(Arc<QuerySet<A>>, u32)>,
) -> Result<(), QueryUseError> {
if let Some((query_set_id, query_index)) = active_query.take() {
// We can unwrap here as the validity was validated when the active query was set
let query_set = storage.get(query_set_id).unwrap();

if let Some((query_set, query_index)) = active_query.take() {
unsafe { raw_encoder.end_query(query_set.raw.as_ref().unwrap(), query_index) };

Ok(())
} else {
Err(QueryUseError::AlreadyStopped)
}
}

pub(super) fn validate_and_begin_pipeline_statistics_query<A: HalApi>(
query_set: Arc<QuerySet<A>>,
raw_encoder: &mut A::CommandEncoder,
query_index: u32,
reset_state: Option<&mut QueryResetMap<A>>,
active_query: &mut Option<(Arc<QuerySet<A>>, u32)>,
) -> Result<(), QueryUseError> {
let needs_reset = reset_state.is_none();
query_set.validate_query(
SimplifiedQueryType::PipelineStatistics,
query_index,
reset_state,
)?;

if let Some((_old, old_idx)) = active_query.take() {
return Err(QueryUseError::AlreadyStarted {
active_query_index: old_idx,
new_query_index: query_index,
});
}
let (query_set, _) = &active_query.insert((query_set, query_index));

unsafe {
// If we don't have a reset state tracker which can defer resets, we must reset now.
if needs_reset {
raw_encoder.reset_queries(query_set.raw(), query_index..(query_index + 1));
}
raw_encoder.begin_query(query_set.raw(), query_index);
}

Ok(())
}

pub(super) fn end_pipeline_statistics_query<A: HalApi>(
raw_encoder: &mut A::CommandEncoder,
storage: &Storage<QuerySet<A>>,
active_query: &mut Option<(id::QuerySetId, u32)>,
active_query: &mut Option<(Arc<QuerySet<A>>, u32)>,
) -> Result<(), QueryUseError> {
if let Some((query_set_id, query_index)) = active_query.take() {
// We can unwrap here as the validity was validated when the active query was set
let query_set = storage.get(query_set_id).unwrap();

if let Some((query_set, query_index)) = active_query.take() {
unsafe { raw_encoder.end_query(query_set.raw(), query_index) };

Ok(())
} else {
Err(QueryUseError::AlreadyStopped)
@@ -384,7 +360,7 @@ impl Global {
.add_single(&*query_set_guard, query_set_id)
.ok_or(QueryError::InvalidQuerySet(query_set_id))?;

query_set.validate_and_write_timestamp(raw_encoder, query_set_id, query_index, None)?;
query_set.validate_and_write_timestamp(raw_encoder, query_index, None)?;

Ok(())
}
@@ -1,3 +1,6 @@
use crate::command::{
validate_and_begin_occlusion_query, validate_and_begin_pipeline_statistics_query,
};
use crate::resource::Resource;
use crate::snatch::SnatchGuard;
use crate::{
@@ -2258,7 +2261,6 @@ impl Global {
query_set
.validate_and_write_timestamp(
raw,
query_set_id,
query_index,
Some(&mut cmd_buf_data.pending_query_resets),
)
@@ -2278,22 +2280,20 @@ impl Global {
.ok_or(RenderCommandError::InvalidQuerySet(query_set_id))
.map_pass_err(scope)?;

query_set
.validate_and_begin_occlusion_query(
raw,
query_set_id,
query_index,
Some(&mut cmd_buf_data.pending_query_resets),
&mut active_query,
)
.map_pass_err(scope)?;
validate_and_begin_occlusion_query(
query_set.clone(),
raw,
query_index,
Some(&mut cmd_buf_data.pending_query_resets),
&mut active_query,
)
.map_pass_err(scope)?;
}
RenderCommand::EndOcclusionQuery => {
api_log!("RenderPass::end_occlusion_query");
let scope = PassErrorScope::EndOcclusionQuery;

end_occlusion_query(raw, &*query_set_guard, &mut active_query)
.map_pass_err(scope)?;
end_occlusion_query(raw, &mut active_query).map_pass_err(scope)?;
}
RenderCommand::BeginPipelineStatisticsQuery {
query_set_id,
@@ -2308,21 +2308,20 @@ impl Global {
.ok_or(RenderCommandError::InvalidQuerySet(query_set_id))
.map_pass_err(scope)?;

query_set
.validate_and_begin_pipeline_statistics_query(
raw,
query_set_id,
query_index,
Some(&mut cmd_buf_data.pending_query_resets),
&mut active_query,
)
.map_pass_err(scope)?;
validate_and_begin_pipeline_statistics_query(
query_set.clone(),
raw,
query_index,
Some(&mut cmd_buf_data.pending_query_resets),
&mut active_query,
)
.map_pass_err(scope)?;
}
RenderCommand::EndPipelineStatisticsQuery => {
api_log!("RenderPass::end_pipeline_statistics_query");
let scope = PassErrorScope::EndPipelineStatisticsQuery;

end_pipeline_statistics_query(raw, &*query_set_guard, &mut active_query)
end_pipeline_statistics_query(raw, &mut active_query)
.map_pass_err(scope)?;
}
RenderCommand::ExecuteBundle(bundle_id) => {