refactor(core): extract Global::validate_pass_timestamp_writes

This commit is contained in:
Erich Gubler 2024-11-21 22:50:26 -05:00
parent 29e7fe3fe2
commit 19d80fe229
3 changed files with 59 additions and 105 deletions

View File

@ -28,9 +28,7 @@ use crate::{
use thiserror::Error;
use wgt::{BufferAddress, DynamicOffset};
use super::{
bind::BinderError, memory_init::CommandBufferTextureMemoryActions, SimplifiedQueryType,
};
use super::{bind::BinderError, memory_init::CommandBufferTextureMemoryActions};
use crate::ray_tracing::TlasAction;
use std::sync::Arc;
use std::{fmt, mem::size_of, str};
@ -310,64 +308,15 @@ impl Global {
Err(e) => return make_err(e, arc_desc),
};
arc_desc.timestamp_writes = if let Some(tw) = desc.timestamp_writes {
let &PassTimestampWrites {
query_set,
beginning_of_pass_write_index,
end_of_pass_write_index,
} = tw;
match cmd_buf
.device
.require_features(wgt::Features::TIMESTAMP_QUERY)
{
Ok(()) => (),
Err(e) => return make_err(e.into(), arc_desc),
}
let query_set = match hub.query_sets.get(query_set).get() {
Ok(query_set) => query_set,
Err(e) => return make_err(e.into(), arc_desc),
};
match query_set.same_device(&cmd_buf.device) {
Ok(()) => (),
Err(e) => return make_err(e.into(), arc_desc),
}
for idx in [beginning_of_pass_write_index, end_of_pass_write_index]
.into_iter()
.flatten()
{
match query_set.validate_query(SimplifiedQueryType::Timestamp, idx, None) {
Ok(()) => (),
Err(e) => return make_err(e.into(), arc_desc),
}
}
if let Some((begin, end)) = beginning_of_pass_write_index.zip(end_of_pass_write_index) {
if begin == end {
return make_err(
CommandEncoderError::TimestampWriteIndicesEqual { idx: begin },
arc_desc,
);
}
}
if beginning_of_pass_write_index
.or(end_of_pass_write_index)
.is_none()
{
return make_err(CommandEncoderError::TimestampWriteIndicesMissing, arc_desc);
}
Some(ArcPassTimestampWrites {
query_set,
beginning_of_pass_write_index,
end_of_pass_write_index,
arc_desc.timestamp_writes = match desc
.timestamp_writes
.map(|tw| {
Self::validate_pass_timestamp_writes(&cmd_buf.device, &hub.query_sets.read(), tw)
})
} else {
None
.transpose()
{
Ok(ok) => ok,
Err(e) => return make_err(e, arc_desc),
};
(ComputePass::new(Some(cmd_buf), arc_desc), None)

View File

@ -33,7 +33,8 @@ use crate::snatch::SnatchGuard;
use crate::init_tracker::BufferInitTrackerAction;
use crate::ray_tracing::{BlasAction, TlasAction};
use crate::resource::{InvalidResourceError, Labeled};
use crate::resource::{Fallible, InvalidResourceError, Labeled, ParentDevice as _, QuerySet};
use crate::storage::Storage;
use crate::track::{DeviceTracker, Tracker, UsageScope};
use crate::LabelHelpers;
use crate::{api_log, global::Global, id, resource_log, Label};
@ -782,6 +783,50 @@ impl Global {
}
Ok(())
}
fn validate_pass_timestamp_writes(
device: &Device,
query_sets: &Storage<Fallible<QuerySet>>,
timestamp_writes: &PassTimestampWrites,
) -> Result<ArcPassTimestampWrites, CommandEncoderError> {
let &PassTimestampWrites {
query_set,
beginning_of_pass_write_index,
end_of_pass_write_index,
} = timestamp_writes;
device.require_features(wgt::Features::TIMESTAMP_QUERY)?;
let query_set = query_sets.get(query_set).get()?;
query_set.same_device(device)?;
for idx in [beginning_of_pass_write_index, end_of_pass_write_index]
.into_iter()
.flatten()
{
query_set.validate_query(SimplifiedQueryType::Timestamp, idx, None)?;
}
if let Some((begin, end)) = beginning_of_pass_write_index.zip(end_of_pass_write_index) {
if begin == end {
return Err(CommandEncoderError::TimestampWriteIndicesEqual { idx: begin });
}
}
if beginning_of_pass_write_index
.or(end_of_pass_write_index)
.is_none()
{
return Err(CommandEncoderError::TimestampWriteIndicesMissing);
}
Ok(ArcPassTimestampWrites {
query_set,
beginning_of_pass_write_index,
end_of_pass_write_index,
})
}
}
fn push_constant_clear<PushFn>(offset: u32, size_bytes: u32, mut push_fn: PushFn)

View File

@ -1,7 +1,6 @@
use crate::binding_model::BindGroup;
use crate::command::{
validate_and_begin_occlusion_query, validate_and_begin_pipeline_statistics_query,
SimplifiedQueryType,
};
use crate::init_tracker::BufferInitTrackerAction;
use crate::pipeline::RenderPipeline;
@ -1392,49 +1391,10 @@ impl Global {
None
};
arc_desc.timestamp_writes = if let Some(tw) = desc.timestamp_writes {
let &PassTimestampWrites {
query_set,
beginning_of_pass_write_index,
end_of_pass_write_index,
} = tw;
let query_set = query_sets.get(query_set).get()?;
device.require_features(wgt::Features::TIMESTAMP_QUERY)?;
query_set.same_device(device)?;
for idx in [beginning_of_pass_write_index, end_of_pass_write_index]
.into_iter()
.flatten()
{
query_set.validate_query(SimplifiedQueryType::Timestamp, idx, None)?;
}
if let Some((begin, end)) =
beginning_of_pass_write_index.zip(end_of_pass_write_index)
{
if begin == end {
return Err(CommandEncoderError::TimestampWriteIndicesEqual { idx: begin });
}
}
if beginning_of_pass_write_index
.or(end_of_pass_write_index)
.is_none()
{
return Err(CommandEncoderError::TimestampWriteIndicesMissing);
}
Some(ArcPassTimestampWrites {
query_set,
beginning_of_pass_write_index,
end_of_pass_write_index,
})
} else {
None
};
arc_desc.timestamp_writes = desc
.timestamp_writes
.map(|tw| Global::validate_pass_timestamp_writes(device, &query_sets, tw))
.transpose()?;
arc_desc.occlusion_query_set =
if let Some(occlusion_query_set) = desc.occlusion_query_set {