move most device validity checks inside the device's methods

This commit is contained in:
teoxoy 2024-06-19 11:02:21 +02:00 committed by Teodor Tanasoaia
parent edc2cd9615
commit 4b5666ceff
3 changed files with 31 additions and 40 deletions

View File

@ -349,6 +349,10 @@ impl RenderBundleEncoder {
device: &Arc<Device<A>>,
hub: &Hub<A>,
) -> Result<RenderBundle<A>, RenderBundleError> {
let scope = PassErrorScope::Bundle;
device.check_is_valid().map_pass_err(scope)?;
let bind_group_guard = hub.bind_groups.read();
let pipeline_guard = hub.render_pipelines.read();
let buffer_guard = hub.buffers.read();

View File

@ -159,9 +159,6 @@ impl Global {
break 'error DeviceError::InvalidDeviceId.into();
}
};
if !device.is_valid() {
break 'error DeviceError::Lost.into();
}
if desc.usage.is_empty() {
// Per spec, `usage` must not be zero.
@ -556,9 +553,6 @@ impl Global {
Ok(device) => device,
Err(_) => break 'error DeviceError::InvalidDeviceId.into(),
};
if !device.is_valid() {
break 'error DeviceError::Lost.into();
}
#[cfg(feature = "trace")]
if let Some(ref mut trace) = *device.trace.lock() {
trace.add(trace::Action::CreateTexture(fid.id(), desc.clone()));
@ -879,9 +873,6 @@ impl Global {
Ok(device) => device,
Err(_) => break 'error DeviceError::InvalidDeviceId.into(),
};
if !device.is_valid() {
break 'error DeviceError::Lost.into();
}
#[cfg(feature = "trace")]
if let Some(ref mut trace) = *device.trace.lock() {
@ -943,15 +934,17 @@ impl Global {
Ok(device) => device,
Err(_) => break 'error DeviceError::InvalidDeviceId.into(),
};
if !device.is_valid() {
break 'error DeviceError::Lost.into();
}
#[cfg(feature = "trace")]
if let Some(ref mut trace) = *device.trace.lock() {
trace.add(trace::Action::CreateBindGroupLayout(fid.id(), desc.clone()));
}
// this check can't go in the body of `create_bind_group_layout` since the closure might not get called
if let Err(e) = device.check_is_valid() {
break 'error e.into();
}
let entry_map = match bgl::EntryMap::from_entries(&device.limits, &desc.entries) {
Ok(map) => map,
Err(e) => break 'error e,
@ -1042,9 +1035,6 @@ impl Global {
Ok(device) => device,
Err(_) => break 'error DeviceError::InvalidDeviceId.into(),
};
if !device.is_valid() {
break 'error DeviceError::Lost.into();
}
#[cfg(feature = "trace")]
if let Some(ref mut trace) = *device.trace.lock() {
@ -1100,9 +1090,6 @@ impl Global {
Ok(device) => device,
Err(_) => break 'error DeviceError::InvalidDeviceId.into(),
};
if !device.is_valid() {
break 'error DeviceError::Lost.into();
}
#[cfg(feature = "trace")]
if let Some(ref mut trace) = *device.trace.lock() {
@ -1193,9 +1180,6 @@ impl Global {
Ok(device) => device,
Err(_) => break 'error DeviceError::InvalidDeviceId.into(),
};
if !device.is_valid() {
break 'error DeviceError::Lost.into();
}
#[cfg(feature = "trace")]
if let Some(ref mut trace) = *device.trace.lock() {
@ -1271,9 +1255,6 @@ impl Global {
Ok(device) => device,
Err(_) => break 'error DeviceError::InvalidDeviceId.into(),
};
if !device.is_valid() {
break 'error DeviceError::Lost.into();
}
#[cfg(feature = "trace")]
if let Some(ref mut trace) = *device.trace.lock() {
@ -1427,9 +1408,6 @@ impl Global {
);
}
};
if !device.is_valid() {
break 'error command::RenderBundleError::from_device_error(DeviceError::Lost);
}
#[cfg(feature = "trace")]
if let Some(ref mut trace) = *device.trace.lock() {
@ -1496,10 +1474,6 @@ impl Global {
Ok(device) => device,
Err(_) => break 'error DeviceError::InvalidDeviceId.into(),
};
if !device.is_valid() {
break 'error DeviceError::Lost.into();
}
#[cfg(feature = "trace")]
if let Some(ref mut trace) = *device.trace.lock() {
trace.add(trace::Action::CreateQuerySet {
@ -1573,9 +1547,6 @@ impl Global {
Ok(device) => device,
Err(_) => break 'error DeviceError::InvalidDeviceId.into(),
};
if !device.is_valid() {
break 'error DeviceError::Lost.into();
}
#[cfg(feature = "trace")]
if let Some(ref mut trace) = *device.trace.lock() {
trace.add(trace::Action::CreateRenderPipeline {
@ -1709,9 +1680,6 @@ impl Global {
Ok(device) => device,
Err(_) => break 'error DeviceError::InvalidDeviceId.into(),
};
if !device.is_valid() {
break 'error DeviceError::Lost.into();
}
#[cfg(feature = "trace")]
if let Some(ref mut trace) = *device.trace.lock() {
@ -1841,9 +1809,6 @@ impl Global {
// TODO: Handle error properly
Err(crate::storage::InvalidId) => break 'error DeviceError::InvalidDeviceId.into(),
};
if !device.is_valid() {
break 'error DeviceError::Lost.into();
}
#[cfg(feature = "trace")]
if let Some(ref mut trace) = *device.trace.lock() {
trace.add(trace::Action::CreatePipelineCache {

View File

@ -570,6 +570,8 @@ impl<A: HalApi> Device<A> {
) -> Result<Buffer<A>, resource::CreateBufferError> {
debug_assert_eq!(self.as_info().id().backend(), A::VARIANT);
self.check_is_valid()?;
if desc.size > self.limits.max_buffer_size {
return Err(resource::CreateBufferError::MaxBufferSize {
requested: desc.size,
@ -735,6 +737,8 @@ impl<A: HalApi> Device<A> {
) -> Result<Texture<A>, resource::CreateTextureError> {
use resource::{CreateTextureError, TextureDimensionError};
self.check_is_valid()?;
if desc.usage.is_empty() || desc.usage.contains_invalid_bits() {
return Err(CreateTextureError::InvalidUsage(desc.usage));
}
@ -1310,6 +1314,8 @@ impl<A: HalApi> Device<A> {
self: &Arc<Self>,
desc: &resource::SamplerDescriptor,
) -> Result<Sampler<A>, resource::CreateSamplerError> {
self.check_is_valid()?;
if desc
.address_modes
.iter()
@ -1421,6 +1427,8 @@ impl<A: HalApi> Device<A> {
desc: &pipeline::ShaderModuleDescriptor<'a>,
source: pipeline::ShaderModuleSource<'a>,
) -> Result<pipeline::ShaderModule<A>, pipeline::CreateShaderModuleError> {
self.check_is_valid()?;
let (module, source) = match source {
#[cfg(feature = "wgsl")]
pipeline::ShaderModuleSource::Wgsl(code) => {
@ -1551,6 +1559,8 @@ impl<A: HalApi> Device<A> {
desc: &pipeline::ShaderModuleDescriptor<'a>,
source: &'a [u32],
) -> Result<pipeline::ShaderModule<A>, pipeline::CreateShaderModuleError> {
self.check_is_valid()?;
self.require_features(wgt::Features::SPIRV_SHADER_PASSTHROUGH)?;
let hal_desc = hal::ShaderModuleDescriptor {
label: desc.label.to_hal(self.instance_flags),
@ -2072,6 +2082,7 @@ impl<A: HalApi> Device<A> {
) -> Result<BindGroup<A>, binding_model::CreateBindGroupError> {
use crate::binding_model::{BindingResource as Br, CreateBindGroupError as Error};
self.check_is_valid()?;
layout.same_device(self)?;
{
@ -2465,6 +2476,8 @@ impl<A: HalApi> Device<A> {
) -> Result<binding_model::PipelineLayout<A>, binding_model::CreatePipelineLayoutError> {
use crate::binding_model::CreatePipelineLayoutError as Error;
self.check_is_valid()?;
let bind_group_layouts_count = desc.bind_group_layouts.len();
let device_max_bind_groups = self.limits.max_bind_groups as usize;
if bind_group_layouts_count > device_max_bind_groups {
@ -2616,6 +2629,8 @@ impl<A: HalApi> Device<A> {
implicit_context: Option<ImplicitPipelineContext>,
hub: &Hub<A>,
) -> Result<pipeline::ComputePipeline<A>, pipeline::CreateComputePipelineError> {
self.check_is_valid()?;
// This has to be done first, or otherwise the IDs may be pointing to entries
// that are not even in the storage.
if let Some(ref ids) = implicit_context {
@ -2764,6 +2779,8 @@ impl<A: HalApi> Device<A> {
) -> Result<pipeline::RenderPipeline<A>, pipeline::CreateRenderPipelineError> {
use wgt::TextureFormatFeatureFlags as Tfff;
self.check_is_valid()?;
// This has to be done first, or otherwise the IDs may be pointing to entries
// that are not even in the storage.
if let Some(ref ids) = implicit_context {
@ -3414,6 +3431,9 @@ impl<A: HalApi> Device<A> {
desc: &pipeline::PipelineCacheDescriptor,
) -> Result<pipeline::PipelineCache<A>, pipeline::CreatePipelineCacheError> {
use crate::pipeline_cache;
self.check_is_valid()?;
self.require_features(wgt::Features::PIPELINE_CACHE)?;
let data = if let Some((data, validation_key)) = desc
.data
@ -3536,6 +3556,8 @@ impl<A: HalApi> Device<A> {
) -> Result<QuerySet<A>, resource::CreateQuerySetError> {
use resource::CreateQuerySetError as Error;
self.check_is_valid()?;
match desc.ty {
wgt::QueryType::Occlusion => {}
wgt::QueryType::Timestamp => {