mirror of
https://github.com/gfx-rs/wgpu.git
synced 2024-11-22 14:55:05 +00:00
move same device validation in compute_pass_end_impl
This commit is contained in:
parent
53f8477b15
commit
adfb183dc0
@ -53,11 +53,6 @@ pub struct ComputePass<A: HalApi> {
|
|||||||
// Resource binding dedupe state.
|
// Resource binding dedupe state.
|
||||||
current_bind_groups: BindGroupStateChange,
|
current_bind_groups: BindGroupStateChange,
|
||||||
current_pipeline: StateChange<id::ComputePipelineId>,
|
current_pipeline: StateChange<id::ComputePipelineId>,
|
||||||
|
|
||||||
/// The device that this pass is associated with.
|
|
||||||
///
|
|
||||||
/// Used for quick validation during recording.
|
|
||||||
device_id: id::DeviceId,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<A: HalApi> ComputePass<A> {
|
impl<A: HalApi> ComputePass<A> {
|
||||||
@ -68,10 +63,6 @@ impl<A: HalApi> ComputePass<A> {
|
|||||||
timestamp_writes,
|
timestamp_writes,
|
||||||
} = desc;
|
} = desc;
|
||||||
|
|
||||||
let device_id = parent
|
|
||||||
.as_ref()
|
|
||||||
.map_or(id::DeviceId::dummy(0), |p| p.device.as_info().id());
|
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
base: Some(BasePass::new(label)),
|
base: Some(BasePass::new(label)),
|
||||||
parent,
|
parent,
|
||||||
@ -79,8 +70,6 @@ impl<A: HalApi> ComputePass<A> {
|
|||||||
|
|
||||||
current_bind_groups: BindGroupStateChange::new(),
|
current_bind_groups: BindGroupStateChange::new(),
|
||||||
current_pipeline: StateChange::new(),
|
current_pipeline: StateChange::new(),
|
||||||
|
|
||||||
device_id,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -593,6 +582,8 @@ impl Global {
|
|||||||
} => {
|
} => {
|
||||||
let scope = PassErrorScope::SetBindGroup(bind_group.as_info().id());
|
let scope = PassErrorScope::SetBindGroup(bind_group.as_info().id());
|
||||||
|
|
||||||
|
bind_group.device.same_device(device).map_pass_err(scope)?;
|
||||||
|
|
||||||
let max_bind_groups = cmd_buf.limits.max_bind_groups;
|
let max_bind_groups = cmd_buf.limits.max_bind_groups;
|
||||||
if index >= max_bind_groups {
|
if index >= max_bind_groups {
|
||||||
return Err(ComputePassErrorInner::BindGroupIndexOutOfRange {
|
return Err(ComputePassErrorInner::BindGroupIndexOutOfRange {
|
||||||
@ -658,6 +649,8 @@ impl Global {
|
|||||||
let pipeline_id = pipeline.as_info().id();
|
let pipeline_id = pipeline.as_info().id();
|
||||||
let scope = PassErrorScope::SetPipelineCompute(pipeline_id);
|
let scope = PassErrorScope::SetPipelineCompute(pipeline_id);
|
||||||
|
|
||||||
|
pipeline.device.same_device(device).map_pass_err(scope)?;
|
||||||
|
|
||||||
state.pipeline = Some(pipeline_id);
|
state.pipeline = Some(pipeline_id);
|
||||||
|
|
||||||
let pipeline = tracker.compute_pipelines.insert_single(pipeline);
|
let pipeline = tracker.compute_pipelines.insert_single(pipeline);
|
||||||
@ -797,6 +790,8 @@ impl Global {
|
|||||||
pipeline: state.pipeline,
|
pipeline: state.pipeline,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
buffer.device.same_device(device).map_pass_err(scope)?;
|
||||||
|
|
||||||
state.is_ready().map_pass_err(scope)?;
|
state.is_ready().map_pass_err(scope)?;
|
||||||
|
|
||||||
device
|
device
|
||||||
@ -890,6 +885,8 @@ impl Global {
|
|||||||
} => {
|
} => {
|
||||||
let scope = PassErrorScope::WriteTimestamp;
|
let scope = PassErrorScope::WriteTimestamp;
|
||||||
|
|
||||||
|
query_set.device.same_device(device).map_pass_err(scope)?;
|
||||||
|
|
||||||
device
|
device
|
||||||
.require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)
|
.require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)
|
||||||
.map_pass_err(scope)?;
|
.map_pass_err(scope)?;
|
||||||
@ -906,6 +903,8 @@ impl Global {
|
|||||||
} => {
|
} => {
|
||||||
let scope = PassErrorScope::BeginPipelineStatisticsQuery;
|
let scope = PassErrorScope::BeginPipelineStatisticsQuery;
|
||||||
|
|
||||||
|
query_set.device.same_device(device).map_pass_err(scope)?;
|
||||||
|
|
||||||
let query_set = tracker.query_sets.insert_single(query_set);
|
let query_set = tracker.query_sets.insert_single(query_set);
|
||||||
|
|
||||||
validate_and_begin_pipeline_statistics_query(
|
validate_and_begin_pipeline_statistics_query(
|
||||||
@ -994,10 +993,6 @@ impl Global {
|
|||||||
.map_err(|_| ComputePassErrorInner::InvalidBindGroup(index))
|
.map_err(|_| ComputePassErrorInner::InvalidBindGroup(index))
|
||||||
.map_pass_err(scope)?;
|
.map_pass_err(scope)?;
|
||||||
|
|
||||||
if bind_group.device.as_info().id() != pass.device_id {
|
|
||||||
return Err(DeviceError::WrongDevice).map_pass_err(scope);
|
|
||||||
}
|
|
||||||
|
|
||||||
base.commands.push(ArcComputeCommand::SetBindGroup {
|
base.commands.push(ArcComputeCommand::SetBindGroup {
|
||||||
index,
|
index,
|
||||||
num_dynamic_offsets: offsets.len(),
|
num_dynamic_offsets: offsets.len(),
|
||||||
@ -1016,7 +1011,6 @@ impl Global {
|
|||||||
|
|
||||||
let scope = PassErrorScope::SetPipelineCompute(pipeline_id);
|
let scope = PassErrorScope::SetPipelineCompute(pipeline_id);
|
||||||
|
|
||||||
let device_id = pass.device_id;
|
|
||||||
let base = pass.base_mut(scope)?;
|
let base = pass.base_mut(scope)?;
|
||||||
if redundant {
|
if redundant {
|
||||||
// Do redundant early-out **after** checking whether the pass is ended or not.
|
// Do redundant early-out **after** checking whether the pass is ended or not.
|
||||||
@ -1031,10 +1025,6 @@ impl Global {
|
|||||||
.map_err(|_| ComputePassErrorInner::InvalidPipeline(pipeline_id))
|
.map_err(|_| ComputePassErrorInner::InvalidPipeline(pipeline_id))
|
||||||
.map_pass_err(scope)?;
|
.map_pass_err(scope)?;
|
||||||
|
|
||||||
if pipeline.device.as_info().id() != device_id {
|
|
||||||
return Err(DeviceError::WrongDevice).map_pass_err(scope);
|
|
||||||
}
|
|
||||||
|
|
||||||
base.commands.push(ArcComputeCommand::SetPipeline(pipeline));
|
base.commands.push(ArcComputeCommand::SetPipeline(pipeline));
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -1108,7 +1098,6 @@ impl Global {
|
|||||||
indirect: true,
|
indirect: true,
|
||||||
pipeline: pass.current_pipeline.last_state,
|
pipeline: pass.current_pipeline.last_state,
|
||||||
};
|
};
|
||||||
let device_id = pass.device_id;
|
|
||||||
let base = pass.base_mut(scope)?;
|
let base = pass.base_mut(scope)?;
|
||||||
|
|
||||||
let buffer = hub
|
let buffer = hub
|
||||||
@ -1118,10 +1107,6 @@ impl Global {
|
|||||||
.map_err(|_| ComputePassErrorInner::InvalidBuffer(buffer_id))
|
.map_err(|_| ComputePassErrorInner::InvalidBuffer(buffer_id))
|
||||||
.map_pass_err(scope)?;
|
.map_pass_err(scope)?;
|
||||||
|
|
||||||
if buffer.device.as_info().id() != device_id {
|
|
||||||
return Err(DeviceError::WrongDevice).map_pass_err(scope);
|
|
||||||
}
|
|
||||||
|
|
||||||
base.commands
|
base.commands
|
||||||
.push(ArcComputeCommand::<A>::DispatchIndirect { buffer, offset });
|
.push(ArcComputeCommand::<A>::DispatchIndirect { buffer, offset });
|
||||||
|
|
||||||
@ -1185,7 +1170,6 @@ impl Global {
|
|||||||
query_index: u32,
|
query_index: u32,
|
||||||
) -> Result<(), ComputePassError> {
|
) -> Result<(), ComputePassError> {
|
||||||
let scope = PassErrorScope::WriteTimestamp;
|
let scope = PassErrorScope::WriteTimestamp;
|
||||||
let device_id = pass.device_id;
|
|
||||||
let base = pass.base_mut(scope)?;
|
let base = pass.base_mut(scope)?;
|
||||||
|
|
||||||
let hub = A::hub(self);
|
let hub = A::hub(self);
|
||||||
@ -1196,10 +1180,6 @@ impl Global {
|
|||||||
.map_err(|_| ComputePassErrorInner::InvalidQuerySet(query_set_id))
|
.map_err(|_| ComputePassErrorInner::InvalidQuerySet(query_set_id))
|
||||||
.map_pass_err(scope)?;
|
.map_pass_err(scope)?;
|
||||||
|
|
||||||
if query_set.device.as_info().id() != device_id {
|
|
||||||
return Err(DeviceError::WrongDevice).map_pass_err(scope);
|
|
||||||
}
|
|
||||||
|
|
||||||
base.commands.push(ArcComputeCommand::WriteTimestamp {
|
base.commands.push(ArcComputeCommand::WriteTimestamp {
|
||||||
query_set,
|
query_set,
|
||||||
query_index,
|
query_index,
|
||||||
@ -1215,7 +1195,6 @@ impl Global {
|
|||||||
query_index: u32,
|
query_index: u32,
|
||||||
) -> Result<(), ComputePassError> {
|
) -> Result<(), ComputePassError> {
|
||||||
let scope = PassErrorScope::BeginPipelineStatisticsQuery;
|
let scope = PassErrorScope::BeginPipelineStatisticsQuery;
|
||||||
let device_id = pass.device_id;
|
|
||||||
let base = pass.base_mut(scope)?;
|
let base = pass.base_mut(scope)?;
|
||||||
|
|
||||||
let hub = A::hub(self);
|
let hub = A::hub(self);
|
||||||
@ -1226,10 +1205,6 @@ impl Global {
|
|||||||
.map_err(|_| ComputePassErrorInner::InvalidQuerySet(query_set_id))
|
.map_err(|_| ComputePassErrorInner::InvalidQuerySet(query_set_id))
|
||||||
.map_pass_err(scope)?;
|
.map_pass_err(scope)?;
|
||||||
|
|
||||||
if query_set.device.as_info().id() != device_id {
|
|
||||||
return Err(DeviceError::WrongDevice).map_pass_err(scope);
|
|
||||||
}
|
|
||||||
|
|
||||||
base.commands
|
base.commands
|
||||||
.push(ArcComputeCommand::BeginPipelineStatisticsQuery {
|
.push(ArcComputeCommand::BeginPipelineStatisticsQuery {
|
||||||
query_set,
|
query_set,
|
||||||
|
Loading…
Reference in New Issue
Block a user