move same device validation in compute_pass_end_impl

This commit is contained in:
teoxoy 2024-06-18 12:12:31 +02:00 committed by Nicolas Silva
parent 53f8477b15
commit adfb183dc0

View File

@ -53,11 +53,6 @@ pub struct ComputePass<A: HalApi> {
// Resource binding dedupe state. // Resource binding dedupe state.
current_bind_groups: BindGroupStateChange, current_bind_groups: BindGroupStateChange,
current_pipeline: StateChange<id::ComputePipelineId>, current_pipeline: StateChange<id::ComputePipelineId>,
/// The device that this pass is associated with.
///
/// Used for quick validation during recording.
device_id: id::DeviceId,
} }
impl<A: HalApi> ComputePass<A> { impl<A: HalApi> ComputePass<A> {
@ -68,10 +63,6 @@ impl<A: HalApi> ComputePass<A> {
timestamp_writes, timestamp_writes,
} = desc; } = desc;
let device_id = parent
.as_ref()
.map_or(id::DeviceId::dummy(0), |p| p.device.as_info().id());
Self { Self {
base: Some(BasePass::new(label)), base: Some(BasePass::new(label)),
parent, parent,
@ -79,8 +70,6 @@ impl<A: HalApi> ComputePass<A> {
current_bind_groups: BindGroupStateChange::new(), current_bind_groups: BindGroupStateChange::new(),
current_pipeline: StateChange::new(), current_pipeline: StateChange::new(),
device_id,
} }
} }
@ -593,6 +582,8 @@ impl Global {
} => { } => {
let scope = PassErrorScope::SetBindGroup(bind_group.as_info().id()); let scope = PassErrorScope::SetBindGroup(bind_group.as_info().id());
bind_group.device.same_device(device).map_pass_err(scope)?;
let max_bind_groups = cmd_buf.limits.max_bind_groups; let max_bind_groups = cmd_buf.limits.max_bind_groups;
if index >= max_bind_groups { if index >= max_bind_groups {
return Err(ComputePassErrorInner::BindGroupIndexOutOfRange { return Err(ComputePassErrorInner::BindGroupIndexOutOfRange {
@ -658,6 +649,8 @@ impl Global {
let pipeline_id = pipeline.as_info().id(); let pipeline_id = pipeline.as_info().id();
let scope = PassErrorScope::SetPipelineCompute(pipeline_id); let scope = PassErrorScope::SetPipelineCompute(pipeline_id);
pipeline.device.same_device(device).map_pass_err(scope)?;
state.pipeline = Some(pipeline_id); state.pipeline = Some(pipeline_id);
let pipeline = tracker.compute_pipelines.insert_single(pipeline); let pipeline = tracker.compute_pipelines.insert_single(pipeline);
@ -797,6 +790,8 @@ impl Global {
pipeline: state.pipeline, pipeline: state.pipeline,
}; };
buffer.device.same_device(device).map_pass_err(scope)?;
state.is_ready().map_pass_err(scope)?; state.is_ready().map_pass_err(scope)?;
device device
@ -890,6 +885,8 @@ impl Global {
} => { } => {
let scope = PassErrorScope::WriteTimestamp; let scope = PassErrorScope::WriteTimestamp;
query_set.device.same_device(device).map_pass_err(scope)?;
device device
.require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES) .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)
.map_pass_err(scope)?; .map_pass_err(scope)?;
@ -906,6 +903,8 @@ impl Global {
} => { } => {
let scope = PassErrorScope::BeginPipelineStatisticsQuery; let scope = PassErrorScope::BeginPipelineStatisticsQuery;
query_set.device.same_device(device).map_pass_err(scope)?;
let query_set = tracker.query_sets.insert_single(query_set); let query_set = tracker.query_sets.insert_single(query_set);
validate_and_begin_pipeline_statistics_query( validate_and_begin_pipeline_statistics_query(
@ -994,10 +993,6 @@ impl Global {
.map_err(|_| ComputePassErrorInner::InvalidBindGroup(index)) .map_err(|_| ComputePassErrorInner::InvalidBindGroup(index))
.map_pass_err(scope)?; .map_pass_err(scope)?;
if bind_group.device.as_info().id() != pass.device_id {
return Err(DeviceError::WrongDevice).map_pass_err(scope);
}
base.commands.push(ArcComputeCommand::SetBindGroup { base.commands.push(ArcComputeCommand::SetBindGroup {
index, index,
num_dynamic_offsets: offsets.len(), num_dynamic_offsets: offsets.len(),
@ -1016,7 +1011,6 @@ impl Global {
let scope = PassErrorScope::SetPipelineCompute(pipeline_id); let scope = PassErrorScope::SetPipelineCompute(pipeline_id);
let device_id = pass.device_id;
let base = pass.base_mut(scope)?; let base = pass.base_mut(scope)?;
if redundant { if redundant {
// Do redundant early-out **after** checking whether the pass is ended or not. // Do redundant early-out **after** checking whether the pass is ended or not.
@ -1031,10 +1025,6 @@ impl Global {
.map_err(|_| ComputePassErrorInner::InvalidPipeline(pipeline_id)) .map_err(|_| ComputePassErrorInner::InvalidPipeline(pipeline_id))
.map_pass_err(scope)?; .map_pass_err(scope)?;
if pipeline.device.as_info().id() != device_id {
return Err(DeviceError::WrongDevice).map_pass_err(scope);
}
base.commands.push(ArcComputeCommand::SetPipeline(pipeline)); base.commands.push(ArcComputeCommand::SetPipeline(pipeline));
Ok(()) Ok(())
@ -1108,7 +1098,6 @@ impl Global {
indirect: true, indirect: true,
pipeline: pass.current_pipeline.last_state, pipeline: pass.current_pipeline.last_state,
}; };
let device_id = pass.device_id;
let base = pass.base_mut(scope)?; let base = pass.base_mut(scope)?;
let buffer = hub let buffer = hub
@ -1118,10 +1107,6 @@ impl Global {
.map_err(|_| ComputePassErrorInner::InvalidBuffer(buffer_id)) .map_err(|_| ComputePassErrorInner::InvalidBuffer(buffer_id))
.map_pass_err(scope)?; .map_pass_err(scope)?;
if buffer.device.as_info().id() != device_id {
return Err(DeviceError::WrongDevice).map_pass_err(scope);
}
base.commands base.commands
.push(ArcComputeCommand::<A>::DispatchIndirect { buffer, offset }); .push(ArcComputeCommand::<A>::DispatchIndirect { buffer, offset });
@ -1185,7 +1170,6 @@ impl Global {
query_index: u32, query_index: u32,
) -> Result<(), ComputePassError> { ) -> Result<(), ComputePassError> {
let scope = PassErrorScope::WriteTimestamp; let scope = PassErrorScope::WriteTimestamp;
let device_id = pass.device_id;
let base = pass.base_mut(scope)?; let base = pass.base_mut(scope)?;
let hub = A::hub(self); let hub = A::hub(self);
@ -1196,10 +1180,6 @@ impl Global {
.map_err(|_| ComputePassErrorInner::InvalidQuerySet(query_set_id)) .map_err(|_| ComputePassErrorInner::InvalidQuerySet(query_set_id))
.map_pass_err(scope)?; .map_pass_err(scope)?;
if query_set.device.as_info().id() != device_id {
return Err(DeviceError::WrongDevice).map_pass_err(scope);
}
base.commands.push(ArcComputeCommand::WriteTimestamp { base.commands.push(ArcComputeCommand::WriteTimestamp {
query_set, query_set,
query_index, query_index,
@ -1215,7 +1195,6 @@ impl Global {
query_index: u32, query_index: u32,
) -> Result<(), ComputePassError> { ) -> Result<(), ComputePassError> {
let scope = PassErrorScope::BeginPipelineStatisticsQuery; let scope = PassErrorScope::BeginPipelineStatisticsQuery;
let device_id = pass.device_id;
let base = pass.base_mut(scope)?; let base = pass.base_mut(scope)?;
let hub = A::hub(self); let hub = A::hub(self);
@ -1226,10 +1205,6 @@ impl Global {
.map_err(|_| ComputePassErrorInner::InvalidQuerySet(query_set_id)) .map_err(|_| ComputePassErrorInner::InvalidQuerySet(query_set_id))
.map_pass_err(scope)?; .map_pass_err(scope)?;
if query_set.device.as_info().id() != device_id {
return Err(DeviceError::WrongDevice).map_pass_err(scope);
}
base.commands base.commands
.push(ArcComputeCommand::BeginPipelineStatisticsQuery { .push(ArcComputeCommand::BeginPipelineStatisticsQuery {
query_set, query_set,