Fix QuerySet ownership of ComputePass (#5671)

* add new tests for checking on query set lifetime

* Fix ownership management of query sets on compute passes for write_timestamp, timestamp_writes (on desc) and pipeline statistic queries

* changelog entry
This commit is contained in:
Andreas Reich 2024-06-04 09:47:27 +02:00 committed by GitHub
parent d258d6ce73
commit 9a27ba53ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 343 additions and 211 deletions

View File

@ -58,7 +58,7 @@ fn independent_cpass<'enc>(encoder: &'enc mut wgpu::CommandEncoder) -> wgpu::Com
This is very useful for library authors, but opens up an easy way for incorrect use, so use with care.
`forget_lifetime` is zero overhead and has no side effects on pass recording.
By @wumpf in [#5569](https://github.com/gfx-rs/wgpu/pull/5569), [#5575](https://github.com/gfx-rs/wgpu/pull/5575), [#5620](https://github.com/gfx-rs/wgpu/pull/5620), [#5768](https://github.com/gfx-rs/wgpu/pull/5768) (together with @kpreid).
By @wumpf in [#5569](https://github.com/gfx-rs/wgpu/pull/5569), [#5575](https://github.com/gfx-rs/wgpu/pull/5575), [#5620](https://github.com/gfx-rs/wgpu/pull/5620), [#5768](https://github.com/gfx-rs/wgpu/pull/5768) (together with @kpreid), [#5671](https://github.com/gfx-rs/wgpu/pull/5671).
#### Querying shader compilation errors
@ -116,6 +116,7 @@ By @stefnotch in [#5410](https://github.com/gfx-rs/wgpu/pull/5410)
#### General
- Ensure render pipelines have at least 1 target. By @ErichDonGubler in [#5715](https://github.com/gfx-rs/wgpu/pull/5715)
- `wgpu::ComputePass` now internally takes ownership of `QuerySet` for both `wgpu::ComputePassTimestampWrites` as well as timestamp writes and statistics query, fixing crashes when destroying `QuerySet` before ending the pass. By @wumpf in [#5671](https://github.com/gfx-rs/wgpu/pull/5671)
#### Metal

View File

@ -1,9 +1,5 @@
//! Tests that compute passes take ownership of resources that are associated with.
//! I.e. once a resource is passed in to a compute pass, it can be dropped.
//!
//! TODO: Also should test resource ownership for:
//! * write_timestamp
//! * begin_pipeline_statistics_query
use std::num::NonZeroU64;
@ -36,15 +32,10 @@ async fn compute_pass_resource_ownership(ctx: TestingContext) {
let mut encoder = ctx
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("encoder"),
});
.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
{
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("compute_pass"),
timestamp_writes: None, // TODO: See description above, we should test this as well once we lift the lifetime bound.
});
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
cpass.set_pipeline(&pipeline);
cpass.set_bind_group(0, &bind_group, &[]);
cpass.dispatch_workgroups_indirect(&indirect_buffer, 0);
@ -58,18 +49,115 @@ async fn compute_pass_resource_ownership(ctx: TestingContext) {
.panic_on_timeout();
}
// Ensure that the compute pass still executed normally.
encoder.copy_buffer_to_buffer(&gpu_buffer, 0, &cpu_buffer, 0, buffer_size);
ctx.queue.submit([encoder.finish()]);
cpu_buffer.slice(..).map_async(wgpu::MapMode::Read, |_| ());
ctx.async_poll(wgpu::Maintain::wait())
.await
.panic_on_timeout();
assert_compute_pass_executed_normally(encoder, gpu_buffer, cpu_buffer, buffer_size, ctx).await;
}
let data = cpu_buffer.slice(..).get_mapped_range();
#[gpu_test]
static COMPUTE_PASS_QUERY_SET_OWNERSHIP_PIPELINE_STATISTICS: GpuTestConfiguration =
GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.test_features_limits()
.features(wgpu::Features::PIPELINE_STATISTICS_QUERY),
)
.run_async(compute_pass_query_set_ownership_pipeline_statistics);
let floats: &[f32] = bytemuck::cast_slice(&data);
assert_eq!(floats, [2.0, 4.0, 6.0, 8.0]);
async fn compute_pass_query_set_ownership_pipeline_statistics(ctx: TestingContext) {
let ResourceSetup {
gpu_buffer,
cpu_buffer,
buffer_size,
indirect_buffer: _,
bind_group,
pipeline,
} = resource_setup(&ctx);
let query_set = ctx.device.create_query_set(&wgpu::QuerySetDescriptor {
label: Some("query_set"),
ty: wgpu::QueryType::PipelineStatistics(
wgpu::PipelineStatisticsTypes::COMPUTE_SHADER_INVOCATIONS,
),
count: 1,
});
let mut encoder = ctx
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
{
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
cpass.set_pipeline(&pipeline);
cpass.set_bind_group(0, &bind_group, &[]);
cpass.begin_pipeline_statistics_query(&query_set, 0);
cpass.dispatch_workgroups(1, 1, 1);
cpass.end_pipeline_statistics_query();
// Drop the query set. Then do a device poll to make sure it's not dropped too early, no matter what.
drop(query_set);
ctx.async_poll(wgpu::Maintain::wait())
.await
.panic_on_timeout();
}
assert_compute_pass_executed_normally(encoder, gpu_buffer, cpu_buffer, buffer_size, ctx).await;
}
#[gpu_test]
static COMPUTE_PASS_QUERY_TIMESTAMPS: GpuTestConfiguration =
GpuTestConfiguration::new()
.parameters(TestParameters::default().test_features_limits().features(
wgpu::Features::TIMESTAMP_QUERY | wgpu::Features::TIMESTAMP_QUERY_INSIDE_PASSES,
))
.run_async(compute_pass_query_timestamps);
async fn compute_pass_query_timestamps(ctx: TestingContext) {
let ResourceSetup {
gpu_buffer,
cpu_buffer,
buffer_size,
indirect_buffer: _,
bind_group,
pipeline,
} = resource_setup(&ctx);
let query_set_timestamp_writes = ctx.device.create_query_set(&wgpu::QuerySetDescriptor {
label: Some("query_set_timestamp_writes"),
ty: wgpu::QueryType::Timestamp,
count: 2,
});
let query_set_write_timestamp = ctx.device.create_query_set(&wgpu::QuerySetDescriptor {
label: Some("query_set_write_timestamp"),
ty: wgpu::QueryType::Timestamp,
count: 1,
});
let mut encoder = ctx
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
{
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("compute_pass"),
timestamp_writes: Some(wgpu::ComputePassTimestampWrites {
query_set: &query_set_timestamp_writes,
beginning_of_pass_write_index: Some(0),
end_of_pass_write_index: Some(1),
}),
});
cpass.set_pipeline(&pipeline);
cpass.set_bind_group(0, &bind_group, &[]);
cpass.write_timestamp(&query_set_write_timestamp, 0);
cpass.dispatch_workgroups(1, 1, 1);
// Drop the query sets. Then do a device poll to make sure they're not dropped too early, no matter what.
drop(query_set_timestamp_writes);
drop(query_set_write_timestamp);
ctx.async_poll(wgpu::Maintain::wait())
.await
.panic_on_timeout();
}
assert_compute_pass_executed_normally(encoder, gpu_buffer, cpu_buffer, buffer_size, ctx).await;
}
#[gpu_test]
@ -89,9 +177,7 @@ async fn compute_pass_keep_encoder_alive(ctx: TestingContext) {
let mut encoder = ctx
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("encoder"),
});
.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
let cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("compute_pass"),
@ -119,6 +205,26 @@ async fn compute_pass_keep_encoder_alive(ctx: TestingContext) {
valid(&ctx.device, || drop(cpass));
}
async fn assert_compute_pass_executed_normally(
mut encoder: wgpu::CommandEncoder,
gpu_buffer: wgpu::Buffer,
cpu_buffer: wgpu::Buffer,
buffer_size: u64,
ctx: TestingContext,
) {
encoder.copy_buffer_to_buffer(&gpu_buffer, 0, &cpu_buffer, 0, buffer_size);
ctx.queue.submit([encoder.finish()]);
cpu_buffer.slice(..).map_async(wgpu::MapMode::Read, |_| ());
ctx.async_poll(wgpu::Maintain::wait())
.await
.panic_on_timeout();
let data = cpu_buffer.slice(..).get_mapped_range();
let floats: &[f32] = bytemuck::cast_slice(&data);
assert_eq!(floats, [2.0, 4.0, 6.0, 8.0]);
}
// Setup ------------------------------------------------------------
struct ResourceSetup {

View File

@ -5,15 +5,15 @@ use crate::{
compute_command::{ArcComputeCommand, ComputeCommand},
end_pipeline_statistics_query,
memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
BasePass, BindGroupStateChange, CommandBuffer, CommandEncoderError, CommandEncoderStatus,
MapPassErr, PassErrorScope, QueryUseError, StateChange,
validate_and_begin_pipeline_statistics_query, BasePass, BindGroupStateChange,
CommandBuffer, CommandEncoderError, CommandEncoderStatus, MapPassErr, PassErrorScope,
QueryUseError, StateChange,
},
device::{DeviceError, MissingDownlevelFlags, MissingFeatures},
error::{ErrorFormatter, PrettyError},
global::Global,
hal_api::HalApi,
hal_label,
id::{self},
hal_label, id,
init_tracker::MemoryInitKind,
resource::{self, Resource},
snatch::SnatchGuard,
@ -48,7 +48,7 @@ pub struct ComputePass<A: HalApi> {
/// If it is none, this pass is invalid and any operation on it will return an error.
parent: Option<Arc<CommandBuffer<A>>>,
timestamp_writes: Option<ComputePassTimestampWrites>,
timestamp_writes: Option<ArcComputePassTimestampWrites<A>>,
// Resource binding dedupe state.
current_bind_groups: BindGroupStateChange,
@ -57,11 +57,16 @@ pub struct ComputePass<A: HalApi> {
impl<A: HalApi> ComputePass<A> {
/// If the parent command buffer is invalid, the returned pass will be invalid.
fn new(parent: Option<Arc<CommandBuffer<A>>>, desc: &ComputePassDescriptor) -> Self {
fn new(parent: Option<Arc<CommandBuffer<A>>>, desc: ArcComputePassDescriptor<A>) -> Self {
let ArcComputePassDescriptor {
label,
timestamp_writes,
} = desc;
Self {
base: Some(BasePass::new(&desc.label)),
base: Some(BasePass::new(label)),
parent,
timestamp_writes: desc.timestamp_writes.cloned(),
timestamp_writes,
current_bind_groups: BindGroupStateChange::new(),
current_pipeline: StateChange::new(),
@ -107,6 +112,16 @@ pub struct ComputePassTimestampWrites {
pub end_of_pass_write_index: Option<u32>,
}
/// Describes the writing of timestamp values in a compute pass with the query set resolved.
struct ArcComputePassTimestampWrites<A: HalApi> {
/// The query set to write the timestamps to.
pub query_set: Arc<resource::QuerySet<A>>,
/// The index of the query set at which a start timestamp of this pass is written, if any.
pub beginning_of_pass_write_index: Option<u32>,
/// The index of the query set at which an end timestamp of this pass is written, if any.
pub end_of_pass_write_index: Option<u32>,
}
#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a> {
pub label: Label<'a>,
@ -114,6 +129,12 @@ pub struct ComputePassDescriptor<'a> {
pub timestamp_writes: Option<&'a ComputePassTimestampWrites>,
}
struct ArcComputePassDescriptor<'a, A: HalApi> {
pub label: &'a Label<'a>,
/// Defines where and when timestamp values will be written for this pass.
pub timestamp_writes: Option<ArcComputePassTimestampWrites<A>>,
}
#[derive(Clone, Debug, Error, Eq, PartialEq)]
#[non_exhaustive]
pub enum DispatchError {
@ -310,13 +331,37 @@ impl Global {
pub fn command_encoder_create_compute_pass<A: HalApi>(
&self,
encoder_id: id::CommandEncoderId,
desc: &ComputePassDescriptor,
desc: &ComputePassDescriptor<'_>,
) -> (ComputePass<A>, Option<CommandEncoderError>) {
let hub = A::hub(self);
let mut arc_desc = ArcComputePassDescriptor {
label: &desc.label,
timestamp_writes: None, // Handle only once we resolved the encoder.
};
match CommandBuffer::lock_encoder(hub, encoder_id) {
Ok(cmd_buf) => (ComputePass::new(Some(cmd_buf), desc), None),
Err(err) => (ComputePass::new(None, desc), Some(err)),
Ok(cmd_buf) => {
arc_desc.timestamp_writes = if let Some(tw) = desc.timestamp_writes {
let Ok(query_set) = hub.query_sets.read().get_owned(tw.query_set) else {
return (
ComputePass::new(None, arc_desc),
Some(CommandEncoderError::InvalidTimestampWritesQuerySetId),
);
};
Some(ArcComputePassTimestampWrites {
query_set,
beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
end_of_pass_write_index: tw.end_of_pass_write_index,
})
} else {
None
};
(ComputePass::new(Some(cmd_buf), arc_desc), None)
}
Err(err) => (ComputePass::new(None, arc_desc), Some(err)),
}
}
@ -349,7 +394,7 @@ impl Global {
.take()
.ok_or(ComputePassErrorInner::PassEnded)
.map_pass_err(scope)?;
self.compute_pass_end_impl(parent, base, pass.timestamp_writes.as_ref())
self.compute_pass_end_impl(parent, base, pass.timestamp_writes.take())
}
#[doc(hidden)]
@ -360,11 +405,26 @@ impl Global {
timestamp_writes: Option<&ComputePassTimestampWrites>,
) -> Result<(), ComputePassError> {
let hub = A::hub(self);
let scope = PassErrorScope::PassEncoder(encoder_id);
let cmd_buf = CommandBuffer::get_encoder(hub, encoder_id)
.map_pass_err(PassErrorScope::PassEncoder(encoder_id))?;
let cmd_buf = CommandBuffer::get_encoder(hub, encoder_id).map_pass_err(scope)?;
let commands = ComputeCommand::resolve_compute_command_ids(A::hub(self), &base.commands)?;
let timestamp_writes = if let Some(tw) = timestamp_writes {
Some(ArcComputePassTimestampWrites {
query_set: hub
.query_sets
.read()
.get_owned(tw.query_set)
.map_err(|_| ComputePassErrorInner::InvalidQuerySet(tw.query_set))
.map_pass_err(scope)?,
beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
end_of_pass_write_index: tw.end_of_pass_write_index,
})
} else {
None
};
self.compute_pass_end_impl::<A>(
&cmd_buf,
BasePass {
@ -382,13 +442,11 @@ impl Global {
&self,
cmd_buf: &CommandBuffer<A>,
base: BasePass<ArcComputeCommand<A>>,
timestamp_writes: Option<&ComputePassTimestampWrites>,
mut timestamp_writes: Option<ArcComputePassTimestampWrites<A>>,
) -> Result<(), ComputePassError> {
profiling::scope!("CommandEncoder::run_compute_pass");
let pass_scope = PassErrorScope::Pass(Some(cmd_buf.as_info().id()));
let hub = A::hub(self);
let device = &cmd_buf.device;
if !device.is_valid() {
return Err(ComputePassErrorInner::InvalidDevice(
@ -410,7 +468,13 @@ impl Global {
string_data: base.string_data.to_vec(),
push_constant_data: base.push_constant_data.to_vec(),
},
timestamp_writes: timestamp_writes.cloned(),
timestamp_writes: timestamp_writes
.as_ref()
.map(|tw| ComputePassTimestampWrites {
query_set: tw.query_set.as_info().id(),
beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
end_of_pass_write_index: tw.end_of_pass_write_index,
}),
});
}
@ -428,8 +492,6 @@ impl Global {
*status = CommandEncoderStatus::Error;
let raw = encoder.open().map_pass_err(pass_scope)?;
let query_set_guard = hub.query_sets.read();
let mut state = State {
binder: Binder::new(),
pipeline: None,
@ -441,12 +503,19 @@ impl Global {
let mut string_offset = 0;
let mut active_query = None;
let timestamp_writes = if let Some(tw) = timestamp_writes {
let query_set: &resource::QuerySet<A> = tracker
.query_sets
.add_single(&*query_set_guard, tw.query_set)
.ok_or(ComputePassErrorInner::InvalidQuerySet(tw.query_set))
.map_pass_err(pass_scope)?;
let snatch_guard = device.snatchable_lock.read();
let indices = &device.tracker_indices;
tracker.buffers.set_size(indices.buffers.size());
tracker.textures.set_size(indices.textures.size());
tracker.bind_groups.set_size(indices.bind_groups.size());
tracker
.compute_pipelines
.set_size(indices.compute_pipelines.size());
tracker.query_sets.set_size(indices.query_sets.size());
let timestamp_writes = if let Some(tw) = timestamp_writes.take() {
let query_set = tracker.query_sets.insert_single(tw.query_set);
// Unlike in render passes we can't delay resetting the query sets since
// there is no auxiliary pass.
@ -476,17 +545,6 @@ impl Global {
None
};
let snatch_guard = device.snatchable_lock.read();
let indices = &device.tracker_indices;
tracker.buffers.set_size(indices.buffers.size());
tracker.textures.set_size(indices.textures.size());
tracker.bind_groups.set_size(indices.bind_groups.size());
tracker
.compute_pipelines
.set_size(indices.compute_pipelines.size());
tracker.query_sets.set_size(indices.query_sets.size());
let discard_hal_labels = self
.instance
.flags
@ -812,7 +870,6 @@ impl Global {
query_set,
query_index,
} => {
let query_set_id = query_set.as_info().id();
let scope = PassErrorScope::WriteTimestamp;
device
@ -822,33 +879,29 @@ impl Global {
let query_set = tracker.query_sets.insert_single(query_set);
query_set
.validate_and_write_timestamp(raw, query_set_id, query_index, None)
.validate_and_write_timestamp(raw, query_index, None)
.map_pass_err(scope)?;
}
ArcComputeCommand::BeginPipelineStatisticsQuery {
query_set,
query_index,
} => {
let query_set_id = query_set.as_info().id();
let scope = PassErrorScope::BeginPipelineStatisticsQuery;
let query_set = tracker.query_sets.insert_single(query_set);
query_set
.validate_and_begin_pipeline_statistics_query(
raw,
query_set_id,
query_index,
None,
&mut active_query,
)
.map_pass_err(scope)?;
validate_and_begin_pipeline_statistics_query(
query_set.clone(),
raw,
query_index,
None,
&mut active_query,
)
.map_pass_err(scope)?;
}
ArcComputeCommand::EndPipelineStatisticsQuery => {
let scope = PassErrorScope::EndPipelineStatisticsQuery;
end_pipeline_statistics_query(raw, &*query_set_guard, &mut active_query)
.map_pass_err(scope)?;
end_pipeline_statistics_query(raw, &mut active_query).map_pass_err(scope)?;
}
}
}
@ -919,10 +972,9 @@ impl Global {
let bind_group = hub
.bind_groups
.read()
.get(bind_group_id)
.get_owned(bind_group_id)
.map_err(|_| ComputePassErrorInner::InvalidBindGroup(index))
.map_pass_err(scope)?
.clone();
.map_pass_err(scope)?;
base.commands.push(ArcComputeCommand::SetBindGroup {
index,
@ -952,10 +1004,9 @@ impl Global {
let pipeline = hub
.compute_pipelines
.read()
.get(pipeline_id)
.get_owned(pipeline_id)
.map_err(|_| ComputePassErrorInner::InvalidPipeline(pipeline_id))
.map_pass_err(scope)?
.clone();
.map_pass_err(scope)?;
base.commands.push(ArcComputeCommand::SetPipeline(pipeline));
@ -1035,10 +1086,9 @@ impl Global {
let buffer = hub
.buffers
.read()
.get(buffer_id)
.get_owned(buffer_id)
.map_err(|_| ComputePassErrorInner::InvalidBuffer(buffer_id))
.map_pass_err(scope)?
.clone();
.map_pass_err(scope)?;
base.commands
.push(ArcComputeCommand::<A>::DispatchIndirect { buffer, offset });
@ -1109,10 +1159,9 @@ impl Global {
let query_set = hub
.query_sets
.read()
.get(query_set_id)
.get_owned(query_set_id)
.map_err(|_| ComputePassErrorInner::InvalidQuerySet(query_set_id))
.map_pass_err(scope)?
.clone();
.map_pass_err(scope)?;
base.commands.push(ArcComputeCommand::WriteTimestamp {
query_set,
@ -1135,10 +1184,9 @@ impl Global {
let query_set = hub
.query_sets
.read()
.get(query_set_id)
.get_owned(query_set_id)
.map_err(|_| ComputePassErrorInner::InvalidQuerySet(query_set_id))
.map_pass_err(scope)?
.clone();
.map_pass_err(scope)?;
base.commands
.push(ArcComputeCommand::BeginPipelineStatisticsQuery {

View File

@ -633,6 +633,8 @@ pub enum CommandEncoderError {
Device(#[from] DeviceError),
#[error("Command encoder is locked by a previously created render/compute pass. Before recording any new commands, the pass must be ended.")]
Locked,
#[error("QuerySet provided for pass timestamp writes is invalid.")]
InvalidTimestampWritesQuerySetId,
}
impl Global {

View File

@ -13,7 +13,7 @@ use crate::{
storage::Storage,
Epoch, FastHashMap, Index,
};
use std::{iter, marker::PhantomData};
use std::{iter, marker::PhantomData, sync::Arc};
use thiserror::Error;
use wgt::BufferAddress;
@ -185,15 +185,14 @@ pub enum ResolveError {
impl<A: HalApi> QuerySet<A> {
fn validate_query(
&self,
query_set_id: id::QuerySetId,
query_type: SimplifiedQueryType,
query_index: u32,
reset_state: Option<&mut QueryResetMap<A>>,
) -> Result<&A::QuerySet, QueryUseError> {
) -> Result<(), QueryUseError> {
// We need to defer our resets because we are in a renderpass,
// add the usage to the reset map.
if let Some(reset) = reset_state {
let used = reset.use_query_set(query_set_id, self, query_index);
let used = reset.use_query_set(self.info.id(), self, query_index);
if used {
return Err(QueryUseError::UsedTwiceInsideRenderpass { query_index });
}
@ -214,133 +213,110 @@ impl<A: HalApi> QuerySet<A> {
});
}
Ok(self.raw())
Ok(())
}
pub(super) fn validate_and_write_timestamp(
&self,
raw_encoder: &mut A::CommandEncoder,
query_set_id: id::QuerySetId,
query_index: u32,
reset_state: Option<&mut QueryResetMap<A>>,
) -> Result<(), QueryUseError> {
let needs_reset = reset_state.is_none();
let query_set = self.validate_query(
query_set_id,
SimplifiedQueryType::Timestamp,
query_index,
reset_state,
)?;
self.validate_query(SimplifiedQueryType::Timestamp, query_index, reset_state)?;
unsafe {
// If we don't have a reset state tracker which can defer resets, we must reset now.
if needs_reset {
raw_encoder.reset_queries(self.raw(), query_index..(query_index + 1));
}
raw_encoder.write_timestamp(query_set, query_index);
}
Ok(())
}
pub(super) fn validate_and_begin_occlusion_query(
&self,
raw_encoder: &mut A::CommandEncoder,
query_set_id: id::QuerySetId,
query_index: u32,
reset_state: Option<&mut QueryResetMap<A>>,
active_query: &mut Option<(id::QuerySetId, u32)>,
) -> Result<(), QueryUseError> {
let needs_reset = reset_state.is_none();
let query_set = self.validate_query(
query_set_id,
SimplifiedQueryType::Occlusion,
query_index,
reset_state,
)?;
if let Some((_old_id, old_idx)) = active_query.replace((query_set_id, query_index)) {
return Err(QueryUseError::AlreadyStarted {
active_query_index: old_idx,
new_query_index: query_index,
});
}
unsafe {
// If we don't have a reset state tracker which can defer resets, we must reset now.
if needs_reset {
raw_encoder
.reset_queries(self.raw.as_ref().unwrap(), query_index..(query_index + 1));
}
raw_encoder.begin_query(query_set, query_index);
}
Ok(())
}
pub(super) fn validate_and_begin_pipeline_statistics_query(
&self,
raw_encoder: &mut A::CommandEncoder,
query_set_id: id::QuerySetId,
query_index: u32,
reset_state: Option<&mut QueryResetMap<A>>,
active_query: &mut Option<(id::QuerySetId, u32)>,
) -> Result<(), QueryUseError> {
let needs_reset = reset_state.is_none();
let query_set = self.validate_query(
query_set_id,
SimplifiedQueryType::PipelineStatistics,
query_index,
reset_state,
)?;
if let Some((_old_id, old_idx)) = active_query.replace((query_set_id, query_index)) {
return Err(QueryUseError::AlreadyStarted {
active_query_index: old_idx,
new_query_index: query_index,
});
}
unsafe {
// If we don't have a reset state tracker which can defer resets, we must reset now.
if needs_reset {
raw_encoder.reset_queries(self.raw(), query_index..(query_index + 1));
}
raw_encoder.begin_query(query_set, query_index);
raw_encoder.write_timestamp(self.raw(), query_index);
}
Ok(())
}
}
pub(super) fn validate_and_begin_occlusion_query<A: HalApi>(
query_set: Arc<QuerySet<A>>,
raw_encoder: &mut A::CommandEncoder,
query_index: u32,
reset_state: Option<&mut QueryResetMap<A>>,
active_query: &mut Option<(Arc<QuerySet<A>>, u32)>,
) -> Result<(), QueryUseError> {
let needs_reset = reset_state.is_none();
query_set.validate_query(SimplifiedQueryType::Occlusion, query_index, reset_state)?;
if let Some((_old, old_idx)) = active_query.take() {
return Err(QueryUseError::AlreadyStarted {
active_query_index: old_idx,
new_query_index: query_index,
});
}
let (query_set, _) = &active_query.insert((query_set, query_index));
unsafe {
// If we don't have a reset state tracker which can defer resets, we must reset now.
if needs_reset {
raw_encoder.reset_queries(query_set.raw(), query_index..(query_index + 1));
}
raw_encoder.begin_query(query_set.raw(), query_index);
}
Ok(())
}
pub(super) fn end_occlusion_query<A: HalApi>(
raw_encoder: &mut A::CommandEncoder,
storage: &Storage<QuerySet<A>>,
active_query: &mut Option<(id::QuerySetId, u32)>,
active_query: &mut Option<(Arc<QuerySet<A>>, u32)>,
) -> Result<(), QueryUseError> {
if let Some((query_set_id, query_index)) = active_query.take() {
// We can unwrap here as the validity was validated when the active query was set
let query_set = storage.get(query_set_id).unwrap();
if let Some((query_set, query_index)) = active_query.take() {
unsafe { raw_encoder.end_query(query_set.raw.as_ref().unwrap(), query_index) };
Ok(())
} else {
Err(QueryUseError::AlreadyStopped)
}
}
pub(super) fn validate_and_begin_pipeline_statistics_query<A: HalApi>(
query_set: Arc<QuerySet<A>>,
raw_encoder: &mut A::CommandEncoder,
query_index: u32,
reset_state: Option<&mut QueryResetMap<A>>,
active_query: &mut Option<(Arc<QuerySet<A>>, u32)>,
) -> Result<(), QueryUseError> {
let needs_reset = reset_state.is_none();
query_set.validate_query(
SimplifiedQueryType::PipelineStatistics,
query_index,
reset_state,
)?;
if let Some((_old, old_idx)) = active_query.take() {
return Err(QueryUseError::AlreadyStarted {
active_query_index: old_idx,
new_query_index: query_index,
});
}
let (query_set, _) = &active_query.insert((query_set, query_index));
unsafe {
// If we don't have a reset state tracker which can defer resets, we must reset now.
if needs_reset {
raw_encoder.reset_queries(query_set.raw(), query_index..(query_index + 1));
}
raw_encoder.begin_query(query_set.raw(), query_index);
}
Ok(())
}
pub(super) fn end_pipeline_statistics_query<A: HalApi>(
raw_encoder: &mut A::CommandEncoder,
storage: &Storage<QuerySet<A>>,
active_query: &mut Option<(id::QuerySetId, u32)>,
active_query: &mut Option<(Arc<QuerySet<A>>, u32)>,
) -> Result<(), QueryUseError> {
if let Some((query_set_id, query_index)) = active_query.take() {
// We can unwrap here as the validity was validated when the active query was set
let query_set = storage.get(query_set_id).unwrap();
if let Some((query_set, query_index)) = active_query.take() {
unsafe { raw_encoder.end_query(query_set.raw(), query_index) };
Ok(())
} else {
Err(QueryUseError::AlreadyStopped)
@ -384,7 +360,7 @@ impl Global {
.add_single(&*query_set_guard, query_set_id)
.ok_or(QueryError::InvalidQuerySet(query_set_id))?;
query_set.validate_and_write_timestamp(raw_encoder, query_set_id, query_index, None)?;
query_set.validate_and_write_timestamp(raw_encoder, query_index, None)?;
Ok(())
}

View File

@ -1,3 +1,6 @@
use crate::command::{
validate_and_begin_occlusion_query, validate_and_begin_pipeline_statistics_query,
};
use crate::resource::Resource;
use crate::snatch::SnatchGuard;
use crate::{
@ -2258,7 +2261,6 @@ impl Global {
query_set
.validate_and_write_timestamp(
raw,
query_set_id,
query_index,
Some(&mut cmd_buf_data.pending_query_resets),
)
@ -2278,22 +2280,20 @@ impl Global {
.ok_or(RenderCommandError::InvalidQuerySet(query_set_id))
.map_pass_err(scope)?;
query_set
.validate_and_begin_occlusion_query(
raw,
query_set_id,
query_index,
Some(&mut cmd_buf_data.pending_query_resets),
&mut active_query,
)
.map_pass_err(scope)?;
validate_and_begin_occlusion_query(
query_set.clone(),
raw,
query_index,
Some(&mut cmd_buf_data.pending_query_resets),
&mut active_query,
)
.map_pass_err(scope)?;
}
RenderCommand::EndOcclusionQuery => {
api_log!("RenderPass::end_occlusion_query");
let scope = PassErrorScope::EndOcclusionQuery;
end_occlusion_query(raw, &*query_set_guard, &mut active_query)
.map_pass_err(scope)?;
end_occlusion_query(raw, &mut active_query).map_pass_err(scope)?;
}
RenderCommand::BeginPipelineStatisticsQuery {
query_set_id,
@ -2308,21 +2308,20 @@ impl Global {
.ok_or(RenderCommandError::InvalidQuerySet(query_set_id))
.map_pass_err(scope)?;
query_set
.validate_and_begin_pipeline_statistics_query(
raw,
query_set_id,
query_index,
Some(&mut cmd_buf_data.pending_query_resets),
&mut active_query,
)
.map_pass_err(scope)?;
validate_and_begin_pipeline_statistics_query(
query_set.clone(),
raw,
query_index,
Some(&mut cmd_buf_data.pending_query_resets),
&mut active_query,
)
.map_pass_err(scope)?;
}
RenderCommand::EndPipelineStatisticsQuery => {
api_log!("RenderPass::end_pipeline_statistics_query");
let scope = PassErrorScope::EndPipelineStatisticsQuery;
end_pipeline_statistics_query(raw, &*query_set_guard, &mut active_query)
end_pipeline_statistics_query(raw, &mut active_query)
.map_pass_err(scope)?;
}
RenderCommand::ExecuteBundle(bundle_id) => {