diff --git a/wgpu-core/src/device/mod.rs b/wgpu-core/src/device/mod.rs index 7f800d529..3bbeb1b3a 100644 --- a/wgpu-core/src/device/mod.rs +++ b/wgpu-core/src/device/mod.rs @@ -963,6 +963,7 @@ impl Device { let hal_desc = hal::ShaderModuleDescriptor { label: desc.label.borrow_option(), + runtime_checks: desc.shader_bound_checks.runtime_checks(), }; let raw = match unsafe { self.raw.create_shader_module(&hal_desc, hal_shader) } { Ok(raw) => raw, @@ -1001,6 +1002,7 @@ impl Device { self.require_features(wgt::Features::SPIRV_SHADER_PASSTHROUGH)?; let hal_desc = hal::ShaderModuleDescriptor { label: desc.label.borrow_option(), + runtime_checks: desc.shader_bound_checks.runtime_checks(), }; let hal_shader = hal::ShaderInput::SpirV(source); let raw = match unsafe { self.raw.create_shader_module(&hal_desc, hal_shader) } { diff --git a/wgpu-core/src/pipeline.rs b/wgpu-core/src/pipeline.rs index be4b234f5..12195d85e 100644 --- a/wgpu-core/src/pipeline.rs +++ b/wgpu-core/src/pipeline.rs @@ -18,6 +18,7 @@ pub enum ShaderModuleSource<'a> { #[cfg_attr(feature = "replay", derive(serde::Deserialize))] pub struct ShaderModuleDescriptor<'a> { pub label: Label<'a>, + pub shader_bound_checks: wgt::ShaderBoundChecks, } #[derive(Debug)] diff --git a/wgpu-hal/examples/halmark/main.rs b/wgpu-hal/examples/halmark/main.rs index 694ac118f..4f9dbbdfc 100644 --- a/wgpu-hal/examples/halmark/main.rs +++ b/wgpu-hal/examples/halmark/main.rs @@ -137,7 +137,10 @@ impl Example { .unwrap(); hal::NagaShader { module, info } }; - let shader_desc = hal::ShaderModuleDescriptor { label: None }; + let shader_desc = hal::ShaderModuleDescriptor { + label: None, + runtime_checks: false, + }; let shader = unsafe { device .create_shader_module(&shader_desc, hal::ShaderInput::Naga(naga_shader)) diff --git a/wgpu-hal/src/lib.rs b/wgpu-hal/src/lib.rs index 0ba2ba9ea..c38f76077 100644 --- a/wgpu-hal/src/lib.rs +++ b/wgpu-hal/src/lib.rs @@ -898,6 +898,7 @@ pub enum ShaderInput<'a> { pub struct ShaderModuleDescriptor<'a> { pub label: Label<'a>, + pub runtime_checks: bool, } /// Describes a programmable pipeline stage. diff --git a/wgpu-hal/src/vulkan/device.rs b/wgpu-hal/src/vulkan/device.rs index cb2cc1f9c..8c4edad51 100644 --- a/wgpu-hal/src/vulkan/device.rs +++ b/wgpu-hal/src/vulkan/device.rs @@ -582,15 +582,32 @@ impl super::Device { let stage_flags = crate::auxil::map_naga_stage(naga_stage); let vk_module = match *stage.module { super::ShaderModule::Raw(raw) => raw, - super::ShaderModule::Intermediate(ref naga_shader) => { + super::ShaderModule::Intermediate { + ref naga_shader, + runtime_checks, + } => { let pipeline_options = naga::back::spv::PipelineOptions { entry_point: stage.entry_point.to_string(), shader_stage: naga_stage, }; + let temp_options; + let options = if !runtime_checks { + temp_options = naga::back::spv::Options { + bounds_check_policies: naga::back::BoundsCheckPolicies { + index: naga::back::BoundsCheckPolicy::Unchecked, + buffer: naga::back::BoundsCheckPolicy::Unchecked, + image: naga::back::BoundsCheckPolicy::Unchecked, + }, + ..self.naga_options.clone() + }; + &temp_options + } else { + &self.naga_options + }; let spv = naga::back::spv::write_vec( &naga_shader.module, &naga_shader.info, - &self.naga_options, + options, Some(&pipeline_options), ) .map_err(|e| crate::PipelineError::Linkage(stage_flags, format!("{}", e)))?; @@ -610,7 +627,7 @@ impl super::Device { _entry_point: entry_point, temp_raw_module: match *stage.module { super::ShaderModule::Raw(_) => None, - super::ShaderModule::Intermediate(_) => Some(vk_module), + super::ShaderModule::Intermediate { .. } => Some(vk_module), }, }) } @@ -1179,13 +1196,24 @@ impl crate::Device for super::Device { .workarounds .contains(super::Workarounds::SEPARATE_ENTRY_POINTS) { - return Ok(super::ShaderModule::Intermediate(naga_shader)); + return Ok(super::ShaderModule::Intermediate { + naga_shader, + runtime_checks: desc.runtime_checks, + }); + } + let mut naga_options = self.naga_options.clone(); + if !desc.runtime_checks { + naga_options.bounds_check_policies = naga::back::BoundsCheckPolicies { + index: naga::back::BoundsCheckPolicy::Unchecked, + buffer: naga::back::BoundsCheckPolicy::Unchecked, + image: naga::back::BoundsCheckPolicy::Unchecked, + }; } Cow::Owned( naga::back::spv::write_vec( &naga_shader.module, &naga_shader.info, - &self.naga_options, + &naga_options, None, ) .map_err(|e| crate::ShaderError::Compilation(format!("{}", e)))?, @@ -1208,7 +1236,7 @@ impl crate::Device for super::Device { super::ShaderModule::Raw(raw) => { let _ = self.shared.raw.destroy_shader_module(raw, None); } - super::ShaderModule::Intermediate(_) => {} + super::ShaderModule::Intermediate { .. } => {} } } diff --git a/wgpu-hal/src/vulkan/mod.rs b/wgpu-hal/src/vulkan/mod.rs index 1c727a6e2..be9d120e0 100644 --- a/wgpu-hal/src/vulkan/mod.rs +++ b/wgpu-hal/src/vulkan/mod.rs @@ -372,7 +372,10 @@ pub struct CommandBuffer { #[derive(Debug)] pub enum ShaderModule { Raw(vk::ShaderModule), - Intermediate(crate::NagaShader), + Intermediate { + naga_shader: crate::NagaShader, + runtime_checks: bool, + }, } #[derive(Debug)] diff --git a/wgpu-types/src/lib.rs b/wgpu-types/src/lib.rs index 64ebf1814..729664d9c 100644 --- a/wgpu-types/src/lib.rs +++ b/wgpu-types/src/lib.rs @@ -3440,3 +3440,42 @@ pub struct DispatchIndirectArgs { /// Z dimension of the grid of workgroups to dispatch. pub group_size_z: u32, } + +/// Describes how shader bound checks should be performed. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "trace", derive(serde::Serialize))] +#[cfg_attr(feature = "replay", derive(serde::Deserialize))] +pub struct ShaderBoundChecks { + runtime_checks: bool, +} + +impl ShaderBoundChecks { + /// Creates a new configuration where the shader is bound checked. + pub fn new() -> Self { + ShaderBoundChecks { + runtime_checks: true, + } + } + + /// Creates a new configuration where the shader isn't bound checked. + /// + /// # Safety + /// The caller MUST ensure that all shaders built with this configuration don't perform any + /// out of bounds reads or writes. + pub unsafe fn unchecked() -> Self { + ShaderBoundChecks { + runtime_checks: false, + } + } + + /// Query whether runtime bound checks are enabled in this configuration + pub fn runtime_checks(&self) -> bool { + self.runtime_checks + } +} + +impl Default for ShaderBoundChecks { + fn default() -> Self { + Self::new() + } +} diff --git a/wgpu/src/backend/direct.rs b/wgpu/src/backend/direct.rs index d530e3b1c..836ddd020 100644 --- a/wgpu/src/backend/direct.rs +++ b/wgpu/src/backend/direct.rs @@ -971,10 +971,12 @@ impl crate::Context for Context { &self, device: &Self::DeviceId, desc: &ShaderModuleDescriptor, + shader_bound_checks: wgt::ShaderBoundChecks, ) -> Self::ShaderModuleId { let global = &self.0; let descriptor = wgc::pipeline::ShaderModuleDescriptor { label: desc.label.map(Borrowed), + shader_bound_checks, }; let source = match desc.source { #[cfg(feature = "spirv")] @@ -1014,6 +1016,9 @@ impl crate::Context for Context { let global = &self.0; let descriptor = wgc::pipeline::ShaderModuleDescriptor { label: desc.label.map(Borrowed), + // Doesn't matter the value since spirv shaders aren't mutated to include + // runtime checks + shader_bound_checks: wgt::ShaderBoundChecks::unchecked(), }; let (id, error) = wgc::gfx_select!( device.id => global.device_create_shader_module_spirv(device.id, &descriptor, Borrowed(&desc.source), PhantomData) diff --git a/wgpu/src/backend/web.rs b/wgpu/src/backend/web.rs index e5c4b60c7..21bb863d9 100644 --- a/wgpu/src/backend/web.rs +++ b/wgpu/src/backend/web.rs @@ -1182,6 +1182,7 @@ impl crate::Context for Context { &self, device: &Self::DeviceId, desc: &crate::ShaderModuleDescriptor, + _shader_bound_checks: wgt::ShaderBoundChecks, ) -> Self::ShaderModuleId { let mut descriptor = match desc.source { #[cfg(feature = "spirv-web")] diff --git a/wgpu/src/lib.rs b/wgpu/src/lib.rs index a0325541a..b89574a24 100644 --- a/wgpu/src/lib.rs +++ b/wgpu/src/lib.rs @@ -242,6 +242,7 @@ trait Context: Debug + Send + Sized + Sync { &self, device: &Self::DeviceId, desc: &ShaderModuleDescriptor, + shader_bound_checks: wgt::ShaderBoundChecks, ) -> Self::ShaderModuleId; unsafe fn device_create_shader_module_spirv( &self, @@ -1664,7 +1665,37 @@ impl Device { pub fn create_shader_module(&self, desc: &ShaderModuleDescriptor) -> ShaderModule { ShaderModule { context: Arc::clone(&self.context), - id: Context::device_create_shader_module(&*self.context, &self.id, desc), + id: Context::device_create_shader_module( + &*self.context, + &self.id, + desc, + wgt::ShaderBoundChecks::new(), + ), + } + } + + /// Creates a shader module from either SPIR-V or WGSL source code without runtime checks. + /// + /// # Safety + /// In contrast with [`create_shader_module`](Self::create_shader_module) this function + /// creates a shader module without runtime checks which allows shaders to perform + /// operations which can lead to undefined behavior like indexing out of bounds, thus it's + /// the caller responsibility to pass a shader which doesn't perform any of this + /// operations. + /// + /// This has no effect on web. + pub unsafe fn create_shader_module_unchecked( + &self, + desc: &ShaderModuleDescriptor, + ) -> ShaderModule { + ShaderModule { + context: Arc::clone(&self.context), + id: Context::device_create_shader_module( + &*self.context, + &self.id, + desc, + wgt::ShaderBoundChecks::unchecked(), + ), } }