From 4b5666ceff3b5c42b8a376ff45b6bac4a7fd3386 Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Wed, 19 Jun 2024 11:02:21 +0200 Subject: [PATCH] move most device validity checks inside the device's methods --- wgpu-core/src/command/bundle.rs | 4 +++ wgpu-core/src/device/global.rs | 45 ++++---------------------------- wgpu-core/src/device/resource.rs | 22 ++++++++++++++++ 3 files changed, 31 insertions(+), 40 deletions(-) diff --git a/wgpu-core/src/command/bundle.rs b/wgpu-core/src/command/bundle.rs index 0df93f7b8..2c971deeb 100644 --- a/wgpu-core/src/command/bundle.rs +++ b/wgpu-core/src/command/bundle.rs @@ -349,6 +349,10 @@ impl RenderBundleEncoder { device: &Arc>, hub: &Hub, ) -> Result, RenderBundleError> { + let scope = PassErrorScope::Bundle; + + device.check_is_valid().map_pass_err(scope)?; + let bind_group_guard = hub.bind_groups.read(); let pipeline_guard = hub.render_pipelines.read(); let buffer_guard = hub.buffers.read(); diff --git a/wgpu-core/src/device/global.rs b/wgpu-core/src/device/global.rs index af7fd69eb..d3d1c5be5 100644 --- a/wgpu-core/src/device/global.rs +++ b/wgpu-core/src/device/global.rs @@ -159,9 +159,6 @@ impl Global { break 'error DeviceError::InvalidDeviceId.into(); } }; - if !device.is_valid() { - break 'error DeviceError::Lost.into(); - } if desc.usage.is_empty() { // Per spec, `usage` must not be zero. @@ -556,9 +553,6 @@ impl Global { Ok(device) => device, Err(_) => break 'error DeviceError::InvalidDeviceId.into(), }; - if !device.is_valid() { - break 'error DeviceError::Lost.into(); - } #[cfg(feature = "trace")] if let Some(ref mut trace) = *device.trace.lock() { trace.add(trace::Action::CreateTexture(fid.id(), desc.clone())); @@ -879,9 +873,6 @@ impl Global { Ok(device) => device, Err(_) => break 'error DeviceError::InvalidDeviceId.into(), }; - if !device.is_valid() { - break 'error DeviceError::Lost.into(); - } #[cfg(feature = "trace")] if let Some(ref mut trace) = *device.trace.lock() { @@ -943,15 +934,17 @@ impl Global { Ok(device) => device, Err(_) => break 'error DeviceError::InvalidDeviceId.into(), }; - if !device.is_valid() { - break 'error DeviceError::Lost.into(); - } #[cfg(feature = "trace")] if let Some(ref mut trace) = *device.trace.lock() { trace.add(trace::Action::CreateBindGroupLayout(fid.id(), desc.clone())); } + // this check can't go in the body of `create_bind_group_layout` since the closure might not get called + if let Err(e) = device.check_is_valid() { + break 'error e.into(); + } + let entry_map = match bgl::EntryMap::from_entries(&device.limits, &desc.entries) { Ok(map) => map, Err(e) => break 'error e, @@ -1042,9 +1035,6 @@ impl Global { Ok(device) => device, Err(_) => break 'error DeviceError::InvalidDeviceId.into(), }; - if !device.is_valid() { - break 'error DeviceError::Lost.into(); - } #[cfg(feature = "trace")] if let Some(ref mut trace) = *device.trace.lock() { @@ -1100,9 +1090,6 @@ impl Global { Ok(device) => device, Err(_) => break 'error DeviceError::InvalidDeviceId.into(), }; - if !device.is_valid() { - break 'error DeviceError::Lost.into(); - } #[cfg(feature = "trace")] if let Some(ref mut trace) = *device.trace.lock() { @@ -1193,9 +1180,6 @@ impl Global { Ok(device) => device, Err(_) => break 'error DeviceError::InvalidDeviceId.into(), }; - if !device.is_valid() { - break 'error DeviceError::Lost.into(); - } #[cfg(feature = "trace")] if let Some(ref mut trace) = *device.trace.lock() { @@ -1271,9 +1255,6 @@ impl Global { Ok(device) => device, Err(_) => break 'error DeviceError::InvalidDeviceId.into(), }; - if !device.is_valid() { - break 'error DeviceError::Lost.into(); - } #[cfg(feature = "trace")] if let Some(ref mut trace) = *device.trace.lock() { @@ -1427,9 +1408,6 @@ impl Global { ); } }; - if !device.is_valid() { - break 'error command::RenderBundleError::from_device_error(DeviceError::Lost); - } #[cfg(feature = "trace")] if let Some(ref mut trace) = *device.trace.lock() { @@ -1496,10 +1474,6 @@ impl Global { Ok(device) => device, Err(_) => break 'error DeviceError::InvalidDeviceId.into(), }; - if !device.is_valid() { - break 'error DeviceError::Lost.into(); - } - #[cfg(feature = "trace")] if let Some(ref mut trace) = *device.trace.lock() { trace.add(trace::Action::CreateQuerySet { @@ -1573,9 +1547,6 @@ impl Global { Ok(device) => device, Err(_) => break 'error DeviceError::InvalidDeviceId.into(), }; - if !device.is_valid() { - break 'error DeviceError::Lost.into(); - } #[cfg(feature = "trace")] if let Some(ref mut trace) = *device.trace.lock() { trace.add(trace::Action::CreateRenderPipeline { @@ -1709,9 +1680,6 @@ impl Global { Ok(device) => device, Err(_) => break 'error DeviceError::InvalidDeviceId.into(), }; - if !device.is_valid() { - break 'error DeviceError::Lost.into(); - } #[cfg(feature = "trace")] if let Some(ref mut trace) = *device.trace.lock() { @@ -1841,9 +1809,6 @@ impl Global { // TODO: Handle error properly Err(crate::storage::InvalidId) => break 'error DeviceError::InvalidDeviceId.into(), }; - if !device.is_valid() { - break 'error DeviceError::Lost.into(); - } #[cfg(feature = "trace")] if let Some(ref mut trace) = *device.trace.lock() { trace.add(trace::Action::CreatePipelineCache { diff --git a/wgpu-core/src/device/resource.rs b/wgpu-core/src/device/resource.rs index 2e7efa0d4..83fe695d0 100644 --- a/wgpu-core/src/device/resource.rs +++ b/wgpu-core/src/device/resource.rs @@ -570,6 +570,8 @@ impl Device { ) -> Result, resource::CreateBufferError> { debug_assert_eq!(self.as_info().id().backend(), A::VARIANT); + self.check_is_valid()?; + if desc.size > self.limits.max_buffer_size { return Err(resource::CreateBufferError::MaxBufferSize { requested: desc.size, @@ -735,6 +737,8 @@ impl Device { ) -> Result, resource::CreateTextureError> { use resource::{CreateTextureError, TextureDimensionError}; + self.check_is_valid()?; + if desc.usage.is_empty() || desc.usage.contains_invalid_bits() { return Err(CreateTextureError::InvalidUsage(desc.usage)); } @@ -1310,6 +1314,8 @@ impl Device { self: &Arc, desc: &resource::SamplerDescriptor, ) -> Result, resource::CreateSamplerError> { + self.check_is_valid()?; + if desc .address_modes .iter() @@ -1421,6 +1427,8 @@ impl Device { desc: &pipeline::ShaderModuleDescriptor<'a>, source: pipeline::ShaderModuleSource<'a>, ) -> Result, pipeline::CreateShaderModuleError> { + self.check_is_valid()?; + let (module, source) = match source { #[cfg(feature = "wgsl")] pipeline::ShaderModuleSource::Wgsl(code) => { @@ -1551,6 +1559,8 @@ impl Device { desc: &pipeline::ShaderModuleDescriptor<'a>, source: &'a [u32], ) -> Result, pipeline::CreateShaderModuleError> { + self.check_is_valid()?; + self.require_features(wgt::Features::SPIRV_SHADER_PASSTHROUGH)?; let hal_desc = hal::ShaderModuleDescriptor { label: desc.label.to_hal(self.instance_flags), @@ -2072,6 +2082,7 @@ impl Device { ) -> Result, binding_model::CreateBindGroupError> { use crate::binding_model::{BindingResource as Br, CreateBindGroupError as Error}; + self.check_is_valid()?; layout.same_device(self)?; { @@ -2465,6 +2476,8 @@ impl Device { ) -> Result, binding_model::CreatePipelineLayoutError> { use crate::binding_model::CreatePipelineLayoutError as Error; + self.check_is_valid()?; + let bind_group_layouts_count = desc.bind_group_layouts.len(); let device_max_bind_groups = self.limits.max_bind_groups as usize; if bind_group_layouts_count > device_max_bind_groups { @@ -2616,6 +2629,8 @@ impl Device { implicit_context: Option, hub: &Hub, ) -> Result, pipeline::CreateComputePipelineError> { + self.check_is_valid()?; + // This has to be done first, or otherwise the IDs may be pointing to entries // that are not even in the storage. if let Some(ref ids) = implicit_context { @@ -2764,6 +2779,8 @@ impl Device { ) -> Result, pipeline::CreateRenderPipelineError> { use wgt::TextureFormatFeatureFlags as Tfff; + self.check_is_valid()?; + // This has to be done first, or otherwise the IDs may be pointing to entries // that are not even in the storage. if let Some(ref ids) = implicit_context { @@ -3414,6 +3431,9 @@ impl Device { desc: &pipeline::PipelineCacheDescriptor, ) -> Result, pipeline::CreatePipelineCacheError> { use crate::pipeline_cache; + + self.check_is_valid()?; + self.require_features(wgt::Features::PIPELINE_CACHE)?; let data = if let Some((data, validation_key)) = desc .data @@ -3536,6 +3556,8 @@ impl Device { ) -> Result, resource::CreateQuerySetError> { use resource::CreateQuerySetError as Error; + self.check_is_valid()?; + match desc.ty { wgt::QueryType::Occlusion => {} wgt::QueryType::Timestamp => {