validate for same device via Arc::ptr_eq rather than IDs

This commit is contained in:
teoxoy 2024-06-18 11:59:16 +02:00 committed by Nicolas Silva
parent a2fcd72606
commit 53f8477b15
8 changed files with 52 additions and 141 deletions

View File

@ -104,9 +104,7 @@ impl Global {
.get(dst)
.map_err(|_| ClearError::InvalidBuffer(dst))?;
if dst_buffer.device.as_info().id() != cmd_buf.device.as_info().id() {
return Err(DeviceError::WrongDevice.into());
}
dst_buffer.device.same_device(&cmd_buf.device)?;
cmd_buf_data
.trackers
@ -203,9 +201,7 @@ impl Global {
.get(dst)
.map_err(|_| ClearError::InvalidTexture(dst))?;
if dst_texture.device.as_info().id() != cmd_buf.device.as_info().id() {
return Err(DeviceError::WrongDevice.into());
}
dst_texture.device.same_device(&cmd_buf.device)?;
// Check if subresource aspects are valid.
let clear_aspects =

View File

@ -361,7 +361,7 @@ impl Global {
);
};
if query_set.device.as_info().id() != cmd_buf.device.as_info().id() {
if query_set.device.same_device(&cmd_buf.device).is_err() {
return (
ComputePass::new(None, arc_desc),
Some(CommandEncoderError::WrongDeviceForTimestampWritesQuerySet),

View File

@ -9,7 +9,7 @@ use crate::{
hal_api::HalApi,
id::{self, Id},
init_tracker::MemoryInitKind,
resource::{QuerySet, Resource},
resource::QuerySet,
storage::Storage,
Epoch, FastHashMap, Index,
};
@ -405,9 +405,7 @@ impl Global {
.add_single(&*query_set_guard, query_set_id)
.ok_or(QueryError::InvalidQuerySet(query_set_id))?;
if query_set.device.as_info().id() != cmd_buf.device.as_info().id() {
return Err(DeviceError::WrongDevice.into());
}
query_set.device.same_device(&cmd_buf.device)?;
let (dst_buffer, dst_pending) = {
let buffer_guard = hub.buffers.read();
@ -415,9 +413,7 @@ impl Global {
.get(destination)
.map_err(|_| QueryError::InvalidBuffer(destination))?;
if dst_buffer.device.as_info().id() != cmd_buf.device.as_info().id() {
return Err(DeviceError::WrongDevice.into());
}
dst_buffer.device.same_device(&cmd_buf.device)?;
tracker
.buffers

View File

@ -1476,9 +1476,7 @@ impl Global {
.ok_or(RenderCommandError::InvalidBindGroup(bind_group_id))
.map_pass_err(scope)?;
if bind_group.device.as_info().id() != device.as_info().id() {
return Err(DeviceError::WrongDevice).map_pass_err(scope);
}
bind_group.device.same_device(device).map_pass_err(scope)?;
bind_group
.validate_dynamic_bindings(index, &temp_offsets, &cmd_buf.limits)
@ -1544,9 +1542,7 @@ impl Global {
.ok_or(RenderCommandError::InvalidPipeline(pipeline_id))
.map_pass_err(scope)?;
if pipeline.device.as_info().id() != device.as_info().id() {
return Err(DeviceError::WrongDevice).map_pass_err(scope);
}
pipeline.device.same_device(device).map_pass_err(scope)?;
info.context
.check_compatible(
@ -1673,9 +1669,7 @@ impl Global {
.merge_single(&*buffer_guard, buffer_id, hal::BufferUses::INDEX)
.map_pass_err(scope)?;
if buffer.device.as_info().id() != device.as_info().id() {
return Err(DeviceError::WrongDevice).map_pass_err(scope);
}
buffer.device.same_device(device).map_pass_err(scope)?;
check_buffer_usage(buffer_id, buffer.usage, BufferUsages::INDEX)
.map_pass_err(scope)?;
@ -1726,9 +1720,7 @@ impl Global {
.merge_single(&*buffer_guard, buffer_id, hal::BufferUses::VERTEX)
.map_pass_err(scope)?;
if buffer.device.as_info().id() != device.as_info().id() {
return Err(DeviceError::WrongDevice).map_pass_err(scope);
}
buffer.device.same_device(device).map_pass_err(scope)?;
let max_vertex_buffers = device.limits.max_vertex_buffers;
if slot >= max_vertex_buffers {
@ -2333,9 +2325,7 @@ impl Global {
.ok_or(RenderCommandError::InvalidRenderBundle(bundle_id))
.map_pass_err(scope)?;
if bundle.device.as_info().id() != device.as_info().id() {
return Err(DeviceError::WrongDevice).map_pass_err(scope);
}
bundle.device.same_device(device).map_pass_err(scope)?;
info.context
.check_compatible(

View File

@ -602,9 +602,7 @@ impl Global {
.get(source)
.map_err(|_| TransferError::InvalidBuffer(source))?;
if src_buffer.device.as_info().id() != device.as_info().id() {
return Err(DeviceError::WrongDevice.into());
}
src_buffer.device.same_device(device)?;
cmd_buf_data
.trackers
@ -628,9 +626,7 @@ impl Global {
.get(destination)
.map_err(|_| TransferError::InvalidBuffer(destination))?;
if dst_buffer.device.as_info().id() != device.as_info().id() {
return Err(DeviceError::WrongDevice.into());
}
dst_buffer.device.same_device(device)?;
cmd_buf_data
.trackers
@ -781,9 +777,7 @@ impl Global {
.get(destination.texture)
.map_err(|_| TransferError::InvalidTexture(destination.texture))?;
if dst_texture.device.as_info().id() != device.as_info().id() {
return Err(DeviceError::WrongDevice.into());
}
dst_texture.device.same_device(device)?;
let (hal_copy_size, array_layer_count) = validate_texture_copy_range(
destination,
@ -816,9 +810,7 @@ impl Global {
.get(source.buffer)
.map_err(|_| TransferError::InvalidBuffer(source.buffer))?;
if src_buffer.device.as_info().id() != device.as_info().id() {
return Err(DeviceError::WrongDevice.into());
}
src_buffer.device.same_device(device)?;
tracker
.buffers
@ -951,9 +943,7 @@ impl Global {
.get(source.texture)
.map_err(|_| TransferError::InvalidTexture(source.texture))?;
if src_texture.device.as_info().id() != device.as_info().id() {
return Err(DeviceError::WrongDevice.into());
}
src_texture.device.same_device(device)?;
let (hal_copy_size, array_layer_count) =
validate_texture_copy_range(source, &src_texture.desc, CopySide::Source, copy_size)?;
@ -1007,9 +997,7 @@ impl Global {
.get(destination.buffer)
.map_err(|_| TransferError::InvalidBuffer(destination.buffer))?;
if dst_buffer.device.as_info().id() != device.as_info().id() {
return Err(DeviceError::WrongDevice.into());
}
dst_buffer.device.same_device(device)?;
tracker
.buffers
@ -1139,12 +1127,8 @@ impl Global {
.get(destination.texture)
.map_err(|_| TransferError::InvalidTexture(source.texture))?;
if src_texture.device.as_info().id() != device.as_info().id() {
return Err(DeviceError::WrongDevice.into());
}
if dst_texture.device.as_info().id() != device.as_info().id() {
return Err(DeviceError::WrongDevice.into());
}
src_texture.device.same_device(device)?;
dst_texture.device.same_device(device)?;
// src and dst texture format must be copy-compatible
// https://gpuweb.github.io/gpuweb/#copy-compatible

View File

@ -15,7 +15,6 @@ use crate::{
pipeline, present,
resource::{
self, BufferAccessError, BufferAccessResult, BufferMapOperation, CreateBufferError,
Resource,
},
validation::check_buffer_usage,
Label, LabelHelpers as _,
@ -1125,8 +1124,8 @@ impl Global {
Err(..) => break 'error binding_model::CreateBindGroupError::InvalidLayout,
};
if bind_group_layout.device.as_info().id() != device.as_info().id() {
break 'error DeviceError::WrongDevice.into();
if let Err(e) = bind_group_layout.device.same_device(&device) {
break 'error e.into();
}
let bind_group = match device.create_bind_group(&bind_group_layout, desc, hub) {

View File

@ -12,7 +12,7 @@ use crate::{
global::Global,
hal_api::HalApi,
hal_label,
id::{self, DeviceId, QueueId},
id::{self, QueueId},
init_tracker::{has_copy_partial_init_tracker_coverage, TextureInitRange},
lock::{rank, Mutex, RwLockWriteGuard},
resource::{
@ -352,15 +352,6 @@ pub struct InvalidQueue;
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum QueueWriteError {
#[error(
"Device of queue ({:?}) does not match device of write recipient ({:?})",
queue_device_id,
target_device_id
)]
DeviceMismatch {
queue_device_id: DeviceId,
target_device_id: DeviceId,
},
#[error(transparent)]
Queue(#[from] DeviceError),
#[error(transparent)]
@ -405,13 +396,10 @@ impl Global {
let hub = A::hub(self);
let buffer_device_id = hub
let buffer = hub
.buffers
.get(buffer_id)
.map_err(|_| TransferError::InvalidBuffer(buffer_id))?
.device
.as_info()
.id();
.map_err(|_| TransferError::InvalidBuffer(buffer_id))?;
let queue = hub
.queues
@ -420,15 +408,7 @@ impl Global {
let device = queue.device.as_ref().unwrap();
{
let queue_device_id = device.as_info().id();
if buffer_device_id != queue_device_id {
return Err(QueueWriteError::DeviceMismatch {
queue_device_id,
target_device_id: buffer_device_id,
});
}
}
buffer.device.same_device(device)?;
let data_size = data.len() as wgt::BufferAddress;
@ -607,7 +587,7 @@ impl Global {
fn queue_write_staging_buffer_impl<A: HalApi>(
&self,
device: &Device<A>,
device: &Arc<Device<A>>,
pending_writes: &mut PendingWrites<A>,
staging_buffer: &StagingBuffer<A>,
buffer_id: id::BufferId,
@ -632,9 +612,7 @@ impl Global {
.get(&snatch_guard)
.ok_or(TransferError::InvalidBuffer(buffer_id))?;
if dst.device.as_info().id() != device.as_info().id() {
return Err(DeviceError::WrongDevice.into());
}
dst.device.same_device(device)?;
let src_buffer_size = staging_buffer.size;
self.queue_validate_write_buffer_impl(&dst, buffer_id, buffer_offset, src_buffer_size)?;
@ -717,9 +695,7 @@ impl Global {
.get(destination.texture)
.map_err(|_| TransferError::InvalidTexture(destination.texture))?;
if dst.device.as_info().id().into_queue_id() != queue_id {
return Err(DeviceError::WrongDevice.into());
}
dst.device.same_device(device)?;
if !dst.desc.usage.contains(wgt::TextureUsages::COPY_DST) {
return Err(
@ -1200,9 +1176,7 @@ impl Global {
Err(_) => continue,
};
if cmdbuf.device.as_info().id().into_queue_id() != queue_id {
return Err(DeviceError::WrongDevice.into());
}
cmdbuf.device.same_device(device)?;
#[cfg(feature = "trace")]
if let Some(ref mut trace) = *device.trace.lock() {

View File

@ -313,6 +313,12 @@ impl<A: HalApi> Device<A> {
self.valid.load(Ordering::Acquire)
}
pub fn same_device(self: &Arc<Self>, other: &Arc<Self>) -> Result<(), DeviceError> {
Arc::ptr_eq(self, other)
.then_some(())
.ok_or(DeviceError::WrongDevice)
}
pub(crate) fn release_queue(&self, queue: A::Queue) {
assert!(self.queue_to_drop.set(queue).is_ok());
}
@ -1837,6 +1843,7 @@ impl<A: HalApi> Device<A> {
}
pub(crate) fn create_buffer_binding<'a>(
self: &Arc<Self>,
bb: &binding_model::BufferBinding,
binding: u32,
decl: &wgt::BindGroupLayoutEntry,
@ -1846,7 +1853,6 @@ impl<A: HalApi> Device<A> {
used: &mut BindGroupStates<A>,
storage: &'a Storage<Buffer<A>>,
limits: &wgt::Limits,
device_id: id::Id<id::markers::Device>,
snatch_guard: &'a SnatchGuard<'a>,
) -> Result<hal::BufferBinding<'a, A>, binding_model::CreateBindGroupError> {
use crate::binding_model::CreateBindGroupError as Error;
@ -1898,9 +1904,7 @@ impl<A: HalApi> Device<A> {
.add_single(storage, bb.buffer_id, internal_use)
.ok_or(Error::InvalidBuffer(bb.buffer_id))?;
if buffer.device.as_info().id() != device_id {
return Err(DeviceError::WrongDevice.into());
}
buffer.device.same_device(self)?;
check_buffer_usage(bb.buffer_id, buffer.usage, pub_usage)?;
let raw_buffer = buffer
@ -1981,10 +1985,10 @@ impl<A: HalApi> Device<A> {
}
fn create_sampler_binding<'a>(
self: &Arc<Self>,
used: &BindGroupStates<A>,
storage: &'a Storage<Sampler<A>>,
id: id::Id<id::markers::Sampler>,
device_id: id::Id<id::markers::Device>,
) -> Result<&'a Sampler<A>, binding_model::CreateBindGroupError> {
use crate::binding_model::CreateBindGroupError as Error;
@ -1993,9 +1997,7 @@ impl<A: HalApi> Device<A> {
.add_single(storage, id)
.ok_or(Error::InvalidSampler(id))?;
if sampler.device.as_info().id() != device_id {
return Err(DeviceError::WrongDevice.into());
}
sampler.device.same_device(self)?;
Ok(sampler)
}
@ -2017,9 +2019,7 @@ impl<A: HalApi> Device<A> {
.add_single(storage, id)
.ok_or(Error::InvalidTextureView(id))?;
if view.device.as_info().id() != self.as_info().id() {
return Err(DeviceError::WrongDevice.into());
}
view.device.same_device(self)?;
let (pub_usage, internal_use) = self.texture_use_parameters(
binding,
@ -2038,9 +2038,7 @@ impl<A: HalApi> Device<A> {
texture_id,
))?;
if texture.device.as_info().id() != view.device.as_info().id() {
return Err(DeviceError::WrongDevice.into());
}
texture.device.same_device(&view.device)?;
check_texture_usage(texture.desc.usage, pub_usage)?;
@ -2113,7 +2111,7 @@ impl<A: HalApi> Device<A> {
.ok_or(Error::MissingBindingDeclaration(binding))?;
let (res_index, count) = match entry.resource {
Br::Buffer(ref bb) => {
let bb = Self::create_buffer_binding(
let bb = self.create_buffer_binding(
bb,
binding,
decl,
@ -2123,7 +2121,6 @@ impl<A: HalApi> Device<A> {
&mut used,
&*buffer_guard,
&self.limits,
self.as_info().id(),
&snatch_guard,
)?;
@ -2137,7 +2134,7 @@ impl<A: HalApi> Device<A> {
let res_index = hal_buffers.len();
for bb in bindings_array.iter() {
let bb = Self::create_buffer_binding(
let bb = self.create_buffer_binding(
bb,
binding,
decl,
@ -2147,7 +2144,6 @@ impl<A: HalApi> Device<A> {
&mut used,
&*buffer_guard,
&self.limits,
self.as_info().id(),
&snatch_guard,
)?;
hal_buffers.push(bb);
@ -2156,12 +2152,7 @@ impl<A: HalApi> Device<A> {
}
Br::Sampler(id) => match decl.ty {
wgt::BindingType::Sampler(ty) => {
let sampler = Self::create_sampler_binding(
&used,
&sampler_guard,
id,
self.as_info().id(),
)?;
let sampler = self.create_sampler_binding(&used, &sampler_guard, id)?;
let (allowed_filtering, allowed_comparison) = match ty {
wgt::SamplerBindingType::Filtering => (None, false),
@ -2203,12 +2194,7 @@ impl<A: HalApi> Device<A> {
let res_index = hal_samplers.len();
for &id in bindings_array.iter() {
let sampler = Self::create_sampler_binding(
&used,
&sampler_guard,
id,
self.as_info().id(),
)?;
let sampler = self.create_sampler_binding(&used, &sampler_guard, id)?;
hal_samplers.push(sampler.raw());
}
@ -2537,9 +2523,7 @@ impl<A: HalApi> Device<A> {
// Validate total resource counts and check for a matching device
for bgl in &bind_group_layouts {
if bgl.device.as_info().id() != self.as_info().id() {
return Err(DeviceError::WrongDevice.into());
}
bgl.device.same_device(self)?;
count_validator.merge(&bgl.binding_count_validator);
}
@ -2647,9 +2631,7 @@ impl<A: HalApi> Device<A> {
.get(desc.stage.module)
.map_err(|_| validation::StageError::InvalidModule)?;
if shader_module.device.as_info().id() != self.as_info().id() {
return Err(DeviceError::WrongDevice.into());
}
shader_module.device.same_device(self)?;
// Get the pipeline layout from the desc if it is provided.
let pipeline_layout = match desc.layout {
@ -2659,9 +2641,7 @@ impl<A: HalApi> Device<A> {
.get(pipeline_layout_id)
.map_err(|_| pipeline::CreateComputePipelineError::InvalidLayout)?;
if pipeline_layout.device.as_info().id() != self.as_info().id() {
return Err(DeviceError::WrongDevice.into());
}
pipeline_layout.device.same_device(self)?;
Some(pipeline_layout)
}
@ -2723,9 +2703,7 @@ impl<A: HalApi> Device<A> {
break 'cache None;
};
if cache.device.as_info().id() != self.as_info().id() {
return Err(DeviceError::WrongDevice.into());
}
cache.device.same_device(self)?;
Some(cache)
};
@ -3103,9 +3081,7 @@ impl<A: HalApi> Device<A> {
.get(pipeline_layout_id)
.map_err(|_| pipeline::CreateRenderPipelineError::InvalidLayout)?;
if pipeline_layout.device.as_info().id() != self.as_info().id() {
return Err(DeviceError::WrongDevice.into());
}
pipeline_layout.device.same_device(self)?;
Some(pipeline_layout)
}
@ -3140,9 +3116,7 @@ impl<A: HalApi> Device<A> {
error: validation::StageError::InvalidModule,
}
})?;
if vertex_shader_module.device.as_info().id() != self.as_info().id() {
return Err(DeviceError::WrongDevice.into());
}
vertex_shader_module.device.same_device(self)?;
let stage_err = |error| pipeline::CreateRenderPipelineError::Stage { stage, error };
@ -3334,9 +3308,7 @@ impl<A: HalApi> Device<A> {
break 'cache None;
};
if cache.device.as_info().id() != self.as_info().id() {
return Err(DeviceError::WrongDevice.into());
}
cache.device.same_device(self)?;
Some(cache)
};