diff --git a/wgpu-core/src/command/clear.rs b/wgpu-core/src/command/clear.rs index 9ef0f24d4..71f8c0441 100644 --- a/wgpu-core/src/command/clear.rs +++ b/wgpu-core/src/command/clear.rs @@ -104,9 +104,7 @@ impl Global { .get(dst) .map_err(|_| ClearError::InvalidBuffer(dst))?; - if dst_buffer.device.as_info().id() != cmd_buf.device.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + dst_buffer.device.same_device(&cmd_buf.device)?; cmd_buf_data .trackers @@ -203,9 +201,7 @@ impl Global { .get(dst) .map_err(|_| ClearError::InvalidTexture(dst))?; - if dst_texture.device.as_info().id() != cmd_buf.device.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + dst_texture.device.same_device(&cmd_buf.device)?; // Check if subresource aspects are valid. let clear_aspects = diff --git a/wgpu-core/src/command/compute.rs b/wgpu-core/src/command/compute.rs index acbff0a03..206894247 100644 --- a/wgpu-core/src/command/compute.rs +++ b/wgpu-core/src/command/compute.rs @@ -361,7 +361,7 @@ impl Global { ); }; - if query_set.device.as_info().id() != cmd_buf.device.as_info().id() { + if query_set.device.same_device(&cmd_buf.device).is_err() { return ( ComputePass::new(None, arc_desc), Some(CommandEncoderError::WrongDeviceForTimestampWritesQuerySet), diff --git a/wgpu-core/src/command/query.rs b/wgpu-core/src/command/query.rs index bd4f9e991..dce8bfe10 100644 --- a/wgpu-core/src/command/query.rs +++ b/wgpu-core/src/command/query.rs @@ -9,7 +9,7 @@ use crate::{ hal_api::HalApi, id::{self, Id}, init_tracker::MemoryInitKind, - resource::{QuerySet, Resource}, + resource::QuerySet, storage::Storage, Epoch, FastHashMap, Index, }; @@ -405,9 +405,7 @@ impl Global { .add_single(&*query_set_guard, query_set_id) .ok_or(QueryError::InvalidQuerySet(query_set_id))?; - if query_set.device.as_info().id() != cmd_buf.device.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + query_set.device.same_device(&cmd_buf.device)?; let (dst_buffer, dst_pending) = { let buffer_guard = hub.buffers.read(); @@ -415,9 +413,7 @@ impl Global { .get(destination) .map_err(|_| QueryError::InvalidBuffer(destination))?; - if dst_buffer.device.as_info().id() != cmd_buf.device.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + dst_buffer.device.same_device(&cmd_buf.device)?; tracker .buffers diff --git a/wgpu-core/src/command/render.rs b/wgpu-core/src/command/render.rs index defd6a608..e08dd9f02 100644 --- a/wgpu-core/src/command/render.rs +++ b/wgpu-core/src/command/render.rs @@ -1476,9 +1476,7 @@ impl Global { .ok_or(RenderCommandError::InvalidBindGroup(bind_group_id)) .map_pass_err(scope)?; - if bind_group.device.as_info().id() != device.as_info().id() { - return Err(DeviceError::WrongDevice).map_pass_err(scope); - } + bind_group.device.same_device(device).map_pass_err(scope)?; bind_group .validate_dynamic_bindings(index, &temp_offsets, &cmd_buf.limits) @@ -1544,9 +1542,7 @@ impl Global { .ok_or(RenderCommandError::InvalidPipeline(pipeline_id)) .map_pass_err(scope)?; - if pipeline.device.as_info().id() != device.as_info().id() { - return Err(DeviceError::WrongDevice).map_pass_err(scope); - } + pipeline.device.same_device(device).map_pass_err(scope)?; info.context .check_compatible( @@ -1673,9 +1669,7 @@ impl Global { .merge_single(&*buffer_guard, buffer_id, hal::BufferUses::INDEX) .map_pass_err(scope)?; - if buffer.device.as_info().id() != device.as_info().id() { - return Err(DeviceError::WrongDevice).map_pass_err(scope); - } + buffer.device.same_device(device).map_pass_err(scope)?; check_buffer_usage(buffer_id, buffer.usage, BufferUsages::INDEX) .map_pass_err(scope)?; @@ -1726,9 +1720,7 @@ impl Global { .merge_single(&*buffer_guard, buffer_id, hal::BufferUses::VERTEX) .map_pass_err(scope)?; - if buffer.device.as_info().id() != device.as_info().id() { - return Err(DeviceError::WrongDevice).map_pass_err(scope); - } + buffer.device.same_device(device).map_pass_err(scope)?; let max_vertex_buffers = device.limits.max_vertex_buffers; if slot >= max_vertex_buffers { @@ -2333,9 +2325,7 @@ impl Global { .ok_or(RenderCommandError::InvalidRenderBundle(bundle_id)) .map_pass_err(scope)?; - if bundle.device.as_info().id() != device.as_info().id() { - return Err(DeviceError::WrongDevice).map_pass_err(scope); - } + bundle.device.same_device(device).map_pass_err(scope)?; info.context .check_compatible( diff --git a/wgpu-core/src/command/transfer.rs b/wgpu-core/src/command/transfer.rs index 6c7073900..bf5ce67e1 100644 --- a/wgpu-core/src/command/transfer.rs +++ b/wgpu-core/src/command/transfer.rs @@ -602,9 +602,7 @@ impl Global { .get(source) .map_err(|_| TransferError::InvalidBuffer(source))?; - if src_buffer.device.as_info().id() != device.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + src_buffer.device.same_device(device)?; cmd_buf_data .trackers @@ -628,9 +626,7 @@ impl Global { .get(destination) .map_err(|_| TransferError::InvalidBuffer(destination))?; - if dst_buffer.device.as_info().id() != device.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + dst_buffer.device.same_device(device)?; cmd_buf_data .trackers @@ -781,9 +777,7 @@ impl Global { .get(destination.texture) .map_err(|_| TransferError::InvalidTexture(destination.texture))?; - if dst_texture.device.as_info().id() != device.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + dst_texture.device.same_device(device)?; let (hal_copy_size, array_layer_count) = validate_texture_copy_range( destination, @@ -816,9 +810,7 @@ impl Global { .get(source.buffer) .map_err(|_| TransferError::InvalidBuffer(source.buffer))?; - if src_buffer.device.as_info().id() != device.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + src_buffer.device.same_device(device)?; tracker .buffers @@ -951,9 +943,7 @@ impl Global { .get(source.texture) .map_err(|_| TransferError::InvalidTexture(source.texture))?; - if src_texture.device.as_info().id() != device.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + src_texture.device.same_device(device)?; let (hal_copy_size, array_layer_count) = validate_texture_copy_range(source, &src_texture.desc, CopySide::Source, copy_size)?; @@ -1007,9 +997,7 @@ impl Global { .get(destination.buffer) .map_err(|_| TransferError::InvalidBuffer(destination.buffer))?; - if dst_buffer.device.as_info().id() != device.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + dst_buffer.device.same_device(device)?; tracker .buffers @@ -1139,12 +1127,8 @@ impl Global { .get(destination.texture) .map_err(|_| TransferError::InvalidTexture(source.texture))?; - if src_texture.device.as_info().id() != device.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } - if dst_texture.device.as_info().id() != device.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + src_texture.device.same_device(device)?; + dst_texture.device.same_device(device)?; // src and dst texture format must be copy-compatible // https://gpuweb.github.io/gpuweb/#copy-compatible diff --git a/wgpu-core/src/device/global.rs b/wgpu-core/src/device/global.rs index d6133f438..e9863376b 100644 --- a/wgpu-core/src/device/global.rs +++ b/wgpu-core/src/device/global.rs @@ -15,7 +15,6 @@ use crate::{ pipeline, present, resource::{ self, BufferAccessError, BufferAccessResult, BufferMapOperation, CreateBufferError, - Resource, }, validation::check_buffer_usage, Label, LabelHelpers as _, @@ -1125,8 +1124,8 @@ impl Global { Err(..) => break 'error binding_model::CreateBindGroupError::InvalidLayout, }; - if bind_group_layout.device.as_info().id() != device.as_info().id() { - break 'error DeviceError::WrongDevice.into(); + if let Err(e) = bind_group_layout.device.same_device(&device) { + break 'error e.into(); } let bind_group = match device.create_bind_group(&bind_group_layout, desc, hub) { diff --git a/wgpu-core/src/device/queue.rs b/wgpu-core/src/device/queue.rs index 33af48388..eab96eed6 100644 --- a/wgpu-core/src/device/queue.rs +++ b/wgpu-core/src/device/queue.rs @@ -12,7 +12,7 @@ use crate::{ global::Global, hal_api::HalApi, hal_label, - id::{self, DeviceId, QueueId}, + id::{self, QueueId}, init_tracker::{has_copy_partial_init_tracker_coverage, TextureInitRange}, lock::{rank, Mutex, RwLockWriteGuard}, resource::{ @@ -352,15 +352,6 @@ pub struct InvalidQueue; #[derive(Clone, Debug, Error)] #[non_exhaustive] pub enum QueueWriteError { - #[error( - "Device of queue ({:?}) does not match device of write recipient ({:?})", - queue_device_id, - target_device_id - )] - DeviceMismatch { - queue_device_id: DeviceId, - target_device_id: DeviceId, - }, #[error(transparent)] Queue(#[from] DeviceError), #[error(transparent)] @@ -405,13 +396,10 @@ impl Global { let hub = A::hub(self); - let buffer_device_id = hub + let buffer = hub .buffers .get(buffer_id) - .map_err(|_| TransferError::InvalidBuffer(buffer_id))? - .device - .as_info() - .id(); + .map_err(|_| TransferError::InvalidBuffer(buffer_id))?; let queue = hub .queues @@ -420,15 +408,7 @@ impl Global { let device = queue.device.as_ref().unwrap(); - { - let queue_device_id = device.as_info().id(); - if buffer_device_id != queue_device_id { - return Err(QueueWriteError::DeviceMismatch { - queue_device_id, - target_device_id: buffer_device_id, - }); - } - } + buffer.device.same_device(device)?; let data_size = data.len() as wgt::BufferAddress; @@ -607,7 +587,7 @@ impl Global { fn queue_write_staging_buffer_impl( &self, - device: &Device, + device: &Arc>, pending_writes: &mut PendingWrites, staging_buffer: &StagingBuffer, buffer_id: id::BufferId, @@ -632,9 +612,7 @@ impl Global { .get(&snatch_guard) .ok_or(TransferError::InvalidBuffer(buffer_id))?; - if dst.device.as_info().id() != device.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + dst.device.same_device(device)?; let src_buffer_size = staging_buffer.size; self.queue_validate_write_buffer_impl(&dst, buffer_id, buffer_offset, src_buffer_size)?; @@ -717,9 +695,7 @@ impl Global { .get(destination.texture) .map_err(|_| TransferError::InvalidTexture(destination.texture))?; - if dst.device.as_info().id().into_queue_id() != queue_id { - return Err(DeviceError::WrongDevice.into()); - } + dst.device.same_device(device)?; if !dst.desc.usage.contains(wgt::TextureUsages::COPY_DST) { return Err( @@ -1200,9 +1176,7 @@ impl Global { Err(_) => continue, }; - if cmdbuf.device.as_info().id().into_queue_id() != queue_id { - return Err(DeviceError::WrongDevice.into()); - } + cmdbuf.device.same_device(device)?; #[cfg(feature = "trace")] if let Some(ref mut trace) = *device.trace.lock() { diff --git a/wgpu-core/src/device/resource.rs b/wgpu-core/src/device/resource.rs index f4702bc91..6cb6653d0 100644 --- a/wgpu-core/src/device/resource.rs +++ b/wgpu-core/src/device/resource.rs @@ -313,6 +313,12 @@ impl Device { self.valid.load(Ordering::Acquire) } + pub fn same_device(self: &Arc, other: &Arc) -> Result<(), DeviceError> { + Arc::ptr_eq(self, other) + .then_some(()) + .ok_or(DeviceError::WrongDevice) + } + pub(crate) fn release_queue(&self, queue: A::Queue) { assert!(self.queue_to_drop.set(queue).is_ok()); } @@ -1837,6 +1843,7 @@ impl Device { } pub(crate) fn create_buffer_binding<'a>( + self: &Arc, bb: &binding_model::BufferBinding, binding: u32, decl: &wgt::BindGroupLayoutEntry, @@ -1846,7 +1853,6 @@ impl Device { used: &mut BindGroupStates, storage: &'a Storage>, limits: &wgt::Limits, - device_id: id::Id, snatch_guard: &'a SnatchGuard<'a>, ) -> Result, binding_model::CreateBindGroupError> { use crate::binding_model::CreateBindGroupError as Error; @@ -1898,9 +1904,7 @@ impl Device { .add_single(storage, bb.buffer_id, internal_use) .ok_or(Error::InvalidBuffer(bb.buffer_id))?; - if buffer.device.as_info().id() != device_id { - return Err(DeviceError::WrongDevice.into()); - } + buffer.device.same_device(self)?; check_buffer_usage(bb.buffer_id, buffer.usage, pub_usage)?; let raw_buffer = buffer @@ -1981,10 +1985,10 @@ impl Device { } fn create_sampler_binding<'a>( + self: &Arc, used: &BindGroupStates, storage: &'a Storage>, id: id::Id, - device_id: id::Id, ) -> Result<&'a Sampler, binding_model::CreateBindGroupError> { use crate::binding_model::CreateBindGroupError as Error; @@ -1993,9 +1997,7 @@ impl Device { .add_single(storage, id) .ok_or(Error::InvalidSampler(id))?; - if sampler.device.as_info().id() != device_id { - return Err(DeviceError::WrongDevice.into()); - } + sampler.device.same_device(self)?; Ok(sampler) } @@ -2017,9 +2019,7 @@ impl Device { .add_single(storage, id) .ok_or(Error::InvalidTextureView(id))?; - if view.device.as_info().id() != self.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + view.device.same_device(self)?; let (pub_usage, internal_use) = self.texture_use_parameters( binding, @@ -2038,9 +2038,7 @@ impl Device { texture_id, ))?; - if texture.device.as_info().id() != view.device.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + texture.device.same_device(&view.device)?; check_texture_usage(texture.desc.usage, pub_usage)?; @@ -2113,7 +2111,7 @@ impl Device { .ok_or(Error::MissingBindingDeclaration(binding))?; let (res_index, count) = match entry.resource { Br::Buffer(ref bb) => { - let bb = Self::create_buffer_binding( + let bb = self.create_buffer_binding( bb, binding, decl, @@ -2123,7 +2121,6 @@ impl Device { &mut used, &*buffer_guard, &self.limits, - self.as_info().id(), &snatch_guard, )?; @@ -2137,7 +2134,7 @@ impl Device { let res_index = hal_buffers.len(); for bb in bindings_array.iter() { - let bb = Self::create_buffer_binding( + let bb = self.create_buffer_binding( bb, binding, decl, @@ -2147,7 +2144,6 @@ impl Device { &mut used, &*buffer_guard, &self.limits, - self.as_info().id(), &snatch_guard, )?; hal_buffers.push(bb); @@ -2156,12 +2152,7 @@ impl Device { } Br::Sampler(id) => match decl.ty { wgt::BindingType::Sampler(ty) => { - let sampler = Self::create_sampler_binding( - &used, - &sampler_guard, - id, - self.as_info().id(), - )?; + let sampler = self.create_sampler_binding(&used, &sampler_guard, id)?; let (allowed_filtering, allowed_comparison) = match ty { wgt::SamplerBindingType::Filtering => (None, false), @@ -2203,12 +2194,7 @@ impl Device { let res_index = hal_samplers.len(); for &id in bindings_array.iter() { - let sampler = Self::create_sampler_binding( - &used, - &sampler_guard, - id, - self.as_info().id(), - )?; + let sampler = self.create_sampler_binding(&used, &sampler_guard, id)?; hal_samplers.push(sampler.raw()); } @@ -2537,9 +2523,7 @@ impl Device { // Validate total resource counts and check for a matching device for bgl in &bind_group_layouts { - if bgl.device.as_info().id() != self.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + bgl.device.same_device(self)?; count_validator.merge(&bgl.binding_count_validator); } @@ -2647,9 +2631,7 @@ impl Device { .get(desc.stage.module) .map_err(|_| validation::StageError::InvalidModule)?; - if shader_module.device.as_info().id() != self.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + shader_module.device.same_device(self)?; // Get the pipeline layout from the desc if it is provided. let pipeline_layout = match desc.layout { @@ -2659,9 +2641,7 @@ impl Device { .get(pipeline_layout_id) .map_err(|_| pipeline::CreateComputePipelineError::InvalidLayout)?; - if pipeline_layout.device.as_info().id() != self.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + pipeline_layout.device.same_device(self)?; Some(pipeline_layout) } @@ -2723,9 +2703,7 @@ impl Device { break 'cache None; }; - if cache.device.as_info().id() != self.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + cache.device.same_device(self)?; Some(cache) }; @@ -3103,9 +3081,7 @@ impl Device { .get(pipeline_layout_id) .map_err(|_| pipeline::CreateRenderPipelineError::InvalidLayout)?; - if pipeline_layout.device.as_info().id() != self.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + pipeline_layout.device.same_device(self)?; Some(pipeline_layout) } @@ -3140,9 +3116,7 @@ impl Device { error: validation::StageError::InvalidModule, } })?; - if vertex_shader_module.device.as_info().id() != self.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + vertex_shader_module.device.same_device(self)?; let stage_err = |error| pipeline::CreateRenderPipelineError::Stage { stage, error }; @@ -3334,9 +3308,7 @@ impl Device { break 'cache None; }; - if cache.device.as_info().id() != self.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + cache.device.same_device(self)?; Some(cache) };