From edc2cd9615a679528ec780785aa36b5c7566ea46 Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Wed, 19 Jun 2024 10:46:28 +0200 Subject: [PATCH] introduce `Device.check_is_valid` --- wgpu-core/src/command/bundle.rs | 10 +-- wgpu-core/src/command/clear.rs | 8 +-- wgpu-core/src/command/compute.rs | 9 +-- wgpu-core/src/command/render.rs | 4 +- wgpu-core/src/command/transfer.rs | 22 ++----- wgpu-core/src/device/global.rs | 101 ++++++++++++++++-------------- wgpu-core/src/device/mod.rs | 10 ++- wgpu-core/src/device/resource.rs | 6 ++ wgpu-core/src/present.rs | 12 +--- 9 files changed, 84 insertions(+), 98 deletions(-) diff --git a/wgpu-core/src/command/bundle.rs b/wgpu-core/src/command/bundle.rs index 1a7733e65..0df93f7b8 100644 --- a/wgpu-core/src/command/bundle.rs +++ b/wgpu-core/src/command/bundle.rs @@ -1510,10 +1510,12 @@ pub struct RenderBundleError { } impl RenderBundleError { - pub(crate) const INVALID_DEVICE: Self = RenderBundleError { - scope: PassErrorScope::Bundle, - inner: RenderBundleErrorInner::Device(DeviceError::Invalid), - }; + pub fn from_device_error(e: DeviceError) -> Self { + Self { + scope: PassErrorScope::Bundle, + inner: e.into(), + } + } } impl PrettyError for RenderBundleError { fn fmt_pretty(&self, fmt: &mut ErrorFormatter) { diff --git a/wgpu-core/src/command/clear.rs b/wgpu-core/src/command/clear.rs index 80167d2c2..e75ca798f 100644 --- a/wgpu-core/src/command/clear.rs +++ b/wgpu-core/src/command/clear.rs @@ -9,7 +9,7 @@ use crate::{ get_lowest_common_denom, global::Global, hal_api::HalApi, - id::{BufferId, CommandEncoderId, DeviceId, TextureId}, + id::{BufferId, CommandEncoderId, TextureId}, init_tracker::{MemoryInitKind, TextureInitRange}, resource::{ParentDevice, Resource, Texture, TextureClearMode}, snatch::SnatchGuard, @@ -26,8 +26,6 @@ use wgt::{math::align_to, BufferAddress, BufferUsages, ImageSubresourceRange, Te pub enum ClearError { #[error("To use clear_texture the CLEAR_TEXTURE feature needs to be enabled")] MissingClearTextureFeature, - #[error("Device {0:?} is invalid")] - InvalidDevice(DeviceId), #[error("Buffer {0:?} is invalid or destroyed")] InvalidBuffer(BufferId), #[error("Texture {0:?} is invalid or destroyed")] @@ -238,9 +236,7 @@ impl Global { } let device = &cmd_buf.device; - if !device.is_valid() { - return Err(ClearError::InvalidDevice(cmd_buf.device.as_info().id())); - } + device.check_is_valid()?; let (encoder, tracker) = cmd_buf_data.open_encoder_and_tracker()?; let snatch_guard = device.snatchable_lock.read(); diff --git a/wgpu-core/src/command/compute.rs b/wgpu-core/src/command/compute.rs index 3e19caf51..795403a32 100644 --- a/wgpu-core/src/command/compute.rs +++ b/wgpu-core/src/command/compute.rs @@ -161,8 +161,6 @@ pub enum ComputePassErrorInner { InvalidParentEncoder, #[error("Bind group at index {0:?} is invalid")] InvalidBindGroup(u32), - #[error("Device {0:?} is invalid")] - InvalidDevice(id::DeviceId), #[error("Bind group index {index} is greater than the device's requested `max_bind_group` limit {max}")] BindGroupIndexOutOfRange { index: u32, max: u32 }, #[error("Compute pipeline {0:?} is invalid")] @@ -452,12 +450,7 @@ impl Global { let pass_scope = PassErrorScope::Pass(Some(cmd_buf.as_info().id())); let device = &cmd_buf.device; - if !device.is_valid() { - return Err(ComputePassErrorInner::InvalidDevice( - cmd_buf.device.as_info().id(), - )) - .map_pass_err(pass_scope); - } + device.check_is_valid().map_pass_err(pass_scope)?; let mut cmd_buf_data = cmd_buf.data.lock(); let cmd_buf_data = cmd_buf_data.as_mut().unwrap(); diff --git a/wgpu-core/src/command/render.rs b/wgpu-core/src/command/render.rs index 5f0dde1db..cb8141f73 100644 --- a/wgpu-core/src/command/render.rs +++ b/wgpu-core/src/command/render.rs @@ -1368,9 +1368,7 @@ impl Global { }); } - if !device.is_valid() { - return Err(DeviceError::Lost).map_pass_err(pass_scope); - } + device.check_is_valid().map_pass_err(pass_scope)?; let encoder = &mut cmd_buf_data.encoder; let status = &mut cmd_buf_data.status; diff --git a/wgpu-core/src/command/transfer.rs b/wgpu-core/src/command/transfer.rs index d04510f83..997d8ecd9 100644 --- a/wgpu-core/src/command/transfer.rs +++ b/wgpu-core/src/command/transfer.rs @@ -8,12 +8,12 @@ use crate::{ error::{ErrorFormatter, PrettyError}, global::Global, hal_api::HalApi, - id::{BufferId, CommandEncoderId, DeviceId, TextureId}, + id::{BufferId, CommandEncoderId, TextureId}, init_tracker::{ has_copy_partial_init_tracker_coverage, MemoryInitKind, TextureInitRange, TextureInitTrackerAction, }, - resource::{ParentDevice, Resource, Texture, TextureErrorDimension}, + resource::{ParentDevice, Texture, TextureErrorDimension}, snatch::SnatchGuard, track::{TextureSelector, Tracker}, }; @@ -41,8 +41,6 @@ pub enum CopySide { #[derive(Clone, Debug, Error)] #[non_exhaustive] pub enum TransferError { - #[error("Device {0:?} is invalid")] - InvalidDevice(DeviceId), #[error("Buffer {0:?} is invalid or destroyed")] InvalidBuffer(BufferId), #[error("Texture {0:?} is invalid or destroyed")] @@ -579,9 +577,7 @@ impl Global { let cmd_buf_data = cmd_buf_data.as_mut().unwrap(); let device = &cmd_buf.device; - if !device.is_valid() { - return Err(TransferError::InvalidDevice(cmd_buf.device.as_info().id()).into()); - } + device.check_is_valid()?; #[cfg(feature = "trace")] if let Some(ref mut list) = cmd_buf_data.commands { @@ -746,9 +742,7 @@ impl Global { let cmd_buf = CommandBuffer::get_encoder(hub, command_encoder_id)?; let device = &cmd_buf.device; - if !device.is_valid() { - return Err(TransferError::InvalidDevice(cmd_buf.device.as_info().id()).into()); - } + device.check_is_valid()?; let mut cmd_buf_data = cmd_buf.data.lock(); let cmd_buf_data = cmd_buf_data.as_mut().unwrap(); @@ -913,9 +907,7 @@ impl Global { let cmd_buf = CommandBuffer::get_encoder(hub, command_encoder_id)?; let device = &cmd_buf.device; - if !device.is_valid() { - return Err(TransferError::InvalidDevice(cmd_buf.device.as_info().id()).into()); - } + device.check_is_valid()?; let mut cmd_buf_data = cmd_buf.data.lock(); let cmd_buf_data = cmd_buf_data.as_mut().unwrap(); @@ -1092,9 +1084,7 @@ impl Global { let cmd_buf = CommandBuffer::get_encoder(hub, command_encoder_id)?; let device = &cmd_buf.device; - if !device.is_valid() { - return Err(TransferError::InvalidDevice(cmd_buf.device.as_info().id()).into()); - } + device.check_is_valid()?; let snatch_guard = device.snatchable_lock.read(); diff --git a/wgpu-core/src/device/global.rs b/wgpu-core/src/device/global.rs index db1cd98b9..af7fd69eb 100644 --- a/wgpu-core/src/device/global.rs +++ b/wgpu-core/src/device/global.rs @@ -31,7 +31,7 @@ use std::{ sync::{atomic::Ordering, Arc}, }; -use super::{ImplicitPipelineIds, InvalidDevice, UserClosures}; +use super::{ImplicitPipelineIds, UserClosures}; impl Global { pub fn adapter_is_surface_supported( @@ -101,10 +101,13 @@ impl Global { pub fn device_features( &self, device_id: DeviceId, - ) -> Result { + ) -> Result { let hub = A::hub(self); - let device = hub.devices.get(device_id).map_err(|_| InvalidDevice)?; + let device = hub + .devices + .get(device_id) + .map_err(|_| DeviceError::InvalidDeviceId)?; Ok(device.features) } @@ -112,10 +115,13 @@ impl Global { pub fn device_limits( &self, device_id: DeviceId, - ) -> Result { + ) -> Result { let hub = A::hub(self); - let device = hub.devices.get(device_id).map_err(|_| InvalidDevice)?; + let device = hub + .devices + .get(device_id) + .map_err(|_| DeviceError::InvalidDeviceId)?; Ok(device.limits.clone()) } @@ -123,10 +129,13 @@ impl Global { pub fn device_downlevel_properties( &self, device_id: DeviceId, - ) -> Result { + ) -> Result { let hub = A::hub(self); - let device = hub.devices.get(device_id).map_err(|_| InvalidDevice)?; + let device = hub + .devices + .get(device_id) + .map_err(|_| DeviceError::InvalidDeviceId)?; Ok(device.downlevel.clone()) } @@ -147,7 +156,7 @@ impl Global { let device = match hub.devices.get(device_id) { Ok(device) => device, Err(_) => { - break 'error DeviceError::Invalid.into(); + break 'error DeviceError::InvalidDeviceId.into(); } }; if !device.is_valid() { @@ -348,7 +357,7 @@ impl Global { hub.devices .get(device_id) - .map_err(|_| DeviceError::Invalid)? + .map_err(|_| DeviceError::InvalidDeviceId)? .wait_for_submit(last_submission) } @@ -367,11 +376,9 @@ impl Global { let device = hub .devices .get(device_id) - .map_err(|_| DeviceError::Invalid)?; + .map_err(|_| DeviceError::InvalidDeviceId)?; let snatch_guard = device.snatchable_lock.read(); - if !device.is_valid() { - return Err(DeviceError::Lost.into()); - } + device.check_is_valid()?; let buffer = hub .buffers @@ -429,10 +436,8 @@ impl Global { let device = hub .devices .get(device_id) - .map_err(|_| DeviceError::Invalid)?; - if !device.is_valid() { - return Err(DeviceError::Lost.into()); - } + .map_err(|_| DeviceError::InvalidDeviceId)?; + device.check_is_valid()?; let snatch_guard = device.snatchable_lock.read(); @@ -549,7 +554,7 @@ impl Global { let error = 'error: { let device = match hub.devices.get(device_id) { Ok(device) => device, - Err(_) => break 'error DeviceError::Invalid.into(), + Err(_) => break 'error DeviceError::InvalidDeviceId.into(), }; if !device.is_valid() { break 'error DeviceError::Lost.into(); @@ -603,7 +608,7 @@ impl Global { let error = 'error: { let device = match hub.devices.get(device_id) { Ok(device) => device, - Err(_) => break 'error DeviceError::Invalid.into(), + Err(_) => break 'error DeviceError::InvalidDeviceId.into(), }; // NB: Any change done through the raw texture handle will not be @@ -675,7 +680,7 @@ impl Global { let error = 'error: { let device = match hub.devices.get(device_id) { Ok(device) => device, - Err(_) => break 'error DeviceError::Invalid.into(), + Err(_) => break 'error DeviceError::InvalidDeviceId.into(), }; // NB: Any change done through the raw buffer handle will not be @@ -872,7 +877,7 @@ impl Global { let error = 'error: { let device = match hub.devices.get(device_id) { Ok(device) => device, - Err(_) => break 'error DeviceError::Invalid.into(), + Err(_) => break 'error DeviceError::InvalidDeviceId.into(), }; if !device.is_valid() { break 'error DeviceError::Lost.into(); @@ -936,7 +941,7 @@ impl Global { let error = 'error: { let device = match hub.devices.get(device_id) { Ok(device) => device, - Err(_) => break 'error DeviceError::Invalid.into(), + Err(_) => break 'error DeviceError::InvalidDeviceId.into(), }; if !device.is_valid() { break 'error DeviceError::Lost.into(); @@ -1035,7 +1040,7 @@ impl Global { let error = 'error: { let device = match hub.devices.get(device_id) { Ok(device) => device, - Err(_) => break 'error DeviceError::Invalid.into(), + Err(_) => break 'error DeviceError::InvalidDeviceId.into(), }; if !device.is_valid() { break 'error DeviceError::Lost.into(); @@ -1093,7 +1098,7 @@ impl Global { let error = 'error: { let device = match hub.devices.get(device_id) { Ok(device) => device, - Err(_) => break 'error DeviceError::Invalid.into(), + Err(_) => break 'error DeviceError::InvalidDeviceId.into(), }; if !device.is_valid() { break 'error DeviceError::Lost.into(); @@ -1186,7 +1191,7 @@ impl Global { let error = 'error: { let device = match hub.devices.get(device_id) { Ok(device) => device, - Err(_) => break 'error DeviceError::Invalid.into(), + Err(_) => break 'error DeviceError::InvalidDeviceId.into(), }; if !device.is_valid() { break 'error DeviceError::Lost.into(); @@ -1264,7 +1269,7 @@ impl Global { let error = 'error: { let device = match hub.devices.get(device_id) { Ok(device) => device, - Err(_) => break 'error DeviceError::Invalid.into(), + Err(_) => break 'error DeviceError::InvalidDeviceId.into(), }; if !device.is_valid() { break 'error DeviceError::Lost.into(); @@ -1325,7 +1330,7 @@ impl Global { let error = 'error: { let device = match hub.devices.get(device_id) { Ok(device) => device, - Err(_) => break 'error DeviceError::Invalid, + Err(_) => break 'error DeviceError::InvalidDeviceId, }; if !device.is_valid() { break 'error DeviceError::Lost; @@ -1416,10 +1421,14 @@ impl Global { let error = 'error: { let device = match hub.devices.get(bundle_encoder.parent()) { Ok(device) => device, - Err(_) => break 'error command::RenderBundleError::INVALID_DEVICE, + Err(_) => { + break 'error command::RenderBundleError::from_device_error( + DeviceError::InvalidDeviceId, + ); + } }; if !device.is_valid() { - break 'error command::RenderBundleError::INVALID_DEVICE; + break 'error command::RenderBundleError::from_device_error(DeviceError::Lost); } #[cfg(feature = "trace")] @@ -1485,7 +1494,7 @@ impl Global { let error = 'error: { let device = match hub.devices.get(device_id) { Ok(device) => device, - Err(_) => break 'error DeviceError::Invalid.into(), + Err(_) => break 'error DeviceError::InvalidDeviceId.into(), }; if !device.is_valid() { break 'error DeviceError::Lost.into(); @@ -1562,7 +1571,7 @@ impl Global { let error = 'error: { let device = match hub.devices.get(device_id) { Ok(device) => device, - Err(_) => break 'error DeviceError::Invalid.into(), + Err(_) => break 'error DeviceError::InvalidDeviceId.into(), }; if !device.is_valid() { break 'error DeviceError::Lost.into(); @@ -1698,7 +1707,7 @@ impl Global { let error = 'error: { let device = match hub.devices.get(device_id) { Ok(device) => device, - Err(_) => break 'error DeviceError::Invalid.into(), + Err(_) => break 'error DeviceError::InvalidDeviceId.into(), }; if !device.is_valid() { break 'error DeviceError::Lost.into(); @@ -1830,7 +1839,7 @@ impl Global { let device = match hub.devices.get(device_id) { Ok(device) => device, // TODO: Handle error properly - Err(crate::storage::InvalidId) => break 'error DeviceError::Invalid.into(), + Err(crate::storage::InvalidId) => break 'error DeviceError::InvalidDeviceId.into(), }; if !device.is_valid() { break 'error DeviceError::Lost.into(); @@ -2000,10 +2009,10 @@ impl Global { let device = match device_guard.get(device_id) { Ok(device) => device, - Err(_) => break 'error DeviceError::Invalid.into(), + Err(_) => break 'error DeviceError::InvalidDeviceId.into(), }; - if !device.is_valid() { - break 'error DeviceError::Lost.into(); + if let Err(e) = device.check_is_valid() { + break 'error e.into(); } #[cfg(feature = "trace")] @@ -2139,13 +2148,16 @@ impl Global { #[cfg(feature = "replay")] /// Only triage suspected resource IDs. This helps us to avoid ID collisions /// upon creating new resources when re-playing a trace. - pub fn device_maintain_ids(&self, device_id: DeviceId) -> Result<(), InvalidDevice> { + pub fn device_maintain_ids(&self, device_id: DeviceId) -> Result<(), DeviceError> { let hub = A::hub(self); - let device = hub.devices.get(device_id).map_err(|_| InvalidDevice)?; - if !device.is_valid() { - return Err(InvalidDevice); - } + let device = hub + .devices + .get(device_id) + .map_err(|_| DeviceError::InvalidDeviceId)?; + + device.check_is_valid()?; + device.lock_life().triage_suspected(&device.trackers); Ok(()) } @@ -2164,7 +2176,7 @@ impl Global { let device = hub .devices .get(device_id) - .map_err(|_| DeviceError::Invalid)?; + .map_err(|_| DeviceError::InvalidDeviceId)?; if let wgt::Maintain::WaitForSubmissionIndex(submission_index) = maintain { if submission_index.queue_id != device_id.into_queue_id() { @@ -2687,10 +2699,7 @@ impl Global { } drop(snatch_guard); - if !buffer.device.is_valid() { - return Err(DeviceError::Lost.into()); - } - + buffer.device.check_is_valid()?; buffer.unmap() } } diff --git a/wgpu-core/src/device/mod.rs b/wgpu-core/src/device/mod.rs index f44764f94..b8c499218 100644 --- a/wgpu-core/src/device/mod.rs +++ b/wgpu-core/src/device/mod.rs @@ -380,10 +380,6 @@ fn map_buffer( Ok(mapping.ptr) } -#[derive(Clone, Debug, Error)] -#[error("Device is invalid")] -pub struct InvalidDevice; - #[derive(Clone, Debug)] pub struct DeviceMismatch { pub(super) res: ResourceErrorIdent, @@ -411,14 +407,16 @@ impl std::error::Error for DeviceMismatch {} #[derive(Clone, Debug, Error)] #[non_exhaustive] pub enum DeviceError { - #[error("Parent device is invalid.")] - Invalid, + #[error("{0} is invalid.")] + Invalid(ResourceErrorIdent), #[error("Parent device is lost")] Lost, #[error("Not enough memory left.")] OutOfMemory, #[error("Creation of a resource failed for a reason other than running out of memory.")] ResourceCreationFailed, + #[error("DeviceId is invalid")] + InvalidDeviceId, #[error("QueueId is invalid")] InvalidQueueId, #[error(transparent)] diff --git a/wgpu-core/src/device/resource.rs b/wgpu-core/src/device/resource.rs index d44d98b96..2e7efa0d4 100644 --- a/wgpu-core/src/device/resource.rs +++ b/wgpu-core/src/device/resource.rs @@ -313,6 +313,12 @@ impl Device { self.valid.load(Ordering::Acquire) } + pub fn check_is_valid(&self) -> Result<(), DeviceError> { + self.is_valid() + .then_some(()) + .ok_or_else(|| DeviceError::Invalid(self.error_ident())) + } + pub(crate) fn release_queue(&self, queue: A::Queue) { assert!(self.queue_to_drop.set(queue).is_ok()); } diff --git a/wgpu-core/src/present.rs b/wgpu-core/src/present.rs index 7f5939feb..95840b133 100644 --- a/wgpu-core/src/present.rs +++ b/wgpu-core/src/present.rs @@ -136,9 +136,7 @@ impl Global { let (device, config) = if let Some(ref present) = *surface.presentation.lock() { match present.device.downcast_clone::() { Some(device) => { - if !device.is_valid() { - return Err(DeviceError::Lost.into()); - } + device.check_is_valid()?; (device, present.config.clone()) } None => return Err(SurfaceError::NotConfigured), @@ -303,9 +301,7 @@ impl Global { }; let device = present.device.downcast_ref::().unwrap(); - if !device.is_valid() { - return Err(DeviceError::Lost.into()); - } + device.check_is_valid()?; let queue = device.get_queue().unwrap(); #[cfg(feature = "trace")] @@ -397,9 +393,7 @@ impl Global { }; let device = present.device.downcast_ref::().unwrap(); - if !device.is_valid() { - return Err(DeviceError::Lost.into()); - } + device.check_is_valid()?; #[cfg(feature = "trace")] if let Some(ref mut trace) = *device.trace.lock() {