Add method to create shader modules without runtime checks (#1978)

* Add method to create shader modules without runtime checks

* Use opaque struct to represent shader bound checks
This commit is contained in:
João Capucho 2021-09-27 20:51:44 +01:00 committed by GitHub
parent a7aa72ba1c
commit 9bc5908492
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 123 additions and 9 deletions

View File

@ -963,6 +963,7 @@ impl<A: HalApi> Device<A> {
let hal_desc = hal::ShaderModuleDescriptor {
label: desc.label.borrow_option(),
runtime_checks: desc.shader_bound_checks.runtime_checks(),
};
let raw = match unsafe { self.raw.create_shader_module(&hal_desc, hal_shader) } {
Ok(raw) => raw,
@ -1001,6 +1002,7 @@ impl<A: HalApi> Device<A> {
self.require_features(wgt::Features::SPIRV_SHADER_PASSTHROUGH)?;
let hal_desc = hal::ShaderModuleDescriptor {
label: desc.label.borrow_option(),
runtime_checks: desc.shader_bound_checks.runtime_checks(),
};
let hal_shader = hal::ShaderInput::SpirV(source);
let raw = match unsafe { self.raw.create_shader_module(&hal_desc, hal_shader) } {

View File

@ -18,6 +18,7 @@ pub enum ShaderModuleSource<'a> {
#[cfg_attr(feature = "replay", derive(serde::Deserialize))]
pub struct ShaderModuleDescriptor<'a> {
pub label: Label<'a>,
pub shader_bound_checks: wgt::ShaderBoundChecks,
}
#[derive(Debug)]

View File

@ -137,7 +137,10 @@ impl<A: hal::Api> Example<A> {
.unwrap();
hal::NagaShader { module, info }
};
let shader_desc = hal::ShaderModuleDescriptor { label: None };
let shader_desc = hal::ShaderModuleDescriptor {
label: None,
runtime_checks: false,
};
let shader = unsafe {
device
.create_shader_module(&shader_desc, hal::ShaderInput::Naga(naga_shader))

View File

@ -898,6 +898,7 @@ pub enum ShaderInput<'a> {
pub struct ShaderModuleDescriptor<'a> {
pub label: Label<'a>,
pub runtime_checks: bool,
}
/// Describes a programmable pipeline stage.

View File

@ -582,15 +582,32 @@ impl super::Device {
let stage_flags = crate::auxil::map_naga_stage(naga_stage);
let vk_module = match *stage.module {
super::ShaderModule::Raw(raw) => raw,
super::ShaderModule::Intermediate(ref naga_shader) => {
super::ShaderModule::Intermediate {
ref naga_shader,
runtime_checks,
} => {
let pipeline_options = naga::back::spv::PipelineOptions {
entry_point: stage.entry_point.to_string(),
shader_stage: naga_stage,
};
let temp_options;
let options = if !runtime_checks {
temp_options = naga::back::spv::Options {
bounds_check_policies: naga::back::BoundsCheckPolicies {
index: naga::back::BoundsCheckPolicy::Unchecked,
buffer: naga::back::BoundsCheckPolicy::Unchecked,
image: naga::back::BoundsCheckPolicy::Unchecked,
},
..self.naga_options.clone()
};
&temp_options
} else {
&self.naga_options
};
let spv = naga::back::spv::write_vec(
&naga_shader.module,
&naga_shader.info,
&self.naga_options,
options,
Some(&pipeline_options),
)
.map_err(|e| crate::PipelineError::Linkage(stage_flags, format!("{}", e)))?;
@ -610,7 +627,7 @@ impl super::Device {
_entry_point: entry_point,
temp_raw_module: match *stage.module {
super::ShaderModule::Raw(_) => None,
super::ShaderModule::Intermediate(_) => Some(vk_module),
super::ShaderModule::Intermediate { .. } => Some(vk_module),
},
})
}
@ -1179,13 +1196,24 @@ impl crate::Device<super::Api> for super::Device {
.workarounds
.contains(super::Workarounds::SEPARATE_ENTRY_POINTS)
{
return Ok(super::ShaderModule::Intermediate(naga_shader));
return Ok(super::ShaderModule::Intermediate {
naga_shader,
runtime_checks: desc.runtime_checks,
});
}
let mut naga_options = self.naga_options.clone();
if !desc.runtime_checks {
naga_options.bounds_check_policies = naga::back::BoundsCheckPolicies {
index: naga::back::BoundsCheckPolicy::Unchecked,
buffer: naga::back::BoundsCheckPolicy::Unchecked,
image: naga::back::BoundsCheckPolicy::Unchecked,
};
}
Cow::Owned(
naga::back::spv::write_vec(
&naga_shader.module,
&naga_shader.info,
&self.naga_options,
&naga_options,
None,
)
.map_err(|e| crate::ShaderError::Compilation(format!("{}", e)))?,
@ -1208,7 +1236,7 @@ impl crate::Device<super::Api> for super::Device {
super::ShaderModule::Raw(raw) => {
let _ = self.shared.raw.destroy_shader_module(raw, None);
}
super::ShaderModule::Intermediate(_) => {}
super::ShaderModule::Intermediate { .. } => {}
}
}

View File

@ -372,7 +372,10 @@ pub struct CommandBuffer {
#[derive(Debug)]
pub enum ShaderModule {
Raw(vk::ShaderModule),
Intermediate(crate::NagaShader),
Intermediate {
naga_shader: crate::NagaShader,
runtime_checks: bool,
},
}
#[derive(Debug)]

View File

@ -3440,3 +3440,42 @@ pub struct DispatchIndirectArgs {
/// Z dimension of the grid of workgroups to dispatch.
pub group_size_z: u32,
}
/// Describes how shader bound checks should be performed.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "trace", derive(serde::Serialize))]
#[cfg_attr(feature = "replay", derive(serde::Deserialize))]
pub struct ShaderBoundChecks {
runtime_checks: bool,
}
impl ShaderBoundChecks {
/// Creates a new configuration where the shader is bound checked.
pub fn new() -> Self {
ShaderBoundChecks {
runtime_checks: true,
}
}
/// Creates a new configuration where the shader isn't bound checked.
///
/// # Safety
/// The caller MUST ensure that all shaders built with this configuration don't perform any
/// out of bounds reads or writes.
pub unsafe fn unchecked() -> Self {
ShaderBoundChecks {
runtime_checks: false,
}
}
/// Query whether runtime bound checks are enabled in this configuration
pub fn runtime_checks(&self) -> bool {
self.runtime_checks
}
}
impl Default for ShaderBoundChecks {
fn default() -> Self {
Self::new()
}
}

View File

@ -971,10 +971,12 @@ impl crate::Context for Context {
&self,
device: &Self::DeviceId,
desc: &ShaderModuleDescriptor,
shader_bound_checks: wgt::ShaderBoundChecks,
) -> Self::ShaderModuleId {
let global = &self.0;
let descriptor = wgc::pipeline::ShaderModuleDescriptor {
label: desc.label.map(Borrowed),
shader_bound_checks,
};
let source = match desc.source {
#[cfg(feature = "spirv")]
@ -1014,6 +1016,9 @@ impl crate::Context for Context {
let global = &self.0;
let descriptor = wgc::pipeline::ShaderModuleDescriptor {
label: desc.label.map(Borrowed),
// Doesn't matter the value since spirv shaders aren't mutated to include
// runtime checks
shader_bound_checks: wgt::ShaderBoundChecks::unchecked(),
};
let (id, error) = wgc::gfx_select!(
device.id => global.device_create_shader_module_spirv(device.id, &descriptor, Borrowed(&desc.source), PhantomData)

View File

@ -1182,6 +1182,7 @@ impl crate::Context for Context {
&self,
device: &Self::DeviceId,
desc: &crate::ShaderModuleDescriptor,
_shader_bound_checks: wgt::ShaderBoundChecks,
) -> Self::ShaderModuleId {
let mut descriptor = match desc.source {
#[cfg(feature = "spirv-web")]

View File

@ -242,6 +242,7 @@ trait Context: Debug + Send + Sized + Sync {
&self,
device: &Self::DeviceId,
desc: &ShaderModuleDescriptor,
shader_bound_checks: wgt::ShaderBoundChecks,
) -> Self::ShaderModuleId;
unsafe fn device_create_shader_module_spirv(
&self,
@ -1664,7 +1665,37 @@ impl Device {
pub fn create_shader_module(&self, desc: &ShaderModuleDescriptor) -> ShaderModule {
ShaderModule {
context: Arc::clone(&self.context),
id: Context::device_create_shader_module(&*self.context, &self.id, desc),
id: Context::device_create_shader_module(
&*self.context,
&self.id,
desc,
wgt::ShaderBoundChecks::new(),
),
}
}
/// Creates a shader module from either SPIR-V or WGSL source code without runtime checks.
///
/// # Safety
/// In contrast with [`create_shader_module`](Self::create_shader_module) this function
/// creates a shader module without runtime checks which allows shaders to perform
/// operations which can lead to undefined behavior like indexing out of bounds, thus it's
/// the caller responsibility to pass a shader which doesn't perform any of this
/// operations.
///
/// This has no effect on web.
pub unsafe fn create_shader_module_unchecked(
&self,
desc: &ShaderModuleDescriptor,
) -> ShaderModule {
ShaderModule {
context: Arc::clone(&self.context),
id: Context::device_create_shader_module(
&*self.context,
&self.id,
desc,
wgt::ShaderBoundChecks::unchecked(),
),
}
}