move pipeline ident to DispatchError::IncompatibleBindGroup

This commit is contained in:
teoxoy 2024-06-26 11:11:49 +02:00 committed by Teodor Tanasoaia
parent 1e9844af29
commit d0e63c5c05
3 changed files with 31 additions and 50 deletions

View File

@ -17,8 +17,10 @@ use crate::{
hal_api::HalApi,
hal_label, id,
init_tracker::{BufferInitTrackerAction, MemoryInitKind},
pipeline::ComputePipeline,
resource::{
self, Buffer, DestroyedResourceError, MissingBufferUsageError, ParentDevice, Resource,
ResourceErrorIdent,
},
snatch::SnatchGuard,
track::{ResourceUsageCompatibilityError, Tracker, TrackerIndex, UsageScope},
@ -136,13 +138,17 @@ struct ArcComputePassDescriptor<'a, A: HalApi> {
pub timestamp_writes: Option<ArcComputePassTimestampWrites<A>>,
}
#[derive(Clone, Debug, Error, Eq, PartialEq)]
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DispatchError {
#[error("Compute pipeline must be set")]
MissingPipeline,
#[error("Incompatible bind group at index {index} in the current compute pipeline")]
IncompatibleBindGroup { index: u32, diff: Vec<String> },
#[error("Bind group at index {index} is incompatible with the current set {pipeline}")]
IncompatibleBindGroup {
index: u32,
pipeline: ResourceErrorIdent,
diff: Vec<String>,
},
#[error(
"Each current dispatch group size dimension ({current:?}) must be less or equal to {limit}"
)]
@ -254,7 +260,7 @@ where
struct State<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder, A: HalApi> {
binder: Binder<A>,
pipeline: Option<id::ComputePipelineId>,
pipeline: Option<Arc<ComputePipeline<A>>>,
scope: UsageScope<'scope, A>,
debug_scope_depth: u32,
@ -284,22 +290,20 @@ impl<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder, A: HalApi>
State<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder, A>
{
fn is_ready(&self) -> Result<(), DispatchError> {
let bind_mask = self.binder.invalid_mask();
if bind_mask != 0 {
//let (expected, provided) = self.binder.entries[index as usize].info();
let index = bind_mask.trailing_zeros();
return Err(DispatchError::IncompatibleBindGroup {
index,
diff: self.binder.bgl_diff(),
});
if let Some(pipeline) = self.pipeline.as_ref() {
let bind_mask = self.binder.invalid_mask();
if bind_mask != 0 {
return Err(DispatchError::IncompatibleBindGroup {
index: bind_mask.trailing_zeros(),
pipeline: pipeline.error_ident(),
diff: self.binder.bgl_diff(),
});
}
self.binder.check_late_buffer_bindings()?;
Ok(())
} else {
Err(DispatchError::MissingPipeline)
}
if self.pipeline.is_none() {
return Err(DispatchError::MissingPipeline);
}
self.binder.check_late_buffer_bindings()?;
Ok(())
}
// `extra_buffer` is there to represent the indirect buffer that is also
@ -630,17 +634,11 @@ impl Global {
.map_pass_err(scope)?;
}
ArcComputeCommand::Dispatch(groups) => {
let scope = PassErrorScope::Dispatch {
indirect: false,
pipeline: state.pipeline,
};
let scope = PassErrorScope::Dispatch { indirect: false };
dispatch(&mut state, groups).map_pass_err(scope)?;
}
ArcComputeCommand::DispatchIndirect { buffer, offset } => {
let scope = PassErrorScope::Dispatch {
indirect: true,
pipeline: state.pipeline,
};
let scope = PassErrorScope::Dispatch { indirect: true };
dispatch_indirect(&mut state, cmd_buf, buffer, offset).map_pass_err(scope)?;
}
ArcComputeCommand::PushDebugGroup { color: _, len } => {
@ -798,11 +796,11 @@ fn set_bind_group<A: HalApi>(
fn set_pipeline<A: HalApi>(
state: &mut State<A>,
cmd_buf: &CommandBuffer<A>,
pipeline: Arc<crate::pipeline::ComputePipeline<A>>,
pipeline: Arc<ComputePipeline<A>>,
) -> Result<(), ComputePassErrorInner> {
pipeline.same_device_as(cmd_buf)?;
state.pipeline = Some(pipeline.as_info().id());
state.pipeline = Some(pipeline.clone());
let pipeline = state.tracker.compute_pipelines.insert_single(pipeline);
@ -1150,10 +1148,7 @@ impl Global {
groups_y: u32,
groups_z: u32,
) -> Result<(), ComputePassError> {
let scope = PassErrorScope::Dispatch {
indirect: false,
pipeline: pass.current_pipeline.last_state,
};
let scope = PassErrorScope::Dispatch { indirect: false };
let base = pass.base_mut(scope)?;
base.commands.push(ArcComputeCommand::<A>::Dispatch([
@ -1170,10 +1165,7 @@ impl Global {
offset: BufferAddress,
) -> Result<(), ComputePassError> {
let hub = A::hub(self);
let scope = PassErrorScope::Dispatch {
indirect: true,
pipeline: pass.current_pipeline.last_state,
};
let scope = PassErrorScope::Dispatch { indirect: true };
let base = pass.base_mut(scope)?;
let buffer = hub

View File

@ -128,10 +128,7 @@ impl ComputeCommand {
ArcComputeCommand::DispatchIndirect {
buffer: buffers_guard.get_owned(buffer_id).map_err(|_| {
ComputePassError {
scope: PassErrorScope::Dispatch {
indirect: true,
pipeline: None, // TODO: not used right now, but once we do the resolve during recording we can use this again.
},
scope: PassErrorScope::Dispatch { indirect: true },
inner: ComputePassErrorInner::InvalidBufferId(buffer_id),
}
})?,

View File

@ -907,10 +907,7 @@ pub enum PassErrorScope {
#[error("In a execute_bundle command")]
ExecuteBundle,
#[error("In a dispatch command, indirect:{indirect}")]
Dispatch {
indirect: bool,
pipeline: Option<id::ComputePipelineId>,
},
Dispatch { indirect: bool },
#[error("In a push_debug_group command")]
PushDebugGroup,
#[error("In a pop_debug_group command")]
@ -949,11 +946,6 @@ impl PrettyError for PassErrorScope {
} => {
fmt.render_pipeline_label(&id);
}
Self::Dispatch {
pipeline: Some(id), ..
} => {
fmt.compute_pipeline_label(&id);
}
_ => {}
}
}